Layers
Reusable building blocks shared across model families. These exist to match specific timm constructs in semantics, layout, and default parameters. They are not intended as a general-purpose layer library.
Building blocks
Luximm.Layers.std_conv — Function
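std_conv(kW, kH, in, out; stride=1, pad=0, eps=1f-8) -> @compact block
Cross-correlation convolution whose kernel is standardized at forward time. Each output channel's weights are shifted to zero mean and divided by sqrt(var + eps), with the mean and variance taken over the kW × kH × in kernel entries. Use it in place of Conv when porting a timm model that wraps its convolutions in StdConv2d (BiT-ResNet, etc.). Note that timm's NFNet models use the related ScaledStdConv2d, which adds a learnable gain and fan-in scaling on top of the standardization.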
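To make the standardization step concrete, here is a minimal plain-Julia sketch of the kernel transform, assuming timm StdConv2d semantics (per-output-channel mean and biased variance over the kW, kH and in axes). It is not Luximm's implementation: the name standardize_kernel is hypothetical, and the convolution itself is omitted.

```julia
using Statistics

# Sketch (assumption): per-output-channel weight standardization in the style of
# timm's StdConv2d, for a Lux-layout kernel of shape (kW, kH, in, out).
# Mean and biased variance are computed over the kW, kH and in axes, which is what
# F.batch_norm(..., training=True) does over the flattened kernel in PyTorch.
function standardize_kernel(W::AbstractArray{T,4}; eps = 1f-8) where {T<:AbstractFloat}
    mu = mean(W; dims = (1, 2, 3))                    # shape (1, 1, 1, out)
    v = var(W; dims = (1, 2, 3), corrected = false)   # biased variance, shape (1, 1, 1, out)
    return (W .- mu) ./ sqrt.(v .+ T(eps))
end

# The standardized kernel would then feed an ordinary cross-correlation.
W = randn(Float32, 3, 3, 4, 8)
Wstd = standardize_kernel(W)
```

Because the transform is recomputed from the raw weights on every forward pass, the stored parameter remains the unstandardized kernel, which is also what a PyTorch state dict contains.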
Luximm.Layers.layernorm2d — Function
layernorm2d(C; eps=1f-6) -> Lux.LayerNorm
LayerNorm with per-channel affine parameters for a 4D tensor in Lux's (W, H, C, N) layout. Normalizes over the channel axis at each spatial location and batch element, then applies a learnable per-channel scale and bias. Numerically equivalent to timm's LayerNorm2d.
Implemented as Lux.LayerNorm((1, 1, C); dims = 3, epsilon = eps), so the affine parameter leaves :scale and :bias have shape (1, 1, C, 1). PyTorch state-dict entries <prefix>.weight and <prefix>.bias are stored as (C,) and must be reshaped to (1, 1, C, 1) when loading (see Luximm.Interop.as_channel4d).
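The following sketch, using only Base and the Statistics and Random standard libraries, spells out what the layer computes and how a (C,) PyTorch vector is reshaped to the (1, 1, C, 1) leaf shape. layernorm2d_ref and to_channel4d are hypothetical names for illustration; the sketch does not assume anything about the actual signature of Luximm.Interop.as_channel4d.

```julia
using Statistics, Random

# Sketch (assumption): what a LayerNorm2d-style layer computes on a (W, H, C, N) array.
# Mean and biased variance are taken over the channel axis (dims = 3) separately for
# every spatial location and batch element; then per-channel scale and bias are applied.
function layernorm2d_ref(x::AbstractArray{T,4}, scale, bias; eps = 1f-6) where {T}
    mu = mean(x; dims = 3)
    v = var(x; dims = 3, corrected = false)
    return scale .* (x .- mu) ./ sqrt.(v .+ T(eps)) .+ bias
end

# Hypothetical helper: PyTorch stores LayerNorm2d weight/bias with shape (C,);
# reshaping to (1, 1, C, 1) lets them broadcast over a (W, H, C, N) array.
to_channel4d(v::AbstractVector) = reshape(v, 1, 1, :, 1)

C = 8
weight = ones(Float32, C)   # stand-in for <prefix>.weight
bias = zeros(Float32, C)    # stand-in for <prefix>.bias
x = randn(MersenneTwister(0), Float32, 5, 5, C, 2)
y = layernorm2d_ref(x, to_channel4d(weight), to_channel4d(bias))
```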
Luximm.Layers.grn_layer — Function
grn_layer(C; eps=1f-6) -> @compact block
Global Response Normalization (as in ConvNeXt-V2) for (W, H, C, N) tensors. For each sample, it computes the L2 norm of each channel's spatial map, divides each norm by the mean of those norms across channels (plus eps) to obtain the normalized response n, and returns x + bias + scale * (x * n). The per-channel parameters scale and bias are both initialized to zero, so the layer starts out as the identity.
The function is named grn_layer rather than grn so that the name grn stays free for the @compact field in containing blocks; PyTorch keys such as mlp.grn.weight then map directly to (..., :grn, :scale).
PyTorch state-dict keys <prefix>.weight and <prefix>.bias map to the :scale and :bias leaves of this layer with the identity transform.
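A minimal plain-array sketch of the GRN formula for a (W, H, C, N) input, assuming scale and bias are already shaped (1, 1, C, 1) so they broadcast per channel. grn_ref is a hypothetical name, not part of Luximm. The example also shows why zero-initialized parameters make the layer start as the identity.

```julia
using Statistics

# Sketch (assumption): Global Response Normalization on a (W, H, C, N) array,
# computing x + bias + scale * (x * n) as described for grn_layer.
function grn_ref(x::AbstractArray{T,4}, scale, bias; eps = 1f-6) where {T}
    g = sqrt.(sum(abs2, x; dims = (1, 2)))   # L2 norm of each channel's spatial map: (1, 1, C, N)
    n = g ./ (mean(g; dims = 3) .+ T(eps))   # divide by the mean over channels: (1, 1, C, N)
    return x .+ bias .+ scale .* (x .* n)
end

C = 3
x = randn(Float32, 4, 4, C, 2)
scale = zeros(Float32, 1, 1, C, 1)   # both parameters start at zero
bias = zeros(Float32, 1, 1, C, 1)
y = grn_ref(x, scale, bias)          # identical to x at initialization
```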
Initializers
Luximm.Layers.kaiming_normal_fan_out — Function
kaiming_normal_fan_out(rng, [T,] dims...) -> Array{T}
Equivalent to PyTorch's nn.init.kaiming_normal_(weight, mode='fan_out', nonlinearity='relu').
std = sqrt(2 / fan_out), where 2 is the squared ReLU gain and fan_out follows PyTorch's _calculate_fan_in_and_fan_out: out_channels * prod(receptive_field). A Lux Conv weight has shape (kW, kH, in, out), so fan_out = dims[end] * prod(dims[1:end-2]). A Dense weight has shape (out, in), so fan_out = dims[1].
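To illustrate the fan-out rule for both Lux weight layouts, here is a sketch in plain Julia (Random standard library only). fan_out and kaiming_normal_fan_out_ref are hypothetical helpers written for this example; they follow the (rng, [T,] dims...) calling convention from the signature, and defaulting T to Float32 is an assumption.

```julia
using Random

# Sketch (assumption): fan_out per the rules above.
#   Dense weight (out, in):         fan_out = out
#   Conv weight (kW, kH, in, out):  fan_out = out * kW * kH
fan_out(dims::NTuple{2,Int}) = dims[1]
fan_out(dims::Tuple{Vararg{Int}}) = dims[end] * prod(dims[1:end-2])

# Kaiming normal, fan-out mode, ReLU gain: std = sqrt(2 / fan_out).
function kaiming_normal_fan_out_ref(rng::AbstractRNG, ::Type{T}, dims::Int...) where {T<:AbstractFloat}
    sigma = sqrt(T(2) / T(fan_out(dims)))
    return sigma .* randn(rng, T, dims...)
end
# Assumed default element type when T is omitted.
kaiming_normal_fan_out_ref(rng::AbstractRNG, dims::Int...) =
    kaiming_normal_fan_out_ref(rng, Float32, dims...)

rng = MersenneTwister(0)
Wc = kaiming_normal_fan_out_ref(rng, 3, 3, 16, 64)      # Conv: fan_out = 64 * 3 * 3 = 576
Wd = kaiming_normal_fan_out_ref(rng, Float32, 10, 128)  # Dense: fan_out = 10
```

Grouped convolutions need no special case here: PyTorch's fan_out uses only the output-channel count and the receptive field, not the (per-group) input-channel count.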
Luximm.Layers.normal_init — Function
normal_init(; std) -> (rng, dims...) -> Array{Float32}
Returns a closure that draws Normal(0, std) samples in Float32, mirroring PyTorch's nn.init.normal_(weight, mean=0., std=std). Used for timm's BiT classifier head (std = 0.01).
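A sketch of the closure pattern, assuming the (rng, dims...) convention from the signature: the keyword std is captured once, and the returned function can be passed wherever an initializer with that convention is expected. normal_init_ref is a hypothetical stand-in, not Luximm's code.

```julia
using Random

# Sketch (assumption): initializer factory. The returned closure takes (rng, dims...)
# and draws Normal(0, std) samples in Float32, like nn.init.normal_(w, mean=0., std=std).
normal_init_ref(; std) =
    (rng::AbstractRNG, dims::Int...) -> Float32(std) .* randn(rng, Float32, dims...)

head_init = normal_init_ref(std = 0.01)          # BiT classifier head setting
W = head_init(MersenneTwister(42), 1000, 2048)   # Dense weight, shape (out, in)
```

Binding std by keyword and returning a closure keeps the head initializer configurable per model while preserving the plain (rng, dims...) initializer interface.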