Input design

Design by adaptive sampling

PDF: https://arxiv.org/pdf/1810.03714.pdf

Even after a few months, the paper still feels dense to digest. However, I think I have finally grokked it.

Setup:

  • We have an oracle $p(y|x)$ (a.k.a. a property prediction model).
  • The oracle lets us compute the probability that the property $y$ lands in a set $S$ of values.
    • In the case of maximization, $S$ is the set of values $y$ s.t. $y \geq y_{max}$, where $y_{max}$ is the expected value of $y$ under $p(y|x)$ at the $x$ that maximizes that expectation.
    • In practice, we instead use $S$ such that $y \geq \gamma$, where $\gamma \leq y_{max}$.
    • In the case of specification, $S$ is the set of values $y$ s.t. $y = y_{target}$ (strictly speaking, values within an infinitesimally small interval around it).
    • The conditional probability of the set $S$, i.e. "the probability that our property desideratum is satisfied for a given input", is $P(S|x) \equiv P(Y \in S|x) = \int p(y|x) I_{S}(y) dy$.
      • $I_{S}(y)$ is an indicator function for whether $y \in S$, takes on $1$ if yes, or $0$ otherwise.
    • When we do thresholding, i.e. require $y$ to exceed a threshold, $P(S|x) = P(Y \geq \gamma|x) = 1 - CDF(x, \gamma)$, where $CDF(x, \gamma)$ is the cumulative distribution function of the oracle's predictive distribution $p(y|x)$, evaluated at $\gamma$. (This is not an empirical CDF: it comes from the oracle's own predictive distribution for input $x$, and it has a closed form when the oracle outputs, e.g., a Gaussian mean and variance.)
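To make the thresholding case concrete, here is a minimal sketch assuming the oracle reports a Gaussian predictive distribution (mean and standard deviation) for each $x$. `toy_oracle`, with $p(y|x) = N(x^2, 0.5^2)$, is a hypothetical stand-in, not something from the paper. The sketch computes $P(S|x)$ in closed form and also by Monte Carlo over the indicator integral $\int p(y|x) I_{S}(y) dy$; the two should agree.

```python
import math
import random


def prob_in_S(x, gamma, oracle):
    """P(S|x) = P(Y >= gamma | x) = 1 - CDF(x, gamma) for a Gaussian oracle.

    `oracle(x)` is assumed to return the predictive mean and std of p(y|x).
    erfc is used instead of 1 - CDF so the tail stays accurate.
    """
    mean, std = oracle(x)
    z = (gamma - mean) / std
    return 0.5 * math.erfc(z / math.sqrt(2.0))


def prob_in_S_monte_carlo(x, gamma, oracle, n_samples=20_000, seed=0):
    """Same quantity via the integral of p(y|x) * I_S(y): average the indicator."""
    rng = random.Random(seed)
    mean, std = oracle(x)
    hits = sum(1 for _ in range(n_samples) if rng.gauss(mean, std) >= gamma)
    return hits / n_samples


def toy_oracle(x):
    """Hypothetical oracle: p(y|x) = N(x**2, 0.5**2)."""
    return x ** 2, 0.5


print(prob_in_S(2.0, 4.0, toy_oracle))  # predicted mean sits exactly at gamma
print(prob_in_S(0.0, 4.0, toy_oracle))  # far below gamma, essentially zero
print(prob_in_S_monte_carlo(2.2, 4.0, toy_oracle), prob_in_S(2.2, 4.0, toy_oracle))
```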

Approach: Design input to satisfy desired property.

  • $S$ is the set of property values that satisfy what we want.
  • We want to maximize the expected probability that our desideratum is satisfied, where expectation is performed over a generative model distribution $p(x|\theta)$, like a VAE.
  • This means we optimize $\theta$ w.r.t. the desideratum.
  • We want $\hat{\theta} = \underset{\theta}{argmax} \log P(S|\theta)$, i.e. the $\theta$ that maximizes the log probability of $S$, where $P(S|\theta) = \mathbb{E}_{p(x|\theta)}[P(S|x)]$.
  • The paper mentions a few problems that I don't fully understand. However, it does arrive at the optimization problem that we want to solve:
    • $\theta^{(t+1)} = \underset{\theta}{argmax} \sum_{i=1}^M P(S^{(t)}|x_i^{(t)})\log p(x_i^{(t)}|\theta)$
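    • In English: at each iteration $t$, we draw $M$ samples $x_i^{(t)} \sim p(x|\theta^{(t)})$ and want the $\theta$ that maximizes the sum, over samples, of (the probability $P(S^{(t)}|x_i^{(t)})$ that the sample satisfies the property set $S^{(t)}$) times (the log-likelihood of that sample under the generative model with parameters $\theta$). In other words, a weighted maximum likelihood fit, where samples more likely to satisfy the property get more weight.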
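    • Key idea here: we must choose the set $S^{(t)}$ so that $P(S^{(t)}|x)$ is non-vanishing for the current samples. How so? Set its threshold $\gamma^{(t)}$ from the empirical distribution of the current samples' predicted property values (e.g., a high quantile, capped at $\gamma$), so that it rises as the samples improve.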
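For a 1-D Gaussian generative model, $p(x|\theta) = N(x; \mu, \sigma^2)$ with $\theta = (\mu, \sigma)$, the argmax in the update above has a closed form. Writing $w_i = P(S^{(t)}|x_i^{(t)})$ and setting the derivatives of $\sum_i w_i \log p(x_i^{(t)}|\theta)$ with respect to $\mu$ and $\sigma$ to zero gives a weighted mean and variance:

$$\frac{\partial}{\partial \mu}: \sum_{i=1}^M w_i \frac{x_i^{(t)} - \mu}{\sigma^2} = 0 \;\Rightarrow\; \mu^{(t+1)} = \frac{\sum_{i} w_i x_i^{(t)}}{\sum_{i} w_i}$$

$$\frac{\partial}{\partial \sigma}: \sum_{i=1}^M w_i \left(-\frac{1}{\sigma} + \frac{(x_i^{(t)} - \mu)^2}{\sigma^3}\right) = 0 \;\Rightarrow\; \left(\sigma^{(t+1)}\right)^2 = \frac{\sum_{i} w_i \left(x_i^{(t)} - \mu^{(t+1)}\right)^2}{\sum_{i} w_i}$$

Samples with $w_i \approx 0$ drop out of the fit, which is why $S^{(t)}$ has to keep $P(S^{(t)}|x)$ non-vanishing: if every $w_i$ is zero, the update is undefined.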

How can we implement this using dummy data?

It's probably most instructive to start with an $x^2$ model, where we know the ground truth, and to adaptively find inputs $x$ that give high values of $x^2$.

  • Generative model for $x$: $x \sim N(\mu, \sigma)$. Here, $(\mu, \sigma) = \theta$, the set of parameters we want to optimize.
  • Designing the inputs means changing $\mu$ and $\sigma$ so that we keep generating values of $x$ with higher $x^2$.
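Here is a minimal sketch of that toy loop. Several pieces are my own assumptions, not the paper's settings: the oracle is a hypothetical $p(y|x) = N(x^2, 0.5^2)$, the relaxed threshold $\gamma^{(t)}$ is the 80th percentile of the oracle's predicted means for the current samples (capped at a final target $\gamma = 25$), and each iteration draws $M = 200$ samples. It starts at $\mu = 1$ rather than $0$: because $x^2$ is symmetric, starting at $0$ would up-weight both tails, which in expectation widens $\sigma$ rather than moving $\mu$.

```python
import math
import random


def toy_oracle(x):
    """Hypothetical oracle p(y|x) = N(x**2, 0.5**2); returns (mean, std)."""
    return x ** 2, 0.5


def prob_at_least(gamma, mean, std):
    """P(y >= gamma) under N(mean, std**2)."""
    return 0.5 * math.erfc((gamma - mean) / (std * math.sqrt(2.0)))


def dbas_gaussian(gamma=25.0, mu=1.0, sigma=1.0, m=200, quantile=0.8,
                  iters=20, seed=0):
    """Adaptive-sampling loop for theta = (mu, sigma) of x ~ N(mu, sigma^2)."""
    rng = random.Random(seed)
    history = [(mu, sigma)]
    for _ in range(iters):
        # Sample x_i^(t) ~ p(x | theta^(t)).
        xs = [rng.gauss(mu, sigma) for _ in range(m)]
        # Relaxed set S^(t): threshold at a quantile of predicted means, capped at gamma.
        preds = sorted(toy_oracle(x)[0] for x in xs)
        gamma_t = min(gamma, preds[int(quantile * (m - 1))])
        # Weights w_i = P(S^(t) | x_i^(t)) from the oracle.
        ws = [prob_at_least(gamma_t, *toy_oracle(x)) for x in xs]
        total = sum(ws)
        # Weighted maximum likelihood for a Gaussian: weighted mean, then weighted std.
        mu = sum(w * x for w, x in zip(ws, xs)) / total
        sigma = math.sqrt(sum(w * (x - mu) ** 2 for w, x in zip(ws, xs)) / total)
        history.append((mu, sigma))
    return history


history = dbas_gaussian()
for t, (mu, sigma) in enumerate(history[:5]):
    print(f"t={t}: mu={mu:.3f}, sigma={sigma:.3f}")
```

In this setup $\sigma$ tends to shrink quickly over the first few iterations, after which $\mu$ creeps upward slowly; nothing in this plain weighted fit keeps the search distribution wide.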

Dataset Distillation

URL: https://arxiv.org/abs/1811.10959

Notes:

Unlike these approaches, we are interested in understanding the intrinsic properties of the training data rather than a specific trained model.

Sections of the paper:

  1. Train a network with fixed initialization using one grad descent step.
  2. Use randomly initialized weights (more challenging)
  3. Linear network -- understand limitations
  4. Use more gradient steps
  5. How to obtain distilled images using different initializations.

random note to self: This is a paper where the goal fits into the paradigm of Input design.

Algorithm

The algorithm (with some paraphrasing for myself)

Firstly, the required inputs that are of interest:

  1. A prediction model $f(x, \theta)$ for the learning task at hand. $x$ is the actual data, and $\theta$ are its parameters.
  2. A distribution for initial weights, $p(\theta_0)$.
  3. The desired number of distilled data points, $M$.

Secondly, some other parameters that may be of interest:

  1. Step size, $\alpha$
  2. Batch size, $n$, such that $n \ll M$.
  3. Number of optimization iterations, $T$
  4. $\tilde{\eta}_0$, i.e. the initial value for $\tilde{\eta}$, the (learned) learning rate used when training on the distilled, synthetic data.

The algorithm in words, translated:

  1. Initialize $\tilde{x}$, the distilled input data, using a random matrix.
  2. Set $\tilde{\eta}$ to $\tilde{\eta}_0$.
  3. For each training step $t=1$ to $T$:
    1. Get a batch of $n$ data points from the real training data $\mathbf{x}$. We call this $\mathbf{x_t}$; in the paper, it is written $\mathbf{x_t} = \{x_{t,j}\}_{j=1}^{n}$.
    2. Draw a batch of initial weights $\theta_{0}^{(j)} \sim p(\theta_0)$, where $j$ indexes the individual weight draws.
    3. For each sampled weight $\theta_{0}^{(j)}$:
      1. Compute $\theta_{1}^{(j)}$ by taking one step in the negative direction of the gradient of the loss on the distilled data: $\theta_1^{(j)} = \theta_0^{(j)} - \tilde{\eta} \nabla_{\theta_0^{(j)}} l(\tilde{x}, \theta_0^{(j)})$.
      2. Now evaluate the loss of the updated weights on real data: $L^{(j)} = l(\mathbf{x_t}, \theta_{1}^{(j)})$.
    4. Then, update the distilled data values $\tilde{x}$ and $\tilde{\eta}$:
      1. $\tilde{x} \leftarrow \tilde{x} - \alpha \nabla_{\tilde{x}} \sum_{j} L^{(j)}$: "update distilled data in the negative direction of the gradient of the sum of losses over all sampled weights $\theta_0^{(j)}$"
      2. $\tilde{\eta} \leftarrow \tilde{\eta} - \alpha \nabla_{\tilde{\eta}} \sum_{j} L^{(j)}$: same semantic meaning, except we do it on $\tilde{\eta}$, the learning rate to use when using distilled data to train the model.
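The algorithm is easiest to follow on a model small enough to differentiate by hand. Below is a minimal sketch under assumptions of my own, not the paper's setup: a 1-D linear model $f(x, w) = w x$ with squared error, real data from a hypothetical $y = 2x$ plus noise, $p(\theta_0) = N(0, 1)$, $\tilde{\eta}_0 = 0.5$, and $M = 1$ distilled point whose target $\tilde{y} = 1$ is held fixed (regression has no class labels to fix, so this choice is arbitrary). The helper names (`make_real_data`, `loss_and_grads`, `distill`, `mean_real_loss`) are hypothetical. Step 4's gradients come from the chain rule through the single inner step.

```python
import random


def make_real_data(n_points, rng, slope=2.0, noise=0.1):
    """Hypothetical real dataset: y = slope * x + Gaussian noise, x ~ U(-1, 1)."""
    data = []
    for _ in range(n_points):
        x = rng.uniform(-1.0, 1.0)
        data.append((x, slope * x + rng.gauss(0.0, noise)))
    return data


def loss_and_grads(x_tilde, y_tilde, eta, w0, batch):
    """Inner step on distilled data, then loss on a real batch (steps 3.3.1-3.3.2).

    Model f(x, w) = w * x with squared error. Returns L^(j) and its gradients
    w.r.t. each distilled input x~_i and the learning rate eta~.
    """
    m, n = len(x_tilde), len(batch)
    # Inner gradient of l(x~, w) = mean_i (w * x~_i - y~_i)^2 at w = w0.
    g = sum(2.0 * (w0 * xt - yt) * xt for xt, yt in zip(x_tilde, y_tilde)) / m
    w1 = w0 - eta * g
    # Outer loss on real data and its derivative w.r.t. w1.
    loss = sum((w1 * x - y) ** 2 for x, y in batch) / n
    dl_dw1 = sum(2.0 * (w1 * x - y) * x for x, y in batch) / n
    # Chain rule through w1 = w0 - eta * g.
    d_eta = dl_dw1 * (-g)
    d_x = [dl_dw1 * (-eta) * 2.0 * (2.0 * w0 * xt - yt) / m
           for xt, yt in zip(x_tilde, y_tilde)]
    return loss, d_x, d_eta


def distill(data, m=1, steps=3000, n=32, n_draws=8, alpha=0.02, eta0=0.5, seed=0):
    rng = random.Random(seed)
    x_tilde = [rng.uniform(0.0, 1.0) for _ in range(m)]  # step 1: random init
    y_tilde = [1.0] * m  # hypothetical fixed distilled targets
    eta = eta0  # step 2
    for _ in range(steps):  # step 3
        batch = rng.sample(data, n)  # 3.1: real batch x_t
        grad_x, grad_eta = [0.0] * m, 0.0
        for _ in range(n_draws):  # 3.2-3.3: theta_0^(j) ~ N(0, 1)
            w0 = rng.gauss(0.0, 1.0)
            _, d_x, d_eta = loss_and_grads(x_tilde, y_tilde, eta, w0, batch)
            grad_x = [a + b for a, b in zip(grad_x, d_x)]
            grad_eta += d_eta
        # 3.4: descend the averaged (not summed) loss; this only rescales alpha.
        x_tilde = [xt - alpha * gx / n_draws for xt, gx in zip(x_tilde, grad_x)]
        eta -= alpha * grad_eta / n_draws
    return x_tilde, y_tilde, eta


def mean_real_loss(x_tilde, y_tilde, eta, data, n_draws=200, seed=1):
    """Average real-data loss after one distilled step, over fresh random inits."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_draws):
        loss, _, _ = loss_and_grads(x_tilde, y_tilde, eta, rng.gauss(0.0, 1.0), data)
        total += loss
    return total / n_draws


data = make_real_data(256, random.Random(42))
x_tilde, y_tilde, eta = distill(data)
print(x_tilde, eta)
print(mean_real_loss(x_tilde, y_tilde, 0.0, data))  # no training step
print(mean_real_loss(x_tilde, y_tilde, eta, data))  # one step on distilled data
```

With these choices an exact solution exists: $w_1 = w_0(1 - 2\tilde{\eta}\tilde{x}^2) + 2\tilde{\eta}\tilde{x}$, so $\tilde{x} = 0.5$ and $\tilde{\eta} = 2$ send every random $w_0$ to $w_1 = 2$. The loop should drift toward that point, and the last two printed losses compare real-data loss with no training step against one step on the single distilled point.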

It seems like this is going to be a paper I need to take a second look over.

In any case, some ideas to think about:

  1. Could we retrain UniRep by distilling protein space? Doing so might make retraining really easy, and might yield a set of weights that are free from the license restrictions?
  2. What about distilling graph data? How easy would this be?

Why I'm so interested in dataset distillation

It's a neat idea for a few reasons.

The biggest reason: compression + fast training. I am tired of waiting for large models to finish training. If we can distill a dataset to its essentials, then we should be able to train large models faster.

Another ancillary reason: intellectual interest! There's an interesting question for me: what is the smallest dataset that is necessary for a model to explain the data?

What would dataset distillation in a quadratic model look like?

Some things I'd expect:

  • The learned $x$ should collapse to a few points? Like it should be able to get a good model representative of the training data, using just two points.
    • We might not get such an extreme distillation, maybe?
  • On the other hand, because of the use of distribution draws for parameters, I can see how we would end up marginalizing over all possible pairs of "inducing" points
  • Oh wait, we actually get to define the number of inducing points!
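A quick check on the "two points" guess, under assumptions I am adding here: the quadratic model is linear in its parameters, $f(x, \theta) = \theta^\top \phi(x)$ with $\phi(x) = (1, x, x^2)^\top$, the inner loss is mean squared error, and each distilled input $\tilde{x}_i$ comes with a distilled target $\tilde{y}_i$ (hypothetical, since regression needs targets). One gradient step from a random $\theta_0$ then gives:

$$\nabla_\theta l(\tilde{x}, \theta) = \frac{2}{M}\sum_{i=1}^{M} \phi(\tilde{x}_i)\left(\phi(\tilde{x}_i)^\top \theta - \tilde{y}_i\right) = H\theta - b$$

$$H = \frac{2}{M}\sum_{i=1}^{M} \phi(\tilde{x}_i)\phi(\tilde{x}_i)^\top, \qquad b = \frac{2}{M}\sum_{i=1}^{M} \tilde{y}_i\,\phi(\tilde{x}_i)$$

$$\theta_1 = \theta_0 - \tilde{\eta}(H\theta_0 - b) = (I - \tilde{\eta} H)\theta_0 + \tilde{\eta} b$$

Since $H$ is a sum of $M$ rank-one matrices, $\text{rank}(H) \leq M$. With $M = 2$ there is a direction $v \neq 0$ with $Hv = 0$, so $(I - \tilde{\eta}H)v = v$: the component of $\theta_0$ along $v$ passes through the step unchanged, and $\theta_1$ still depends on the random initialization. So with randomly drawn initial weights, two points cannot fully pin down a 3-parameter quadratic in one step; $M \geq 3$ is necessary, though not by itself sufficient.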

The setup: We set up a bunch of training data points along the x-axis line, and create noisy $y$ outputs for the function of interest, in this case, a quadratic model. We should be able to distill points to include:

  • The hump/bowl
  • Points on both sides.

Then, we try to distill $x$ to a small set of $M = 10$ points.