Defining Rules

Most of the time, Mooncake.jl can differentiate your code directly, but you will need to intervene if your code uses a language feature which Mooncake.jl does not support. This does not always require writing your own rrule!! from scratch. This section first describes several strategies which let you avoid writing rrule!!s in many situations, and then covers the more involved process of actually writing rules.

Simplifying Code via Overlays

Mooncake.@mooncake_overlay (macro)
@mooncake_overlay method_expr

Define a method of a function which only Mooncake can see. This can be used to write versions of methods which can be successfully differentiated by Mooncake if the original cannot be.

For example, suppose that you have a function

julia> foo(x::Float64) = bar(x)
foo (generic function with 1 method)

where Mooncake.jl fails to differentiate bar for some reason. If you have access to another function baz, which does the same thing as bar, but does so in a way which Mooncake.jl can differentiate, you can simply write:

julia> Mooncake.@mooncake_overlay foo(x::Float64) = baz(x)

When looking up the code for foo(::Float64), Mooncake.jl will see this method, rather than the original, and differentiate it instead.

A Worked Example

To show how @mooncake_overlay works in practice, we use one to change the definition of a function and observe how the answer that Mooncake.jl gives changes accordingly. Do not do this in practice: this is just a simple way to demonstrate how to use overlays!

First, consider a simple example:

julia> scale(x) = 2x
scale (generic function with 1 method)

julia> rule = Mooncake.build_rrule(Tuple{typeof(scale), Float64});

julia> Mooncake.value_and_gradient!!(rule, scale, 5.0)
(10.0, (NoTangent(), 2.0))

We can use @mooncake_overlay to change the definition which Mooncake.jl sees:

julia> Mooncake.@mooncake_overlay scale(x) = 3x

julia> rule = Mooncake.build_rrule(Tuple{typeof(scale), Float64});

julia> Mooncake.value_and_gradient!!(rule, scale, 5.0)
(15.0, (NoTangent(), 3.0))

As the output shows, the result of differentiating with Mooncake.jl has changed to reflect the overlaid definition of the method.

Additionally, it is possible to use the usual multi-line syntax to declare an overlay:

julia> Mooncake.@mooncake_overlay function scale(x)
           return 4x
       end

julia> rule = Mooncake.build_rrule(Tuple{typeof(scale), Float64});

julia> Mooncake.value_and_gradient!!(rule, scale, 5.0)
(20.0, (NoTangent(), 4.0))
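Because an overlay is visible only to Mooncake.jl, calling the function directly in ordinary Julia code still uses the original method. The following minimal sketch, reusing the toy scale function from above, checks both sides: the ordinary call returns the original result, while Mooncake.jl differentiates the overlaid definition.

```julia
import Mooncake

scale(x) = 2x

# Overlay visible only to Mooncake.jl; ordinary Julia dispatch ignores it.
Mooncake.@mooncake_overlay scale(x) = 3x

# Ordinary call: uses the original method.
primal_value = scale(5.0)  # 10.0

# Mooncake.jl looks up the overlaid method instead.
rule = Mooncake.build_rrule(Tuple{typeof(scale), Float64})
mooncake_value, grads = Mooncake.value_and_gradient!!(rule, scale, 5.0)
# mooncake_value == 15.0, grads[2] == 3.0
```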

Functions with Zero Adjoint

If the above strategy does not work, but you find yourself in the surprisingly common situation where the adjoint of the derivative of your function is always zero, you can write a rule straightforwardly using the following:

Mooncake.zero_adjoint (function)
zero_adjoint(f::CoDual, x::Vararg{CoDual, N}) where {N}

Utility functionality for constructing rrule!!s for functions whose adjoints always return zero.

NOTE: you should only use this function if you cannot use the @zero_adjoint macro.

You make use of this functionality by writing a method of Mooncake.rrule!!, and passing all of its arguments (including the function itself) to this function. For example:

julia> import Mooncake: zero_adjoint, DefaultCtx, ReverseMode, @is_primitive, zero_fcodual, rrule!!, CoDual, NoRData

julia> foo(x::Vararg{Int}) = 5
foo (generic function with 1 method)

julia> @is_primitive DefaultCtx ReverseMode Tuple{typeof(foo), Vararg{Int}}

julia> rrule!!(f::CoDual{typeof(foo)}, x::Vararg{CoDual{Int}}) = zero_adjoint(f, x...);

julia> rrule!!(zero_fcodual(foo), zero_fcodual(3), zero_fcodual(2))[2](NoRData())
(NoRData(), NoRData(), NoRData())

WARNING: this is only correct if the output of primal(f)(map(primal, x)...) does not alias anything in f or x. This is always the case if the result is a bits type, but more care may be required if it is not.
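The note above recommends the @zero_adjoint macro whenever it applies. A minimal sketch of that route follows, assuming a Mooncake.jl version whose @zero_adjoint takes a context type followed by a signature; heaviside is a hypothetical piecewise-constant function whose derivative is zero wherever it is defined. Its Float64 result is a bits type, so the aliasing warning does not apply.

```julia
import Mooncake
using Mooncake: @zero_adjoint, DefaultCtx

# Hypothetical step function: piecewise constant, so its derivative is zero.
heaviside(x::Float64) = x > 0 ? 1.0 : 0.0

# Declare a primitive whose rule returns zero cotangents for every argument.
@zero_adjoint DefaultCtx Tuple{typeof(heaviside), Float64}

rule = Mooncake.build_rrule(Tuple{typeof(heaviside), Float64})
val, grads = Mooncake.value_and_gradient!!(rule, heaviside, 2.0)
# val == 1.0, grads[2] == 0.0
```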


Using ChainRules.jl

ChainRules.jl provides a large number of rules for differentiating functions in reverse-mode. These rules are methods of the ChainRulesCore.rrule function. There are some instances where it is most convenient to implement a Mooncake.rrule!! by wrapping an existing ChainRulesCore.rrule.

There is enough similarity between these two systems that most of the boilerplate code can be avoided.
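As a minimal sketch, assuming a Mooncake.jl version that provides the @from_rrule macro (taking a context type and a signature), an existing ChainRulesCore.rrule can be reused as below. The function my_square is hypothetical; ChainRulesCore is imported rather than brought in with using, so that its NoTangent does not clash with any similarly named export from Mooncake.

```julia
import ChainRulesCore
import Mooncake

# Hypothetical function with a hand-written ChainRulesCore.rrule.
my_square(x::Float64) = x * x

function ChainRulesCore.rrule(::typeof(my_square), x::Float64)
    # Pullback: maps the output cotangent dy to cotangents for (my_square, x).
    my_square_pullback(dy) = (ChainRulesCore.NoTangent(), 2x * dy)
    return my_square(x), my_square_pullback
end

# Treat this signature as a primitive, implementing its rrule!! by calling
# the ChainRulesCore.rrule above.
Mooncake.@from_rrule Mooncake.DefaultCtx Tuple{typeof(my_square), Float64}

rule = Mooncake.build_rrule(Tuple{typeof(my_square), Float64})
val, grads = Mooncake.value_and_gradient!!(rule, my_square, 3.0)
# val == 9.0, grads[2] == 6.0
```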

Adding Methods To rrule!! And build_primitive_rrule

If the above strategies do not work for you, you should first implement a method of Mooncake.is_primitive for the signature of interest:

Mooncake.is_primitive (function)
is_primitive(ctx::Type, mode::Type{<:Mode}, sig::Type{<:Tuple}, world::UInt)

Returns a Bool specifying whether the methods specified by sig are considered primitives in context ctx, in mode mode, at world age world.

julia> using Mooncake: is_primitive, DefaultCtx, ReverseMode

julia> is_primitive(DefaultCtx, ReverseMode, Tuple{typeof(sin), Float64}, Base.get_world_counter())
true

The world argument is needed because the rules which Mooncake derives are associated with a particular Julia world age. As a result, anything declared a primitive after a rule has been constructed ought not to be considered a primitive by that rule. If a rule has already been derived and you subsequently add new @is_primitive declarations, you can explicitly derive a new rule (e.g. via build_frule, build_rrule, or a function from the higher-level interface such as prepare_derivative_cache, prepare_pullback_cache or prepare_gradient_cache) so that the new declarations take effect. To see how this works, consider the following:

julia> using Mooncake: is_primitive, DefaultCtx, ReverseMode, @is_primitive

julia> foo(x::Float64) = 5x
foo (generic function with 1 method)

julia> old_world_age = Base.get_world_counter();

julia> @is_primitive DefaultCtx ReverseMode Tuple{typeof(foo),Float64}

julia> new_world_age = Base.get_world_counter();

julia> is_primitive(DefaultCtx, ReverseMode, Tuple{typeof(foo),Float64}, old_world_age)
false

julia> is_primitive(DefaultCtx, ReverseMode, Tuple{typeof(foo),Float64}, new_world_age)
true

Observe that is_primitive returns false for the world age prior to declaring foo a primitive, but true afterwards. For more information on Julia's world age mechanism, see https://docs.julialang.org/en/v1/manual/methods/#Redefining-Methods .


Then implement a method of one of the following:

Mooncake.rrule!! (function)
rrule!!(f::CoDual, x::CoDual...)

Performs the forwards-pass of AD. The tangent field of f and each x should contain the forwards tangent data (fdata) associated to each corresponding primal field.

Returns a 2-tuple. The first element, y, is a CoDual whose primal field is the value associated to running f.primal(map(x -> x.primal, x)...), and whose tangent field is its associated fdata. The second element contains the pullback, which runs the reverse-pass. It maps from the rdata associated to y to the rdata associated to f and each x.

using Mooncake: zero_fcodual, CoDual, NoFData, rrule!!
y, pb!! = rrule!!(zero_fcodual(sin), CoDual(5.0, NoFData()))
pb!!(1.0)

# output

(NoRData(), 0.28366218546322625)
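To complement the example above, here is a minimal sketch of writing a rule by hand for a hypothetical scalar function cube, declaring it a primitive with @is_primitive (in the form used in the is_primitive docstring) and then adding a method of rrule!!. Float64 carries no fdata, so the output CoDual holds NoFData(), and the pullback maps the Float64 rdata of the output to the rdata of the function and its argument.

```julia
import Mooncake
using Mooncake: CoDual, NoFData, NoRData, @is_primitive, DefaultCtx, ReverseMode, primal

# Hypothetical scalar function used for illustration.
cube(x::Float64) = x^3

# Tell Mooncake.jl to use a hand-written rule for this signature.
@is_primitive DefaultCtx ReverseMode Tuple{typeof(cube), Float64}

function Mooncake.rrule!!(::CoDual{typeof(cube)}, x::CoDual{Float64})
    px = primal(x)
    # Forwards pass: Float64 has no fdata, so the output carries NoFData().
    y = CoDual(cube(px), NoFData())
    # Reverse pass: rdata of y (a Float64) -> rdata of (cube, x).
    cube_pb!!(dy::Float64) = NoRData(), 3 * px^2 * dy
    return y, cube_pb!!
end

rule = Mooncake.build_rrule(Tuple{typeof(cube), Float64})
val, grads = Mooncake.value_and_gradient!!(rule, cube, 2.0)
# val == 8.0, grads[2] == 12.0
```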
Mooncake.build_primitive_rrule (function)
build_primitive_rrule(sig::Type{<:Tuple})

Construct an rrule for signature sig. For this function to be called in build_rrule, you must also ensure that is_primitive(context_type, ReverseMode, sig, world) returns true, preferably by using the @is_primitive macro. The callable returned by this function must obey the rrule!! interface, but there are no restrictions on the type of the callable itself; for example, you might return a callable struct. By default, this function returns rrule!!, so most of the time you should just implement a method of rrule!!.

Extended Help

The purpose of this function is to permit computation at rule construction time, which can be re-used at runtime. For example, you might wish to derive some information from sig which you use at runtime (e.g. the fdata type of one of the arguments). While constant propagation will often optimise this kind of computation away, it will sometimes fail to do so in hard-to-predict circumstances. Consequently, if you need certain computations not to happen at runtime in order to guarantee good performance, you might wish to e.g. emit a callable struct with type parameters which are the result of this computation. In this context, the motivation for using this function is the same as that of using staged programming (e.g. via @generated functions) more generally.
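A minimal sketch of this pattern follows. It assumes that build_rrule calls build_primitive_rrule(sig) for signatures declared primitive, as described above. Everything named here is hypothetical: poly is a fixed polynomial, and PolyRule is a callable struct whose derivative coefficients are computed once, when the rule is built, rather than on every call.

```julia
import Mooncake
using Mooncake: CoDual, NoFData, NoRData, @is_primitive, DefaultCtx, ReverseMode, primal

# Hypothetical polynomial p(x) = 1 + 2x + 3x^2 (coefficients in ascending order).
const POLY_COEFFS = (1.0, 2.0, 3.0)
poly(x::Float64) = evalpoly(x, POLY_COEFFS)

@is_primitive DefaultCtx ReverseMode Tuple{typeof(poly), Float64}

# Callable struct holding data computed at rule-construction time.
struct PolyRule{N}
    deriv_coeffs::NTuple{N, Float64}
end

function (rule::PolyRule)(::CoDual{typeof(poly)}, x::CoDual{Float64})
    px = primal(x)
    # p'(x), evaluated from the precomputed derivative coefficients.
    dp = evalpoly(px, rule.deriv_coeffs)
    poly_pb!!(dy::Float64) = NoRData(), dp * dy
    return CoDual(poly(px), NoFData()), poly_pb!!
end

# Runs once per rule: differentiate the coefficient tuple.
function Mooncake.build_primitive_rrule(::Type{<:Tuple{typeof(poly), Float64}})
    N = length(POLY_COEFFS) - 1
    return PolyRule(ntuple(i -> i * POLY_COEFFS[i + 1], N))
end

rule = Mooncake.build_rrule(Tuple{typeof(poly), Float64})
val, grads = Mooncake.value_and_gradient!!(rule, poly, 2.0)
# val == 17.0, grads[2] == 14.0
```

Here the saving is tiny, but the structure is the same one you would use for more expensive work derived from sig.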


Canonicalising Tangent Types

For some differentiation rules, Mooncake performs an explicit canonicalisation step inside frule!!/rrule!! that collapses heterogeneous array and tangent types into a small set of canonical representations. By canonicalising at the rule boundary, a single implementation can support many combinations of argument and tangent types without duplicating logic or relying on complex dispatch. This allows the remainder of the rule to assume a single, well-defined tangent representation.

Recall that rrule!! methods in Mooncake receive CoDual-wrapped arguments, including the function itself. Each CoDual carries both a primal value and its associated fdata (the forwards component of its tangent). Consider a kron rule:

function Mooncake.rrule!!(
    ::CoDual{typeof(kron)},
    x1::CoDual{<:AbstractVecOrMat{<:T}},
    x2::CoDual{<:AbstractVecOrMat{<:T}},
) where {T<:Base.IEEEFloat}
    # Canonicalise inputs: although this method constrains `x1`/`x2` to `AbstractVecOrMat`,
    # they may still be realised by many concrete array types (e.g. vectors, matrices, views,
    # `Diagonal`, `Symmetric`, `PDMat`, and other wrappers). Canonicalising at the rule boundary
    # avoids a proliferation of specialised methods and lets the pullback operate on a single,
    # predictable dense matrix tangent representation.
    # `matrixify` returns a tuple (primal, tangent_matrix).
    px1, dx1 = matrixify(x1)
    px2, dx2 = matrixify(x2)

    # Run the primal computation
    y = kron(px1, px2)
    dy = zero(y)

    # Work with canonicalised tangent arrays.
    function kron_pb!!(::NoRData)
        # Run the pullback computation
        # Code omitted here for brevity
        return NoRData(), NoRData(), NoRData()
    end
    return CoDual(y, dy), kron_pb!!
end

The key insight is that matrixify is one of several canonicalisation utilities (alongside arrayify) used to reconcile heterogeneous tangent representations into simple, uniform forms. In this case, tangents associated with vectors, matrices, views, Diagonal, Symmetric, PDMat, and other array wrappers are converted into a standard dense matrix representation that the rule can consume directly. Without this step, the rule would require multiple specialised methods or intricate dispatch logic to account for every admissible tangent representation.

Although this pattern is especially visible in BLAS- and LAPACK-backed rules, where performance-critical kernels must accommodate many array wrappers, it is not specific to linear algebra. Canonicalisation is a general rule-design technique: it isolates type heterogeneity at the boundary of the rule, simplifies the core logic, and improves maintainability across any domain where primitives admit many equivalent tangent representations (e.g. broadcasting, structured arrays, or custom numeric types).
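The following sketch illustrates the technique outside Mooncake, using plain arrays rather than Mooncake tangents. The helper canonical_matrix is hypothetical and is not Mooncake's matrixify; kron_pullback_dense and kron_pullback are likewise illustrative names. Vectors, views and structured matrices are all mapped to a dense Matrix at the boundary, so the pullback of kron is written only once.

```julia
using LinearAlgebra

# Hypothetical canonicalisation helper: every admissible input becomes a dense Matrix.
canonical_matrix(x::AbstractVector) = reshape(collect(x), :, 1)
canonical_matrix(x::AbstractMatrix) = Matrix(x)

# Core pullback for Y = kron(A, B), written once, for dense matrices only.
# kron(A, B)[(i-1)*m2 + k, (j-1)*n2 + l] == A[i, j] * B[k, l].
function kron_pullback_dense(dY::Matrix{T}, A::Matrix{T}, B::Matrix{T}) where {T<:Base.IEEEFloat}
    m1, n1 = size(A)
    m2, n2 = size(B)
    dA = zeros(T, m1, n1)
    dB = zeros(T, m2, n2)
    for j in 1:n1, i in 1:m1
        # Block of dY corresponding to A[i, j] * B in the forwards pass.
        block = view(dY, ((i - 1) * m2 + 1):(i * m2), ((j - 1) * n2 + 1):(j * n2))
        dA[i, j] = sum(block .* B)
        dB .+= A[i, j] .* block
    end
    return dA, dB
end

# Boundary: canonicalise once, run the single kernel, then restore argument shapes.
function kron_pullback(dy::AbstractVecOrMat, x1::AbstractVecOrMat, x2::AbstractVecOrMat)
    dA, dB = kron_pullback_dense(canonical_matrix(dy), canonical_matrix(x1), canonical_matrix(x2))
    return reshape(dA, size(x1)), reshape(dB, size(x2))
end

# Vectors, views and Diagonal all go through the same kernel.
da, db = kron_pullback(ones(4), [1.0, 2.0], [3.0, 4.0])        # [7.0, 7.0], [3.0, 3.0]
dD, dB = kron_pullback(ones(4, 4), Diagonal([1.0, 2.0]), [1.0 2.0; 3.0 4.0])  # fill(10.0, 2, 2), fill(3.0, 2, 2)
```

Note that the gradient with respect to the Diagonal argument comes back as a dense matrix: this is the trade-off of canonicalising to a single representation.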