Defining Rules
Most of the time, Mooncake.jl can just differentiate your code, but you will need to intervene if you make use of a language feature which is unsupported. However, this does not always necessitate writing your own rrule!! from scratch. In this section, we first detail some useful strategies which can help you avoid having to write rrule!!s in many situations, and then discuss the more involved process of actually writing rules.
Simplifying Code via Overlays
Mooncake.@mooncake_overlay — Macro
@mooncake_overlay method_expr
Define a method of a function which only Mooncake can see. This can be used to write versions of methods which can be successfully differentiated by Mooncake if the original cannot be.
For example, suppose that you have a function
julia> foo(x::Float64) = bar(x)
foo (generic function with 1 method)
where Mooncake.jl fails to differentiate bar for some reason. If you have access to another function baz, which does the same thing as bar, but does so in a way which Mooncake.jl can differentiate, you can simply write:
julia> Mooncake.@mooncake_overlay foo(x::Float64) = baz(x)
When looking up the code for foo(::Float64), Mooncake.jl will see this method, rather than the original, and differentiate it instead.
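Because the overlay is visible only to Mooncake, ordinary calls to the function still run the original method. The following minimal sketch shows this. It assumes Mooncake.jl is installed. twice is a hypothetical function, and the overlay deliberately changes its behaviour only so that the effect is observable; you would not do this in real code.

```julia
using Mooncake

# Primal definition, used by ordinary Julia dispatch.
twice(x::Float64) = 2x

# Overlay seen only by Mooncake; deliberately different so the effect is observable.
Mooncake.@mooncake_overlay twice(x::Float64) = 3x

# Calling `twice` directly still runs the original method.
primal_value = twice(5.0)

# Mooncake differentiates the overlay instead of the original.
rule = Mooncake.build_rrule(Tuple{typeof(twice), Float64})
value, grads = Mooncake.value_and_gradient!!(rule, twice, 5.0)
```

Here primal_value is 10.0, taken from the original definition. Both value (15.0) and the gradient grads[2] (3.0) come from the overlay.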
A Worked Example
To demonstrate how to use @mooncake_overlay in practice, we show how the answer that Mooncake.jl gives changes if you change the definition of a function using @mooncake_overlay. Do not do this in practice; it is just a simple way to demonstrate how to use overlays!
First, consider a simple example:
julia> scale(x) = 2x
scale (generic function with 1 method)
julia> rule = Mooncake.build_rrule(Tuple{typeof(scale), Float64});
julia> Mooncake.value_and_gradient!!(rule, scale, 5.0)
(10.0, (NoTangent(), 2.0))
We can use @mooncake_overlay to change the definition which Mooncake.jl sees:
julia> Mooncake.@mooncake_overlay scale(x) = 3x
julia> rule = Mooncake.build_rrule(Tuple{typeof(scale), Float64});
julia> Mooncake.value_and_gradient!!(rule, scale, 5.0)
(15.0, (NoTangent(), 3.0))
As can be seen from the output, the result of differentiating using Mooncake.jl has changed to reflect the overlay-ed definition of the method.
Additionally, it is possible to use the usual multi-line syntax to declare an overlay:
julia> Mooncake.@mooncake_overlay function scale(x)
return 4x
end
julia> rule = Mooncake.build_rrule(Tuple{typeof(scale), Float64});
julia> Mooncake.value_and_gradient!!(rule, scale, 5.0)
(20.0, (NoTangent(), 4.0))
Functions with Zero Adjoint
If the above strategy does not work, but you find yourself in the surprisingly common situation that the adjoint of the derivative of your function is always zero, you can very straightforwardly write a rule by making use of the following:
Mooncake.@zero_adjoint — Macro
@zero_adjoint ctx sig
Equivalent to @zero_derivative ctx sig ReverseMode. Consult the docstring for @zero_derivative for more information.
Mooncake.zero_adjoint — Function
zero_adjoint(f::CoDual, x::Vararg{CoDual, N}) where {N}
Utility functionality for constructing rrule!!s for functions whose adjoints always return zero.
NOTE: you should only make use of this function if you cannot make use of the @zero_adjoint macro.
You make use of this functionality by writing a method of Mooncake.rrule!!, and passing all of its arguments (including the function itself) to this function. For example:
julia> import Mooncake: zero_adjoint, DefaultCtx, ReverseMode, @is_primitive, zero_fcodual, rrule!!, CoDual, NoRData
julia> foo(x::Vararg{Int}) = 5
foo (generic function with 1 method)
julia> @is_primitive DefaultCtx ReverseMode Tuple{typeof(foo), Vararg{Int}}
julia> rrule!!(f::CoDual{typeof(foo)}, x::Vararg{CoDual{Int}}) = zero_adjoint(f, x...);
julia> rrule!!(zero_fcodual(foo), zero_fcodual(3), zero_fcodual(2))[2](NoRData())
(NoRData(), NoRData(), NoRData())
WARNING: this is only correct if the output of primal(f)(map(primal, x)...) does not alias anything in f or x. This is always the case if the result is a bits type, but more care may be required if it is not.
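Where possible, the @zero_adjoint macro is the simpler route. The minimal sketch below assumes Mooncake.jl is installed; count_positive and weighted_sum are hypothetical functions. Because count_positive returns an Int, which is a bits type, the aliasing caveat above does not apply.

```julia
using Mooncake
using Mooncake: @zero_adjoint, DefaultCtx

# Hypothetical helper whose integer output has a zero derivative.
count_positive(x::Vector{Float64}) = count(>(0), x)

# Declare it a primitive whose adjoint is always zero.
@zero_adjoint DefaultCtx Tuple{typeof(count_positive), Vector{Float64}}

# A differentiable function that calls the helper.
weighted_sum(x::Vector{Float64}) = sum(x) * count_positive(x)

rule = Mooncake.build_rrule(Tuple{typeof(weighted_sum), Vector{Float64}})
value, grads = Mooncake.value_and_gradient!!(rule, weighted_sum, [1.0, -2.0, 3.0])
```

Only sum(x) contributes to the gradient, so each entry of grads[2] should equal count_positive(x), which is 2 here.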
Using ChainRules.jl
ChainRules.jl provides a large number of rules for differentiating functions in reverse-mode. These rules are methods of the ChainRulesCore.rrule function. There are some instances where it is most convenient to implement a Mooncake.rrule!! by wrapping an existing ChainRulesCore.rrule.
There is enough similarity between these two systems that most of the boilerplate code can be avoided.
Mooncake.@from_rrule — Macro
@from_rrule ctx sig [has_kwargs=false]
Equivalent to @from_chainrules ctx sig has_kwargs ReverseMode. See @from_chainrules for more information.
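The following sketch shows @from_rrule in use. It assumes Mooncake.jl and ChainRulesCore.jl are both installed. myexp is a hypothetical function with a hand-written ChainRulesCore.rrule, which Mooncake then reuses. ChainRulesCore.NoTangent is written fully qualified to avoid any name clash with Mooncake's own NoTangent.

```julia
using Mooncake, ChainRulesCore
using Mooncake: @from_rrule, DefaultCtx

# Hypothetical function with a hand-written ChainRulesCore rule.
myexp(x::Float64) = exp(x)

function ChainRulesCore.rrule(::typeof(myexp), x::Float64)
    y = myexp(x)
    # d/dx exp(x) = exp(x), so the cotangent of x is dy * y.
    myexp_pullback(dy) = (ChainRulesCore.NoTangent(), dy * y)
    return y, myexp_pullback
end

# Generate a Mooncake rule for this signature from the ChainRulesCore.rrule above.
@from_rrule DefaultCtx Tuple{typeof(myexp), Float64}

rule = Mooncake.build_rrule(Tuple{typeof(myexp), Float64})
value, grads = Mooncake.value_and_gradient!!(rule, myexp, 1.0)
```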
Adding Methods To rrule!! And build_primitive_rrule
If the above strategies do not work for you, you should first implement a method of Mooncake.is_primitive for the signature of interest:
Mooncake.is_primitive — Function
is_primitive(ctx::Type, mode::Type{<:Mode}, sig::Type{<:Tuple}, world::UInt)
Returns a Bool specifying whether the methods specified by sig are considered primitives in context ctx, in mode mode, at world age world.
julia> using Mooncake: is_primitive, DefaultCtx, ReverseMode
julia> is_primitive(DefaultCtx, ReverseMode, Tuple{typeof(sin), Float64}, Base.get_world_counter())
true
The world argument is needed because rules which Mooncake derives are associated with a particular Julia world age. As a result, anything declared a primitive after a rule has been constructed ought not to be considered a primitive by that rule. Should it be needed in cases where a rule has previously been derived, one can explicitly derive a new rule after new @is_primitive declarations (e.g. via build_frule, build_rrule, or a function from the higher-level interface such as prepare_derivative_cache, prepare_pullback_cache or prepare_gradient_cache). To see how this works, consider the following:
julia> using Mooncake: is_primitive, DefaultCtx, ReverseMode, @is_primitive
julia> foo(x::Float64) = 5x
foo (generic function with 1 method)
julia> old_world_age = Base.get_world_counter();
julia> @is_primitive DefaultCtx ReverseMode Tuple{typeof(foo),Float64}
julia> new_world_age = Base.get_world_counter();
julia> is_primitive(DefaultCtx, ReverseMode, Tuple{typeof(foo),Float64}, old_world_age)
false
julia> is_primitive(DefaultCtx, ReverseMode, Tuple{typeof(foo),Float64}, new_world_age)
true
Observe that is_primitive returns false for the world age prior to declaring foo a primitive, but true afterwards. For more information on Julia's world age mechanism, see https://docs.julialang.org/en/v1/manual/methods/#Redefining-Methods .
Then implement a method of one of the following:
Mooncake.rrule!! — Function
rrule!!(f::CoDual, x::CoDual...)
Performs the forwards-pass of AD. The tangent field of f and each x should contain the forwards tangent data (fdata) associated to each corresponding primal field.
Returns a 2-tuple. The first element, y, is a CoDual whose primal field is the value associated to running f.primal(map(x -> x.primal, x)...), and whose tangent field is its associated fdata. The second element contains the pullback, which runs the reverse-pass. It maps from the rdata associated to y to the rdata associated to f and each x.
using Mooncake: zero_fcodual, CoDual, NoFData, rrule!!
y, pb!! = rrule!!(zero_fcodual(sin), CoDual(5.0, NoFData()))
pb!!(1.0)
# output
(NoRData(), 0.28366218546322625)
Mooncake.build_primitive_rrule — Function
build_primitive_rrule(sig::Type{<:Tuple})
Construct an rrule for signature sig. For this function to be called in build_rrule, you must also ensure that a method of _is_primitive(context_type, ReverseMode, sig) exists, preferably by using the @is_primitive macro. The callable returned by this must obey the rrule interface, but there are no restrictions on the type of callable itself. For example, you might return a callable struct. By default, this function returns rrule!! so, most of the time, you should just implement a method of rrule!!.
Extended Help
The purpose of this function is to permit computation at rule construction time, which can be re-used at runtime. For example, you might wish to derive some information from sig which you use at runtime (e.g. the fdata type of one of the arguments). While constant propagation will often optimise this kind of computation away, it will sometimes fail to do so in hard-to-predict circumstances. Consequently, if you need certain computations not to happen at runtime in order to guarantee good performance, you might wish to e.g. emit a callable struct with type parameters which are the result of this computation. In this context, the motivation for using this function is the same as that of using staged programming (e.g. via @generated functions) more generally.
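The following sketch puts these pieces together by hand-writing a reverse-mode rule for a hypothetical scalar function, mysquare. It assumes Mooncake.jl is installed and uses the `@is_primitive ctx mode sig` form shown earlier in this section. Float64 has no fdata, so the output CoDual carries NoFData(), and the pullback receives the rdata of the output as a Float64.

```julia
using Mooncake
using Mooncake: CoDual, NoFData, NoRData, DefaultCtx, ReverseMode, @is_primitive, primal, zero_fcodual
import Mooncake: rrule!!

# Hypothetical scalar function for which we hand-write a reverse-mode rule.
mysquare(x::Float64) = x * x

# Step 1: declare the signature a primitive in reverse mode.
@is_primitive DefaultCtx ReverseMode Tuple{typeof(mysquare), Float64}

# Step 2: implement rrule!! (forwards pass plus a pullback closure).
function rrule!!(::CoDual{typeof(mysquare)}, x::CoDual{Float64})
    xp = primal(x)
    y = CoDual(mysquare(xp), NoFData())
    # Maps the rdata of y (a Float64) to rdata for (mysquare, x).
    mysquare_pb!!(dy::Float64) = (NoRData(), 2 * xp * dy)
    return y, mysquare_pb!!
end

# Call the rule directly, as in the `sin` example above.
y, pb!! = rrule!!(zero_fcodual(mysquare), zero_fcodual(3.0))
pb!!(1.0)  # (NoRData(), 6.0)
```

Because build_primitive_rrule returns rrule!! by default, build_rrule should pick up this method without any further work.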
Canonicalising Tangent Types
For some differentiation rules, Mooncake performs an explicit canonicalisation step inside frule!!/rrule!! that collapses heterogeneous array and tangent types into a small set of canonical representations. By canonicalising at the rule boundary, a single implementation can support many combinations of argument and tangent types without duplicating logic or relying on complex dispatch. This allows the remainder of the rule to assume a single, well-defined tangent representation.
Recall that rrule!! methods in Mooncake receive CoDual-wrapped arguments, including the function itself. Each CoDual carries both a primal value and an associated tangent (or FData). Consider a kron rule:
function Mooncake.rrule!!(
::CoDual{typeof(kron)},
x1::CoDual{<:AbstractVecOrMat{<:T}},
x2::CoDual{<:AbstractVecOrMat{<:T}},
) where {T<:Base.IEEEFloat}
# Canonicalise inputs: although this method constrains `x1`/`x2` to `AbstractVecOrMat`,
# they may still be realised by many concrete array types (e.g. vectors, matrices, views,
# `Diagonal`, `Symmetric`, `PDMat`, and other wrappers). Canonicalising at the rule boundary
# avoids a proliferation of specialised methods and lets the pullback operate on a single,
# predictable dense matrix tangent representation.
# `matrixify` returns a tuple (primal, tangent_matrix).
px1, dx1 = matrixify(x1)
px2, dx2 = matrixify(x2)
# Run the primal computation
y = kron(px1, px2)
dy = zero(y)
# Work with canonicalised tangent arrays.
function kron_pb!!(::NoRData)
# Run the pullback computation
# Code omitted here for brevity
return NoRData(), NoRData(), NoRData()
end
return CoDual(y, dy), kron_pb!!
end
The key insight is that matrixify is one of several canonicalisation utilities (alongside arrayify) used to reconcile heterogeneous tangent representations into simple, uniform forms. In this case, tangents associated with vectors, matrices, views, Diagonal, Symmetric, PDMat, and other array wrappers are converted into a standard dense matrix representation that the rule can consume directly. Without this step, the rule would require multiple specialised methods or intricate dispatch logic to account for every admissible tangent representation.
Although this pattern is especially visible in BLAS- and LAPACK-backed rules—where performance-critical kernels must accommodate many array wrappers—it is not specific to linear algebra. Canonicalisation is a general rule-design technique: it isolates type heterogeneity at the boundary of the rule, simplifies the core logic, and improves maintainability across any domain where primitives admit many equivalent tangent representations (e.g. broadcasting, structured arrays, or custom numeric types).
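The plain-Julia sketch below illustrates the idea without calling Mooncake. A hypothetical canonicalise function stands in for utilities like matrixify, collapsing structured inputs into dense matrices. A single dense kernel then computes the reverse pass of kron. This is not Mooncake's implementation, whose pullback body is omitted above.

```julia
using LinearAlgebra

# Hypothetical canonicaliser (not Mooncake's `matrixify`): collapse any vector or
# matrix, including wrappers such as Diagonal, Symmetric or views, into a dense Matrix.
canonicalise(x::AbstractVector) = reshape(collect(x), :, 1)
canonicalise(x::AbstractMatrix) = Matrix(x)

# Single reverse-pass kernel for y = kron(A, B), written once for dense matrices.
# Block (i, j) of y equals A[i, j] * B, so each block of dy feeds dA[i, j] and dB.
function kron_pullback_dense(A::Matrix{T}, B::Matrix{T}, dy::Matrix{T}) where {T}
    p, q = size(B)
    dA = zeros(T, size(A))
    dB = zeros(T, size(B))
    for j in axes(A, 2), i in axes(A, 1)
        block = view(dy, (i - 1) * p + 1:i * p, (j - 1) * q + 1:j * q)
        dA[i, j] = sum(block .* B)
        dB .+= A[i, j] .* block
    end
    return dA, dB
end

# Structured inputs are canonicalised once at the boundary, then share one kernel.
A = canonicalise(Diagonal([1.0, 2.0]))
B = canonicalise(Symmetric([1.0 2.0; 2.0 3.0]))
dA, dB = kron_pullback_dense(A, B, ones(4, 4))
```

With dy set to all ones, every entry of dA equals sum(B) = 8.0, and every entry of dB equals sum(A) = 3.0.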
Customising Friendly Gradients
When friendly_tangents=true is passed to value_and_gradient!! or prepare_gradient_cache, Mooncake converts its internal tangent representation into user-facing values. The conversion by type is:
- Immutable structs, mutable structs (with standard MutableTangent), and closures with differentiable fields: NamedTuple of per-field gradients, keyed by field name.
- Tuple: Tuple of per-element gradients.
- AbstractArray with non-IEEEFloat (or complex) eltype: array of per-element gradients.
- AbstractArray with IEEEFloat (or complex) eltype: plain array tangent, unchanged.
- Callables with no captured differentiable state: NoTangent(), unchanged.
- AbstractDict: a dict of the same type as the primal, with the same keys and gradient values.
- Everything else (primitive types, zero-field types, mutable structs with custom tangent types): raw Mooncake tangent, unchanged unless customised as described below.
For example, with friendly_tangents=false (default), an immutable struct Foo with fields a::Float64 and b::Vector{Float64} returns a Mooncake.Tangent wrapping (a = da, b = db), and a mutable struct Bar with the same fields returns a Mooncake.MutableTangent wrapping (a = da, b = db). With friendly_tangents=true both unwrap to the plain NamedTuple (a = da, b = db) where da::Float64 and db::Vector{Float64}.
An override is needed when the default output is unreadable or unintuitive — for example, types that store only a compressed representation, such as LinearAlgebra.Symmetric, which stores only one triangle of the full matrix but logically represents both.
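To see why Symmetric benefits from an override, note that its backing array can hold arbitrary values in the triangle it ignores. The following plain-Julia illustration uses only LinearAlgebra:

```julia
using LinearAlgebra

# The backing array holds a value in the lower triangle that Symmetric ignores.
data = [1.0 2.0; 99.0 3.0]
S = Symmetric(data, :U)  # reads only the upper triangle

S[2, 1]          # 2.0, mirrored from data[1, 2]
parent(S)[2, 1]  # 99.0, stored but not part of the logical matrix
Matrix(S)        # the full logical matrix [1.0 2.0; 2.0 3.0]
```

A tangent that mirrors this storage layout would carry information in only one triangle. A user, by contrast, usually expects a gradient shaped like the full logical matrix.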
Two hooks control this conversion:
Mooncake.FriendlyTangentCache — Type
FriendlyTangentCache{M, B}
Pre-allocated output buffer for the user-facing gradient of a non-composite primal type, carrying a mode flag M that drives dispatch in tangent_to_friendly!!.
A type is non-composite if friendly_tangent_cache returns a single FriendlyTangentCache{M} for it. A type is composite if friendly_tangent_cache instead recurses into its sub-components and returns a NamedTuple, Tuple, or Array of per-element caches. The modes below do not apply to composite types.
M is one of the following mode types:
User-overridable modes (for use in custom friendly_tangent_cache overloads):
- AsRaw: default for all non-composite types without an explicit override (Julia primitive types, float arrays, types with custom tangent types, zero-field types). buffer is nothing, so there is no allocation at prepare time. The raw Mooncake tangent is returned directly, aliasing internal cache storage; copy it before the next AD call with the same cache if you need to retain it.
- AsPrimal: opt-in; buffer is a copy of the primal (via _copy_output). At runtime, non-differentiable fields are refreshed from the current primal and the tangent is written in via tangent_to_primal_internal!!. Used for mutable collections (e.g. AbstractDict) where the user-facing gradient should have the same container type as the primal.
- AsCustomised: opt-in; buffer is a user-supplied friendly output buffer (e.g. Matrix{T} for Symmetric{T}). At runtime, tangent_to_friendly_internal!! is called to fill it.
Internal mode (generated automatically; do not use in overloads):
- AsMutableFields: used internally for mutable structs with fields and the standard MutableTangent tangent type. buffer is a NamedTuple of per-field caches built at prepare time. At runtime, each field is recursively converted and the results are assembled into a NamedTuple. Mutable structs with a custom tangent type fall through to AsRaw.
Override friendly_tangent_cache to return a FriendlyTangentCache of the desired mode for custom types.
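The AsRaw caveat about aliasing applies to any result that lives in a reused output buffer. The plain-Julia analogy below does not call Mooncake; grad_into! is a hypothetical stand-in for an AD call that writes its result into cache storage.

```julia
# Hypothetical stand-in for an AD call that writes its result into a reused buffer.
function grad_into!(buffer::Vector{Float64}, x::Vector{Float64})
    buffer .= 2 .* x  # gradient of sum(abs2, x)
    return buffer
end

buffer = zeros(3)
g1 = grad_into!(buffer, [1.0, 2.0, 3.0])
g1_kept = copy(g1)                        # snapshot taken before the next call
g2 = grad_into!(buffer, [0.0, 0.0, 1.0])  # overwrites the shared buffer

g1 === g2  # true: both names refer to the same storage
g1         # [0.0, 0.0, 2.0]; the first result has been overwritten
g1_kept    # [2.0, 4.0, 6.0]; preserved by the copy
```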
Mooncake.friendly_tangent_cache — Function
friendly_tangent_cache(x)
Return a pre-allocated cache for the user-facing gradient of the primal x.
A primal type is non-composite if this function returns a single FriendlyTangentCache{M} for it. It is composite if this function recurses into sub-components and returns a nested NamedTuple, Tuple, or Array of per-element caches instead.
Behaviour by type category:
| Category | Cache returned |
|---|---|
| Immutable struct with fields and standard Tangent | NamedTuple of per-field caches (composite) |
| Tuple | Tuple of per-element caches (composite) |
| AbstractArray with non-float eltype | Array of per-element caches via map (composite) |
| Mutable struct with fields and standard MutableTangent | FriendlyTangentCache{AsMutableFields}; per-field NamedTuple at runtime (non-composite, internal mode) |
| AbstractDict | FriendlyTangentCache{AsPrimal} (non-composite) |
| LinearAlgebra.Symmetric / Hermitian / SymTridiagonal | FriendlyTangentCache{AsCustomised} (non-composite) |
| Everything else (Julia primitive types, float arrays, custom-tangent types, zero-field types) | FriendlyTangentCache{AsRaw} (non-composite) |
Override to opt a type into a non-composite mode with a custom buffer:
Mooncake.friendly_tangent_cache(x::MyMatrix{T}) where {T} =
Mooncake.FriendlyTangentCache{Mooncake.AsCustomised}(Matrix{T}(undef, size(x)...))
Overloads for LinearAlgebra.Symmetric, LinearAlgebra.Hermitian, and LinearAlgebra.SymTridiagonal live in src/rules/linear_algebra.jl.
Mutable structs whose fields form a self-referential cycle (e.g. a linked-list node whose next field points to another instance of the same type) will cause a StackOverflowError when this function descends into their fields at prepare time. Override friendly_tangent_cache for such types to avoid recursion:
Mooncake.friendly_tangent_cache(::MyRecursiveType) =
Mooncake.FriendlyTangentCache{Mooncake.AsRaw}(nothing)
Mooncake.tangent_to_friendly!! — Function
tangent_to_friendly!!(dest, primal, tangent, c::MaybeCache)
tangent_to_friendly!!(primal, tangent)
Translate a Mooncake tangent to a user-facing gradient.
The 4-argument form dispatches on the FriendlyTangentCache mode stored in dest (or recurses into a NamedTuple / AbstractArray dest tree). c is an IdDict or NoCache used to handle aliased mutable buffers across a single call.
The 2-argument form is a convenience wrapper: it calls friendly_tangent_cache to build dest and creates a fresh cache c, then delegates to the 4-argument form.
Returns the unwrapped user-facing value (not the FriendlyTangentCache wrapper).
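To see why a single cache c is shared across one call, consider a container whose fields alias the same mutable array. The plain-Julia sketch below is not Mooncake's implementation. convert_array! is a hypothetical helper that keys an IdDict on object identity, so aliased inputs map to one shared output.

```julia
# Hypothetical converter (not Mooncake API): produce one output array per distinct
# input array, keyed by object identity so that aliasing is preserved.
function convert_array!(cache::IdDict{Any,Any}, x::Vector{Float64})
    return get!(() -> copy(x), cache, x)
end

shared = [1.0, 2.0]
fields = (a = shared, b = shared, c = [3.0])  # a and b alias the same buffer

cache = IdDict{Any,Any}()
converted = map(v -> convert_array!(cache, v), fields)

converted.a === converted.b  # true: aliasing preserved in the output
converted.a === shared       # false: outputs are fresh arrays
```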
Mooncake.tangent_to_friendly_internal!! — Function
tangent_to_friendly_internal!!(dest, primal, tangent)
Implementation hook for the AsCustomised mode of tangent_to_friendly!!.
Override together with friendly_tangent_cache (returning a FriendlyTangentCache{Mooncake.AsCustomised}) to provide a direct tangent → friendly conversion for custom types. dest is the pre-allocated output buffer from the cache (used for dispatch on its type and for in-place writing); primal is available for additional dispatch if needed.
Overloads for LinearAlgebra.Symmetric, LinearAlgebra.Hermitian, and LinearAlgebra.SymTridiagonal live in src/rules/linear_algebra.jl.
Example: full-matrix gradient for a structured matrix type
Suppose MyMatrix{T} stores data compactly but represents a full matrix. To expose a plain Matrix{T} gradient to the user:
# Step 1: tell Mooncake to use a pre-allocated Matrix{T} buffer.
Mooncake.friendly_tangent_cache(x::MyMatrix{T}) where {T} =
Mooncake.FriendlyTangentCache{Mooncake.AsCustomised}(Matrix{T}(undef, size(x)...))
# Step 2: implement the conversion from internal tangent to the buffer.
# Argument order: (dest, primal, tangent) — dest is first, used for dispatch on its type.
function Mooncake.tangent_to_friendly_internal!!(
dest::Matrix{T}, ::MyMatrix{T}, tangent
) where {T}
# `val` unwraps the stored field tangent; adjust the field name to match MyMatrix's layout.
copyto!(dest, Mooncake.val(tangent.fields.data))
return dest
end
Any struct that contains a MyMatrix field will automatically expose that field's gradient as a Matrix{T}. No additional overrides are required, because the default struct recursion builds a NamedTuple of per-field friendly gradients.
The existing overloads for LinearAlgebra.Symmetric, LinearAlgebra.Hermitian, and LinearAlgebra.SymTridiagonal in src/rules/linear_algebra.jl follow exactly this pattern and serve as reference implementations.