# Misc. Internals Notes
This document contains an assortment of notes on some implementation details in Mooncake.jl. It is occasionally helpful to have them here for reference, but they are typically not essential reading unless you are working on the specific parts of Mooncake.jl to which they pertain.
## `tangent_type` and friends

_Last checked: 21/01/2025, Julia v1.10.7 / v1.11.2, Mooncake 0.4._
### Background
Mooncake.jl makes extensive use of `@generated` functions to ensure that its `tangent_type` function (among others) is both type-stable and constant-folds. I recently changed how `tangent_type` is implemented in Mooncake.jl to ensure that the implementations respect some specific limitations of generated functions. Here I outline the overall problem, the mistake the previous implementation made, and how the recent changes fix it.
### tangent_type

`tangent_type` is a regular Julia function which, given a "primal" type, returns another type: the tangent type. It is side-effect free, and its return value is determined entirely by the type of its argument. This means it should be possible to constant-fold. For example, consider the following definitions:
```julia
tangent_type(::Type{Float64}) = Float64
tangent_type(::Type{P}) where {P<:Tuple} = Tuple{map(tangent_type, fieldtypes(P))...}
```
If we inspect the `IRCode` associated with this for `Float64`, we see that everything is as expected – the function literally just returns `Float64`:
```julia
julia> Base.code_ircode(tangent_type, (Type{Float64}, ))[1]
1 ─     return Main.Float64
 => Type{Float64}
```
A simple `Tuple` type will also have this property:
```julia
julia> Base.code_ircode(tangent_type, (Type{Tuple{Float64}}, ))[1]
1 ─     return Tuple{Float64}
 => Type{Tuple{Float64}}
```
However, for even slightly more complicated types, things fall over:
```julia
julia> Base.code_ircode(tangent_type, (Type{Tuple{Tuple{Float64}}}, ))[1]
1 1 ─ %1 = Main.tangent_type::Core.Const(tangent_type)
  │   %2 = invoke %1(Tuple{Float64}::Type{Tuple{Float64}})::Type{<:Tuple}
  │   %3 = Core.apply_type(Tuple, %2)::Type{<:Tuple}
  └──      return %3
   => Type{<:Tuple}
```
This is just one specific example, but it is straightforward to find others, necessitating a hunt for a more robust way of implementing `tangent_type`.
### Bad Generated Function Implementation
You might think to instead implement `tangent_type` for `Tuple`s as follows:
```julia
bad_tangent_type(::Type{Float64}) = Float64
@generated function bad_tangent_type(::Type{P}) where {P<:Tuple}
    return Tuple{map(bad_tangent_type, fieldtypes(P))...}
end
bad_tangent_type(::Type{Float32}) = Float32
```
Since the generated function literally just returns the type that we want, it will definitely constant-fold:
```julia
julia> Base.code_ircode(bad_tangent_type, (Type{Tuple{Tuple{Float64}}}, ))[1]
1 1 ─     return Tuple{Tuple{Float64}}
   => Type{Tuple{Tuple{Float64}}}
```
However, this implementation has a crucial flaw: we rely on the definition of `bad_tangent_type` in the body of the `@generated` method of `bad_tangent_type`. This means that if we e.g. add methods to `bad_tangent_type` after the initial definition, they won't show up. For example, in the above block we defined the method of `bad_tangent_type` for `Float32` after that for `Tuple`s. This results in the following error when we call `bad_tangent_type(Tuple{Float32})`:
```julia
julia> bad_tangent_type(Tuple{Float32})
ERROR: MethodError: no method matching bad_tangent_type(::Type{Float32})
The applicable method may be too new: running in world age 26713, while current world is 26714.

Closest candidates are:
  bad_tangent_type(::Type{Float32}) (method too new to be called from this world context.)
   @ Main REPL[10]:1
  bad_tangent_type(::Type{Float64})
   @ Main REPL[8]:1
  bad_tangent_type(::Type{P}) where P<:Tuple
   @ Main REPL[9]:1

Stacktrace:
 [1] map(f::typeof(bad_tangent_type), t::Tuple{DataType})
   @ Base ./tuple.jl:355
 [2] #s1#1
   @ ./REPL[9]:2 [inlined]
 [3] var"#s1#1"(P::Any, ::Any, ::Any)
   @ Main ./none:0
 [4] (::Core.GeneratedFunctionStub)(::UInt64, ::LineNumberNode, ::Any, ::Vararg{Any})
   @ Core ./boot.jl:707
 [5] top-level scope
   @ REPL[12]:1
```
This behaviour of `@generated` functions is discussed in the Julia docs – I would recommend reading that section of the docs if you've not previously, as the explanation is quite clear.
### Good Generated Function Implementation
`@generated` functions can still come to our rescue though. A better implementation is as follows:
```julia
good_tangent_type(::Type{Float64}) = Float64
@generated function good_tangent_type(::Type{P}) where {P<:Tuple}
    exprs = map(p -> :(good_tangent_type($p)), fieldtypes(P))
    return Expr(:curly, :Tuple, exprs...)
end
good_tangent_type(::Type{Float32}) = Float32
```
This leads to generated code which constant-folds / infers correctly:
```julia
julia> Base.code_ircode(good_tangent_type, (Type{Tuple{Tuple{Float64}}}, ))[1]
1 1 ─     return Tuple{Tuple{Float64}}
   => Type{Tuple{Tuple{Float64}}}
```
I believe this works better because the recursion doesn't happen through another function, but appears directly in the function body. This is right at the edge of my understanding of Julia's compiler heuristics surrounding recursion, though, so I might be mistaken.
It also behaves correctly under the addition of new methods of `good_tangent_type`, because `good_tangent_type` only appears in the expression returned by the generated function, not in the body of the generated function itself:
```julia
julia> good_tangent_type(Tuple{Float32})
Tuple{Float32}
```
### Effects Etc
This implementation is nearly sufficient to guarantee correct performance in all situations. However, in some cases even this implementation will fall over. Annoyingly, I've not managed to produce an MWE that is even vaguely minimal to support this claim, so you'll just have to believe me.
Based on all of the examples that I have seen thus far, it appears to be true that if you just tell the compiler that

- for the same inputs, the function always returns the same outputs, and
- the function has no side-effects, so can be removed,

everything will always constant-fold nicely. This can be achieved by using the `Base.@assume_effects` macro in your method definitions, with the effects `:consistent` and `:removable`.
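As a hedged sketch of what this looks like in practice (the names below are illustrative, not Mooncake's actual definitions, and whether the annotation alone recovers constant folding can depend on the Julia version):

```julia
# Illustrative sketch: a tangent_type-like function annotated with effects.
# :consistent tells the compiler the return value depends only on the inputs;
# :removable tells it that calls are side-effect free and may be deleted.
hinted_tangent_type(::Type{Float64}) = Float64

Base.@assume_effects :consistent :removable function hinted_tangent_type(
    ::Type{P}
) where {P<:Tuple}
    return Tuple{map(hinted_tangent_type, fieldtypes(P))...}
end
```

Note that `Base.@assume_effects` is a promise to the compiler, not something it checks: annotating a method that is not actually consistent and removable is undefined behaviour.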
## How Recursion Is Handled

_Last checked: 09/02/2025, Julia v1.10.8 / v1.11.3, Mooncake 0.4.82._
Mooncake handles recursive function calls by delaying code generation for generic function calls until the first time that they are actually run. The docstring below contains a thorough explanation:
**`Mooncake.LazyDerivedRule`** — Type

```julia
LazyDerivedRule(interp, mi::Core.MethodInstance, debug_mode::Bool)
```
For internal use only.
A type-stable wrapper around a `DerivedRule`, which only instantiates the `DerivedRule` when it is first called. This is useful, as it means that if a rule does not get run, it does not have to be derived.
If `debug_mode` is `true`, then the rule constructed will be a `DebugRRule`. This is useful when debugging, but should usually be switched off for production code as it (in general) incurs some runtime overhead.
Note: the signature of the primal for which this is a rule is stored in the type. The only reason to keep this around is for debugging – it is very helpful to have this type visible in the stack trace when something goes wrong, as it allows you to trivially determine which bit of your code is the culprit.
**Extended Help**
There are two main reasons why deferring the construction of a `DerivedRule` until we need to use it is crucial.

The first is to do with recursion. Consider the following function:

```julia
f(x) = x > 0 ? f(x - 1) : x
```
If we generate the `IRCode` for this function, we will see something like the following:
```julia
julia> Base.code_ircode_by_type(Tuple{typeof(f), Float64})[1][1]
1 1 ─ %1  = Base.lt_float(0.0, _2)::Bool
  │   %2  = Base.or_int(%1, false)::Bool
  └──       goto #6 if not %2
  2 ─ %4  = Base.sub_float(_2, 1.0)::Float64
  │   %5  = Base.lt_float(0.0, %4)::Bool
  │   %6  = Base.or_int(%5, false)::Bool
  └──       goto #4 if not %6
  3 ─ %8  = Base.sub_float(%4, 1.0)::Float64
  │   %9  = invoke Main.f(%8::Float64)::Float64
  └──       goto #5
  4 ─       goto #5
  5 ┄ %12 = φ (#3 => %9, #4 => %4)::Float64
  └──       return %12
  6 ─       return _2
```
Suppose that we decide to construct a `DerivedRule` immediately whenever we find an `:invoke` statement in a rule that we're currently building a `DerivedRule` for. In the above example, we would produce an infinite recursion when we attempt to produce a `DerivedRule` for `%9`, because it has the same signature as the call which generates this IR. By instead adopting a policy of constructing a `LazyDerivedRule` whenever we encounter an `:invoke` statement, we avoid this problem.
The second reason that delaying the construction of a `DerivedRule` is essential is that it ensures that we don't derive rules for method instances which aren't run. Suppose that function `B` contains code for which we can't derive a rule – perhaps it contains an unsupported language feature like a `PhiCNode` or an `UpsilonNode`. Suppose that function `A` contains an `:invoke` which refers to function `B`, but that this call is on a branch which deals with error handling, and doesn't get run unless something goes wrong. By deferring the derivation of the rule for `B`, we only ever attempt to derive it if we land on this error-handling branch. Conversely, if we attempted to derive the rule for `B` when we derive the rule for `A`, we would be unable to complete the derivation of the rule for `A`.
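The lazy-construction pattern described above can be sketched in isolation. The following toy is purely illustrative – `LazyRule` and its fields are hypothetical names, not Mooncake's actual `LazyDerivedRule` (which, unlike this sketch, is carefully arranged to be type-stable):

```julia
# Toy version of lazy rule construction: the (potentially expensive, or
# impossible) derivation is deferred until the wrapper is first called, so
# rules for calls on never-taken branches are never derived at all.
mutable struct LazyRule{F}
    derive::F                        # thunk that derives the rule on demand
    rule::Union{Nothing, Function}   # the derived rule, once built
end
LazyRule(derive) = LazyRule(derive, nothing)

function (lr::LazyRule)(args...)
    if lr.rule === nothing
        lr.rule = lr.derive()        # first call: actually derive the rule
    end
    return lr.rule(args...)
end

# Constructing the wrapper is cheap; derivation only happens on first use.
derivation_count = Ref(0)
doubler = LazyRule(() -> (derivation_count[] += 1; x -> 2x))
```

Because `derive` is only invoked inside the call overload, a self-referential rule (as in the recursive `f` above) can safely hold a `LazyRule` for its own signature without triggering infinite recursion at construction time.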