Automating GPU-Aware Finetuning for Embedding Models

At the startup where I work, the core product is a platform for retrieval-augmented classification. Customers provide labeled data, which is embedded and stored to guide a classifier. The biggest lever for adapting those classifiers to a customer’s domain is usually finetuning the embedding model.

I was recently working on a customer project doing intent classification over long support conversations. It ended up driving a complete rewrite of the finetuning system, and I wanted to share a few interesting lessons I learned along the way.

Fixing SentenceTransformers' Group-by-Label Batch Sampler

The first problem I ran into was that BatchHardTripletLoss from SentenceTransformers wasn’t training well. For each sample in the batch (the anchor), this loss finds the farthest same-class sample (the positive) and the closest different-class sample (the negative), then pulls the anchor toward the positive and pushes it away from the negative. Contrastive losses like this shape the embedding space directly, so they preserve useful neighborhood structure better than our baseline, which adds a linear prediction head and trains the embedding model to minimize cross-entropy. But it didn’t seem to be working.

The cause was embarrassingly simple: the sampler was broken. What makes BatchHardTripletLoss special is that it doesn’t require precomputed (anchor, positive, negative) triplets. Instead, it mines them from inside each batch during training, based on the current state of the embedding model being finetuned. This is in-batch online mining. It requires that at least two classes are present in each batch and that every class in the batch has at least two samples. Otherwise the loss can’t find both a positive and a negative for every anchor in the batch. That is what the recommended GroupByLabelBatchSampler is supposed to guarantee.
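The precondition can be checked per batch. Here is a minimal sketch that counts the anchors in a batch that have both an in-batch positive and an in-batch negative. The helper names `minable_anchors` and `utilization` are ours for illustration and are not part of SentenceTransformers.

```python
from collections import Counter


def minable_anchors(batch_labels):
    """Count anchors that have at least one in-batch positive and one negative.

    An anchor needs another sample with the same label (a positive) and a
    sample with a different label (a negative) for in-batch mining to work.
    """
    counts = Counter(batch_labels)
    if len(counts) < 2:
        # single-class batch: no anchor has a negative
        return 0
    # with >= 2 classes every anchor has a negative; it still needs a positive
    return sum(1 for label in batch_labels if counts[label] >= 2)


def utilization(batches):
    """Fraction of sampled examples that can act as anchors."""
    total = sum(len(b) for b in batches)
    usable = sum(minable_anchors(b) for b in batches)
    return usable / total if total else 0.0
```

A single-class batch such as `[2, 2, 2, 2]` scores zero, so it contributes no usable anchors at all.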

This scheme is called PK sampling: P classes per batch, K samples per class. As long as P > 1 and K > 1, the loss has something to mine. But a major regression had slipped into this sampler in the version 3 release of SentenceTransformers, causing it to produce ~99% single-class batches (and thus zero gradients). Anyone using the recommended sampler was training mostly on nothing, and we had missed it when upgrading because our integration tests only ran on tiny datasets that hid the issue.

After narrowing the issue down to this bug in the sampler, I filed a PR that fixed it upstream. Together with the great maintainers at Hugging Face, I came up with a proper PK sampling strategy and a dynamic P-shrinking step that preserves the structure when the underlying class distribution is imbalanced. Sample utilization on imbalanced data went from ~9% to >99%.
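To make PK sampling and P-shrinking concrete, here is a simplified sketch. It is not the upstream GroupByLabelBatchSampler code: the function name `pk_batches` and the choice to draw from the largest remaining classes first are our assumptions, picked so that big classes drain alongside small ones on imbalanced data.

```python
import random
from collections import defaultdict


def pk_batches(labels, p, k, seed=0):
    """Yield batches of indices: up to p classes, exactly k samples per class.

    Simplified P-shrinking: when fewer than p classes still have k unused
    samples, the batch uses as many classes as remain, as long as at least
    two are left so every anchor has both a positive and a negative.
    """
    if p < 2 or k < 2:
        raise ValueError("PK sampling needs p >= 2 and k >= 2")
    rng = random.Random(seed)
    pools = defaultdict(list)
    for idx, label in enumerate(labels):
        pools[label].append(idx)
    for pool in pools.values():
        rng.shuffle(pool)

    while True:
        eligible = [c for c, pool in pools.items() if len(pool) >= k]
        if len(eligible) < 2:
            break  # no batch with a positive and a negative per anchor is possible
        rng.shuffle(eligible)  # random tie-breaking among equally sized pools
        eligible.sort(key=lambda c: len(pools[c]), reverse=True)
        chosen = eligible[:p]  # shrinks below p when fewer classes remain
        batch = []
        for c in chosen:
            batch.extend(pools[c][-k:])
            del pools[c][-k:]
        yield batch
```

With a heavily skewed dataset, leftover samples from the dominant class can still go unused once every other class is exhausted; how much depends on the class distribution, not just the sampler.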

A Contrastive Loss That Scales

The sampler fix made triplet loss train, but it didn’t make it good. Since the loss can only mine positives and negatives from inside the batch, the batch size sets the size of the comparison pool, and on this project the batch was tiny. The p99 token length of our dataset was around 1,800 tokens, and at that sequence length the largest per-device batch we could fit on a 40 GB GPU was ~22. With at most a couple of same-class candidates per anchor and only about 20 different-class candidates to pick negatives from, the loss never got enough signal.

The instinct here might be to reach for LoRA or a smaller model, but neither helps much. For long-context embedding training, activation memory dominates: the intermediate hidden states grow with batch size times sequence length, and with standard attention the attention scores grow quadratically in sequence length, largely independent of how many parameters are trainable. LoRA mostly reduces the memory for trainable parameters and optimizer state, not the activations that dominate long-sequence training. A smaller base model can help, but when p99 inputs are around 1,800 tokens, sequence length and logical batch size are usually the binding constraints. Long sequences and big batches both have to fit on the same card, so there was no way to make triplet’s batch bigger from this end. Test F1 plateaued 6% below the prediction-head baseline.
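A back-of-the-envelope model makes the scaling argument concrete. This is a rough sketch under stated assumptions, not a measurement: it assumes eager attention that materializes the full score matrix (memory-efficient kernels such as FlashAttention avoid that term), and the `hidden_multiplier` of saved hidden-sized tensors per token and layer is an illustrative guess.

```python
def activation_bytes(batch, seq_len, hidden, layers, heads,
                     bytes_per_value=2, hidden_multiplier=16):
    """Very rough activation memory for one forward+backward step.

    Assumptions (illustrative only):
      * eager attention stores batch x heads x seq_len^2 scores per layer;
      * about `hidden_multiplier` hidden-sized tensors per token per layer
        are saved for the backward pass.
    Note that the number of trainable parameters (what LoRA shrinks)
    does not appear anywhere in this estimate.
    """
    per_layer_linear = batch * seq_len * hidden * hidden_multiplier
    per_layer_attention = batch * heads * seq_len * seq_len
    return layers * (per_layer_linear + per_layer_attention) * bytes_per_value
```

Under this model, doubling the batch doubles activation memory, while doubling the sequence length more than doubles it because of the quadratic attention term.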

We looked to the literature for inspiration. The best general-purpose embedding models today are trained with InfoNCE-style contrastive losses on batches of thousands of samples, and one established way to make such batches fit is to chunk the encoder forward pass with GradCache (Gao et al., 2021): encode a few samples at a time without gradients, compute the contrastive loss over the cached embeddings, then re-encode with gradients and backpropagate the cached embedding gradients through the encoder. The loss and gradients match the unchunked version up to floating-point error, but peak activation memory is bounded by the chunk size rather than the batch size, so the comparison pool can grow far beyond what fits in one forward pass. SentenceTransformers has a flagship implementation of this in CachedMultipleNegativesRankingLoss.

That looked close to what we needed, but not quite. MultipleNegativesRankingLoss (MNRL) expects (anchor, positive) pairs as input and treats the positives of all other pairs in the batch as negatives for each anchor. But our data was labeled, with many samples per class and many classes per dataset. Force-fitting it into MNRL would have made same-label samples in the batch “negatives” for each other, which was obviously not what we wanted.
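The size of the problem is easy to count. A minimal sketch, assuming each batch row is one labeled sample and every other row is treated as a negative, as MNRL would do; the helper name `mnrl_false_negatives` is ours:

```python
def mnrl_false_negatives(labels):
    """Count (anchor, in-batch negative) pairs that actually share a label.

    MNRL treats every other row of the batch as a negative for each anchor,
    so with labeled data every other same-label sample is a false negative.
    """
    n = len(labels)
    return sum(
        1
        for i in range(n)
        for j in range(n)
        if i != j and labels[i] == labels[j]
    )
```

For labels `[0, 0, 0, 1, 1]` this gives 8 false-negative pairs (3 × 2 for class 0 plus 2 × 1 for class 1), and the count grows quadratically with the number of same-class samples per batch.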

The right answer for labeled data is supervised contrastive learning (SupCon; Khosla et al., 2020). SupCon generalizes InfoNCE to the labeled case: every same-label sample in the batch is a positive for the anchor, every different-label sample is a negative, and the loss takes a mean over positive pairs in log-probability space. It’s the natural label-aware version of MNRL, and it’s stronger than triplet in the ways that matter once you’re trying to scale:

  • Triplet uses one positive and one negative per anchor. SupCon uses all same-label positives and all different-label negatives in the batch. Same batch, much richer signal per step.
  • Triplet’s hardest-pair mining is a non-smooth max over the batch, so the gradient flows through only the selected pair. SupCon is smooth: every positive contributes its own log-probability term, and every other sample in the batch contributes to the softmax denominator. Training is less spiky.
  • The batch-size leverage compounds. With triplet, a 4× larger batch means 4× more anchors, but each anchor still mines exactly one positive and one negative, so the per-anchor signal doesn’t change. With SupCon, each anchor’s loss term itself grows with the batch: for the same number of classes, 4× more positive terms and 4× more samples in the denominator, all contributing gradient. Triplet’s gradient signal scales linearly with batch size; SupCon’s pairwise signal grows roughly quadratically.
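The linear-versus-quadratic claim can be checked with a small worked count. This sketch assumes a balanced batch (batch size divisible by the number of classes) and counts the anchor-to-sample pairs that receive gradient in one step; the function name is ours.

```python
def pairs_per_step(batch_size, num_classes):
    """Anchor-to-sample pairs that receive gradient in one step.

    Assumes a balanced batch: batch_size divisible by num_classes.
      * batch-hard triplet: one positive pair and one negative pair per anchor;
      * SupCon: each anchor has k - 1 positive terms and is compared against
        all other batch_size - 1 samples in its softmax denominator.
    """
    if batch_size % num_classes != 0:
        raise ValueError("batch_size must be divisible by num_classes")
    k = batch_size // num_classes
    return {
        "triplet": 2 * batch_size,
        "supcon_positive": batch_size * (k - 1),
        "supcon_total": batch_size * (batch_size - 1),
    }
```

With 8 classes, going from a batch of 32 to 128 multiplies the triplet count by 4 (64 to 256) but the SupCon totals by roughly 16 to 20 (96 to 1,920 positive terms, 992 to 16,256 total pairs).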

SentenceTransformers doesn’t have SupCon at all, in either a regular or a cached version. So I built it: SupervisedContrastiveLoss for the base case, and CachedSupervisedContrastiveLoss for when the batch doesn’t fit in a single encoder forward pass, with the same three-phase GradCache structure as cached MNRL:

  1. Encode without gradients and cache the embeddings plus the per-mini-batch random states.
  2. Compute the SupCon loss over all cached embeddings and cache the gradients with respect to them.
  3. Restore each mini-batch’s random state so the dropout masks line up exactly with phase one, re-encode with gradients, and inject the cached gradients via a surrogate dot-product.
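The three phases can be sketched in NumPy. This is a toy illustration under loud assumptions, not the SentenceTransformers implementation: the encoder is a linear map with inverted dropout (so "re-encode with gradients" reduces to a matrix product), embeddings are not L2-normalized, and all function names are ours. The SupCon term is the L_out form from Khosla et al., with orphan anchors (no in-batch positive) excluded from the mean.

```python
import numpy as np


def supcon_loss_and_grad(z, labels, tau=0.1):
    """SupCon (L_out form) over embeddings z, plus its gradient w.r.t. z.

    Orphan anchors are left out of the mean but still act as negatives.
    Assumes at least one anchor has an in-batch positive.
    """
    n = len(labels)
    not_self = ~np.eye(n, dtype=bool)
    pos = (labels[:, None] == labels[None, :]) & not_self
    s = np.where(not_self, z @ z.T / tau, -np.inf)  # exclude self-similarity
    m = s.max(axis=1, keepdims=True)
    log_denom = m + np.log(np.exp(s - m).sum(axis=1, keepdims=True))
    log_prob = s - log_denom  # -inf on the diagonal, never selected by `pos`
    n_pos = pos.sum(axis=1)
    valid = n_pos > 0
    n_valid = valid.sum()
    per_anchor = -np.where(pos, log_prob, 0.0).sum(axis=1) / np.maximum(n_pos, 1)
    loss = per_anchor[valid].sum() / n_valid
    # dL/ds_ij = (softmax_ij - [j in P(i)] / |P(i)|) / n_valid for valid anchors i
    prob = np.exp(log_prob)  # 0 on the diagonal
    g_s = (prob - pos / np.maximum(n_pos, 1)[:, None]) * valid[:, None] / n_valid
    grad_z = (g_s + g_s.T) @ z / tau  # chain rule through s = z z^T / tau
    return loss, grad_z


def dropout_mask(rng, shape, p):
    """Inverted-dropout mask: keep with probability 1 - p, rescale kept units."""
    return (rng.random(shape) >= p) / (1.0 - p)


def cached_supcon_grad(x, labels, w, chunk, rng, p=0.1, tau=0.1):
    """GradCache-style SupCon gradient for a toy encoder z = dropout(x) @ w."""
    starts = list(range(0, len(x), chunk))
    # Phase 1: encode chunk by chunk (no autograd needed), caching RNG states.
    states, chunks = [], []
    for s in starts:
        states.append(rng.bit_generator.state)
        xc = x[s:s + chunk]
        chunks.append((xc * dropout_mask(rng, xc.shape, p)) @ w)
    z = np.concatenate(chunks)
    # Phase 2: one loss over the whole batch; cache dL/dz.
    loss, grad_z = supcon_loss_and_grad(z, labels, tau)
    # Phase 3: restore each chunk's RNG state, re-encode, backprop the cached slice.
    grad_w = np.zeros_like(w)
    for s, state in zip(starts, states):
        rng.bit_generator.state = state  # reproduces the phase-1 dropout mask
        xc = x[s:s + chunk]
        xm = xc * dropout_mask(rng, xc.shape, p)
        grad_w += xm.T @ grad_z[s:s + chunk]  # gradient of sum(z_c * g_c) w.r.t. w
    return loss, grad_w
```

Skipping the `rng.bit_generator.state = state` line still runs and still returns finite gradients; they are simply computed against different dropout masks than the loss saw in phase one.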

This kind of three-phase setup is easy to get silently wrong: if the dropout masks don’t line up exactly between phase one and phase three, the gradient is mathematically wrong but training still looks healthy. To ensure correctness, I added a test that asserts the cached and unchunked versions produce the same loss and the same encoder gradients to within a small epsilon, including the orphan-anchor edge case where a class has only one sample in the batch. I also added the gather-with-grad path for distributed training, so training on N GPUs grows the comparison pool to the global batch instead of computing N independent local losses.
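Why the gather matters can be seen by comparing forward values only. The NumPy sketch below (our own helper names) contrasts the average of per-device SupCon losses with one loss over the gathered global batch; the same mismatch is why plain gradient accumulation is wrong for in-batch contrastive losses. It does not show the autograd-aware all-gather a real multi-GPU implementation needs.

```python
import numpy as np


def supcon_loss(z, labels, tau=0.1):
    """Forward-only SupCon (L_out form); anchors without a positive are skipped."""
    n = len(labels)
    not_self = ~np.eye(n, dtype=bool)
    pos = (labels[:, None] == labels[None, :]) & not_self
    s = np.where(not_self, z @ z.T / tau, -np.inf)
    m = s.max(axis=1, keepdims=True)
    log_prob = s - (m + np.log(np.exp(s - m).sum(axis=1, keepdims=True)))
    n_pos = pos.sum(axis=1)
    valid = n_pos > 0
    per_anchor = -np.where(pos, log_prob, 0.0).sum(axis=1)[valid] / n_pos[valid]
    return per_anchor.mean()


def local_vs_global(z, labels, num_devices, tau=0.1):
    """Mean of per-device losses vs. one loss over the gathered global batch.

    Assumes every shard contains at least one anchor with a positive.
    """
    shards = np.array_split(np.arange(len(labels)), num_devices)
    local = np.mean([supcon_loss(z[idx], labels[idx], tau) for idx in shards])
    return local, supcon_loss(z, labels, tau)
```

With a 16-sample batch split across two devices, each local anchor is compared against 7 other samples instead of 15, and the averaged local loss generally differs from the global one; with a single device the two coincide.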

SupCon worked. It beat the prediction-head baseline by a comfortable margin when run on medium batch sizes, not just on this project but also across several other projects. Contrastive loss just needed the right implementation and batch size to work well. It also preserved neighborhood structure in the embedding space better than the prediction head, which improved the efficacy of our other analysis tools.

But getting there took a lot of trial and error across losses, batch sizes, learning rates, and sequence lengths — most of it hyperparameter search by hand. The next thing to build was making that search a first-class part of the platform.

Compiling Semantic Training Intents into Hardware-Grounded Execution Plans

Two things needed to be automated: how each individual run maps onto the available hardware, and the search across runs for good hyperparameters. The platform runs on many different GPU configurations (hosted cloud, self-hosted, Docker, or binaries for local experiments), and the per-device batch size that actually fits depends on the GPU’s memory as well as on the model, the sequence length, the dtype, and whether gradient checkpointing is on. We needed to abstract all of that away from the domain model we expose to customers, so they could simply express their intent like this:

model.finetune(
    dataset,
    EmbeddingFinetuneConfig(
        loss="contrastive",
        batch_size=[64, 128, 256],
        learning_rate=(1e-5, 1e-4),
        max_seq_length="p99",
        trial_count=12,
    ),
)

The platform compiles these semantic knobs — loss type, logical batch size, truncation strategy — into a hardware-specific training plan: per-device batch size, gradient accumulation steps, mini-batch chunking, gradient checkpointing, and multi-GPU gathering. Any field can be a scalar for a single run, a list of choices, or a (min, max) tuple range; the platform searches across the combinations within the trial_count budget. The important constraint is that the compile step is not allowed to change the loss semantics. It can change how the batch is encoded, split, gathered, or accumulated, but it cannot quietly turn one contrastive objective into many smaller ones.
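The text doesn’t say which search strategy the platform uses, so here is only a minimal random-search sketch of how scalar, list, and range fields could expand into one trial. The function name `sample_trial` is hypothetical, and the sampling rules (categorical lists, log-uniform float ranges, uniform int ranges) are our assumptions.

```python
import math
import random


def sample_trial(space, rng):
    """Draw one trial config from scalars, lists of choices, and (min, max) tuples.

    Assumptions: lists are categorical choices; float ranges are sampled
    log-uniformly (a common choice for learning rates); int ranges uniformly.
    """
    trial = {}
    for name, value in space.items():
        if isinstance(value, list):
            trial[name] = rng.choice(value)
        elif isinstance(value, tuple):
            low, high = value
            if isinstance(low, int) and isinstance(high, int):
                trial[name] = rng.randint(low, high)  # inclusive on both ends
            else:
                trial[name] = math.exp(rng.uniform(math.log(low), math.log(high)))
        else:
            trial[name] = value  # scalar: fixed for every trial
    return trial
```

Usage with the fields from the config above: `sample_trial({"loss": "contrastive", "batch_size": [64, 128, 256], "learning_rate": (1e-5, 1e-4), "max_seq_length": "p99"}, random.Random(0))`, called once per trial up to the `trial_count` budget.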

Mid-epoch OOMs waste hours of compute and can leave a poisoned CUDA process that needs a restart, so the platform needs to determine the maximum batch size that fits on a GPU before training starts. I started with a heuristic based on the model’s parameter count, embedding dimension, dtype, and the data’s maximum sequence length. That got close, but it had to be conservative: in one case it was off by roughly 40%, enough to turn a reasonable sweep into a much slower one. So I ended up using the heuristic only as a safe starting point, then empirically measuring the peak memory of a forward and backward pass at the maximum sequence length to extrapolate a more accurate maximum batch size. From there, the rest of the execution plan is relatively straightforward to determine. The code looks something like this:
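One way to do the extrapolation step is a linear fit over two probes. This is a sketch under an assumed model (peak memory roughly linear in batch size at a fixed sequence length); the function name `extrapolate_batch_limit` and the 0.9 safety factor are hypothetical. In PyTorch, each probe’s peak could be read with `torch.cuda.reset_peak_memory_stats()` before the pass and `torch.cuda.max_memory_allocated()` after it.

```python
def extrapolate_batch_limit(probes, budget_bytes, safety=0.9):
    """Estimate the largest batch that fits from measured peak-memory probes.

    probes: [(batch_size, peak_bytes), ...] from forward+backward passes at
    the maximum sequence length, with at least two distinct batch sizes.
    Assumed model: peak = fixed + per_sample * batch, where `fixed` covers
    weights, gradients, and optimizer state.
    """
    ordered = sorted(probes)
    (b1, m1), (b2, m2) = ordered[0], ordered[-1]
    if b2 == b1:
        raise ValueError("need probes at two different batch sizes")
    per_sample = (m2 - m1) / (b2 - b1)
    fixed = m1 - per_sample * b1
    return max(0, int((safety * budget_bytes - fixed) // per_sample))
```

For example, probes of 12.4 GB at batch 2 and 14.8 GB at batch 4 imply 1.2 GB per sample on top of 10 GB fixed, so a 40 GB card with a 10% safety margin gives a limit of 21.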

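def resolve_training_plan(config, model, dataset):
    # resolve_seq_len, probe_batch_limit, is_bf16_supported, get_device_count,
    # and TrainingExecutionPlan are platform internals
    max_seq_length = resolve_seq_len(
        config.max_seq_length,  # "p90", "p95", "p99", "max", or an int
        tokenizer=model.tokenizer,
        dataset=dataset,
    )

    dtype = "bf16" if is_bf16_supported() else "fp32"
    device_count = get_device_count()
    device_batch_limit = probe_batch_limit(model, max_seq_length, dtype)

    # batch_size may be a scalar, a list of choices, or a (min, max) range;
    # the plan is shared across trials, so size it for the largest batch
    largest_batch = (
        max(config.batch_size)
        if isinstance(config.batch_size, (list, tuple))
        else config.batch_size
    )

    gather_across_devices = config.loss == "contrastive" and device_count > 1

    # triplet batches must fit on a single device; other losses can spread
    # one step across all devices
    step_capacity = (
        device_batch_limit
        if config.loss == "triplet"
        else device_batch_limit * device_count
    )

    # enable gradient checkpointing only if the largest batch doesn't fit in
    # one step; this trades ~30% throughput for larger batch sizes
    gradient_checkpointing = (
        config.loss in {"contrastive", "triplet"}
        and largest_batch > step_capacity
    )
    if gradient_checkpointing:
        # re-probe, since checkpointing changes the activation footprint
        device_batch_limit = probe_batch_limit(
            model, max_seq_length, dtype,
            gradient_checkpointing=True,
        )

    return TrainingExecutionPlan(
        loss=config.loss,
        batch_size=config.batch_size,
        dtype=dtype,
        max_seq_length=max_seq_length,
        device_count=device_count,
        device_batch_limit=device_batch_limit,
        gather_across_devices=gather_across_devices,
        gradient_checkpointing=gradient_checkpointing,
    )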

Sweeps share this training plan across all trials: the device probe depends only on the model, sequence length, and dtype, so it runs once. What changes per trial is the logical batch size, so the batch split has to be recomputed for each trial, and how the batch is split depends on the loss. Additive losses, such as the prediction head or the class-proxy loss (where examples are compared against learned class representatives rather than against other examples in the batch), can accumulate gradients across chunks, which is mathematically equivalent to running the full batch at once. For in-batch contrastive losses, gradient accumulation is wrong: accumulating chunks produces many small contrastive losses with small comparison pools instead of one loss over the full pool, so the platform automatically switches to cached SupCon and chunks only the encoder. Triplet mining needs every sample on the same device, so when a triplet batch doesn’t fit on one GPU, the platform raises an error instead of silently splitting it across devices or chunking it into a different loss.

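from math import ceil
from typing import assert_never  # Python 3.11+; typing_extensions on older versions


def resolve_batch_split(plan):
    # plan.batch_size is the logical batch size of the current trial
    per_device_share = plan.batch_size // plan.device_count

    match plan.loss:
        case "contrastive":
            # one loss over the whole (gathered) batch; only the encoder
            # forward is chunked, and each chunk must fit on a single device
            return BatchSplit(
                per_device=per_device_share,
                encoder_chunk=min(per_device_share, plan.device_batch_limit),
            )

        case "triplet":
            # in-batch mining needs the whole batch on one device
            if plan.batch_size > plan.device_batch_limit:
                raise ValueError(
                    f"Triplet batch of {plan.batch_size} exceeds the "
                    f"per-device limit of {plan.device_batch_limit}"
                )
            return BatchSplit(per_device=plan.batch_size)

        case "prediction" | "class-proxy":
            # additive losses: accumulation over micro-batches is exact, and
            # each micro-batch stays within the per-device limit
            grad_accum = max(1, ceil(per_device_share / plan.device_batch_limit))
            return BatchSplit(
                per_device=ceil(per_device_share / grad_accum),
                grad_accum=grad_accum,
            )

        case _ as unreachable:
            assert_never(unreachable)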

Customers don’t even have to pick ranges. Setting trial_count=9 alone tells the platform to inject the highest-impact ranges for the chosen loss within the budget, so the customer doesn’t need to know which knobs matter for which loss. Underperformers get pruned mid-training and the winner’s checkpoint is promoted as the final model, so a sweep delivers the same artifact as a single run.
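The pruning rule isn’t specified here. As one common option, a median-stopping rule (the idea behind, for example, Optuna’s MedianPruner) fits in a few lines; this sketch is an assumption for illustration, not the platform’s pruner, and `should_prune` is a hypothetical name.

```python
from statistics import median


def should_prune(score_at_step, peer_scores_at_step, min_peers=3):
    """Median-stopping sketch: prune a trial whose validation score at a step
    is below the median of other trials' scores at that same step.

    Higher scores are better. With too few peers to compare against,
    never prune.
    """
    if len(peer_scores_at_step) < min_peers:
        return False
    return score_at_step < median(peer_scores_at_step)
```

The `min_peers` guard keeps the first few trials from being pruned against a median that doesn’t yet mean much.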

Modeling Becomes One Knob

By the end of the modeling work, test F1 had climbed from 0.78 with triplet to 0.86 with SupCon, an eight-point gain that closed roughly a third of the remaining gap to perfect classification. Then it stopped. I spent several more rounds trying to push it past 0.87 with different models, input formats, training schedules, weight decay, bigger batches, and various forms of data pruning, but none of it moved the number. The obvious next step was online hard-negative mining from a vector database, which would have extended the comparison pool beyond the batch, and I had a sketch of it ready to go. But before shipping it, I took a closer look at the test data and found that about 5.7% of examples were either mislabeled or genuinely ambiguous. That explained the plateau: the model was no longer the obvious bottleneck.

The platform’s real value showed up there. When training comes down to filling out a config, trying another loss or batch size or sequence length costs almost nothing. If none of it moves the number, you can’t tell yourself you just haven’t tried hard enough. The leverage moves to the data: better labels, targeted synthetic generation, and error-driven oversampling. That’s the data-centric ML argument, and it’s the one Orca is built around. Most teams I’ve worked with under-invest in that side because modeling feels like progress in a way that data work doesn’t.