One Weird Trick to Speed Up Your TensorFlow Model 100X...

written by Eric J. Ma on 2020-02-13 | tags: deep learning optimization software development data science

You’ve got a TensorFlow 1.13 model on hand, written by an academic group, and you know it’s going to be useful, because you’ve already done some preliminary experiments with it. The model takes in protein sequences, and returns vectors, one per protein sequence. The vectors turn out to be excellent descriptors for downstream supervised learning, to predict protein activity from sequence data.

The only problem? It’s slow. Excruciatingly slow. Not just on CPU. It’s slow on your local GPU as well. You can’t run this locally without waiting for hours to process the thousands of protein sequences that you have on hand.

Slow? What now?

What do we do, then? Install a newer GPU with faster processors and more RAM? “Lift and shift” onto cloud GPUs?

I’d argue that all of those are evil choices, because premature optimization is the root of all evil (Donald Knuth). Buying a newer GPU without knowing the root cause of the model’s slowness could be construed as a form of premature optimization, because we would be moving to more powerful hardware without first finding out whether the model can be improved on the hardware we already have. Moving the model onto a cloud GPU, renting another person’s computer before pushing the model to its limits locally, is probably even worse, because we end up with more variables to debug. By the aforementioned premises and Knuth’s maxim, those options are, therefore, evil choices.

Rather provocatively, I’d like to suggest that one right option when faced with a very slow deep learning model is to make it fast on a single CPU. If we make the model fast locally on a single CPU, then we might have a good chance of dissecting the source of the model’s slowness, and thus make it really fast on different hardware.

This is exactly what we did with the UniRep model. Our protein engineering intern in Basel, Arkadij, had already done the groundwork testing the model on our high performance computing cluster. It was slow in our hands, and so we decided it was time to really dig into the model.


Because the model is an RNN model that processes sequential data, we first need to know how the model processes one window of a sequence. In the deep learning literature, this is known as an “RNN cell”. We isolated that part of the RNN, and carefully translated each line of code into NumPy. Because it processes one “window” or “step” at a time, we call it mlstm1900_step.
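To make the idea of an RNN cell concrete, here is a minimal sketch of one step of a multiplicative LSTM written against the NumPy-style API. It is not the actual mlstm1900_step: the function name, the parameter names, the gate ordering, the toy sizes, and the use of jax.nn.sigmoid are all assumptions made for illustration.

```python
import jax.numpy as jnp
from jax import random
from jax.nn import sigmoid


def mlstm_step_sketch(params, carry, x_t):
    """One step of a simplified multiplicative LSTM cell (illustrative only)."""
    h_prev, c_prev = carry
    # The multiplicative state mixes the current input with the previous hidden state.
    m = (x_t @ params["wmx"]) * (h_prev @ params["wmh"])
    # All four gate pre-activations come from one sum of matrix products, then get split.
    z = x_t @ params["wx"] + m @ params["wh"] + params["b"]
    i, f, o, u = jnp.split(z, 4, axis=-1)
    c = sigmoid(f) * c_prev + sigmoid(i) * jnp.tanh(u)
    h = sigmoid(o) * jnp.tanh(c)
    return (h, c), h  # (state for the next step, output of this step)


def init_params(key, n_in, n_hidden):
    """Random toy weights; the shapes are assumptions for this sketch."""
    k1, k2, k3, k4 = random.split(key, 4)
    return {
        "wmx": random.normal(k1, (n_in, n_hidden)) * 0.1,
        "wmh": random.normal(k2, (n_hidden, n_hidden)) * 0.1,
        "wx": random.normal(k3, (n_in, 4 * n_hidden)) * 0.1,
        "wh": random.normal(k4, (n_hidden, 4 * n_hidden)) * 0.1,
        "b": jnp.zeros(4 * n_hidden),
    }


params = init_params(random.PRNGKey(0), n_in=10, n_hidden=16)
carry = (jnp.zeros(16), jnp.zeros(16))
(h, c), out = mlstm_step_sketch(params, carry, jnp.ones(10))
```

Returning a (carry, output) pair is a deliberate choice here: it is the signature that jax.lax.scan expects from a step function once the parameters are bound.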

Because the RNN cell is “scanned” over a single sequence, we then defined the function that processes a single sequence. To do this, we took advantage of the jax.lax submodule’s scan function, which literally “scans” the RNN cell over the entire sequence the way TensorFlow’s RNN cell would be scanned. (I originally described this as a vectorized for-loop. Thanks to Stephan Hoyer for pointing out my conceptual error here: lax.scan is a compiled for-loop with a known number of iterations.) Together, that defined the function call that processes a single sequence. This, incidentally, gives us the “batch-” or “sample-wise” RNN function, mlstm1900_batch.
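As a rough illustration of scanning a cell over one sequence, the sketch below uses a toy tanh cell rather than the real mLSTM. The names rnn_step and process_sequence, and the shapes used, are hypothetical.

```python
import jax.numpy as jnp
from jax import lax


def rnn_step(carry, x_t):
    """Toy cell: new hidden state from the previous one and the current input."""
    h = jnp.tanh(0.5 * carry + x_t)
    return h, h  # (carry for the next step, output recorded for this step)


def process_sequence(xs, h0):
    """Scan the cell over the leading (time) axis of a single sequence."""
    h_final, hs = lax.scan(rnn_step, h0, xs)
    return h_final, hs


xs = jnp.ones((7, 4))  # 7 time steps, 4 features per step
h_final, hs = process_sequence(xs, jnp.zeros(4))
```

lax.scan returns the final carry together with the per-step outputs stacked along a new leading axis, so hs has one row per time step.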

Finally, we need a function that processes multiple sequences that are the same length. Here, we need the function to be vectorized, so that we do not incur a Python loop penalty for processing sequences one at a time. This forms the mlstm1900 layer, which, semantically speaking, takes in a batch of sequences and returns a batch of outputs.
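One way to get that vectorization is jax.vmap, which lifts a per-sequence function to a per-batch function. The sketch below again uses a toy cell, and it assumes that the representation is the hidden state averaged over time; all names are hypothetical.

```python
import jax.numpy as jnp
from jax import lax, vmap


def rnn_step(carry, x_t):
    """Toy cell standing in for the real RNN cell."""
    h = jnp.tanh(0.5 * carry + x_t)
    return h, h


def process_sequence(xs):
    """Process one (time, features) sequence into one fixed-size vector."""
    h0 = jnp.zeros(xs.shape[-1])
    _, hs = lax.scan(rnn_step, h0, xs)
    return hs.mean(axis=0)  # average the hidden states over time


# vmap maps over axis 0 by default, i.e. over the sequences in the batch.
process_batch = vmap(process_sequence)

batch = jnp.ones((3, 7, 4))  # 3 sequences, 7 steps each, 4 features
reps = process_batch(batch)
```

Because the per-sequence function never mentions the batch axis, the same code serves both the single-sequence and the batched case.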


So, does the reimplementation work in the same way as the original? Absolutely. We passed an attrappe (Google it, it’s German) sequence through the model, and checked that the reps were identical to the original - which they were. But don’t take my word for it, check out the draft preprint we wrote to document this! Meanwhile, in case you don't want to click over and read our hard work, here's the figure linked from the source. (It's only off by a bit because we added a delta to make the two representations visible.)
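A check of this kind can be sketched as follows. The “original” here is just an explicit NumPy loop standing in for a reference implementation, and the tolerances are assumptions; they are needed because JAX computes in 32-bit floats by default while NumPy defaults to 64-bit.

```python
import numpy as np
import jax.numpy as jnp
from jax import lax


def reference(xs):
    """Stand-in for the original implementation: an explicit NumPy loop."""
    h = np.zeros(xs.shape[1])
    out = []
    for x_t in xs:
        h = np.tanh(0.5 * h + x_t)
        out.append(h)
    return np.stack(out)


def reimplemented(xs):
    """The same recurrence expressed with lax.scan."""
    def step(h, x_t):
        h = jnp.tanh(0.5 * h + x_t)
        return h, h

    _, hs = lax.scan(step, jnp.zeros(xs.shape[1]), jnp.asarray(xs))
    return hs


dummy = np.random.default_rng(0).normal(size=(20, 8))  # dummy input sequence
np.testing.assert_allclose(
    np.asarray(reimplemented(dummy)), reference(dummy), rtol=1e-5, atol=1e-5
)
```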

Congruence between the reps.

Besides that, we also added tests and elaborate docstrings for the functions in there, so anybody reading the model source code can know what the semantic meaning of each of the tensors’ axes is. Having wrestled through undocumented academic code, I found that knowing what the input tensor dimensions were supposed to be was one of the hardest things to grok from looking at a relatively poorly documented model.

By carefully reimplementing the deep learning model in pure JAX/NumPy, we were able to achieve approximately 100X speedup over the original implementation on a single CPU. Given JAX’s automatic compilation to GPU and TPU, the speed improvements we could obtain might be even better, though we have yet to try it out. Again, don't take my word for it - check out the draft preprint if you're curious! The exact figure, though, is below.

Speed comparison between the original and `jax`-based reimplementation.

I suspect that by reimplementing the model in NumPy, we basically eliminated the TensorFlow graph compilation overhead, which was the original source of slowness. By using some of JAX’s tricks, including lax.scan and jax.vmap, we could take advantage of the compiled and vectorized looping that the JAX developers built into the package, thus eliminating Python looping overhead.


So, now, how do we get the model distributed to end-users? This is where thinking about the API matters. I have this long-held opinion that data scientists need some basic proficiency in software skills, as we invariably end up doing tool-building whether we want to or not. Part of software skills involves knowing how to package the model for use in a variety of settings, and thinking about how the end-user would use it. Only then have we made a data product that is of use to others.

Two settings we thought of are the Python user setting, and a web user setting.

Python-user setting

Serving Python users means building pip-installable packages.

In the Python world, it’s idiomatic to consider a Python package as the unit of software that is shipped to others. Some packages are large (think NumPy), others are diverse (think SciPy), some are both (e.g. scikit-learn), and yet others do one and only one thing well (think humanize). UniRep probably falls into the last category.

To fit this idiom, we decided to package up the model as a Python-style package that is pip-installable into the environment of an end-user. While we develop the package, it is pip-installable from our GitHub repository (where we also give full credit and reference to the original authors for the original). We might release it to PyPI, with licenses and the like, as well.

Moreover, we needed a nice API.

With UniRep, there are basically two things one would want to do:

  1. Compute “reps” for a sequence conditioned on the pre-trained weights, and
  2. Evolutionarily-tune ("evo-tune") the model weights to compute more locally-relevant reps.

As such, we designed a get_reps(seqs) API that would correctly process single sequences, multiple sequences of same lengths, and multiple sequences of varying lengths, while not needing the user to think about that.
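One plausible way to hide the length handling from the user is to group sequences by length, vectorize within each group, and return the results in the caller’s original order. The sketch below is an assumption about how such an API could be laid out, not the package’s actual get_reps; the alphabet, the encode and rep_one stand-ins, and the name get_reps_sketch are all hypothetical.

```python
from collections import defaultdict

import jax.numpy as jnp
from jax import vmap

ALPHABET = "ACDEFGHIKLMNPQRSTVWY"  # hypothetical alphabet for this sketch


def encode(seq):
    """Integer-encode one sequence (stand-in for the real featurization)."""
    return jnp.array([ALPHABET.index(aa) for aa in seq])


def rep_one(encoded):
    """Stand-in for the model: one fixed-size vector per encoded sequence."""
    one_hot = jnp.eye(len(ALPHABET))[encoded]
    return one_hot.mean(axis=0)


def get_reps_sketch(seqs):
    """Accept one string or a list of strings of any lengths; return one row each."""
    if isinstance(seqs, str):
        seqs = [seqs]
    # Group indices by length so that each group can be stacked and vectorized.
    by_length = defaultdict(list)
    for idx, seq in enumerate(seqs):
        by_length[len(seq)].append(idx)
    reps = [None] * len(seqs)
    for idxs in by_length.values():
        batch = jnp.stack([encode(seqs[i]) for i in idxs])
        for i, rep in zip(idxs, vmap(rep_one)(batch)):
            reps[i] = rep
    return jnp.stack(reps)  # rows follow the caller's original order


reps = get_reps_sketch(["ACD", "ACDEF", "DCA"])
```

The caller passes plain strings and never sees the grouping, which is the point of the design.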

For evolutionary tuning, we desired that the API worked regardless of whether the user tuned on sequences of equal or varying lengths. In other words, we wanted to wrap the complexity of handling sequences of multiple lengths away from the end-user. This is something we are working on right now, but at least the API sketch as it stands now gives the end-user a single evotune(sequences) function to work with. Designing it this way allows us to work backwards from the outside-in, starting with the desired API and flexibly designing the inner parts. A friend of mine, Ex-Googler Jesse Johnson at Cellarity, wrote an excellent blog post on this.


With a pip-installable package on-hand, we could then depend on it to build other single-use tools. For example, we built a web API for the model by using FastAPI, and designed it to return the calculated representations for a single sequence based on the pre-trained weights. Because evo-tuning is computationally expensive, we don’t plan to offer that on the web API. The web API is also packaged up as a pip-installable package, and can be installed inside a Docker container, thus forming the basis of a deep learning-powered micro-service that pre-calculates “standard”, non-tuned representations of every new protein we might encounter.

Lessons Learned

On the fragility of deep learning models

Our modern way of communicating the architecture of deep learning models is still a work-in-progress, with overloaded terms and ambiguity present. As such, there were minor details that we reimplemented incorrectly that incidentally also had a major impact on the correctness of the reimplementation. For example, the sigmoid() function in TensorFlow was different from the other implementations that we had seen in multiple codebases (we detail this in our preprint), yet both are colloquially described as “sigmoids”.
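To see how much room the word “sigmoid” leaves, compare two S-shaped functions that could both plausibly carry that name. This pair is only an illustration and is not a claim about which variants the codebases in question used; the preprint has those details.

```python
import jax.numpy as jnp


def sigmoid_logistic(x):
    """Textbook logistic function."""
    return 1.0 / (1.0 + jnp.exp(-x))


def sigmoid_tanh(x):
    """A rescaled tanh; algebraically equal to the logistic function of 2x."""
    return 0.5 * jnp.tanh(x) + 0.5


x = jnp.linspace(-3.0, 3.0, 7)
max_gap = jnp.max(jnp.abs(sigmoid_logistic(x) - sigmoid_tanh(x)))
```

Both functions map the real line onto (0, 1) and agree at zero, yet they differ noticeably elsewhere, which is enough to change every gate value in a recurrent model.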

We also know from adversarial examples that deep learning models are notoriously fragile against inputs that have noise added to them. This has implications in the design of the API to the model. If we design the user-facing API to accept tensors rather than strings, then we leave the burden of correctly translating string inputs into tensors (with semantically correct dimensions) on the end-user. If the end-user processed the strings in a way that the model is not robust against, it will handle the processed tensors in a way that may be incorrect for their purposes; yet, there is no robust way to write a unit test against this. The key problem here is that the semantic meaning of the tensors cannot be easily and automatically tested.

As such, in order for the user-facing API of the model to be “friendly” (so to speak), it must accept data in a format that the intended end-user will be handling. This is nothing new in the world of software development, to be user-centric. At the same time, because of its small size and the way we structure our codebase, all of the internals are easily accessible, and how they fit together is also not difficult to understand.
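A small sketch of what accepting strings at the API boundary could look like: the function validates the input and owns the string-to-tensor translation, so the user never has to guess the expected tensor layout. The alphabet and its index order are hypothetical and may differ from the real model’s mapping.

```python
import jax.numpy as jnp

ALPHABET = "ACDEFGHIKLMNPQRSTVWY"  # hypothetical; the real mapping may differ
AA_TO_INT = {aa: i for i, aa in enumerate(ALPHABET)}


def seq_to_tensor(seq):
    """Validate a protein string and return a (length, alphabet) one-hot array."""
    if not isinstance(seq, str) or not seq:
        raise TypeError("expected a non-empty string of amino acid letters")
    seq = seq.upper()
    unknown = sorted(set(seq) - set(ALPHABET))
    if unknown:
        raise ValueError(f"unrecognized characters: {unknown}")
    idxs = jnp.array([AA_TO_INT[aa] for aa in seq])
    return jnp.eye(len(ALPHABET))[idxs]


tensor = seq_to_tensor("acd")
```

Failing loudly on unexpected characters is a design choice: it turns a silent semantic error into one that an ordinary unit test can catch.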

Why tools, and not full apps?

The environment that we're in is a research environment. That, by definition, means we're always doing non-standard things. Moreover, a model is usually but one small component in a large interconnected web of tools that come together to make a useful, consumable app. Rather than try to go end-to-end, it makes more sense to produce modular parts that fit well with one another. The mantra I like to keep in mind here is:

Modularity and composability give us the flexibility for creativity.

And if creativity is not the heart of research, what then is?

If you want to accelerate a deep learning model...

how about making sure it first runs well on a single CPU?

But beyond that, also consider that whole frameworks and their associated overheads might not necessarily be what you need. In our case, we didn't need the entire TF framework to get the model reimplemented. We only really needed XLA (pre-compiled in jaxlib for our systems) and the NumPy API to get it right. Also consider whether a framework is getting in the way of your modelling work, or if it's helping. (If it's the latter, keep using what you're using!)

Encouraging a co-creation mindset

Co-creation is the ethos of the open source world, where unrelated people come together and build software that is useful to others beyond themselves. I've experienced it many times on the pyjanitor and nxviz projects. Technically, nobody has to ask for permission to fork a project, make changes, and even submit a pull request back. Of course, we do so anyways, as a courtesy to the original creator - it's always good to let someone know something's coming their way.

Pair coding, as it turns out, happens to be a wonderful way to encourage co-creation. Much of the reimplementation work was done pair coding with my co-author, Arkadij Kummer, who lives and works at our Basel site. Yes, we're separated by a six-hour time difference, but that doesn't stop us from using all of the good internet-enabled tools available to us to collaborate. We used MS Teams at work for video chatting, and sometimes just did VSCode with VSLiveShare (with audio) to develop code together. All work was version-controlled and shared via Git, which means we could work asynchronously when needed as well. The time spent pair coding is time where knowledge is shared, trust is built, and mentoring relationships are forged. In an upcoming blog post, I will write about pair coding as a practice in data science, particularly for levelling-up junior members of the data science team and sharing knowledge between peers.


I hope you enjoyed reading about our journey reimplementing UniRep in JAX. Keep in mind that JAX isn't necessarily always going to speed up your TF model 100X; I'm clearly being facetious in writing it that way. JAX and TF can both use XLA underneath the hood, but as Matt Johnson pointed out, JAX uses XLA by default while TF does not, and XLA is really the secret sauce that makes the models execute fast. Another difference is that TF1.x needed this compilation time, while JAX compiles just-in-time (when ordered to do so). Because JAX gets us autodiff on the NumPy API, it's a super productive research and teaching tool with a familiar API to many Pythonistas. My hope is this post encourages you to try it out; happy experimenting!
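If you want a first taste, here is a minimal sketch of the two ideas mentioned above: autodiff on NumPy-style code via jax.grad, and just-in-time compilation via jax.jit. The loss function and the toy data are made up for illustration.

```python
import jax.numpy as jnp
from jax import grad, jit


def loss(w, x, y):
    """Mean squared error of a linear model, written with the NumPy-style API."""
    return jnp.mean((x @ w - y) ** 2)


# grad differentiates with respect to the first argument; jit compiles the
# result with XLA on the first call and reuses the compiled version afterwards.
fast_grad = jit(grad(loss))

x = jnp.eye(3)
y = jnp.array([1.0, 2.0, 3.0])
w = jnp.zeros(3)
g = fast_grad(w, x, y)
```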