Interop
PyTorch and HuggingFace plumbing: applying a PyTorch state_dict to a Lux (ps, st) pair, resolving and caching weights through the HuggingFace Hub, and loading .safetensors blobs.
State-dict application
Luximm.Interop.apply_state_dict — Function
apply_state_dict(ps, state_dict, mapping) -> ps
Rebuild a Lux parameter NamedTuple by replacing leaves according to mapping, an iterable of (pytorch_key, lux_path, transform) triples where:
- pytorch_key: dotted name as written by Python's model.state_dict().keys().
- lux_path: tuple of Symbols naming the leaf in ps, e.g. (:stage1, :layer_1, :norm1, :gn, :scale).
- transform: function Array{Float32} -> Array{Float32} applied to the raw HDF5-read array. Common transforms are identity (HDF5-natural matches Lux) and axis_reverse (full axis reversal, used when the Julia tensor layout matches PyTorch's logical axis order rather than its reversed storage order).
The original ps is not mutated; the caller must bind the return value.
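The replacement semantics can be illustrated with a minimal sketch in plain Julia. The names set_leaf and apply_state_dict_sketch, the toy parameter tree, and the key "stem.conv.bias" are all hypothetical and exist only for this illustration; this is not Luximm's implementation, just one way to rebuild a nested NamedTuple without mutating the original.

```julia
# Hypothetical helper: return a copy of the nested NamedTuple `nt`
# in which the leaf addressed by `path` is replaced with `value`.
function set_leaf(nt::NamedTuple, path::Tuple{Vararg{Symbol}}, value)
    key = first(path)
    rest = Base.tail(path)
    new_child = isempty(rest) ? value : set_leaf(getfield(nt, key), rest, value)
    # merge keeps the field order of `nt` and swaps in the new child.
    return merge(nt, NamedTuple{(key,)}((new_child,)))
end

# Sketch of the mapping loop described above (an assumption, not the real code).
function apply_state_dict_sketch(ps, state_dict, mapping)
    for (pytorch_key, lux_path, transform) in mapping
        ps = set_leaf(ps, lux_path, transform(state_dict[pytorch_key]))
    end
    return ps
end

# Toy parameter tree and state dict, invented for this example.
ps = (stem = (conv = (weight = zeros(Float32, 2, 2), bias = zeros(Float32, 2)),),)
state_dict = Dict("stem.conv.bias" => Float32[1, 2])
mapping = [("stem.conv.bias", (:stem, :conv, :bias), identity)]

new_ps = apply_state_dict_sketch(ps, state_dict, mapping)
```

Here new_ps carries the loaded bias while ps still holds zeros, which is why the caller has to bind the return value.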
Luximm.Interop.read_parity — Function
read_parity(path) -> NamedTuple
Read a parity HDF5 fixture and return (input, state_dict, output). Tensors are returned as Float32 arrays in their HDF5-natural Julia layout (PyTorch axes reversed). Per-tensor permutations are the caller's responsibility.
output is a single Array{Float32} if the fixture wrote /output as a dataset, or a Dict{String, Array{Float32}} if it wrote /output/<name>.
Weight-layout transforms
Per-tensor transforms passed in apply_state_dict mappings to bridge PyTorch's stored layout and the Lux-natural layout (see Porting Backbones).
Luximm.Interop.axis_reverse — Function
axis_reverse(a) -> Array
Permutes all axes in reverse order: (d1, d2, ..., dN) -> (dN, ..., d2, d1). Use this for tensors whose Julia layout was designed in PyTorch axis order. The HDF5 read gives back the reversed layout; this permute restores the original axis order.
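To see why the HDF5 read comes back reversed, consider a small row-major buffer read into column-major Julia. The function axis_reverse_sketch below is a hypothetical stand-in written with Base.permutedims, and the 2×3 tensor is made up for illustration; the real axis_reverse may be implemented differently.

```julia
# Sketch only: a stand-in for axis_reverse built on Base.permutedims.
axis_reverse_sketch(a::AbstractArray) =
    permutedims(a, ntuple(i -> ndims(a) - i + 1, ndims(a)))

# A PyTorch tensor with logical shape (2, 3), flattened in row-major order:
#   [[1, 2, 3],
#    [4, 5, 6]]
flat = Float32[1, 2, 3, 4, 5, 6]

# Reading the same buffer into column-major Julia yields the reversed shape,
# so natural[j, i] holds the element PyTorch calls tensor[i, j] (1-based).
natural = reshape(flat, 3, 2)

# Reversing the axes restores PyTorch's logical axis order (shape 2×3).
logical = axis_reverse_sketch(natural)
```

Applying the reversal twice returns the original array, so the same transform converts in either direction.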
Luximm.Interop.pyperm — Function
pyperm(perm) -> Function
Build a transform that applies a specific permutation to the HDF5-read array. Useful when the Julia axis order is neither the HDF5-natural reverse nor a full reverse.
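A transform builder of this kind can be sketched as a closure over permutedims. This assumes perm is a 1-based Julia permutation of the kind Base.permutedims accepts; the docstring does not say which convention pyperm uses, so treat pyperm_sketch as a hypothetical illustration rather than the library's behavior.

```julia
# Hypothetical stand-in: capture `perm` and return a one-argument transform.
pyperm_sketch(perm) = a -> permutedims(a, perm)

# A made-up 4D weight; swap only the first two axes and keep the rest.
w = reshape(Float32.(1:24), 2, 3, 4, 1)
swap_spatial = pyperm_sketch((2, 1, 3, 4))
v = swap_spatial(w)   # size (3, 2, 4, 1)
```

The returned closure has the same one-argument shape as identity, so it can sit in the transform slot of a mapping triple.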
Luximm.Interop.as_channel4d — Function
as_channel4d(a) -> Array
Reshape a (C,) PyTorch norm parameter into (1, 1, C, 1), the shape used by Lux.LayerNorm((1, 1, C); dims = 3) (and so by Luximm.Layers.layernorm2d) for its :scale / :bias leaves on WHCN 4D inputs.
Luximm.Interop.as_token_norm — Function
as_token_norm(a) -> Array
Reshape a (C,) PyTorch norm parameter into (C, 1, 1), the shape used by a channel-axis LayerNorm over a (C, T, N) token tensor (see Luximm.Layers.vit_layernorm) for its :scale / :bias leaves.
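Both reshapes exist so that the per-channel vector broadcasts along the channel axis of the activation. The sketch below uses hypothetical stand-ins (as_channel4d_sketch, as_token_norm_sketch) and made-up tensor sizes to show the broadcasting; it does not call Lux or Luximm.

```julia
# Hypothetical stand-ins for the two reshapes described above.
as_channel4d_sketch(a::AbstractVector) = reshape(a, 1, 1, length(a), 1)
as_token_norm_sketch(a::AbstractVector) = reshape(a, length(a), 1, 1)

scale = Float32[0.5, 1.0, 2.0]           # a (C,) norm parameter with C = 3

# WHCN activation: the channel axis is dimension 3.
x = ones(Float32, 4, 4, 3, 2)
y = x .* as_channel4d_sketch(scale)      # scales each channel slice

# (C, T, N) token tensor: the channel axis is dimension 1.
tokens = ones(Float32, 3, 5, 2)
z = tokens .* as_token_norm_sketch(scale)
```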
Luximm.Interop.adapt_input_conv — Function
adapt_input_conv(in_chans) -> transform
Build a state-dict transform that adapts a stem conv weight to the requested input channel count, mirroring timm.models._helpers.adapt_input_conv.
The transform takes an HDF5-natural Julia conv weight in Lux's (kW, kH, I, O) layout (returned by read_parity or load_safetensors_state_dict) and returns a (kW, kH, in_chans, O) weight following the timm recipe:
- in_chans == I: no-op (identity copy as Float32).
- in_chans == 1, I == 3: sum across the input-channel axis (the canonical RGB-to-grayscale collapse).
- in_chans == 1, I > 3, I % 3 == 0: reshape to (kW, kH, 3, I÷3, O) and sum the size-3 axis, leaving (kW, kH, I÷3, O). This branch matches timm's special case for space-to-depth stems.
- in_chans != I, I == 3: tile the weight across the input-channel axis to cover in_chans, truncate, and rescale by 3 / in_chans so the per-output-element response on a uniform input is preserved.
- Any other shape combination raises; timm itself does not support it.
Plug into apply_state_dict mappings in place of identity on the stem weight entry when in_chans != 3.
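The branches listed above can be sketched in plain Julia as follows. adapt_input_conv_sketch is a hypothetical reimplementation written from that description, not Luximm's source; the weights are made-up arrays of ones so the channel arithmetic is easy to check.

```julia
# Sketch of the recipe above for a weight in (kW, kH, I, O) layout.
function adapt_input_conv_sketch(in_chans::Integer)
    return function (w::AbstractArray{<:Real,4})
        w = Float32.(w)                       # identity copy as Float32
        kW, kH, I, O = size(w)
        if in_chans == I
            return w
        elseif in_chans == 1 && I == 3
            return sum(w; dims = 3)           # RGB -> grayscale collapse
        elseif in_chans == 1 && I > 3 && I % 3 == 0
            grouped = reshape(w, kW, kH, 3, I ÷ 3, O)
            return dropdims(sum(grouped; dims = 3); dims = 3)
        elseif I == 3
            reps = cld(in_chans, 3)           # enough tiles to cover in_chans
            tiled = repeat(w; outer = (1, 1, reps, 1))[:, :, 1:in_chans, :]
            return tiled .* Float32(3 / in_chans)
        else
            error("unsupported combination: in_chans=$in_chans, I=$I")
        end
    end
end

rgb = ones(Float32, 3, 3, 3, 8)               # made-up RGB stem weight
gray = adapt_input_conv_sketch(1)(rgb)        # size (3, 3, 1, 8)
four = adapt_input_conv_sketch(4)(rgb)        # size (3, 3, 4, 8)
```

In the four-channel case each entry becomes 3/4, so summing over the input-channel axis still gives 3, the same response as the original RGB weight on a uniform input.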
HuggingFace Hub
Luximm.Interop.hf_hub_download — Function
hf_hub_download(repo_id, filename; revision="main",
                cache_dir=hf_hub_cache_dir(),
                repo_type="model") -> String
Resolve <repo_id>/<filename> against revision (a branch, tag, or commit) and return the local snapshot path, downloading only what is not already cached. The on-disk layout matches huggingface_hub:
<cache_dir>/models--<org>--<name>/blobs/<etag>
<cache_dir>/models--<org>--<name>/snapshots/<commit>/<filename>
    -> ../../blobs/<etag>
<cache_dir>/models--<org>--<name>/refs/<revision>    # text: <commit>
This means a timm.create_model(..., pretrained=True) call and a hf_hub_download call against the same repo see each other's cached blob.
The function always performs a no-redirect HEAD against https://huggingface.co/<repo_id>/resolve/<revision>/<filename> to look up the current commit (X-Repo-Commit) and the blob's etag (X-Linked-ETag for LFS-backed files, ETag otherwise). If the HEAD fails (e.g. offline), the function falls back to the most recently recorded commit in refs/<revision> and returns the existing snapshot path if present; otherwise the original error is rethrown.
Set repo_type="dataset" for dataset repos; default "model" matches timm's usage.
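The cache layout described for hf_hub_download can be reproduced with a few joinpath calls. hub_paths is a hypothetical helper, and the repo id, commit, and etag values are placeholders; the "<repo_type>s--" directory prefix for non-model repos is an assumption based on huggingface_hub's naming, since the text above only shows the models-- case.

```julia
# Hypothetical helper: build the three cache paths for one file.
function hub_paths(cache_dir, repo_id, filename, commit, etag;
                   revision = "main", repo_type = "model")
    # "org/name" becomes "models--org--name" (assumed prefix: repo_type * "s--").
    repo_dir = joinpath(cache_dir,
                        string(repo_type, "s--", replace(repo_id, "/" => "--")))
    return (
        blob = joinpath(repo_dir, "blobs", etag),
        snapshot = joinpath(repo_dir, "snapshots", commit, filename),
        ref = joinpath(repo_dir, "refs", revision),
    )
end

# Placeholder values, not a real repository or commit.
p = hub_paths("/tmp/hub", "org/name", "model.safetensors", "abc123", "deadbeef")

# Relative symlink target from the snapshot directory to the blob.
link_target = relpath(p.blob, dirname(p.snapshot))
```

For a filename without subdirectories, link_target works out to ../../blobs/<etag>, matching the symlink shown in the layout.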
Luximm.Interop.hf_download — Function
hf_download(url, dest) -> String
Download url to dest unless dest already exists. Returns dest.
If HUGGING_FACE_HUB_TOKEN is set, it is sent as a Bearer token in the Authorization header. The download streams into a sibling temp file and is renamed into place only on success, so an interrupted call (network drop, ^C) never leaves a partial file at dest that a later call would mistake for a cache hit.
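The temp-file-then-rename pattern can be sketched without any network access. write_atomically is a hypothetical helper in which a caller-supplied function stands in for the actual download; it illustrates the pattern only and is not hf_download's implementation.

```julia
# Hypothetical helper: write to a sibling temp file, rename only on success.
function write_atomically(produce!, dest::AbstractString)
    isfile(dest) && return dest                 # treat an existing file as a cache hit
    dir = dirname(abspath(dest))
    mkpath(dir)
    tmp, io = mktemp(dir)                       # temp file in the same directory
    try
        try
            produce!(io)                        # e.g. stream the response body here
        finally
            close(io)
        end
        mv(tmp, dest; force = true)             # publish the finished file
    catch
        rm(tmp; force = true)                   # never leave a partial file behind
        rethrow()
    end
    return dest
end

dir = mktempdir()
dest = joinpath(dir, "weights.bin")
write_atomically(dest) do io
    write(io, "payload")
end

# A simulated interruption: the destination must not appear.
bad = joinpath(dir, "broken.bin")
try
    write_atomically(bad) do io
        write(io, "partial")
        error("simulated network drop")
    end
catch
end
```

Keeping the temp file in the same directory as dest means the final rename stays on one filesystem, which is what lets it replace the target in a single step.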
Luximm.Interop.hf_hub_cache_dir — Function
hf_hub_cache_dir() -> String
Cache root that matches huggingface_hub (Python). Honors the same env-var precedence:
- HF_HUB_CACHE if set,
- otherwise $HF_HOME/hub if HF_HOME is set,
- otherwise ~/.cache/huggingface/hub.
Files downloaded into this directory by Luximm are visible to timm / huggingface_hub, and vice versa.
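The precedence above amounts to a three-step lookup. hub_cache_dir_sketch is a hypothetical stand-in that takes the environment as an argument so it can be exercised with a plain Dict; it covers only the three cases listed and treats any present key as set.

```julia
# Sketch of the documented precedence; `env` defaults to the process environment.
function hub_cache_dir_sketch(env = ENV)
    haskey(env, "HF_HUB_CACHE") && return env["HF_HUB_CACHE"]
    haskey(env, "HF_HOME") && return joinpath(env["HF_HOME"], "hub")
    return joinpath(homedir(), ".cache", "huggingface", "hub")
end

# Made-up environments to exercise each branch.
both = hub_cache_dir_sketch(Dict("HF_HUB_CACHE" => "/a", "HF_HOME" => "/b"))
home_only = hub_cache_dir_sketch(Dict("HF_HOME" => "/b"))
neither = hub_cache_dir_sketch(Dict{String,String}())
```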
SafeTensors
Luximm.Interop.load_safetensors_state_dict — Function
load_safetensors_state_dict(path; reverse_axes=true) -> Dict{String, Array{Float32}}
Read a .safetensors file from disk into a dict of Float32 arrays.
When reverse_axes = true (default), every tensor's axes are reversed so the resulting layout matches read_parity's HDF5-natural Julia layout (PyTorch axes reversed). Set reverse_axes = false to keep PyTorch's logical axis order if a caller wants that layout explicitly.
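The effect of the flag can be sketched on a single tensor's raw bytes. decode_tensor_sketch is a hypothetical helper that assumes the data are row-major Float32 values in the host's byte order and that the tensor has at least one axis; it does not parse a real .safetensors header and is not the library's decoder.

```julia
# Hypothetical decoder for one tensor: raw bytes plus PyTorch's logical shape.
function decode_tensor_sketch(bytes::Vector{UInt8}, torch_shape; reverse_axes = true)
    flat = collect(reinterpret(Float32, bytes))
    # Row-major data read into column-major Julia: the shape comes out reversed.
    natural = reshape(flat, reverse(Tuple(torch_shape))...)
    reverse_axes && return natural
    # Otherwise permute back to PyTorch's logical axis order.
    n = ndims(natural)
    return permutedims(natural, ntuple(i -> n - i + 1, n))
end

# Made-up payload: the 2×3 tensor [[1, 2, 3], [4, 5, 6]] in row-major order.
bytes = collect(reinterpret(UInt8, Float32[1, 2, 3, 4, 5, 6]))

natural = decode_tensor_sketch(bytes, (2, 3))                        # size (3, 2)
logical = decode_tensor_sketch(bytes, (2, 3); reverse_axes = false)  # size (2, 3)
```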