Layers
Reusable building blocks shared across model families. These exist to match specific timm constructs in semantics, layout, and default parameters. They are not intended as a general-purpose layer library.
Building blocks
Luximm.Layers.std_conv — Function
std_conv(kW, kH, in, out; stride=1, pad=0, eps=1f-8) -> @compact block
Cross-correlation convolution whose kernel is standardized at forward time. Use this in place of Conv when porting a timm model that wraps its convolutions in StdConv2d (BiT-ResNet, NFNet, etc.).
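For reference, timm's StdConv2d standardizes each output filter over its (in, kH, kW) elements before convolving. It subtracts the filter mean and divides by sqrt(var + eps), using the biased variance. The sketch below applies that transform in Lux's (kW, kH, in, out) weight layout. It assumes std_conv uses the same per-output-channel standardization. standardize_kernel is a hypothetical helper, not part of Luximm.

```julia
using Statistics

# Hypothetical helper (not Luximm API): weight standardization in the Lux
# (kW, kH, in, out) layout. Each output filter w[:, :, :, o] is shifted to zero
# mean and divided by sqrt(var + eps), using the biased (uncorrected) variance.
function standardize_kernel(w::AbstractArray{T,4}; eps=1f-8) where {T<:AbstractFloat}
    μ = mean(w; dims=(1, 2, 3))                     # (1, 1, 1, out)
    σ² = var(w; dims=(1, 2, 3), corrected=false)    # (1, 1, 1, out)
    return (w .- μ) ./ sqrt.(σ² .+ T(eps))
end

w = randn(Float32, 3, 3, 16, 32)
ws = standardize_kernel(w)
```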
Luximm.Layers.layernorm2d — Function
layernorm2d(C; eps=1f-6) -> Lux.LayerNorm
LayerNorm with per-channel affine parameters for a 4D tensor in Lux's (W, H, C, N) layout. Normalizes over the channel axis at each spatial location and batch element, then applies a learnable scale and bias. Numerically equivalent to timm's LayerNorm2d.
Implemented as Lux.LayerNorm((1, 1, C); dims = 3, epsilon = eps), so the affine parameter leaves :scale and :bias have shape (1, 1, C, 1). PyTorch state-dict entries <prefix>.weight and <prefix>.bias are stored as (C,) and must be reshaped to (1, 1, C, 1) when loading (see Luximm.Interop.as_channel4d).
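The computation can be written in plain Julia, independent of Lux. The sketch below normalizes over axis 3 of a (W, H, C, N) array. It then applies scale and bias of shape (1, 1, C, 1), reshaped from PyTorch-style (C,) vectors. layernorm2d_ref is a hypothetical name. The sketch assumes Luximm.Interop.as_channel4d performs an equivalent plain reshape.

```julia
using Statistics

# Hypothetical reference (not Luximm API): LayerNorm over the channel axis of a
# (W, H, C, N) array, with per-channel affine parameters of shape (1, 1, C, 1).
function layernorm2d_ref(x::AbstractArray{T,4}, scale, bias; eps=1f-6) where {T<:AbstractFloat}
    μ = mean(x; dims=3)                        # (W, H, 1, N)
    σ² = var(x; dims=3, corrected=false)       # biased variance, as in PyTorch
    return (x .- μ) ./ sqrt.(σ² .+ T(eps)) .* scale .+ bias
end

C = 4
torch_weight = ones(Float32, C)                # PyTorch stores (C,)
torch_bias = zeros(Float32, C)
scale = reshape(torch_weight, 1, 1, C, 1)      # (C,) -> (1, 1, C, 1)
bias = reshape(torch_bias, 1, 1, C, 1)

x = randn(Float32, 5, 6, C, 2)
y = layernorm2d_ref(x, scale, bias)
```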
Luximm.Layers.grn_layer — Function
grn_layer(C; eps=1f-6) -> @compact block
Global Response Normalization for (W, H, C, N) tensors. Computes the L2 norm of each channel's spatial map and divides it by the mean of those norms across channels to get the normalized response n. It then returns x + bias + scale * (x * n). The per-channel parameters scale and bias are both initialized to zero, so a freshly initialized layer is the identity.
The function is named grn_layer rather than grn so that containing blocks can still use :grn as an @compact field name. PyTorch keys such as mlp.grn.weight then map directly to (..., :grn, :scale).
PyTorch state-dict keys <prefix>.weight and <prefix>.bias map to the :scale and :bias leaves of this layer with the identity transform.
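A plain-array sketch of the GRN formula follows. It places eps in the denominator next to the channel mean, as timm's GlobalResponseNorm does, and assumes grn_layer does the same. grn_ref and the (1, 1, C, 1) parameter shape are assumptions for illustration, not Luximm API.

```julia
using Statistics

# Hypothetical reference (not Luximm API): Global Response Normalization on a
# (W, H, C, N) array with per-channel scale and bias of shape (1, 1, C, 1).
function grn_ref(x::AbstractArray{T,4}, scale, bias; eps=1f-6) where {T<:AbstractFloat}
    g = sqrt.(sum(abs2, x; dims=(1, 2)))     # (1, 1, C, N): L2 norm of each channel's map
    n = g ./ (mean(g; dims=3) .+ T(eps))     # divide by the mean norm across channels
    return x .+ bias .+ scale .* (x .* n)
end

C = 3
x = randn(Float32, 4, 4, C, 2)
zero_param = zeros(Float32, 1, 1, C, 1)
y0 = grn_ref(x, zero_param, zero_param)       # zero-initialized parameters
```

With both parameters at zero, the output equals the input, which is why zero initialization makes the layer start as the identity.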
Luximm.Layers.se_block — Function
se_block(C; rd_ratio=1/16, rd_divisor=8, rd_channels=nothing, act=NNlib.relu) -> @compact block
Squeeze-and-excitation channel-attention block for (W, H, C, N) tensors. Global-average-pools each channel to a scalar, runs a two-layer 1x1-conv bottleneck (fc1 C→rd → act → fc2 rd→C), and rescales the input by the per-channel sigmoid gate. The bottleneck width rd is rd_channels when given (e.g. CoAtNet's int(attn_ratio * mid_chs)), otherwise se_make_divisible(C * rd_ratio, rd_divisor), matching timm's SEModule. act is the bottleneck activation (ReLU by default; CoAtNet's later recipes use SiLU).
PyTorch keys <prefix>.fc1.weight/bias and <prefix>.fc2.weight/bias map to the :fc1 / :fc2 Conv leaves with the identity transform.
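After pooling, the squeezed tensor is 1 × 1 spatially, so the two 1x1 convolutions act as dense layers on a pooled (C, N) matrix. The sketch below writes the block that way. se_ref, its argument order, and the dense (rd, C) / (C, rd) weight layout are hypothetical. The real layer stores Conv leaves instead.

```julia
using Statistics

sigmoid(z) = one(z) / (one(z) + exp(-z))
relu(z) = max(z, zero(z))

# Hypothetical reference (not Luximm API). W1: (rd, C), b1: (rd,), W2: (C, rd), b2: (C,).
function se_ref(x::AbstractArray{T,4}, W1, b1, W2, b2; act=relu) where {T}
    W, H, C, N = size(x)
    s = reshape(mean(x; dims=(1, 2)), C, N)    # squeeze: global average pool -> (C, N)
    h = act.(W1 * s .+ b1)                     # fc1 + activation -> (rd, N)
    g = sigmoid.(W2 * h .+ b2)                 # fc2 + sigmoid gate -> (C, N)
    return x .* reshape(g, 1, 1, C, N)         # excite: rescale every channel
end

C, rd = 16, 8                                  # 16 * 1/16 = 1, rounded up to the divisor 8
x = randn(Float32, 7, 7, C, 2)
W1, b1 = randn(Float32, rd, C), zeros(Float32, rd)
W2, b2 = zeros(Float32, C, rd), zeros(Float32, C)
y = se_ref(x, W1, b1, W2, b2)                  # zero fc2 => every gate is 0.5
```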
Luximm.Layers.patch_embed — Function
patch_embed(in_chans, embed_dim; patch=16) -> @compact block
Splits a (W, H, in_chans, N) image into non-overlapping patch × patch patches with a convolution whose kernel size and stride both equal patch. Returns a token tensor of shape (embed_dim, num_tokens, N), where num_tokens = (W ÷ patch) * (H ÷ patch). Tokens are ordered width-fastest, matching timm's PatchEmbed.
PyTorch keys <prefix>.proj.weight / <prefix>.proj.bias map to the :proj Conv leaves (identity, or adapt_input_conv for the weight when in_chans != 3).
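The main layout subtlety is the flattening step. Assume the conv output has Lux layout (W ÷ patch, H ÷ patch, embed_dim, N). Moving the channel axis first and reshaping then yields tokens in width-fastest order, because Julia arrays are column-major. tokens_from_map is a hypothetical helper name, and the random array stands in for the conv output.

```julia
# Hypothetical helper (not Luximm API): (Wp, Hp, D, N) conv output -> (D, Wp*Hp, N) tokens.
function tokens_from_map(y::AbstractArray{T,4}) where {T}
    Wp, Hp, D, N = size(y)
    # Column-major reshape of (D, Wp, Hp, N): token index = w + (h - 1) * Wp.
    return reshape(permutedims(y, (3, 1, 2, 4)), D, Wp * Hp, N)
end

W, H, patch, D = 224, 224, 16, 8
Wp, Hp = W ÷ patch, H ÷ patch                  # 14 x 14 grid of patches
y = randn(Float32, Wp, Hp, D, 2)               # stand-in for the stride-patch conv output
tokens = tokens_from_map(y)                    # (8, 196, 2)
```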
Luximm.Layers.mhsa — Function
mhsa(dim; num_heads, qkv_bias=true) -> @compact block
Multi-head self-attention over a (dim, T, N) token tensor. Splits a fused qkv projection into num_heads heads of width dim ÷ num_heads, computes scaled dot-product attention (scale 1/sqrt(head_dim), softmax over the key axis), merges heads, and applies the output proj. Numerically matches timm's Attention with qk_norm=False.
PyTorch keys map as: attn.qkv.weight → (:qkv, :weight) (axis_reverse), attn.qkv.bias → (:qkv, :bias) (identity), attn.proj.weight → (:proj, :weight) (axis_reverse), attn.proj.bias → (:proj, :bias) (identity).
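The per-head arithmetic can be written out directly. The sketch makes three assumptions. Weights are Dense-style (out, in), applied as W * x: (3dim, dim) for qkv and (dim, dim) for proj. The qkv features follow timm's ordering: q, then k, then v, each split into heads with head_dim contiguous. Scores are stored as (key, query), so the softmax runs over dims=1. mhsa_ref and softmax_cols are hypothetical names.

```julia
# Softmax down each column (here: over the key axis).
function softmax_cols(s)
    e = exp.(s .- maximum(s; dims=1))
    return e ./ sum(e; dims=1)
end

# Hypothetical reference (not Luximm API): multi-head self-attention on (dim, T, N).
function mhsa_ref(x::AbstractArray{T,3}, Wqkv, bqkv, Wproj, bproj; num_heads) where {T}
    dim, Tn, N = size(x)
    hd = dim ÷ num_heads                                 # head_dim
    qkv = Wqkv * reshape(x, dim, :) .+ bqkv              # (3dim, T*N)
    qkv = reshape(qkv, hd, num_heads, 3, Tn, N)          # head_dim fastest, then heads, then q|k|v
    out = similar(x, hd, num_heads, Tn, N)
    scale = T(1 / sqrt(hd))
    for n in 1:N, h in 1:num_heads
        q = qkv[:, h, 1, :, n]                           # (hd, T)
        k = qkv[:, h, 2, :, n]
        v = qkv[:, h, 3, :, n]
        a = softmax_cols((k' * q) .* scale)              # (T_key, T_query)
        out[:, h, :, n] = v * a                          # (hd, T_query)
    end
    y = Wproj * reshape(out, dim, :) .+ bproj            # merge heads, output projection
    return reshape(y, dim, Tn, N)
end

dim, Tn, N = 8, 5, 3
x = randn(Float32, dim, Tn, N)
Wqkv, bqkv = randn(Float32, 3 * dim, dim) ./ 4, zeros(Float32, 3 * dim)
Wproj, bproj = randn(Float32, dim, dim) ./ 4, zeros(Float32, dim)
y = mhsa_ref(x, Wqkv, bqkv, Wproj, bproj; num_heads=2)
```

A quick sanity property: self-attention without positional terms is permutation-equivariant, so shuffling the input tokens shuffles the output tokens the same way.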
Luximm.Layers.vit_block — Function
vit_block(dim; num_heads, mlp_ratio=4, eps=1f-6) -> @compact block
A single pre-norm ViT encoder block over a (dim, T, N) token tensor: LayerNorm → mhsa → residual, then LayerNorm → MLP (fc1 → exact GELU → fc2) → residual.
PyTorch keys map as: norm1/norm2.{weight,bias} → (:norm1/:norm2, :scale/:bias) (as_token_norm), attn.* via mhsa, mlp.fc1/fc2.* → (:fc1/:fc2, :weight/:bias) (axis_reverse/identity).
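The pre-norm residual wiring is easy to isolate. In the sketch below, the sub-layers are passed in as plain functions, and the MLP takes its activation as an argument. Exact GELU, x/2 * (1 + erf(x/√2)), needs erf, which comes from SpecialFunctions.jl rather than Base. vit_block_ref and mlp_ref are hypothetical names, and the weight layouts are assumptions: W1 is (mlp_ratio*dim, dim) and W2 is (dim, mlp_ratio*dim).

```julia
# Hypothetical sketch (not Luximm API): pre-norm residual wiring of a ViT block.
function vit_block_ref(x, norm1, attn, norm2, mlp)
    y = x .+ attn(norm1(x))      # attention branch + residual
    return y .+ mlp(norm2(y))    # MLP branch + residual
end

# Hypothetical MLP on (dim, T, N): fc1 -> act -> fc2, applied token-wise.
function mlp_ref(x::AbstractArray{T,3}, W1, b1, W2, b2, act) where {T}
    dim, Tn, N = size(x)
    h = act.(W1 * reshape(x, dim, :) .+ b1)    # fc1 -> (mlp_ratio*dim, T*N)
    return reshape(W2 * h .+ b2, :, Tn, N)     # fc2 -> (dim, T, N)
end

x = randn(Float32, 8, 5, 2)
W1, b1 = randn(Float32, 32, 8), zeros(Float32, 32)
W2, b2 = randn(Float32, 8, 32), zeros(Float32, 8)
m = mlp_ref(x, W1, b1, W2, b2, identity)
```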
Luximm.Layers.vit_layernorm — Function
vit_layernorm(dim; eps=1f-6) -> Lux.LayerNorm
LayerNorm over the channel axis (axis 1) of a (dim, T, N) token tensor, with per-channel affine. Matches timm's ViT nn.LayerNorm(dim, eps=1e-6). The affine :scale / :bias leaves have shape (dim, 1, 1), so the PyTorch (dim,) parameter is reshaped with Luximm.Interop.as_token_norm when loading.
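For token tensors, the same normalization runs over axis 1, with (dim, 1, 1) affine parameters. The sketch assumes as_token_norm is the plain reshape shown here. token_layernorm_ref is a hypothetical name.

```julia
using Statistics

# Hypothetical reference (not Luximm API): LayerNorm over axis 1 of a (dim, T, N)
# token array, with affine parameters of shape (dim, 1, 1).
function token_layernorm_ref(x::AbstractArray{T,3}, scale, bias; eps=1f-6) where {T<:AbstractFloat}
    μ = mean(x; dims=1)
    σ² = var(x; dims=1, corrected=false)
    return (x .- μ) ./ sqrt.(σ² .+ T(eps)) .* scale .+ bias
end

dim = 8
scale = reshape(ones(Float32, dim), dim, 1, 1)   # PyTorch (dim,) -> (dim, 1, 1)
bias = reshape(zeros(Float32, dim), dim, 1, 1)
x = randn(Float32, dim, 5, 2)
y = token_layernorm_ref(x, scale, bias)
```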
Luximm.Layers.rel_pos_attention — Function
rel_pos_attention(dim, dim_out; dim_head=32, window) -> @compact block
2D multi-head self-attention with a relative-position bias, over a WHCN feature map with dim channels. qkv is a fused 1x1 conv dim => 3*dim and proj is a 1x1 conv dim => dim_out; the number of heads is dim ÷ dim_head. A learned bias is gathered from rel_pos_bias_table using a precomputed relative-position index for the window = (Wh, Ww) grid. It is added to the attention scores before the softmax over the key axis. Numerically matches timm's Attention2d(head_first=True, expand_first=False) with RelPosBias.
PyTorch keys: qkv.weight/bias → (:qkv, :weight/:bias) (identity/identity, 1x1 conv), proj.weight/bias → (:proj, :weight/:bias), rel_pos.relative_position_bias_table → (:rel_pos_bias_table,) (axis_reverse, to (num_rel, num_heads)).
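Indexing the table with an (L, L) index matrix gathers the bias for every head at once. The sketch relies on Julia's multidimensional indexing, where table[idx, :] has shape (L, L, num_heads). It keeps scores as [query, key] to match the [i, j] convention of rel_pos_index, so the key-axis softmax runs over dims=2. The constant index matrix is a stand-in. rel_pos_scores and softmax_keys are hypothetical names, and the sketch assumes the (num_rel, num_heads) table layout stated above.

```julia
# Hypothetical sketch (not Luximm API). scores: (L, L, heads) with [i, j] =
# (query i, key j); table: (num_rel, heads); idx: (L, L) 1-based table rows.
function rel_pos_scores(scores::AbstractArray{T,3}, table::AbstractMatrix{T},
                        idx::AbstractMatrix{<:Integer}) where {T}
    bias = table[idx, :]                     # gather -> (L, L, heads)
    return scores .+ bias
end

# Softmax over the key axis, which is dims=2 in this (query, key, head) layout.
function softmax_keys(s)
    e = exp.(s .- maximum(s; dims=2))
    return e ./ sum(e; dims=2)
end

L, heads = 4, 2
table = reshape(Float32.(1:18), 9, heads)    # 9 = (2*2 - 1) * (2*2 - 1) rows for a 2 x 2 window
idx = fill(5, L, L)                          # stand-in index: every pair uses row 5
s = rel_pos_scores(zeros(Float32, L, L, heads), table, idx)
a = softmax_keys(s)
```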
Helpers
Luximm.Layers.se_make_divisible — Function
se_make_divisible(v, divisor=8; min_value=divisor) -> Int
timm's make_divisible: round v to the nearest multiple of divisor, never dropping below min_value and never losing more than 10% of v. Used to size the SE bottleneck channel count.
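The rounding rule fits in a few lines. The sketch follows timm's make_divisible, whose round_limit parameter (default 0.9) implements the 10% rule. The round_limit keyword here is an extra for illustration and is not part of the documented se_make_divisible signature. make_divisible_ref is a hypothetical name.

```julia
# Hypothetical re-implementation (not Luximm API) of timm's make_divisible.
# round_limit mirrors timm's parameter; the default 0.9 gives the 10% rule.
function make_divisible_ref(v, divisor=8; min_value=divisor, round_limit=0.9)
    new_v = max(min_value, (floor(Int, v + divisor / 2) ÷ divisor) * divisor)
    if new_v < round_limit * v
        new_v += divisor      # rounding lost more than 10% of v: go up one step
    end
    return new_v
end

rd = make_divisible_ref(256 / 16)   # SE bottleneck for 256 channels at rd_ratio = 1/16
```

For v = 20 with divisor 16, nearest rounding gives 16. That is below 0.9 * 20 = 18, so the result is bumped to 32.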
Luximm.Layers.rel_pos_index — Function
rel_pos_index(Wh, Ww) -> Matrix{Int}
Relative-position index grid for a Wh × Ww window, matching timm's gen_relative_position_index. Returns an (L, L) matrix (L = Wh*Ww, tokens in width-fastest order) of 1-based row indices into a bias table with (2*Wh-1)*(2*Ww-1) rows. Entry [i, j] is the bias-table row for the relative offset between query token i and key token j.
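The index can be rebuilt from timm's construction. Take each token's (row, column) position, form the offset of query i minus key j, and shift both components to be non-negative. Then combine them row-major with 2*Ww - 1 columns: index = (r_i - r_j + Wh - 1)(2Ww - 1) + (c_i - c_j + Ww - 1) + 1. The sketch assumes Wh counts rows and Ww counts columns, so width-fastest token t sits at 0-based position ((t - 1) ÷ Ww, (t - 1) % Ww). rel_pos_index_ref is a hypothetical name.

```julia
# Hypothetical re-implementation (not Luximm API) of timm's gen_relative_position_index
# for a Wh x Ww window, returning 1-based rows into a (2Wh-1)*(2Ww-1)-row table.
function rel_pos_index_ref(Wh::Integer, Ww::Integer)
    L = Wh * Ww
    row(t) = (t - 1) ÷ Ww          # 0-based row of token t (width-fastest order)
    col(t) = (t - 1) % Ww          # 0-based column of token t
    idx = Matrix{Int}(undef, L, L)
    for i in 1:L, j in 1:L
        dr = row(i) - row(j) + Wh - 1           # shifted into 0:(2Wh - 2)
        dc = col(i) - col(j) + Ww - 1           # shifted into 0:(2Ww - 2)
        idx[i, j] = dr * (2 * Ww - 1) + dc + 1  # row-major combine, then 1-based
    end
    return idx
end

idx = rel_pos_index_ref(2, 2)
```

Every token's offset to itself is (0, 0), so the diagonal always points at the center row of the table.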
Initializers
Luximm.Layers.kaiming_normal_fan_out — Function
kaiming_normal_fan_out(rng, [T,] dims...) -> Array{T}
Julia counterpart of PyTorch's nn.init.kaiming_normal_(weight, mode='fan_out', nonlinearity='relu').
std = sqrt(2 / fan_out), where fan_out follows PyTorch's _calculate_fan_in_and_fan_out: out_channels * prod(receptive_field). A Lux Conv weight has shape (kW, kH, in, out), so fan_out = dims[end] * prod(dims[1:end-2]). For a Dense weight of shape (out, in), fan_out = dims[1].
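The fan-out rule translates into a small sampler. The sketch treats 2-D shapes as Dense (out, in) and all other shapes as Conv (k..., in, out), using the ReLU gain √2. kaiming_normal_fan_out_ref is a hypothetical stand-in written from the formula above, not Luximm's implementation.

```julia
using Random

# Hypothetical stand-in (not Luximm API) following the fan_out rule above.
function kaiming_normal_fan_out_ref(rng::AbstractRNG, ::Type{T}, dims::Integer...) where {T<:AbstractFloat}
    # Dense (out, in): fan_out = out; Conv (k..., in, out): out * prod(kernel dims)
    fan_out = length(dims) == 2 ? dims[1] : dims[end] * prod(dims[1:end-2])
    σ = sqrt(T(2) / T(fan_out))
    return σ .* randn(rng, T, dims...)
end

w = kaiming_normal_fan_out_ref(MersenneTwister(0), Float32, 3, 3, 64, 128)  # fan_out = 128 * 9 = 1152
d = kaiming_normal_fan_out_ref(MersenneTwister(1), Float32, 256, 512)       # fan_out = 256
```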
Luximm.Layers.normal_init — Function
normal_init(; std) -> (rng, dims...) -> Array{Float32}
Closure that produces Normal(0, std) samples in Float32, mirroring PyTorch's nn.init.normal_(weight, mean=0., std=std). Used for timm's BiT classifier head (std = 0.01).
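The closure pattern fixes std once and returns a function with the (rng, dims...) initializer signature. normal_init_ref is a hypothetical stand-in, not Luximm's definition.

```julia
using Random

# Hypothetical stand-in (not Luximm API): returns an initializer closure that
# draws Normal(0, std) samples in Float32.
normal_init_ref(; std) = (rng::AbstractRNG, dims::Integer...) -> Float32(std) .* randn(rng, Float32, dims...)

init = normal_init_ref(std=0.01)            # BiT head setting
w = init(MersenneTwister(42), 1000, 2048)   # e.g. a Dense (out, in) weight
```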