Interop

PyTorch and HuggingFace plumbing: applying a PyTorch state_dict to a Lux (ps, st) pair, resolving and caching weights through the HuggingFace Hub, and loading .safetensors blobs.

State-dict application

Luximm.Interop.apply_state_dict - Function
apply_state_dict(ps, state_dict, mapping) -> ps

Rebuild a Lux parameter NamedTuple by replacing leaves according to mapping, an iterable of (pytorch_key, lux_path, transform) triples where:

  • pytorch_key: dotted name as written by Python's model.state_dict().keys().
  • lux_path: tuple of Symbols naming the leaf in ps, e.g. (:stage1, :layer_1, :norm1, :gn, :scale).
  • transform: a function Array{Float32} -> Array{Float32} applied to the raw HDF5-read array. Common choices are identity (the HDF5-natural layout already matches Lux) and axis_reverse (full axis reversal, for tensors whose Julia layout follows PyTorch's logical axis order rather than the reversed order that the HDF5 read produces).

The original ps is not mutated; the caller must bind the return value.
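The non-mutating rebuild can be pictured with a minimal sketch. This is not Luximm's implementation: set_leaf and apply_mapping_sketch are hypothetical names introduced here, and the state dict is a hand-written Dict rather than one read from a fixture.

```julia
# Hypothetical helper: return a copy of `nt` with the leaf at `path` replaced.
function set_leaf(nt::NamedTuple, path::Tuple, value)
    head = first(path)
    new = length(path) == 1 ? value :
          set_leaf(getfield(nt, head), Base.tail(path), value)
    return merge(nt, NamedTuple{(head,)}((new,)))
end

# Hypothetical stand-in for the documented behaviour: walk the
# (pytorch_key, lux_path, transform) triples and rebuild `ps` leaf by leaf.
function apply_mapping_sketch(ps, state_dict, mapping)
    for (key, path, transform) in mapping
        ps = set_leaf(ps, path, transform(state_dict[key]))
    end
    return ps
end

ps = (stem = (weight = zeros(Float32, 2, 2), bias = zeros(Float32, 2)),)
state_dict = Dict("stem.weight" => Float32[1 2; 3 4],
                  "stem.bias" => Float32[5, 6])
mapping = [
    ("stem.weight", (:stem, :weight), permutedims),  # 2-D transpose as the transform
    ("stem.bias", (:stem, :bias), identity),
]

new_ps = apply_mapping_sketch(ps, state_dict, mapping)
# `ps` still holds zeros; only `new_ps` carries the loaded values.
```

Because NamedTuples are immutable, each replacement builds a new outer tuple, which is why the return value has to be bound by the caller.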

Luximm.Interop.read_parity - Function
read_parity(path) -> NamedTuple

Read a parity HDF5 fixture and return (input, state_dict, output). Tensors are returned as Float32 arrays in their HDF5-natural Julia layout (PyTorch axes reversed). Per-tensor permutations are the caller's responsibility.

output is a single Array{Float32} if the fixture wrote /output as a dataset, or a Dict{String, Array{Float32}} if it wrote /output/<name>.
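The phrase "HDF5-natural Julia layout (PyTorch axes reversed)" follows from row-major versus column-major storage. The sketch below imitates that effect with a plain buffer; it does not read an HDF5 file, and the 2x3 tensor is an invented example.

```julia
# A PyTorch tensor t of shape (2, 3), rows [1 2 3] and [4 5 6],
# is stored row-major, so its flat buffer is:
flat = Float32[1, 2, 3, 4, 5, 6]

# Viewing the same buffer as a column-major Julia array, without
# reordering any bytes, gives the reversed shape (3, 2).
a = reshape(flat, 3, 2)

# PyTorch's t[i, j] (0-based) is Julia's a[j + 1, i + 1]:
t_0_2 = a[3, 1]   # t[0, 2]
t_1_0 = a[1, 2]   # t[1, 0]
```

The same reasoning makes a PyTorch conv weight of shape (O, I, kH, kW) appear as (kW, kH, I, O) on the Julia side.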


Weight-layout transforms

Per-tensor transforms passed in apply_state_dict mappings to bridge PyTorch's stored layout and the Lux-natural layout (see Porting Backbones).

Luximm.Interop.axis_reverse - Function
axis_reverse(a) -> Array

Permutes all axes in reverse order: (d1, d2, ..., dN) -> (dN, ..., d2, d1). Use this for tensors whose Julia layout was designed in PyTorch axis order. The HDF5 read gives back the reversed layout; this permute restores the original axis order.
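A full axis reversal is a single permutedims call. The sketch below uses the hypothetical name axis_reverse_sketch and is only assumed to behave like the documented function.

```julia
# Reverse every axis: (d1, d2, ..., dN) -> (dN, ..., d2, d1).
axis_reverse_sketch(a::AbstractArray) =
    permutedims(a, ntuple(i -> ndims(a) - i + 1, ndims(a)))

a = reshape(collect(Float32, 1:24), 2, 3, 4)
b = axis_reverse_sketch(a)     # size (4, 3, 2)
# Indices are reversed along with the axes: b[k, j, i] == a[i, j, k].
```

Applying the reversal twice returns the original array, so the same transform converts in either direction.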

Luximm.Interop.pyperm - Function
pyperm(perm) -> Function

Build a transform that applies the permutation perm to the HDF5-read array. Useful when the target Julia axis order is neither the HDF5-natural layout (identity) nor its full reversal (axis_reverse).
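A transform factory of this kind can be sketched as a closure over permutedims. The name pyperm_sketch is hypothetical, and the sketch assumes perm is a 1-based Julia permutation applied directly to the HDF5-read array; the real function may use a different convention for perm.

```julia
# Return a function that permutes its argument's axes by `perm`.
pyperm_sketch(perm) = a -> permutedims(a, perm)

swap_spatial = pyperm_sketch((2, 1, 3, 4))   # swap the first two axes only

w = reshape(collect(Float32, 1:24), 1, 2, 3, 4)
v = swap_spatial(w)                          # size (2, 1, 3, 4)
```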

Luximm.Interop.as_channel4d - Function
as_channel4d(a) -> Array

Reshape a (C,) PyTorch norm parameter into (1, 1, C, 1), the shape used by Lux.LayerNorm((1, 1, C); dims = 3) (and so by Luximm.Layers.layernorm2d) for its :scale / :bias leaves on WHCN 4D inputs.
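The reshape exists so the parameter broadcasts along the channel axis of a WHCN tensor. The sketch below uses plain reshape and invented values; it is not the library code.

```julia
gamma = Float32[0.5, 1.0, 2.0]         # a (C,) norm weight, C = 3
gamma4 = reshape(gamma, 1, 1, :, 1)    # (1, 1, C, 1)

x = ones(Float32, 4, 4, 3, 2)          # (W, H, C, N)
y = x .* gamma4                        # each channel scaled by its own weight
```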

Luximm.Interop.as_token_norm - Function
as_token_norm(a) -> Array

Reshape a (C,) PyTorch norm parameter into (C, 1, 1), the shape used by a channel-axis LayerNorm over a (C, T, N) token tensor (see Luximm.Layers.vit_layernorm) for its :scale / :bias leaves.
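For token tensors the channel axis comes first, so the parameter is reshaped to lead with C. This is a sketch with plain reshape and invented values, not the library code.

```julia
beta = Float32[0.1, 0.2, 0.3]        # a (C,) norm bias, C = 3
beta3 = reshape(beta, :, 1, 1)       # (C, 1, 1)

tokens = zeros(Float32, 3, 5, 2)     # (C, T, N)
shifted = tokens .+ beta3            # each channel row gets its own bias
```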

Luximm.Interop.adapt_input_conv - Function
adapt_input_conv(in_chans) -> transform

Build a state-dict transform that adapts a stem conv weight to the requested input channel count, mirroring timm's adapt_input_conv helper (defined in timm.models._manipulate in current timm releases).

The transform takes an HDF5-natural Julia conv weight in Lux's (kW, kH, I, O) layout (returned by read_parity or load_safetensors_state_dict) and returns a (kW, kH, in_chans, O) weight following the timm recipe:

  • in_chans == I: no-op (identity copy as Float32).
  • in_chans == 1, I == 3: sum across the input-channel axis (the canonical RGB-to-grayscale collapse).
  • in_chans == 1, I > 3, I % 3 == 0: reshape to (kW, kH, 3, I÷3, O) and sum the size-3 axis, leaving (kW, kH, I÷3, O). This branch matches timm's special case for space-to-depth stems.
  • in_chans != I, I == 3: tile the weight across the input-channel axis to cover in_chans, truncate, and rescale by 3 / in_chans so the per-output-element response on a uniform input is preserved.
  • Any other shape combination raises; timm itself does not support it.

Plug into apply_state_dict mappings in place of identity on the stem weight entry when in_chans != 3.
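The branches listed above can be written out as a minimal sketch. The name adapt_input_conv_sketch is hypothetical; it follows the bullets in order on a (kW, kH, I, O) weight, and its error type and message are assumptions rather than the library's behaviour.

```julia
function adapt_input_conv_sketch(in_chans::Integer)
    return function (w::AbstractArray{<:Real,4})
        kW, kH, I, O = size(w)
        w32 = Float32.(w)                       # fresh Float32 copy
        if in_chans == I
            return w32                          # nothing to adapt
        elseif in_chans == 1 && I == 3
            return sum(w32; dims = 3)           # RGB -> grayscale, (kW, kH, 1, O)
        elseif in_chans == 1 && I > 3 && I % 3 == 0
            grouped = reshape(w32, kW, kH, 3, I ÷ 3, O)
            return dropdims(sum(grouped; dims = 3); dims = 3)
        elseif I == 3
            reps = cld(in_chans, 3)             # enough tiles to cover in_chans
            tiled = repeat(w32, 1, 1, reps, 1)[:, :, 1:in_chans, :]
            return tiled .* (3f0 / in_chans)    # preserve response on uniform input
        else
            throw(ArgumentError(
                "unsupported conversion from $I to $in_chans input channels"))
        end
    end
end

w = ones(Float32, 3, 3, 3, 8)                   # an RGB stem weight of ones
gray = adapt_input_conv_sketch(1)(w)            # size (3, 3, 1, 8)
four = adapt_input_conv_sketch(4)(w)            # size (3, 3, 4, 8)
```

With a weight of ones, the four-channel result holds 0.75 everywhere, so summing over its four input channels gives 3, the same as summing over the original three.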


HuggingFace Hub

Luximm.Interop.hf_hub_download - Function
hf_hub_download(repo_id, filename; revision="main",
                 cache_dir=hf_hub_cache_dir(),
                 repo_type="model") -> String

Resolve <repo_id>/<filename> against revision (a branch, tag, or commit) and return the local snapshot path, downloading only what is not already cached. The on-disk layout matches huggingface_hub:

<cache_dir>/models--<org>--<name>/blobs/<etag>
<cache_dir>/models--<org>--<name>/snapshots/<commit>/<filename>
    -> ../../blobs/<etag>
<cache_dir>/models--<org>--<name>/refs/<revision>     # text: <commit>

As a result, a timm.create_model(..., pretrained=True) call and a hf_hub_download call against the same repo reuse each other's cached blobs.
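The layout can be reproduced with a few joinpath calls. In this sketch hub_paths_sketch is a hypothetical helper, the commit and etag values are placeholders, and the "<repo_type>s--" folder prefix follows huggingface_hub's naming; whether Luximm builds its paths exactly this way is an assumption.

```julia
function hub_paths_sketch(cache_dir, repo_id, filename, commit, etag;
                          revision = "main", repo_type = "model")
    # "org/name" becomes "models--org--name" (or "datasets--..." for datasets).
    folder = string(repo_type, "s--", replace(repo_id, "/" => "--"))
    root = joinpath(cache_dir, folder)
    blob = joinpath(root, "blobs", etag)
    snapshot = joinpath(root, "snapshots", commit, filename)
    return (
        blob = blob,
        snapshot = snapshot,
        ref = joinpath(root, "refs", revision),            # text file holding <commit>
        link_target = relpath(blob, dirname(snapshot)),    # relative symlink target
    )
end

p = hub_paths_sketch("cache", "org/name", "model.safetensors",
                     "commit123", "etag456")
```

Storing the snapshot entry as a relative link into blobs/ is what lets several revisions share one copy of an unchanged file.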

The function always performs a no-redirect HEAD against https://huggingface.co/<repo_id>/resolve/<revision>/<filename> to look up the current commit (X-Repo-Commit) and the blob's etag (X-Linked-ETag for LFS-backed files, ETag otherwise). If the HEAD fails (e.g. offline), the function falls back to the most recently recorded commit in refs/<revision> and returns the existing snapshot path if present; otherwise the original error is rethrown.

Set repo_type="dataset" for dataset repos; default "model" matches timm's usage.

Luximm.Interop.hf_download - Function
hf_download(url, dest) -> String

Download url to dest unless dest already exists. Returns dest.

If HUGGING_FACE_HUB_TOKEN is set, it is sent as a Bearer token in the Authorization header. The download streams into a sibling temp file and is renamed into place only on success, so an interrupted call (network drop, ^C) never leaves a partial file at dest that a later call would mistake for a cache hit.
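The temp-file-then-rename pattern can be sketched without any network access. Here cached_write_sketch is a hypothetical name, and the fetch! callback stands in for the actual HTTP transfer, which this sketch does not perform.

```julia
function cached_write_sketch(fetch!::Function, dest::AbstractString)
    isfile(dest) && return dest            # cache hit: skip the transfer
    dir = dirname(abspath(dest))
    mkpath(dir)
    tmp, io = mktemp(dir)                  # sibling temp file, same filesystem
    try
        try
            fetch!(io)                     # stream the payload into the temp file
        finally
            close(io)
        end
        mv(tmp, dest)                      # rename only after a complete write
    catch
        rm(tmp; force = true)              # never leave a partial file behind
        rethrow()
    end
    return dest
end

dir = mktempdir()
dest = joinpath(dir, "weights.bin")
cached_write_sketch(io -> write(io, "abc"), dest)
# A second call returns immediately and does not invoke its callback.
cached_write_sketch(io -> error("not reached"), dest)
```

In real use the callback would wrap the transfer itself, for example a Downloads.download call writing to io with an Authorization header built from the token.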

Luximm.Interop.hf_hub_cache_dir - Function
hf_hub_cache_dir() -> String

Cache root that matches huggingface_hub (Python). Honors the same env-var precedence:

  1. HF_HUB_CACHE if set,
  2. otherwise $HF_HOME/hub if HF_HOME is set,
  3. otherwise ~/.cache/huggingface/hub.

Files downloaded into this directory by Luximm are visible to timm / huggingface_hub, and vice versa.
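The three-step precedence above maps onto a short function. In this sketch cache_dir_sketch is a hypothetical name, and the env argument exists only so the lookup can be exercised without touching the real environment.

```julia
function cache_dir_sketch(env = ENV)
    haskey(env, "HF_HUB_CACHE") && return env["HF_HUB_CACHE"]
    haskey(env, "HF_HOME") && return joinpath(env["HF_HOME"], "hub")
    return joinpath(homedir(), ".cache", "huggingface", "hub")
end

explicit = cache_dir_sketch(Dict("HF_HUB_CACHE" => "/a", "HF_HOME" => "/b"))
from_home = cache_dir_sketch(Dict("HF_HOME" => "/b"))
fallback = cache_dir_sketch(Dict{String,String}())
```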


SafeTensors

Luximm.Interop.load_safetensors_state_dict - Function
load_safetensors_state_dict(path; reverse_axes=true) -> Dict{String, Array{Float32}}

Read a .safetensors file from disk into a dict of Float32 arrays.

When reverse_axes = true (default), every tensor's axes are reversed so the result matches read_parity's HDF5-natural Julia layout (PyTorch axes reversed). Pass reverse_axes = false to keep PyTorch's logical axis order instead.
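The effect of the reverse_axes switch can be sketched on an in-memory dict. This does not parse a .safetensors file: to_julia_layout_sketch is a hypothetical name, and raw is assumed to already hold tensors in PyTorch's logical axis order.

```julia
function to_julia_layout_sketch(raw::AbstractDict; reverse_axes::Bool = true)
    out = Dict{String,Array{Float32}}()
    for (name, t) in raw
        a = Float32.(t)                    # convert to Float32
        if reverse_axes && ndims(a) > 1
            a = permutedims(a, ntuple(i -> ndims(a) - i + 1, ndims(a)))
        end
        out[name] = a
    end
    return out
end

raw = Dict("conv.weight" => zeros(Float64, 8, 3, 5, 7),   # (O, I, kH, kW)
           "conv.bias" => zeros(Float64, 8))
reversed = to_julia_layout_sketch(raw)                    # weight: (7, 5, 3, 8)
logical = to_julia_layout_sketch(raw; reverse_axes = false)
```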
