Getting Started

This page walks through Luximm.jl end-to-end: installing the package, loading a pretrained ResNet50 and predicting an ImageNet class, switching to feature-extractor mode, working with non-RGB inputs, and the HuggingFace cache layout.

Installation

Luximm targets Julia 1.12 or newer and is registered in the Julia General registry, so it installs with Pkg.add directly:

using Pkg
Pkg.add("Luximm")

For local hacking, clone the repo and Pkg.develop the path:

using Pkg
Pkg.develop(path = "/path/to/Luximm.jl")

The Python sidecar (pyproject.toml with timm, torch, h5py, safetensors, huggingface-hub) is only needed to regenerate parity fixtures for new backbones. Inference and pretrained-weight loading work from Julia alone. See Testing for the contributor setup.

Loading ResNet50 and predicting an ImageNet class

using Luximm, Lux, Random

# `create_pretrained` is family-agnostic; the symbol selects the
# family. It returns the model and a closure that loads the released
# weights into a `(ps, st)` pair you produce with `Lux.setup`.
model, load = create_pretrained(:resnet50_a1_in1k)
ps, st = Lux.setup(Xoshiro(0), model)
ps, st = load(ps, st)
st = Lux.testmode(st)

x = randn(Float32, 224, 224, 3, 1)
logits, _ = model(x, ps, st)              # (1000, 1)

top5 = partialsortperm(vec(logits), 1:5; rev = true)

A few things worth noting in the snippet:

  • Variant keys are Julia symbols. :resnet50_a1_in1k is the timm model name resnet50.a1_in1k with the dot rewritten as an underscore. The full name with the dot is at RESNET_VARIANTS[:resnet50_a1_in1k].hf_repo.
  • create_pretrained defaults num_classes to the variant's default_num_classes (1000 for the resnet50.a1_in1k ImageNet checkpoint). Pass an explicit num_classes = 0 for a features-only model, or any other Int for a custom head — see the next two sections.
  • Lux.setup produces a ps (parameters) and st (state) NamedTuple. The closure returns a new (ps, st) with the HuggingFace weights merged in. Stateless families (BiT, ConvNeXt, ConvNeXtV2) return st unchanged; ResNet merges BatchNorm running statistics into st. Call Lux.testmode(st) before inference so BatchNorm uses those running statistics instead of the current batch's statistics; for the stateless families it is a no-op but still a safe default.
  • The input is shaped (W, H, C, N), Lux's convention. PyTorch's (N, C, H, W) is read-reversed at load time so most weights land in the layout Lux expects directly.
  • x here is random noise. For a real prediction, replace it with a preprocessed image: resize to 224x224, scale to Float32 in [0, 1], then normalize with the ImageNet mean and std.
  • create_model(variant; ...) (without weight loading) is also exported for from-scratch training and for embedding into an outer @compact — see Composing into a larger model below.
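The normalization step in the list above can be sketched in a few lines. This is a minimal sketch, not part of Luximm: normalize_imagenet is a hypothetical helper, and it assumes the image has already been resized to 224x224 and stored as a (W, H, 3) array with values in [0, 1]. The mean and std are the standard ImageNet channel statistics.

```julia
# Standard ImageNet channel statistics (RGB order).
const IMAGENET_MEAN = Float32[0.485, 0.456, 0.406]
const IMAGENET_STD  = Float32[0.229, 0.224, 0.225]

# Hypothetical helper (not a Luximm export). Assumes `img` is already resized
# and laid out as (W, H, 3) with values in [0, 1].
function normalize_imagenet(img::AbstractArray{<:Real,3})
    size(img, 3) == 3 || throw(ArgumentError("expected 3 channels in dim 3"))
    μ = reshape(IMAGENET_MEAN, 1, 1, 3)   # broadcast over W and H
    σ = reshape(IMAGENET_STD, 1, 1, 3)
    x = (Float32.(img) .- μ) ./ σ
    return reshape(x, size(x)..., 1)      # add the batch dimension: (W, H, 3, 1)
end

img = rand(Float32, 224, 224, 3)          # stand-in for a real, resized image
x = normalize_imagenet(img)
```

The result has the (W, H, C, N) shape the model call expects, so it can replace the random x in the snippet above.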

To map the top-5 indices to class names, pair the variant with any ImageNet class label list (Luximm ships none of its own; the timm imagenet_classes.txt works directly because the class index ordering is unchanged).
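A minimal sketch of that mapping, using only Base Julia. The helpers softmax_vec and topk_labels are hypothetical names, and the labels vector is a stand-in; with a real label file you would presumably build it with readlines on a path of your choosing. Because Julia indexes from 1, line i of a label file (PyTorch class i - 1) lines up with index i of the logits.

```julia
# Numerically stable softmax over a vector of logits.
function softmax_vec(z::AbstractVector{<:Real})
    e = exp.(z .- maximum(z))
    return e ./ sum(e)
end

# Hypothetical helper: pair the k most probable classes with their labels.
function topk_labels(logits, labels::AbstractVector{<:AbstractString}; k::Integer = 5)
    probs = softmax_vec(vec(logits))
    length(labels) == length(probs) ||
        throw(ArgumentError("label count must match class count"))
    idx = partialsortperm(probs, 1:k; rev = true)
    return [(label = labels[i], prob = probs[i]) for i in idx]
end

# Stand-in data: 1000 placeholder labels and logits that favor class 42.
labels = ["class_$i" for i in 1:1000]
logits = zeros(Float32, 1000, 1)
logits[42, 1] = 10f0

top5 = topk_labels(logits, labels)
```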

Feature extractor mode

Pass num_classes = 0 to drop the classifier head and get the post-stage feature map back instead of logits. This matches timm.create_model(..., num_classes=0).forward_features(x):

model, load = create_pretrained(:resnet50_a1_in1k; num_classes = 0)
ps, st = Lux.setup(Xoshiro(0), model)
ps, st = load(ps, st)
st = Lux.testmode(st)

x = randn(Float32, 224, 224, 3, 1)
features, _ = model(x, ps, st)            # (7, 7, 2048, 1)

The output is shaped (W/32, H/32, num_features, N) for every registered backbone. For ResNet50 that is (7, 7, 2048, 1) on a 224x224 input; for convnextv2_atto_fcmae it is (7, 7, 320, 1).

Use this mode when you want to attach Luximm's pretrained encoder to your own downstream head (regression, segmentation, neural ODE, etc.) without carrying around the 1000-class classifier the released checkpoint would otherwise initialize.
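Two common ways to hand the feature map to a downstream head are sketched below, assuming only the (W/32, H/32, num_features, N) shape described above. The features array is a random stand-in, not real model output, and neither step is something Luximm does for you.

```julia
using Statistics: mean

# Stand-in for the (7, 7, 2048, 1) ResNet50 feature map.
features = randn(Float32, 7, 7, 2048, 1)

# Global average pooling: one 2048-vector per sample, shape (2048, 1).
embedding = dropdims(mean(features; dims = (1, 2)); dims = (1, 2))

# Spatial tokens: flatten W and H into one axis, shape (49, 2048, 1).
tokens = reshape(features, :, size(features, 3), size(features, 4))
```

The pooled embedding suits regression or classification heads; the token layout keeps spatial positions for heads that need them.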

Single-channel and other non-RGB inputs

Pass in_chans to create_pretrained. The closure adapts the released 3-channel weight at load time, matching timm's adapt_input_conv semantics: sum across input channels for in_chans = 1, tile and rescale by 3 / in_chans for other counts.

model, load = create_pretrained(:convnextv2_atto_fcmae; in_chans = 1)
ps, st = Lux.setup(Xoshiro(0), model)
ps, st = load(ps, st)
st = Lux.testmode(st)                     # no-op for ConvNeXtV2, kept as a safe default

x = randn(Float32, 224, 224, 1, 1)
features, _ = model(x, ps, st)            # (7, 7, 320, 1)

This is the right path for grayscale medical or scientific imagery, where collapsing the RGB stem is preferable to repeating the single-channel input three times.
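The adaptation rule described above can be sketched as follows. This is an illustration of the stated semantics, not Luximm's loader: adapt_input_conv_sketch is a hypothetical name, and it assumes the stem kernel is stored in Lux's (kW, kH, C_in, C_out) layout.

```julia
# Hypothetical sketch of the in_chans adaptation rule. Assumes the kernel is
# laid out as (kW, kH, C_in, C_out) with C_in == 3.
function adapt_input_conv_sketch(w::AbstractArray{T,4}, in_chans::Integer) where {T}
    size(w, 3) == 3 || throw(ArgumentError("expected a 3-channel stem weight"))
    in_chans == 3 && return w
    # Grayscale: collapse RGB by summing over the input-channel axis.
    in_chans == 1 && return sum(w; dims = 3)
    # Other counts: tile the RGB kernels, truncate, rescale by 3 / in_chans.
    reps  = cld(in_chans, 3)
    tiled = repeat(w, 1, 1, reps, 1)[:, :, 1:in_chans, :]
    return tiled .* T(3 / in_chans)
end

w  = ones(Float32, 7, 7, 3, 64)           # stand-in for a released stem kernel
w1 = adapt_input_conv_sketch(w, 1)        # (7, 7, 1, 64)
w4 = adapt_input_conv_sketch(w, 4)        # (7, 7, 4, 64)
```

The rescale keeps the expected activation magnitude roughly constant when the input channels carry similar content, which is why summing is used for the single-channel case.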

Composing into a larger model

create_pretrained returns the backbone and a closure that already knows the variant; pass prefix = (:backbone,) so the closure writes into the right subtree of the outer Lux.setup result. The backbone itself goes into the outer @compact block as a value:

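using Statistics: mean

num_outputs = 10                          # size of your downstream head

backbone, load_backbone = create_pretrained(:resnet50_a1_in1k;
    num_classes = 0, prefix = (:backbone,))

outer = @compact(
    backbone = backbone,
    head     = Dense(2048 => num_outputs),
) do x
    feats  = backbone(x)                  # (W/32, H/32, 2048, N)
    # Pool the spatial dimensions so Dense sees a (2048, N) matrix.
    pooled = dropdims(mean(feats; dims = (1, 2)); dims = (1, 2))
    @return head(pooled)
end

ps, st = Lux.setup(Xoshiro(0), outer)
ps, st = load_backbone(ps, st)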

The mapping function prefixes every leaf path with prefix..., so apply_state_dict overwrites just the ps.backbone.* and st.backbone.* subtrees. Anything you added (:head, sibling slots, deeper nestings) keeps its Lux.setup initialization. Deeper nestings chain symbols: prefix = (:encoder, :backbone).
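The effect of a prefix on a nested parameter tree can be illustrated with plain NamedTuples. This is a sketch of the idea only: set_subtree is a hypothetical helper, not apply_state_dict, and the tiny ps below is a stand-in for a real Lux.setup result.

```julia
# Hypothetical helper: return a copy of `nt` with the subtree at `path` replaced.
set_subtree(_, ::Tuple{}, new) = new
function set_subtree(nt, path::Tuple, new)
    key   = first(path)
    child = set_subtree(getproperty(nt, key), Base.tail(path), new)
    return merge(nt, NamedTuple{(key,)}((child,)))
end

# Stand-in parameter tree: a backbone slot plus a user-added head.
ps = (backbone = (conv = (weight = zeros(2),),), head = (weight = ones(3),))

# Pretend these are released weights; write them under the (:backbone,) prefix.
loaded = (conv = (weight = fill(7.0, 2),),)
ps2 = set_subtree(ps, (:backbone,), loaded)

# A deeper prefix chains symbols the same way.
nested  = (encoder = (backbone = (w = 0.0,), neck = (w = 1.0,)),)
nested2 = set_subtree(nested, (:encoder, :backbone), (w = 5.0,))
```

Only the addressed subtree changes; sibling slots such as head and neck come back untouched.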

Switching families

create_pretrained (released weights) and create_model (random init) both dispatch on the variant symbol, so picking a different family is just picking a different symbol; there are no per-family entry points to learn. Each family registers its variants under one <FAMILY>_VARIANTS dict (RESNET_VARIANTS for ResNet), which can be enumerated for variant discovery.
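Enumeration might look like the sketch below. To keep it runnable without the package, it uses a stand-in dict, DEMO_VARIANTS, carrying only the two fields this page mentions (hf_repo and default_num_classes); the :demo_variant entry is invented for illustration, and the real registries may hold more fields and entries.

```julia
# Stand-in registry; with Luximm loaded you would iterate RESNET_VARIANTS
# (or another family's dict) instead. The second entry is made up.
DEMO_VARIANTS = Dict(
    :resnet50_a1_in1k => (hf_repo = "timm/resnet50.a1_in1k", default_num_classes = 1000),
    :demo_variant     => (hf_repo = "example/demo.variant",  default_num_classes = 10),
)

# List every variant key with its repo and default head size.
for name in sort!(collect(keys(DEMO_VARIANTS)))
    v = DEMO_VARIANTS[name]
    println(rpad(string(name), 20), v.hf_repo, "  (", v.default_num_classes, " classes)")
end

# Filter for variants whose default head has 1000 classes.
in1k = [name for (name, v) in DEMO_VARIANTS if v.default_num_classes == 1000]
```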

See the API Reference for the full signatures.

Pretrained weights and the HuggingFace cache

The create_pretrained closure resolves model.safetensors against the standard HuggingFace Hub cache layout ($HF_HUB_CACHE, then $HF_HOME/hub, then ~/.cache/huggingface/hub), so the same blob is shared with timm and huggingface_hub: whichever tool downloads first, the other sees a cache hit. Subsequent calls short-circuit on the cached snapshot symlink.
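The lookup order can be sketched in Base Julia. The helpers hub_cache_dir and repo_folder are hypothetical names, not Luximm exports; the models--<org>--<name> folder naming follows the Hub cache convention, with files living under snapshots/<revision>/ inside it.

```julia
# Hypothetical helper: pick the cache root using the precedence described above.
function hub_cache_dir(env = ENV)
    haskey(env, "HF_HUB_CACHE") && return env["HF_HUB_CACHE"]
    haskey(env, "HF_HOME") && return joinpath(env["HF_HOME"], "hub")
    return joinpath(homedir(), ".cache", "huggingface", "hub")
end

# Hub cache folders replace "/" in the repo id with "--" and add a type prefix.
repo_folder(repo_id::AbstractString) = "models--" * replace(repo_id, "/" => "--")

root   = hub_cache_dir(Dict("HF_HOME" => "/data/hf"))   # explicit env for the demo
folder = joinpath(root, repo_folder("timm/resnet50.a1_in1k"))
```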

For gated repos, set HUGGING_FACE_HUB_TOKEN in the environment before calling the loader. The token is forwarded as a bearer header on the resolve and download requests.
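As an illustration of what "forwarded as a bearer header" means, the sketch below builds the header list from the environment. The helper auth_headers is hypothetical and only shows the header shape; how Luximm issues the request is not shown here.

```julia
# Hypothetical helper: an Authorization header when a token is set, else none.
function auth_headers(env = ENV)
    token = get(env, "HUGGING_FACE_HUB_TOKEN", "")
    return isempty(token) ? Pair{String,String}[] : ["Authorization" => "Bearer $token"]
end

headers = auth_headers(Dict("HUGGING_FACE_HUB_TOKEN" => "hf_example"))  # fake token
```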

For lower-level access, hf_hub_download returns the resolved snapshot path directly and mirrors the cache semantics of huggingface_hub.hf_hub_download:

path = hf_hub_download("timm/resnet50.a1_in1k",
                       "model.safetensors";
                       revision = "main")
state_dict = load_safetensors_state_dict(path)

This is the escape hatch for cases where you want to inspect or transform the raw PyTorch state dict before applying it.
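A typical inspection pass might look like the sketch below. It assumes the loaded state dict behaves like an AbstractDict from parameter name to array, which this page does not confirm; the dict here is a small stand-in with illustrative names and shapes, not real checkpoint contents.

```julia
# Stand-in for a loaded state dict (ASSUMED shape: name => array).
state_dict = Dict(
    "conv1.weight" => zeros(Float32, 7, 7, 3, 64),
    "bn1.weight"   => ones(Float32, 64),
    "fc.bias"      => zeros(Float32, 1000),
)

# Print each tensor name with its shape, in sorted order.
for name in sort!(collect(keys(state_dict)))
    println(rpad(name, 16), size(state_dict[name]))
end

# Total element count across all tensors.
n_params = sum(length, values(state_dict))

# Example transform: drop classifier entries before applying the rest.
backbone_only = filter(kv -> !startswith(first(kv), "fc."), state_dict)
```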