Augmenting Statistical Models with Natural Language Parameters
Published by jsteinhardt on September 22, 2024 on LessWrong.
This is a guest post by my student Ruiqi Zhong, who has some very exciting work defining new families of statistical models that can take natural language explanations as parameters. The motivation is that existing statistical models are bad at explaining structured data. To address this problem, we augment these models with natural language parameters, which can represent interpretable abstract features and be learned automatically.
Imagine the following scenario: it is the year 3024. We are historians trying to understand what happened between 2016 and 2024 by looking at how Twitter topics changed across that time period. We are given a dataset of user-posted images sorted by time, $x_1, x_2, \ldots, x_T$, and our goal is to find trends in this dataset to help interpret what happened.
If we successfully achieve our goal, we would discover, for instance, (1) a recurring spike of images depicting athletes every four years for the Olympics, and (2) a large increase in images containing medical concepts during and after the COVID-19 pandemic.
How do we usually discover temporal trends from a dataset? One common approach is to fit a time series model to predict how the features evolve and then interpret the learned model. However, it is unclear what features to use: pixels and neural image embeddings are high-dimensional and uninterpretable, undermining the goal of extracting explainable trends.
We address this problem by augmenting statistical models with interpretable natural language parameters. The figure below depicts a graphical model representation for the case of time series data. We explain the trends in the observed data $[x_1, \ldots, x_T]$ by learning two sets of latent parameters: natural language parameters $\phi$ (the learned features) and real-valued parameters $w$ (the time-varying trends).
$\phi$: natural language descriptions of $K$ different topics, e.g. "depicts athletes competing". Each description is an element of $\Sigma$, the universe of all natural language predicates.
$w_t$: the frequency of each of the $K$ topics at time $t$.
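To make this structure concrete, below is a minimal sketch of what the likelihood of such a model could look like. This is an illustrative simplification rather than the exact objective from the paper: `predicate_holds` is a hypothetical helper (e.g., a prompted LLM that judges whether a predicate is true of an image), and each predicate's presence at time $t$ is modeled as a simple Bernoulli draw with rate $w_{t,k}$.

```python
import numpy as np

def log_likelihood(images, phi, w, predicate_holds):
    """Score a time series model with natural language parameters.

    images:          list of T observed samples, one per time step (for simplicity)
    phi:             list of K natural language predicates,
                     e.g. "depicts athletes competing"
    w:               (T, K) array; w[t, k] is the probability that topic k
                     appears at time t
    predicate_holds: function (predicate, x) -> bool, e.g. a prompted LLM
    """
    ll = 0.0
    for t, x in enumerate(images):
        for k, predicate in enumerate(phi):
            z = predicate_holds(predicate, x)  # denotation phi_k(x) in {0, 1}
            ll += np.log(w[t, k]) if z else np.log(1.0 - w[t, k])
    return ll
```

In the full framework, $\phi$ is learned as well, e.g. by alternating between proposing candidate predicates and refitting $w$; the sketch above only scores a fixed choice of both.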
If our model successfully recovers the underlying trends, then we can visualize $w$ and $\phi$ below and see that: (1) more pictures contain medical concepts (red) starting in 2020, and (2) there are recurring spikes of images depicting athletes competing (blue).
In the rest of this post, we will explain in detail how to specify and learn models with natural language parameters and showcase the model on several real-world applications. We will cover:
A warm-up example of a statistical model with natural language explanations
A modeling language for specifying natural language parameters
Applications of our framework, which can be used to specify models for time series, clustering, and classification. We will go over:
A machine learning application that uses our time series model to monitor trends in LLM usage
A business application that uses our clustering model to taxonomize product reviews
A cognitive science application that uses our classification model to explain what images are more memorable for humans
Thanks to Louise Verkin for helping to typeset the post in Ghost format.
Warm-up Example: Logistic Regression with Natural Language Parameters
Instead of understanding topic shifts across the entire 2016-2024 time window, let's first study a much simpler question: which images are more likely to appear after 2020? The usual way to approach this problem is to:
1. brainstorm some features,
2. extract the real-valued features from each image, and
3. run a logistic regression model on these features to predict the target: $Y = 1$ if the image appears after 2020, $Y = 0$ otherwise.
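Before going through each step in detail, here is a minimal end-to-end sketch of this pipeline. It assumes scikit-learn for step 3 and the same hypothetical `predicate_holds` helper as above for evaluating a natural language predicate on an image; in the actual workflow, the predicate list is the output of step 1 rather than being hard-coded.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

# Step 1 (hard-coded here for illustration): candidate natural language features.
PREDICATES = ["depicts athletes competing", "contains medical concepts"]

def predicate_holds(predicate: str, image) -> int:
    """Hypothetical denotation function: 1 if the predicate is true of the
    image, else 0. In practice, a prompted LLM or a human annotator."""
    raise NotImplementedError

def featurize(images):
    # Step 2: one binary feature per natural language predicate.
    return np.array([[predicate_holds(p, x) for p in PREDICATES] for x in images])

def fit_after_2020_classifier(images, years):
    # Step 3: logistic regression predicting Y = 1 if the image appears after 2020.
    X = featurize(images)
    y = (np.asarray(years) > 2020).astype(int)
    model = LogisticRegression().fit(X, y)
    # Each learned weight indicates how strongly a predicate is associated
    # with appearing after 2020.
    return dict(zip(PREDICATES, model.coef_[0]))
```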
More concretely:
Step 1: Propose different...