Porting Backbones
This page is a contributor guide for adding a new timm backbone to Luximm.jl. The bar is numeric parity with timm on the released weights; the rest is mechanics.
If you are using Claude Code to drive the port, the same workflow is encoded as an agent-facing skill at .claude/skills/timm-to-lux/SKILL.md and loads automatically inside this repo. The skill assumes the Kaimon REPL workflow described in .claude/skills/kaimon-julia/SKILL.md. This page covers the same ground for human contributors without those tools.
Acceptance criteria
A new backbone is mergeable when all four hold:
- Pretrained parity. Forward output of the Lux model with weights loaded via the closure returned by create_pretrained matches timm's forward output on the same input. The bar is two-tier: logits are checked at an absolute max-abs-diff under LOGITS_ATOL = 1f-3, and features (num_classes = 0, and the in_chans = 1 companion) are checked at a relative bar max-abs-diff / max-abs(timm ref) under FEATURES_RTOL = 1f-4. Existing models land well inside this: BiT ResNetV2-50 features around 1.5e-4 absolute (well under the relative bar at typical feature magnitudes), its logits around 2e-5; ConvNeXtV2 atto comparable.
- Random-init parity. With the Lux model initialized using the same _init_weights recipe timm uses for the family, and timm initialized with the same RNG seed, forward outputs match the same logits and features bars. See _CN2_INIT at src/Models/ConvNeXtV2/Model.jl (truncated_normal(mean = 0f0, std = 0.02f0) mirroring timm's trunc_normal_(std = 0.02)) for a worked example.
- State-dict round-trip. The mapping function consumes every PyTorch state_dict key. Mappings raise on a missing key so silent random-init leaks cannot pass parity by coincidence; see the assertion block at the end of convnextv2_mapping in src/Models/ConvNeXtV2/Model.jl.
- Variant table entry. New variants are registered in <FAMILY>_VARIANTS (src/Models/<Family>/Config.jl) with their HuggingFace repo id and default class count, and listed in the README backbone table.
Why the bars are ~1e-3 (logits) and ~1e-4 (features), not 1e-5
Float32 round-off accumulates through the depth of a network. Each conv, norm, and weight-standardization stage reorders sums in ways that differ from PyTorch's BLAS or cuDNN kernels even when the math is semantically identical, so a 50-layer ResNet against the timm reference reliably lands near 1e-4 absolute diff on order-one activations. The shallower head (one pool plus one Dense) sits closer to 1e-5 because almost no new accumulation happens after the backbone.
The test gates are LOGITS_ATOL = 1f-3 (absolute, logits) and FEATURES_RTOL = 1f-4 (relative, features). Both are tight enough to catch the silent-divergence failure modes that actually matter (cross-correlation vs convolution, population vs sample variance, GELU approximation mismatch, axis permutations, missing norm epsilon), and loose enough not to fail on legitimate float32 reordering. The features bar is relative because raw pre-norm features on deep / wide backbones (large and xlarge ConvNeXt, huge ConvNeXtV2) drift by ~1e-3 to ~2e-3 absolute even when their downstream logits stay near 1e-5 after the LayerNorm + classifier squashes them; a relative bar keeps the check scale-free across tiny through huge variants. If your port reports a logits max-abs-diff in the 1e-2 range, or a features relative diff in the 1e-2 range or higher, that is a real bug, not round-off; bisect it with the per-stage parity fixtures described below.
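The difference between the two bars is easiest to see on numbers. The sketch below uses only the two tolerances named in this guide; the helper names (max_abs_diff, rel_diff, logits_ok, features_ok) and the data are made up for illustration and are not taken from the test suite.

```julia
# Tolerances as stated in this guide; everything else here is illustrative.
const LOGITS_ATOL = 1f-3
const FEATURES_RTOL = 1f-4

# Absolute bar: largest elementwise deviation from the reference.
max_abs_diff(y, ref) = maximum(abs.(y .- ref))

# Relative bar: the same deviation scaled by the reference magnitude.
# The eps guard avoids dividing by zero on an all-zero reference.
rel_diff(y, ref) = max_abs_diff(y, ref) / max(maximum(abs.(ref)), eps(Float32))

logits_ok(y, ref) = max_abs_diff(y, ref) < LOGITS_ATOL
features_ok(y, ref) = rel_diff(y, ref) < FEATURES_RTOL

# Synthetic "features" with magnitude around 20 and roughly 8e-4 absolute drift.
ref = Float32[20.0, -12.5, 3.0, 0.5]
y = ref .+ Float32[8f-4, -5f-4, 2f-4, 0f0]

println("abs diff = ", max_abs_diff(y, ref))   # larger than 1e-4
println("rel diff = ", rel_diff(y, ref))       # smaller than 1e-4
println("features pass the relative bar: ", features_ok(y, ref))

# A 1% scale error, by contrast, fails both bars by a wide margin.
y_bad = ref .* 1.01f0
println("scaled output passes: ", features_ok(y_bad, ref) || logits_ok(y_bad, ref))
```

The drifted output would fail an absolute 1e-4 check but passes the relative one, which is the scale-free behavior the features bar is chosen for.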
Reference example
src/Models/ResNetV2/ is the canonical, fully-worked port. Read it first. It exercises most of the shared utilities in one place:
- Pre-activation residual blocks.
- Weight-standardized convolutions (std_conv from Luximm.Layers).
- GroupNorm with explicit epsilon.
- adapt_input_conv stem adaptation for non-RGB inputs.
- The full mapping/loader pattern that every other family follows.
If your port can be expressed as parameter-table changes on top of an existing family, do that. Adding a new variant to BIT_VARIANTS or CONVNEXTV2_VARIANTS is a one-row PR. Adding a new family is the multi-file workflow below.
Workflow
Phase 1: Capture a parity fixture
Each port starts with an HDF5 fixture produced by a small Python sidecar in test/parity/. The sidecar uses the shared _dump_common.py helpers and follows the dump_<family>_io.py convention. Look at test/parity/dump_resnetv2_bit_io.py for the template.
The fixture stores /input (deterministic random input, PyTorch NCHW layout), /state_dict/<key> for every PyTorch parameter, and /output/features plus optional /output/logits. The Julia side reads it via Luximm.Interop.read_parity, which reverses the axes so tensors arrive in Lux's WHCN layout.
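The axis reversal can be checked on plain arrays. This is a minimal sketch of the layout relationship only, written with permutedims from Base; how read_parity performs the reversal internally is not shown here and is not assumed.

```julia
# PyTorch-side logical shape (N, C, H, W): a batch of 2 RGB images, 4 high, 5 wide.
N, C, H, W = 2, 3, 4, 5
x_nchw = reshape(Float32.(1:N*C*H*W), N, C, H, W)

# Reversing the axis order gives Lux's WHCN layout.
x_whcn = permutedims(x_nchw, (4, 3, 2, 1))
println(size(x_whcn))  # (5, 4, 3, 2)

# Element (n, c, h, w) on the PyTorch side is element (w, h, c, n) on the Lux side.
println(x_nchw[2, 1, 3, 4] == x_whcn[4, 3, 1, 2])  # true

# The same rule maps a conv weight (out, in, kH, kW) to (kW, kH, in, out).
w_torch = zeros(Float32, 64, 3, 7, 5)
w_lux = permutedims(w_torch, (4, 3, 2, 1))
println(size(w_lux))  # (5, 7, 3, 64)
```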
Run the dump once per variant:
    uv run python test/parity/dump_<family>_io.py --variant <timm_name> --out data/parity/<key>_io.h5

Optionally dump a single-channel companion fixture (--in-chans 1, output data/parity/<key>_in1c_io.h5) so the in_chans = 1 parity test can run as well.
For a brand-new family, capture more than one fixture: the end-to-end one gives a pass/fail signal with no localization power. Add at least three random seeds and per-stage intermediates by registering forward hooks on model.stages[i]. The per-stage fixtures are the bisection tool when the end-to-end test fails.
Phase 2: Reuse the shared utilities
Do not re-implement anything that already lives under src/Layers/ or src/Interop/. The load-bearing helpers and what they do:
- Luximm.Interop.read_parity returns (input, state_dict, output) as Float32 arrays in WHCN layout (the reverse of PyTorch's logical NCHW). Conv weight (out, in, kH, kW) becomes (kW, kH, in, out), which is exactly Lux's Conv layout. For most parameters the layout is what you want and the per-key transform is identity.
- apply_state_dict rebuilds the parameter NamedTuple by setting leaves from the dict. Mapping entries are triples (pytorch_key, lux_path_tuple, transform). Non-mutating; bind the result.
- Luximm.Interop.axis_reverse and Luximm.Interop.pyperm are ready-made transforms for the cases where the HDF5-natural layout is not what you want, typically Dense weights and LayerNorm scale/bias.
- std_conv, layernorm2d, grn_layer are the building blocks that match timm's StdConv2dSame, channel-axis LayerNorm2d, and Global Response Norm.
Keep the per-family weight mapping in src/Models/<Family>/Model.jl as <family>_mapping(state_dict, variant; prefix, num_classes, in_chans), returning a Vector{Tuple{String, Tuple{Vararg{Symbol}}, Function}}. The prefix argument lets a backbone be nested under a wrapper model.
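To make the triple format concrete, here is a minimal stand-in written against plain NamedTuples. It is not the Luximm implementation: set_leaf and apply_mapping are hypothetical names, the keys and shapes are invented, and the "raise on missing or unconsumed key" behavior is modeled on the acceptance criteria above rather than copied from the package.

```julia
# Hypothetical helper: rebuild a nested NamedTuple with one leaf replaced.
# The input is left untouched, so the caller must bind the result.
function set_leaf(nt::NamedTuple, path::Tuple{Vararg{Symbol}}, value)
    head = first(path)
    updated = length(path) == 1 ? value :
        set_leaf(getproperty(nt, head), Base.tail(path), value)
    return merge(nt, NamedTuple{(head,)}((updated,)))
end

# Hypothetical stand-in for the apply step: every mapping key must exist,
# and every state_dict key must be consumed by some mapping entry.
function apply_mapping(ps::NamedTuple, state_dict::AbstractDict, mapping)
    used = Set{String}()
    for (key, path, transform) in mapping
        haskey(state_dict, key) || error("missing state_dict key: $key")
        ps = set_leaf(ps, path, transform(state_dict[key]))
        push!(used, key)
    end
    leftover = sort([k for k in keys(state_dict) if !(k in used)])
    isempty(leftover) || error("unconsumed state_dict keys: $leftover")
    return ps
end

ps0 = (stem = (weight = zeros(Float32, 2, 2),),
       head = (weight = zeros(Float32, 3, 2), bias = zeros(Float32, 3)))

state_dict = Dict(
    "stem.conv.weight" => ones(Float32, 2, 2),
    "head.fc.weight"   => ones(Float32, 2, 3),  # (in, out) after axis reversal
    "head.fc.bias"     => ones(Float32, 3),
)

mapping = [
    ("stem.conv.weight", (:stem, :weight), identity),
    ("head.fc.weight",   (:head, :weight), permutedims),  # back to (out, in)
    ("head.fc.bias",     (:head, :bias),   identity),
]

ps1 = apply_mapping(ps0, state_dict, mapping)
println(size(ps1.head.weight))   # (3, 2)
println(all(iszero, ps0.stem.weight))  # true: the original is unchanged
```

Dropping the last mapping entry makes apply_mapping raise on the unconsumed head.fc.bias key, which is the failure mode the round-trip criterion is meant to surface.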
Phase 3: Implement the model with @compact
Lux's @compact is the right primitive for composing layers. The pattern is fixed:
    @compact(
        conv1 = Conv((3, 3), in_ch => out_ch; pad = 1, cross_correlation = true),
        norm1 = GroupNorm(out_ch, 32; affine = true, epsilon = 1f-5),
    ) do x
        @return NNlib.relu.(norm1(conv1(x)))
    end

Numeric conventions that bite if you forget them, in rough order of how often they cost real debugging time:
- Cross-correlation, always. PyTorch's Conv2d is cross-correlation; Lux's Conv defaults to true convolution (kernel-flipped). Pass cross_correlation = true to every Conv. Without it, weights load with the right shape but produce mirrored outputs. When you must drop into NNlib.conv directly (weight standardization is the canonical case), pass flipkernel = true on NNlib.DenseConvDims. Same semantics, two flags.
- Explicit padding when zero-padding matters. Conv((k, k), ...; pad = p) works for symmetric same-value padding. For pooling that must pad with zeros instead of -Inf (timm's BiT stem is the canonical case), call NNlib.pad_zeros(x, (l, r, t, b, 0, 0, 0, 0)) first and use pad = 0 on the op.
- Norm defaults are not portable. Always pass epsilon and affine explicitly on GroupNorm, LayerNorm, BatchNorm. PyTorch's nn.GroupNorm defaults to eps = 1e-5, but timm families often override the epsilon per layer, so the value to match is the one the timm model actually uses, not whatever either framework defaults to. Mismatched epsilons silently shift activations and look like a flaky parity failure.
- Variance corrections. Sample variance (Bessel-corrected, the Julia default) and population variance (corrected = false, what PyTorch uses for BN-style stats and for weight standardization) differ by a factor of N / (N - 1). Pass corrected = false to var whenever you are matching a BN- or WS-style operation. See std_conv in src/Layers/StdConv.jl.
- Lux.testmode(st) for parity tests. Otherwise BatchNorm running stats update and any dropout activates, neither of which is what model.eval() does on the PyTorch side.
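The first item is the one that most often slips through, because nothing fails at load time. A one-dimensional sketch in plain Julia shows the effect; xcorr and trueconv are illustrative names, not NNlib or Lux functions.

```julia
# "Valid" 1-D cross-correlation: slide the kernel exactly as stored.
# This is the operation PyTorch's Conv layers compute.
xcorr(x, w) = [sum(x[i + j - 1] * w[j] for j in eachindex(w))
               for i in 1:length(x) - length(w) + 1]

# True convolution is the same sliding sum with the kernel reversed.
trueconv(x, w) = xcorr(x, reverse(w))

x = Float32[1, 2, 4, 8, 16]
w = Float32[1, 0, -1]        # asymmetric kernel, so the flip is visible

println(xcorr(x, w))     # Float32[-3.0, -6.0, -12.0]
println(trueconv(x, w))  # Float32[3.0, 6.0, 12.0]

# With a symmetric kernel the two agree, which is why the bug can hide
# in toy tests and only show up once real weights are loaded.
w_sym = Float32[1, 2, 1]
println(xcorr(x, w_sym) == trueconv(x, w_sym))  # true
```

Both results have the same shape, so a shape check cannot distinguish them; only a numeric comparison against the reference does.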
The forward should look like math: a sequence of broadcasts and tensor ops with no scalar control flow. No x[i] = ..., no Array(x) inside the forward, no if/else on tensor values, no scalar indexing.
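One way to read that rule, sketched on plain arrays with a leaky-ReLU as the example. Both function names are illustrative; the point is only the contrast in style.

```julia
# Style to avoid: scalar indexing, mutation, and a branch on tensor values.
function leaky_scalar(x; slope = 0.01f0)
    y = similar(x)
    for i in eachindex(x)
        y[i] = x[i] > 0 ? x[i] : slope * x[i]
    end
    return y
end

# Preferred style: the same math as whole-array broadcasts.
leaky_broadcast(x; slope = 0.01f0) = max.(x, 0) .+ slope .* min.(x, 0)

x = Float32[-2, -0.5, 0, 1.5]
println(leaky_scalar(x) == leaky_broadcast(x))  # true
```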
Phase 4: Wire the variant table and mapping
src/Models/<Family>/Config.jl holds the variant catalog:
- The <Family>Variant struct captures the architectural knobs (depths, dims, stem channels) plus hf_repo, default_num_classes, and default_input_size.
- <FAMILY>_VARIANTS :: Dict{Symbol, <Family>Variant} lists every registered variant.
Variant keys are the timm model name with the dot rewritten as an underscore (so the key is a single Julia identifier). The full dot-separated name lives at <FAMILY>_VARIANTS[key].hf_repo.
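A minimal sketch of the catalog shape and the key convention follows. DemoVariant, DEMO_VARIANTS, variant_key, and register! are hypothetical names, the field types are assumptions, and the model name and repo id are invented; only the three named fields and the dot-to-underscore rule come from this page.

```julia
# Hypothetical variant record. Field types are assumptions for illustration.
struct DemoVariant
    depths::NTuple{4, Int}
    dims::NTuple{4, Int}
    hf_repo::String
    default_num_classes::Int
    default_input_size::NTuple{2, Int}
end

# timm model name -> Symbol key, with the dot rewritten as an underscore.
variant_key(timm_name::AbstractString) = Symbol(replace(timm_name, "." => "_"))

const DEMO_VARIANTS = Dict{Symbol, DemoVariant}()

# Register a variant under its derived key, refusing duplicates.
function register!(timm_name::AbstractString, v::DemoVariant)
    key = variant_key(timm_name)
    haskey(DEMO_VARIANTS, key) && error("variant already registered: $key")
    DEMO_VARIANTS[key] = v
    return key
end

# Invented name and repo id, used only to exercise the convention.
key = register!("demo_atto.example_in1k",
    DemoVariant((2, 2, 6, 2), (40, 80, 160, 320),
                "timm/demo_atto.example_in1k", 1000, (224, 224)))

println(key)                         # demo_atto_example_in1k
println(DEMO_VARIANTS[key].hf_repo)  # timm/demo_atto.example_in1k
```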
Pick one variant and finish it before adding the second. When a model family has many variants, port the smallest first: it surfaces every shared numeric trap with the fastest test loop, and the mapping function written for one variant typically generalizes by changing only depths or widths. Adding a second variant before the first passes parity creates code paths nothing has exercised and that the first failed test cannot localize.
Phase 5: Add the loader
Each family exposes a load_<family>_pretrained(ps, st, variant; num_classes, in_chans, revision, cache_dir, prefix) -> (ps, st) function. All four existing families share this signature: the stateless ones (BiT/ConvNeXt/ConvNeXtV2) take st and return it unchanged, while ResNet returns an updated st with the BatchNorm running stats merged in. The flow is identical across families:
- Look up the variant in <FAMILY>_VARIANTS.
- Validate num_classes against default_num_classes (or allow 0 for backbone-only).
- Call hf_hub_download to resolve the snapshot path under the HuggingFace Hub cache layout.
- Load via load_safetensors_state_dict. Pass reverse_axes = true so the safetensors arrays end up in the same WHCN layout the HDF5 fixtures produce; this lets one <family>_mapping function serve both fixture-driven tests and production loading.
- Apply via apply_state_dict.
- For families with BatchNorm running stats (ResNet only), additionally apply the state dict to the model state via <family>_state_mapping and apply_state_dict.
Once the variant is registered in <FAMILY>_VARIANTS, it is reachable through the family-agnostic create_pretrained and create_model dispatchers automatically, with no extra wiring.
Keep the constructor and the weight loading separate. The constructor returns a @compact block; the loader takes the result of Lux.setup and returns a new (ps, st). Mixing the two inside @compact makes the model unusable in tests that want a random init.
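The shape of a stateless loader can be sketched without touching the network by passing the I/O steps in as functions. Everything named here is hypothetical: check_num_classes, load_pretrained_sketch, and the four keyword callbacks stand in for the real helpers, and the validation rule (0 or the checkpoint's class count) is an assumption about what "validate" means in the flow above.

```julia
# Assumed rule: either drop the head (0) or keep the checkpoint's class count.
function check_num_classes(num_classes::Integer, default_num_classes::Integer)
    num_classes == 0 || num_classes == default_num_classes ||
        throw(ArgumentError("num_classes must be 0 or $default_num_classes, got $num_classes"))
    return num_classes
end

# Hypothetical loader skeleton. The callbacks are supplied by the caller, so
# this runs offline; they stand in for download, read, mapping, and apply.
function load_pretrained_sketch(ps, st, variant; num_classes,
                                resolve_path, read_weights, build_mapping, apply_weights)
    check_num_classes(num_classes, variant.default_num_classes)
    path = resolve_path(variant.hf_repo)             # resolve the snapshot path
    sd = read_weights(path)                          # state dict, already WHCN
    ps = apply_weights(ps, sd, build_mapping(sd, variant))  # non-mutating apply
    return ps, st                                    # stateless family: st passes through
end

# Stubbed usage on a one-parameter "model".
variant = (hf_repo = "timm/demo", default_num_classes = 1000)
ps0 = (w = zeros(Float32, 2),)
ps1, st1 = load_pretrained_sketch(ps0, NamedTuple(), variant;
    num_classes = 0,
    resolve_path = repo -> joinpath("cache", repo),
    read_weights = path -> Dict("w" => ones(Float32, 2)),
    build_mapping = (sd, v) -> [("w", (:w,), identity)],
    apply_weights = (ps, sd, m) -> merge(ps, (w = sd["w"],)))

println(ps1.w)  # Float32[1.0, 1.0]
```

Because the loader only consumes (ps, st) and returns a new pair, the same constructor output can be used with random-init parameters in tests and with loaded parameters in production.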
Phase 6: Verify parity end-to-end
The verification loop is layered. Run the cheap gates between meaningful edits.
End-to-end first, against the fixture's state_dict (not the HF download) so the test isolates the forward pass:
    using Lux, Luximm, Random, Test

    data = Luximm.Interop.read_parity(_FIXTURE_PATH)
    model = <family>(variant; in_chans = 3, num_classes = 0)
    ps, st = Lux.setup(Xoshiro(0), model)
    st = Lux.testmode(st)  # inference mode, matching model.eval() on the PyTorch side
    # Load the fixture's own state_dict so only the forward pass is under test.
    ps = Luximm.Interop.apply_state_dict(ps, data.state_dict,
        <family>_mapping(data.state_dict, variant))
    y, _ = model(data.input, ps, st)
    expected = data.output["features"]
    diff = maximum(abs.(y .- expected))
    rel = diff / max(maximum(abs.(expected)), eps(Float32))
    @test rel < 1f-4  # FEATURES_RTOL, the relative features bar

The convenience wrapper scripts/test_variant.sh runs exactly this for one variant, dumping the fixture first if absent. See Testing for the variant-test workflow.
When parity fails, the per-stage and per-block fixtures pay off. Walk the forward by hand: run the partial forward up to each stage and compare against data.output["stage_i"]. The first stage where parity breaks localizes the bug. Inside the failing stage, splice the matching state-dict entries into a single block in isolation and compare against the per-block fixture.
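The stage walk can be sketched as a loop that stops at the first stage missing the bar. The names are illustrative, the stages are toy functions, and the reference outputs are held in a plain vector here rather than read from a fixture.

```julia
# Relative max-abs difference, guarded against an all-zero reference.
rel_diff(y, ref) = maximum(abs.(y .- ref)) / max(maximum(abs.(ref)), eps(Float32))

# Run the stages in order and return the index of the first one whose
# output misses the relative bar, or nothing if every stage matches.
function first_failing_stage(stage_fns, x, refs; rtol = 1f-4)
    for (i, f) in enumerate(stage_fns)
        x = f(x)
        rel_diff(x, refs[i]) < rtol || return i
    end
    return nothing
end

# Toy "port" with three stages.
stages = [x -> 2 .* x, x -> x .+ 1, x -> x .^ 2]
x0 = Float32[1, 2, 3]

# Toy "reference" that disagrees from stage 2 onward (adds 1.5 instead of 1).
r1 = 2 .* x0
r2 = r1 .+ 1.5f0
r3 = r2 .^ 2
println(first_failing_stage(stages, x0, [r1, r2, r3]))  # 2

# References produced by the same math report no failing stage.
good = [r1, r1 .+ 1, (r1 .+ 1) .^ 2]
println(first_failing_stage(stages, x0, good))  # nothing
```

Every stage after the first divergence fails too, so only the first index is informative; that is the stage to open up with the per-block fixture.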
Phase 7: Update the bookkeeping files
After parity passes, two repo files need to stay in sync. Forgetting either is a silent regression in CI's path filtering or the user-facing variant table.
- README.md and the API reference. Update the backbone table at the top of README.md, and confirm the API reference page (docs/src/api/models.md) lists the new constructor and loader.
- ci/JimmCI/src/PathFilter.jl. Update _FAMILY_PREFIXES, _FAMILY_EXACT, ALL_FAMILIES, REPRESENTATIVE_VARIANT, and the _SHARED_* lists as appropriate. Without this, CI will silently skip tests for the new code on PR-scope runs; only a full jimm-ci --master sweep would catch it.
Pitfalls reference
A running checklist. Every item here has cost real debugging time on a previous port.
- Cross-correlation. Every Conv needs cross_correlation = true. Every NNlib.conv needs flipkernel = true on its DenseConvDims.
- Zero-padded pooling. NNlib.maxpool with pad = 1 pads with -Inf, and so does nn.MaxPool2d(padding = 1). Stems that zero-pad explicitly before an unpadded max-pool (timm's BiT stem is the canonical case) therefore diverge from a padded max-pool whenever the input has negative values (post-norm activations, leaky-relu outputs). Use pad_zeros first, then pool with pad = 0.
- GN/BN/LN defaults. Always pass epsilon and affine explicitly. Never trust the Lux default to match the PyTorch default.
- Variance correction. Pass corrected = false when matching anything BN-style or WS-style.
- Axis order on apply_state_dict. HDF5 fixtures arrive reversed (Lux-natural). SafeTensors arrive PyTorch-natural by default; pass reverse_axes = true to load_safetensors_state_dict so both sources share one mapping function.
- forward_features vs forward. The fixture and the Julia forward must agree on which one was dumped. Mismatch shows up as a shape error if you are lucky, a silent wrong number if you are not.
- Pre-activation block order. norm then activation then conv vs conv then norm then activation. Both compile. Only one matches the upstream architecture.
- timm's adapt_input_conv for in_chans != 3. The stem weight in the safetensors file is the 3-channel version. The loader applies adapt_input_conv at load time to collapse it. Do not re-collapse on the Julia side.
- Lux.testmode(st). Required for parity. Forgetting it makes the test occasionally pass and occasionally fail depending on RNG state inside dropout.
- Generalize only after one variant passes. Adding a second variant before the first lands creates code paths that the first failed test cannot localize.
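Two items on this checklist, the padding value under a max-pool and the variance correction, can be reproduced with the standard library alone. The sketch below is one-dimensional and uses an illustrative maxpool1d helper; it is not NNlib's pooling code.

```julia
using Statistics

# 1-D max-pool, window 3, stride 1, with an explicit padding value on both ends.
function maxpool1d(x, padval)
    xp = vcat(padval, x, padval)
    return [maximum(xp[i:i+2]) for i in 1:length(x)]
end

x = Float32[-3, -1, -2]   # all-negative input, e.g. post-norm activations

println(maxpool1d(x, -Inf32))  # Float32[-1.0, -1.0, -1.0]
println(maxpool1d(x, 0f0))     # Float32[0.0, -1.0, 0.0]

# Sample vs population variance on four values with mean 2.5.
w = Float32[1, 2, 3, 4]
v_sample = var(w)                     # divides by N - 1
v_pop = var(w; corrected = false)     # divides by N
println(v_sample / v_pop ≈ 4 / 3)     # true: the N / (N - 1) factor
```

With -Inf padding the border outputs come from real inputs; with zero padding the zeros win wherever the window is all-negative, which is exactly the border divergence described above.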