Dataset Distillation

URL: https://arxiv.org/abs/1811.10959

Notes:

Unlike these approaches, we are interested in understanding the intrinsic properties of the training data rather than a specific trained model.

Sections of the paper:

  1. Train a network with a fixed initialization using one gradient descent step.
  2. Use randomly initialized weights (more challenging)
  3. Linear network -- understand limitations
  4. Use more gradient steps
  5. How to obtain distilled images using different initializations.

Random note to self: this paper's goal fits into the paradigm of Input design.

Algorithm

The algorithm (with some paraphrasing for myself)

Firstly, the required inputs that are of interest:

  1. A prediction model $f(x, \theta)$ for the learning task at hand, where $x$ is the actual data and $\theta$ are the model's parameters.
  2. A distribution for initial weights, $p(\theta_0)$.
  3. The desired number of distilled data points, $M$.

Secondly, some other parameters that may be of interest:

  1. Step size, $\alpha$
  2. Batch size, $n$, for the minibatch of real data drawn at each step. (The size constraint that matters is on the distilled set: $M$ is much smaller than the original dataset.)
  3. Number of optimization iterations, $T$
  4. $\tilde{\eta}_0$, i.e. the initial value for $\tilde{\eta}$, the target learning rate for learning on the _distilled_, synthetic data.
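
To keep the notation straight, here is a minimal sketch in JAX that just names these inputs in code. The linear model, the Gaussian initial-weight distribution, and every number below are placeholders of my own choosing, not values from the paper.

```python
import jax
import jax.numpy as jnp

def f(x, theta):
    """Prediction model f(x, theta) -- a linear model as a stand-in."""
    w, b = theta
    return x @ w + b

def sample_theta0(key, in_dim=20, out_dim=3):
    """One draw from the initial-weight distribution p(theta_0), here a small Gaussian."""
    return 0.01 * jax.random.normal(key, (in_dim, out_dim)), jnp.zeros(out_dim)

M = 100             # number of distilled data points
alpha = 0.05        # step size for updating the distilled data
n = 256             # batch size of real data drawn at each outer step
T = 400             # number of outer optimization iterations
eta_tilde_0 = 0.02  # initial value of the learning rate used on the distilled data
```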

The algorithm in words, translated (a runnable sketch follows the list):

  1. Initialize $\tilde{x}$, the distilled input data, using a random matrix.
  2. Set $\tilde{\eta}$ to $\tilde{\eta}_0$.
  3. For each training step $t=1$ to $T$:
    1. Get a batch of $n$ data points from $x$. We call this $\mathbf{x_t}$, and in the paper it is written as $\mathbf{x_t} = \{x_{t,j}\}_{j=1}^{n}$.
    2. Draw a batch of initial weights from $p(\theta_0)$, indexed as $\theta_{0}^{(j)}$. Given the notation in the paper, this means drawing several initial-weight samples, indexed by $j$.
    3. For each sampled weight $\theta_{0}^{(j)}$:
      1. Compute $\theta_{1}^{(j)}$ by taking one step in the negative direction of the gradient of the loss on the distilled data: $\theta_1^{(j)} = \theta_0^{(j)} - \tilde{\eta} \nabla_{\theta_0^{(j)}} l(\tilde{x}, \theta_0^{(j)})$.
      2. Then evaluate the loss of the updated weights on the actual data batch: $L^{(j)} = l( \mathbf{x_t}, \theta_{1}^{(j)})$.
    4. Then, update the distilled data values $\tilde{x}$ and $\tilde{\eta}$:
      1. $\tilde{x} \leftarrow \tilde{x} - \alpha \nabla_{\tilde{x}} \sum_{j} L^{(j)}$: "update the distilled data in the negative direction of the gradient of the sum of losses $L^{(j)}$ over all sampled initializations, evaluated at the stepped weights $\theta_1^{(j)}$"
      2. $\tilde{\eta} \leftarrow \tilde{\eta} - \alpha \nabla_{\tilde{\eta}} \sum_{j} L^{(j)}$: same semantic meaning, except the update is on $\tilde{\eta}$, the learning rate to use when training the model on the distilled data.
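
To make sure I understand the nested gradient structure, here is a self-contained JAX sketch of the loop above on a toy classification problem. Everything concrete is my own stand-in, not the paper's setup: the linear model, the cross-entropy loss, the random "real" dataset, the number of sampled initializations $K$, and all the constants. I also fix the distilled labels $\tilde{y}$ rather than learning them (which, if I remember right, is what the paper does, but treat it as my assumption here).

```python
import jax
import jax.numpy as jnp

IN_DIM, N_CLASSES = 20, 3
M, n, T, K = 12, 256, 300, 4     # distilled points, real batch size, outer steps, init draws
alpha, eta0 = 0.05, 0.05         # outer step size, initial distilled learning rate

def f(x, theta):
    """Prediction model f(x, theta): a linear classifier as a stand-in."""
    w, b = theta
    return x @ w + b

def loss(theta, x, y):
    """Cross-entropy loss l(x, theta); theta comes first so jax.grad targets it by default."""
    logp = jax.nn.log_softmax(f(x, theta))
    return -jnp.mean(jnp.take_along_axis(logp, y[:, None], axis=1))

def sample_theta0(key):
    """One draw from p(theta_0)."""
    return 0.01 * jax.random.normal(key, (IN_DIM, N_CLASSES)), jnp.zeros(N_CLASSES)

def outer_loss(x_tilde, eta, key, x_real, y_real, y_tilde):
    """Sum of L^(j) over j: one inner GD step on the distilled data per sampled
    initialization, then evaluate the stepped weights on the real batch."""
    def single(k):
        theta0 = sample_theta0(k)
        g = jax.grad(loss)(theta0, x_tilde, y_tilde)            # grad wrt theta_0^(j)
        theta1 = jax.tree_util.tree_map(lambda p, gp: p - eta * gp, theta0, g)
        return loss(theta1, x_real, y_real)                     # L^(j) on real data
    return jnp.sum(jax.vmap(single)(jax.random.split(key, K)))

# Step 1: initialize the distilled inputs x_tilde at random; fix the distilled labels y_tilde.
key = jax.random.PRNGKey(0)
x_tilde = jax.random.normal(key, (M, IN_DIM))
y_tilde = jnp.arange(M) % N_CLASSES
eta = eta0                                                      # step 2: eta_tilde <- eta_tilde_0

# The outer gradient differentiates *through* the inner update.
grad_fn = jax.jit(jax.grad(outer_loss, argnums=(0, 1)))

# A fake "real" dataset standing in for x in the notes.
data_key = jax.random.PRNGKey(1)
X = jax.random.normal(data_key, (5000, IN_DIM))
Y = jax.random.randint(data_key, (5000,), 0, N_CLASSES)

for t in range(T):                                              # step 3
    key, k_batch, k_init = jax.random.split(key, 3)
    idx = jax.random.choice(k_batch, X.shape[0], (n,), replace=False)
    gx, geta = grad_fn(x_tilde, eta, k_init, X[idx], Y[idx], y_tilde)
    x_tilde = x_tilde - alpha * gx                              # step 4.1: update distilled data
    eta = eta - alpha * geta                                    # step 4.2: update eta_tilde
```

The design point that makes this short is that the outer gradient flows through the inner gradient step, so a framework with higher-order autodiff (JAX here) does the heavy lifting.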

It seems like this is going to be a paper I need to take a second look over.

In any case, some ideas to think about:

  1. Could we retrain UniRep by distilling protein sequence space? Doing so might make retraining really easy, and we might end up with a set of weights that are free from the license restrictions?
  2. What about distilling graph data? How easy would this be?

Why I'm so interested in dataset distillation

It's a neat idea for a few reasons.

The biggest reason: compression + fast training. I am tired of waiting for large models to finish training. If we can distill a dataset to its essentials, then we should be able to train large models faster.

Another ancillary reason: intellectual interest! There's an interesting question for me: what is the smallest dataset that is necessary for a model to explain the data?

What would dataset distillation in a quadratic model look like?

Some things I'd expect:

  • The learned $\tilde{x}$ should collapse to a few points? A quadratic has only three parameters, so a small handful of points should be enough to recover a model representative of the training data.
    • We might not get such an extreme distillation, though.
  • On the other hand, because the initial parameters are drawn from a distribution, I can see how we would end up marginalizing over all possible small sets of "inducing" points
  • Oh wait, we actually get to define the number of inducing points!

The setup: we lay out a bunch of training data points along the x-axis and generate noisy $y$ outputs from the function of interest, in this case a quadratic. We should be able to distill the data down to points that include:

  • The hump/bowl
  • Points on both sides.

Then, we try to distill $x$ to a small set of $M = 10$ points.
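
Here is what that experiment might look like as a JAX sketch, under my own assumptions: a three-parameter quadratic $y = \theta_0 + \theta_1 x + \theta_2 x^2$ fit with squared error, initial weights drawn from a small Gaussian, and both the distilled inputs and their targets learned (there are no class labels to fix in regression). The true curve, the noise level, and every constant besides $M = 10$ are arbitrary choices of mine.

```python
import jax
import jax.numpy as jnp

M, n, T, K = 10, 64, 500, 4      # distilled points, real batch size, outer steps, init draws
alpha, eta0 = 0.05, 0.1          # outer step size, initial distilled learning rate

def f(x, theta):
    """Quadratic model: y = theta[0] + theta[1]*x + theta[2]*x**2."""
    return theta[0] + theta[1] * x + theta[2] * x ** 2

def loss(theta, x, y):
    """Mean squared error on (x, y)."""
    return jnp.mean((f(x, theta) - y) ** 2)

# Training data: points along the x-axis with noisy quadratic outputs.
key = jax.random.PRNGKey(0)
key, k_noise, k_xtilde = jax.random.split(key, 3)
X = jnp.linspace(-3.0, 3.0, 200)
Y = (X - 1.0) ** 2 + 0.3 * jax.random.normal(k_noise, X.shape)   # bowl centred at x = 1

def outer_loss(x_tilde, y_tilde, eta, key, x_real, y_real):
    """One GD step on the distilled (x_tilde, y_tilde), then MSE on a real batch."""
    def single(k):
        theta0 = 0.1 * jax.random.normal(k, (3,))    # a draw from p(theta_0)
        theta1 = theta0 - eta * jax.grad(loss)(theta0, x_tilde, y_tilde)
        return loss(theta1, x_real, y_real)
    return jnp.sum(jax.vmap(single)(jax.random.split(key, K)))

grad_fn = jax.jit(jax.grad(outer_loss, argnums=(0, 1, 2)))

# In regression, both the distilled inputs and their targets are learned.
x_tilde = jax.random.normal(k_xtilde, (M,))
y_tilde = jnp.zeros(M)
eta = eta0

for t in range(T):
    key, k_batch, k_init = jax.random.split(key, 3)
    idx = jax.random.choice(k_batch, X.shape[0], (n,), replace=False)
    gx, gy, geta = grad_fn(x_tilde, y_tilde, eta, k_init, X[idx], Y[idx])
    x_tilde, y_tilde = x_tilde - alpha * gx, y_tilde - alpha * gy
    eta = eta - alpha * geta

print(jnp.sort(x_tilde))  # do the distilled x values cluster near the bowl and the flanks?
```

If the intuition in the bullets above is right, the sorted $\tilde{x}$ should end up covering the bowl around $x = 1$ plus points on both flanks.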

Input design