LLM series - Snorkel AI: Distillation refines LLMs into smaller (richer) droplets of data science

This is a guest post for the Computer Weekly Developer Network written by Braden Hancock in his capacity as co-founder of Snorkel AI.

As a company, Snorkel AI says it is the first data-centric Artificial Intelligence (AI) platform powered by state-of-the-art techniques in programmatic data labeling and weak supervision.

Hancock writes as follows…

Off-the-shelf Large Language Models (LLMs) have become a favourite tool for proof-of-concept demonstrations due to their surprising accuracy on classification tasks they were not trained for.

While LLMs can quickly demonstrate the feasibility of solving a given task with AI, they rarely make sense for production use cases. Even when LLMs achieve acceptable levels of accuracy (which they usually won’t without customisation), their size, cost, and latency make them impractical for high-throughput tasks.

As a result, we believe that data science teams will turn their focus back to smaller, faster models in 2024 – and we’re not alone. At our recent LLM Summit, 74% of poll respondents said they expect their teams to use LLMs to build smaller models in the new year… and Hugging Face Co-Founder and CEO Clem Delangue echoed a similar prediction.

Distillation to smaller forms

That doesn’t mean those teams won’t benefit from the LLM advances regularly released by ML practitioners and researchers. Instead of using an LLM directly as the final model, data scientists can use LLMs to derive models with similar task-specific performance in a much smaller, and therefore faster and cheaper, form factor.

This process, called “distillation,” can kickstart data-labeling projects and makes foundation models a vital part of the enterprise AI ecosystem.

So then, why not use LLMs for production-level predictive tasks? 

Enterprise data science teams will avoid using LLMs for production predictive applications for the following reasons:

  • Accuracy. Even with advanced prompting techniques, off-the-shelf models rarely achieve production-grade accuracy on enterprise tasks, and hallucinated answers remain a challenge.
  • Cost. Due to their billions of parameters, each inference call to an LLM costs many times that of a specialised neural network, which will typically contain orders of magnitude fewer parameters.
  • Latency. As a side effect of their size, LLMs return predictions slowly. This poses a serious blocker for applications expected to make many predictions per second.

In the best-case scenario, using an LLM for categorical predictions is like using a supercomputer to play Pong. It might work, but it wastes money.

How distillation works

Distillation bootstraps data labeling efforts for predictive tasks.

In its simplest form, a data scientist takes a large amount of unlabeled data and asks a “teacher” LLM to label it. Then, they use this labeled data to “teach” a smaller “student” model that mirrors the accuracy of the LLM on that particular task in a memory footprint orders of magnitude smaller.
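To make the flow concrete, here is a minimal sketch of that teacher-student loop, assuming a scikit-learn student and a placeholder teacher function standing in for a real LLM call (the function, emails and labels below are illustrative, not a specific vendor API or dataset):

```python
# A minimal sketch of LLM-as-teacher distillation for a spam classifier.
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline

def teacher_label(text: str) -> int:
    """Stand-in for a prompt to a 'teacher' LLM that answers 1 (spam) or 0 (not spam)."""
    # In practice this would call GPT-3.5, PaLM 2, Llama 2, etc. and parse the reply.
    return int("winner" in text.lower() or "money" in text.lower())

unlabeled_texts = [
    "You're a winner! Claim your prize today",
    "Minutes from Tuesday's project sync",
    "Send money now to unlock your reward",
    "Lunch on Thursday?",
]

# Step 1: the "teacher" LLM labels the unlabeled corpus.
pseudo_labels = [teacher_label(t) for t in unlabeled_texts]

# Step 2: a small "student" model is trained on the pseudo-labeled data.
student = make_pipeline(TfidfVectorizer(), LogisticRegression(max_iter=1000))
student.fit(unlabeled_texts, pseudo_labels)

# The student now serves predictions at a fraction of the teacher's cost and latency.
print(student.predict(["Claim your free prize, you're a winner"]))
```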

If the teacher model and prompting technique yield reliable results, the team can deploy the student model and move on. Most likely, this will not be the case, but the training data produced can serve as a solid starting point for further data development. Data scientists can algorithmically identify the labels most likely to be wrong and work with subject matter experts to correct them. Through iteration, the project should produce a model that exceeds its target performance metrics in a fraction of the time manual labeling alone would require.
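One common heuristic for surfacing the labels most likely to be wrong, continuing the sketch above, is to rank examples by how little the student model trusts the teacher’s label; this is an illustrative assumption, not a prescribed method:

```python
import numpy as np

# Continuing the earlier sketch: rank the teacher's labels by the student's
# confidence in them, and send the least-trusted examples to subject matter experts.
probs = student.predict_proba(unlabeled_texts)
confidence_in_label = probs[np.arange(len(pseudo_labels)), pseudo_labels]
review_queue = np.argsort(confidence_in_label)[:50]  # indices of the 50 least-trusted labels
```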

Let’s now look at how to improve on distillation with advanced techniques.

While distillation creates an effective starting point, it rarely produces a final model. The “student” can only mirror the effectiveness of the “teacher” and the teacher has likely not been trained on the project’s specific task. However, data scientists can elevate distilled model performance by approaching the data labeling task from multiple angles.

For example, imagine a project to classify emails as spam or not spam.

The team builds a highly engineered prompt and submits it to GPT-3.5, PaLM 2 and Llama 2. The best of the bunch achieves 70% accuracy, but when the models “vote,” accuracy increases to 75%.
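A simple majority vote over the models’ answers might look like this sketch (the model names and votes are purely illustrative):

```python
from collections import Counter

def majority_vote(votes):
    """Return the label most models agreed on (ties fall to the first label seen)."""
    return Counter(votes).most_common(1)[0][0]

# Illustrative per-model answers for one email: 1 = spam, 0 = not spam.
votes = {"gpt-3.5": 1, "palm-2": 1, "llama-2": 0}
print(majority_vote(votes.values()))  # -> 1
```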

The team then further targets specific behaviours known to be “spammy.” Does the email ask for money or personal information? Probably spam. These signals can also originate from non-LLM sources. A regex search for certain phrases such as “you’re a winner” might achieve high precision, as might using a third-party list of known spam senders.
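Each of these signals can be expressed as a small labeling function that votes spam, not spam, or abstains; the phrases and sender list below are invented for illustration:

```python
import re

ABSTAIN, NOT_SPAM, SPAM = -1, 0, 1

def lf_winner_phrase(email: dict) -> int:
    """Regex for a classic scam phrase: high precision, low coverage."""
    return SPAM if re.search(r"you'?re a winner", email["text"], re.I) else ABSTAIN

KNOWN_SPAM_SENDERS = {"promo@example-lottery.test"}  # illustrative third-party list

def lf_known_sender(email: dict) -> int:
    """Vote spam if the sender appears on a known-spammer list."""
    return SPAM if email["sender"] in KNOWN_SPAM_SENDERS else ABSTAIN

def lf_asks_for_money(email: dict) -> int:
    """Vote spam if the email asks for money."""
    return SPAM if re.search(r"wire transfer|send money", email["text"], re.I) else ABSTAIN
```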

As the number of signals increases, “voting” can get complicated and nuanced. Algorithmic approaches well-established in the field of weak supervision can help sort through the noise intelligently and reduce the time needed to label production-grade data from weeks or months to days or hours. The result: high quality data in a fraction of the time.
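One way to do that aggregation is with a weak supervision label model, such as the LabelModel in the open-source Snorkel library; the snippet below is a sketch against the snorkel 0.9-era interface, and the vote matrix is invented for illustration:

```python
import numpy as np
from snorkel.labeling.model import LabelModel

# Rows are examples, columns are signals (LLM votes, regexes, sender lists, ...);
# -1 means the signal abstained on that example.
L = np.array([
    [ 1,  1, -1],
    [-1,  0,  0],
    [ 1, -1,  1],
])

label_model = LabelModel(cardinality=2, verbose=False)
label_model.fit(L_train=L, n_epochs=500, seed=123)

# Probabilistic labels weight each signal by its estimated accuracy,
# rather than treating every vote equally.
probs = label_model.predict_proba(L)
```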

Advanced distillation: building powerful predictive models faster than ever

A lack of labeled data has long been the biggest blocker of enterprise AI projects. In the past, data labeling could only reasonably be achieved through large-scale manual efforts. Modern methods like distillation – and especially advanced distillation techniques – make LLMs a vital tool for enterprise data science teams to build valuable, useful models faster than ever.