URL: https://arxiv.org/abs/1811.10959
Notes:
Unlike these approaches, we are interested in understanding the intrinsic properties of the training data rather than a specific trained model.
Sections of the paper:
- Train a network with a fixed initialization using one gradient descent step.
- Use randomly initialized weights (more challenging)
- Linear network -- understand limitations
- Use more gradient steps
- How to obtain distilled images using different initializations.
random note to self: This is a paper where the goal fits into the paradigm of Input design.
Algorithm
The algorithm (with some paraphrasing for myself)
Firstly, the required inputs that are of interest:
- A prediction model $f(x, \theta)$ for the learning task at hand. $x$ is the actual data, and $\theta$ are its parameters.
- A distribution for initial weights, $p(\theta_0)$.
- The desired number of distilled data points, $M$.
Secondly, some other parameters that may be of interest:
- Step size, $\alpha$
- Batch size, $n$, of real training data drawn at each step (the distilled set size $M$, by contrast, is meant to be far smaller than the full real dataset).
- Number of optimization iterations, $T$
- $\tilde{\eta}_0$, i.e. the initial value for $\tilde{\eta}$, which is the learning rate for learning on the distilled, synthetic data.
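To keep these inputs straight, here's how I might bundle them in code; a minimal sketch, and every concrete value is a placeholder of my own, not taken from the paper.

```python
# Hypothetical bundle of the distillation hyperparameters listed above.
# All concrete values are my own placeholders, not the paper's.
from dataclasses import dataclass

@dataclass
class DistillConfig:
    M: int = 100          # number of distilled data points
    n: int = 256          # batch size of real data sampled each iteration
    T: int = 1000         # number of outer optimization iterations
    alpha: float = 0.01   # step size for updating the distilled data and eta~
    eta0: float = 0.02    # initial value of the learned learning rate eta~
    n_inits: int = 4      # how many initializations theta_0^(j) to sample per step
```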
The algorithm in words, translated:
- Initialize $\tilde{x}$, the distilled input data, using a random matrix.
- Set $\tilde{\eta}$ to $\tilde{\eta}_0$.
- For each training step $t=1$ to $T$:
- Get a batch of $n$ data points from $x$. We call this $\mathbf{x_t}$, and in the paper, it is given the notation $\mathbf{x_t} = \{x_{t,j}\}_{j=1}^{n}$
- Draw a batch of initial weights from $p(\theta_0)$, call them $\theta_{0}^{(j)}$. Given the notation in the paper, the superscript $(j)$ indexes the sampled initializations, i.e. "draw several weight initializations".
- For each sampled weight $\theta_{0}^{(j)}$:
- Compute $\theta_{1}^{(j)}$ by taking one step in the negative direction of the gradient of the loss on the distilled data: $\theta_1^{(j)} = \theta_0^{(j)} - \tilde{\eta} \nabla_{\theta_0^{(j)}} l(\tilde{x}, \theta_0^{(j)})$.
- Now evaluate the new loss on actual data $L^{(j)} = l( \mathbf{x_t}, \theta_{1}^{(j)})$
- Then, update the distilled data values $\tilde{x}$ and $\tilde{\eta}$:
- $\tilde{x} \leftarrow \tilde{x} - \alpha \nabla_{\tilde{x}} \sum_{j} L^{(j)}$: "update the distilled data in the negative direction of the gradient of the sum of losses $L^{(j)}$ over all sampled initializations $j$"
- $\tilde{\eta} \leftarrow \tilde{\eta} - \alpha \nabla_{\tilde{\eta}} \sum_{j} L^{(j)}$: same semantic meaning, except we do it on $\tilde{\eta}$, the learning rate to use when using distilled data to train the model.
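To make the loop concrete for myself, here is a minimal PyTorch sketch of the single-step version above, using a linear classifier on flattened inputs. The model, the random stand-in for the "real" dataset, and every hyperparameter value are my own assumptions for illustration, not the paper's actual setup.

```python
import torch
import torch.nn.functional as F

D, C = 784, 10                                     # input dim, number of classes (assumed MNIST-like)
M, n, n_inits, T, alpha = 100, 256, 4, 1000, 0.01  # placeholder hyperparameters

# Stand-in for the real training set (random tensors just so the sketch runs end to end).
X_real = torch.randn(10_000, D)
Y_real = torch.randint(0, C, (10_000,))

def sample_real_batch(n):
    idx = torch.randint(0, X_real.shape[0], (n,))
    return X_real[idx], Y_real[idx]

# Distilled data: M learnable inputs with fixed, evenly spread labels.
x_tilde = torch.randn(M, D, requires_grad=True)
y_tilde = torch.arange(M) % C
eta = torch.tensor(0.02, requires_grad=True)       # learned inner-loop learning rate eta~

opt = torch.optim.SGD([x_tilde, eta], lr=alpha)

def loss_fn(W, b, x, y):
    # Linear classifier written functionally so gradients can flow through updated weights.
    return F.cross_entropy(x @ W + b, y)

for t in range(T):
    x_t, y_t = sample_real_batch(n)                # batch of real data x_t
    opt.zero_grad()
    total = 0.0
    for _ in range(n_inits):
        # Sample an initialization theta_0^(j) ~ p(theta_0).
        W0 = (0.01 * torch.randn(D, C)).requires_grad_(True)
        b0 = torch.zeros(C, requires_grad=True)
        # One inner gradient step on the distilled data (keep the graph for meta-gradients).
        inner = loss_fn(W0, b0, x_tilde, y_tilde)
        gW, gb = torch.autograd.grad(inner, (W0, b0), create_graph=True)
        W1, b1 = W0 - eta * gW, b0 - eta * gb
        # L^(j): loss of theta_1^(j) on the real batch.
        total = total + loss_fn(W1, b1, x_t, y_t)
    total.backward()                               # meta-gradients flow into x_tilde and eta
    opt.step()
```

The key trick is `create_graph=True` on the inner gradient: it lets the outer `backward()` differentiate through the one-step update and push gradients into $\tilde{x}$ and $\tilde{\eta}$.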
It seems like this is going to be a paper I need to take a second look over.
In any case, some ideas to think about:
- We can retrain UniRep by distilling protein space, maybe? Doing so might make retraining really easy, and we'd get a set of weights that are free from the license restrictions?
- What about distilling graph data? How easy would this be?
Why I'm so interested in dataset distillation
It's a neat idea for a few reasons.
The biggest reason: compression + fast training. I am tired of waiting for large models to finish training. If we can distill a dataset to its essentials, then we should be able to train large models faster.
Another ancillary reason: intellectual interest! There's an interesting question for me: what is the smallest dataset needed to train a model that still explains the original data?
What would dataset distillation in a quadratic model look like?
Some things I'd expect:
- The learned $x$ should collapse to a few points? A quadratic is pinned down by three points, so a handful should be enough to get a good model representative of the training data.
- We might not get such an extreme distillation, though.
- On the other hand, because of the use of distribution draws for parameters, I can see how we would end up marginalizing over all possible pairs of "inducing" points
- Oh wait, we actually get to define the number of inducing points!
The setup: we lay out a bunch of training data points along the x-axis and create noisy $y$ outputs from the function of interest, in this case a quadratic. We should be able to distill points to include:
- The hump/bowl
- Points on both sides.
Then, we try to distill $x$ to a small set of $M = 10$ points.
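A rough sketch of that toy experiment, reusing the same one-step meta-gradient idea; the true coefficients, noise level, and all hyperparameters are made-up choices of mine.

```python
import torch

torch.manual_seed(0)

# "Real" data: x spread along a line, y from a quadratic plus noise.
# The coefficients (2, -3, 1) and the noise level are arbitrary choices for this sketch.
x_real = torch.linspace(-3, 3, 200).unsqueeze(1)
y_real = 2 * x_real**2 - 3 * x_real + 1 + 0.5 * torch.randn_like(x_real)

def features(x):
    # Quadratic model as linear regression on [x^2, x, 1].
    return torch.cat([x**2, x, torch.ones_like(x)], dim=1)

def mse(theta, x, y):
    return ((features(x) @ theta - y) ** 2).mean()

M = 10
x_tilde = torch.randn(M, 1, requires_grad=True)    # distilled inputs
y_tilde = torch.randn(M, 1, requires_grad=True)    # distilled targets (learned too, since this is regression)
eta = torch.tensor(0.05, requires_grad=True)       # learned inner learning rate

opt = torch.optim.Adam([x_tilde, y_tilde, eta], lr=0.01)

for step in range(2000):
    opt.zero_grad()
    total = 0.0
    for _ in range(4):                             # a few sampled initializations theta_0^(j)
        theta0 = (0.1 * torch.randn(3, 1)).requires_grad_(True)
        inner = mse(theta0, x_tilde, y_tilde)
        (g,) = torch.autograd.grad(inner, theta0, create_graph=True)
        theta1 = theta0 - eta * g                  # one gradient step on the distilled data
        total = total + mse(theta1, x_real, y_real)
    total.backward()
    opt.step()

# Curiosity check: where did the distilled x values end up?
print(x_tilde.detach().squeeze().sort().values)
```

Since the model is linear in its features $[x^2, x, 1]$, three well-placed points already determine it exactly, so $M = 10$ is generous; I'd expect the distilled points to spread over the bowl and both arms rather than collapse to a single spot.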