JAX Performance: From 63 Minutes to 2 Minutes

30x Speedup with vmap: Why JAX is Built for RL Research

I continue to be impressed with the JAX library as I learn how to use it. In my last post I mentioned switching from PyTorch to JAX for the alberta-framework, where I saw about a 2.78x speedup just from using jax.lax.scan properly instead of Python loops. That was a real "aha" moment, and I had another one this morning when my honeypot sample size grew from ~3k log entries to ~300k. My first run took about an hour, so I asked Claude Code for improvements that would work on my current hardware, and it recommended vmap.

The results were dramatic enough that I wanted to document them.

The chronos-sec Project

chronos-sec is a side project where I'm applying reinforcement learning to cybersecurity. I have a Cowrie SSH honeypot running on a VPS in Germany collecting real attack traffic. Each SSH session gets translated into a 17-dimensional feature vector (a rough sketch of one appears after the list):

  • Session features: connection age, auth attempts, unique usernames tried, commands executed, files downloaded, lateral movement attempts
  • Behavioral features: command diversity (Shannon entropy), timing randomization (human vs. bot detection)
  • Context features: IP reputation, country code, bytes sent/received, hour and day of week (cyclical encoding)
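For a concrete sense of what the learner sees, here is a hypothetical sketch of one such vector. The field order, the scaling, and the exact 17th slot are my guesses rather than the actual chronos-sec schema; only the cyclical hour/day encoding follows the description above.

import numpy as np
import jax.numpy as jnp

# Hypothetical 17-dim feature vector for one SSH session.
# Field names and scaling are illustrative, not the real chronos-sec layout.
hour, day = 14, 2
session = jnp.array([
    120.0,                      # connection age (seconds)
    5.0, 3.0,                   # auth attempts, unique usernames tried
    8.0, 1.0, 0.0,              # commands executed, files downloaded, lateral movement attempts
    2.1, 0.3,                   # command diversity (Shannon entropy), timing randomization score
    0.7, 49.0,                  # IP reputation score, country code (numeric)
    4096.0, 1024.0,             # bytes sent, bytes received
    np.sin(2 * np.pi * hour / 24), np.cos(2 * np.pi * hour / 24),  # hour of day (cyclical)
    np.sin(2 * np.pi * day / 7), np.cos(2 * np.pi * day / 7),      # day of week (cyclical)
    1.0,                        # placeholder for the 17th feature (my guess)
], dtype=jnp.float32)

assert session.shape == (17,)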

The goal is to train an RL agent that learns which attacker behaviors are worth engaging with longer to collect more intelligence. Right now I'm validating the foundational learning algorithms (IDBD, Autostep) on this real-world data before building the full agent.

The Problem: Multi-Seed Experiments are Expensive

I'm running RL experiments with multiple random seeds to get reliable results. The standard practice is to run 30 or more seeds and report means with confidence intervals.

For chronos-sec, I'm testing 4 experimental conditions (IDBD vs Autostep, each with and without normalization) across ~300k observations of attack traffic. Running that sequentially means:

4 conditions × 30 seeds = 120 sequential runs

Each run has to process all 276k observations through the learning loop. Even with JAX's compiled scan loops, that was taking over an hour of wall-clock time.

The Solution: Batched Execution with vmap

JAX's vmap (vectorized map) transforms a function that processes one input into a function that processes a batch of inputs in parallel. For learning loops, this means instead of running:

Seed 1 → Seed 2 → ... → Seed 30 (serial)

You run:

[Seed 1, Seed 2, ..., Seed 30] (parallel)

I added batched versions of the learning loops to alberta-framework:

def run_learning_loop_batched(learner, stream, num_steps, keys):
    def single_seed_run(key):
        return run_learning_loop(learner, stream, num_steps, key)

    # vmap over the keys dimension - all 30 seeds run simultaneously
    return jax.vmap(single_seed_run)(keys)

That's essentially it. The inner run_learning_loop already uses jax.lax.scan for the sequential timesteps, so we're composing two JAX primitives: vmap for parallelism across seeds, and scan for the temporal sequence within each seed.
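As a usage sketch: run_learning_loop_batched comes from the snippet above, while learner and stream stand in for whatever alberta-framework learner and observation stream you've constructed, so treat those two names as assumptions. Calling the batched loop across 30 seeds then looks like this:

import jax

# 30 independent PRNG keys, one per seed.
base_key = jax.random.PRNGKey(0)
keys = jax.random.split(base_key, 30)

# learner / stream: the IDBD learner and honeypot observation stream (assumed).
batched_states, batched_metrics = run_learning_loop_batched(
    learner, stream, num_steps=276_000, keys=keys
)

# Every leaf in batched_metrics now carries a leading axis of size 30,
# one entry per seed, ready for means and confidence intervals.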

The Results

Running the full chronos-sec Step 1 validation experiment (4 conditions × 30 seeds × 276k observations):

Condition        Sequential   Batched    Speedup
IDBD+Norm        18.5 min     37.6s      ~29.5x
IDBD             8.6 min      17.7s      ~29x
Autostep+Norm    22 min       43.6s      ~30x
Autostep         14 min       27.7s      ~30x
Total            ~63 min      ~2.1 min   ~30x

The speedup factor is almost exactly 30x because we're running 30 seeds in parallel.

Why This Works

The RTX 3070 has 5,888 CUDA cores and 8GB VRAM. My linear learner for these experiments has only 17 features (the observation dimensions from the honeypot state vectors). The memory footprint per seed is tiny:

  • Weights: 17 float32 values (68 bytes)
  • Optimizer state (IDBD): ~51 float32 values per weight (3.5 KB)
  • Normalizer state: ~34 float32 values (136 bytes)

Multiply by 30 seeds and you're still under 200KB of state. The observations array is broadcast (shared memory, not copied 30x), so memory isn't the bottleneck. The GPU can easily run 30 copies of the learner in parallel.
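As a quick back-of-the-envelope check on those numbers (assuming float32, i.e. 4 bytes per value, and taking the per-weight IDBD count and per-feature normalizer count from the list above):

FEATURES, SEEDS = 17, 30

weights = FEATURES * 4                 # 68 bytes
idbd_state = 51 * FEATURES * 4         # ~3.5 KB (51 values per weight, as above)
normalizer = 2 * FEATURES * 4          # 136 bytes (e.g. running mean + variance per feature)

per_seed = weights + idbd_state + normalizer
print(per_seed, "bytes per seed,", per_seed * SEEDS / 1e3, "KB for 30 seeds")
# -> roughly 3.7 KB per seed and ~110 KB total: comfortably under 200 KB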

The throughput numbers confirm this. For the batched IDBD run:

469,695 obs·seed/s ≈ (276,000 obs × 30 seeds) / 17.7 s

That's 8.3 million observation-seed pairs processed in under 18 seconds. We hit nearly the theoretical maximum speedup (30x for 30 seeds) because the learner states are small enough that all 30 copies fit comfortably in GPU memory, letting the CUDA cores stay fully utilized.

The JAX Pattern: scan + vmap

This is a textbook example of how JAX primitives compose:

  1. jax.lax.scan handles the sequential part (stepping through 276k observations one at a time)
  2. jax.vmap wraps around the outside to run 30 copies in parallel (one per seed)

# From alberta_framework/core/learners.py
import jax
import jax.numpy as jnp


def run_learning_loop(learner, stream, num_steps, key):
    # Inner: scan loops through timesteps sequentially
    def step_fn(carry, idx):
        l_state, s_state = carry
        timestep, new_s_state = stream.step(s_state, idx)
        result = learner.update(l_state, timestep.observation, timestep.target)
        return (result.state, new_s_state), result.metrics

    learner_state = learner.init(stream.feature_dim)
    stream_state = stream.init(key)
    (final_learner, _), metrics = jax.lax.scan(
        step_fn, (learner_state, stream_state), jnp.arange(num_steps)
    )
    return final_learner, metrics


# Outer: vmap runs 30 seeds in parallel, one PRNG key per seed
def run_learning_loop_batched(learner, stream, num_steps, keys):
    def single_seed_run(key):
        return run_learning_loop(learner, stream, num_steps, key)

    batched_states, batched_metrics = jax.vmap(single_seed_run)(keys)
    return batched_states, batched_metrics

The key insight is that JAX, through its XLA backend, compiles all of this down to native GPU code. Think of it like a C compiler, but instead of targeting your CPU, it targets the GPU. The first time you call the function it's slow (tracing and compiling), but after that it runs as fast as the hardware allows, with no Python interpreter in the loop.
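You can see this compile-once, run-fast behavior with a toy, self-contained example (nothing chronos-sec specific):

import time
import jax
import jax.numpy as jnp

@jax.jit
def smooth(xs):
    # Toy sequential computation: exponential moving average via lax.scan.
    def step(carry, x):
        new = 0.99 * carry + 0.01 * x
        return new, new
    _, ys = jax.lax.scan(step, jnp.zeros(()), xs)
    return ys

xs = jnp.ones(1_000_000)

start = time.perf_counter()
smooth(xs).block_until_ready()   # first call: traces and compiles, then runs
print("first call: ", time.perf_counter() - start)

start = time.perf_counter()
smooth(xs).block_until_ready()   # second call: reuses the compiled kernel
print("second call:", time.perf_counter() - start)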

Comparison to PyTorch

I'm thankful I stumbled across JAX for the alberta-framework before I got too far into implementing it in PyTorch. I don't know PyTorch well enough to speak intelligently about how I'd do this there, but Claude Code tells me that in PyTorch you'd need to either:

  1. Manually batch by rewriting your learning loop to handle a batch dimension everywhere (tedious, error-prone)
  2. Use DataParallel/DistributedDataParallel, which are designed for multi-GPU training, not multi-seed single-GPU experiments
  3. Just run sequentially and wait an hour

JAX's functional design means vmap works automatically on any pure function. You write the single-seed version, and vmap gives you the batched version for free.
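To make that concrete with a toy pure function (nothing framework-specific):

import jax
import jax.numpy as jnp

def squared_error(w, x, y):
    # Single-example loss: no batch dimension anywhere.
    return (jnp.dot(w, x) - y) ** 2

# vmap over a batch of weight vectors (say, one per seed), sharing x and y.
batched_loss = jax.vmap(squared_error, in_axes=(0, None, None))

ws = jnp.ones((30, 17))               # 30 "seeds" worth of 17-dim weights
x = jnp.ones(17)
y = 1.0
print(batched_loss(ws, x, y).shape)   # (30,)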

The Bigger Picture

This speedup matters for more than just convenience. When experiments take an hour, I run fewer of them. When they take 2 minutes, I can iterate quickly:

  • Test hyperparameter variations
  • Add more seeds for tighter confidence intervals
  • Run ablation studies
  • Debug with real-time feedback

Not to mention that the dataset will keep growing as the honeypot collects more traffic.

For my research, this means I can move faster on the experimental side. The 30x speedup on multi-seed experiments is on top of the 2.78x speedup I already got from scan over eager PyTorch loops. Compounding those gains is what makes JAX worth the learning curve.

What's Next

I'm continuing work on the chronos-sec experiments with this new batched infrastructure. The next piece is the reactive lag analysis, which measures how the normalizers adapt during attack waves. Having fast iteration cycles lets me explore more variations of the normalization strategies.

If you're doing RL research and haven't tried JAX yet, I'd encourage you to take a look and learn the primitives (scan, vmap, jit). The upfront investment pays dividends when you need to scale experiments.


I'm assuming I'll eventually hit a tipping point where I need cloud compute to run longer or larger experiments, but performance gains like these let me keep running on my "little" gaming rig.

Lab Setup:

  • OS: Debian Linux
  • GPU: NVIDIA RTX 3070 (8GB VRAM, 5,888 CUDA cores)
  • CPU: Intel i5-12400
  • Memory: 24GB RAM

Code: