Tangents
As discussed in Representing Gradients, Mooncake requires that each "primal" type be associated with a unique "tangent" type, given by the function tangent_type. Moreover, we must be able to "split" a given tangent into its fdata ("forwards-data") and rdata ("reverse-data"), whose types are given by Mooncake.fdata_type and Mooncake.rdata_type respectively. Furthermore, we (at the very least) require methods of rrule!! for a few core functions in order to be able to differentiate through construction and the getting / setting of fields.
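As a small illustration of the three type-level functions named above, the following sketch queries them for two built-in primal types. It assumes Mooncake's default definitions for Float64 and Vector{Float64}; the variable names are only for illustration.

```julia
using Mooncake

# Tangent type associated with each primal type.
T_scalar = Mooncake.tangent_type(Float64)
T_vector = Mooncake.tangent_type(Vector{Float64})

# fdata_type and rdata_type are applied to the tangent type, not the primal type.
F_scalar = Mooncake.fdata_type(T_scalar)
R_scalar = Mooncake.rdata_type(T_scalar)
F_vector = Mooncake.fdata_type(T_vector)
R_vector = Mooncake.rdata_type(T_vector)

println((T_scalar, F_scalar, R_scalar))
println((T_vector, F_vector, R_vector))
```

With the default definitions, one would expect the scalar to carry all of its tangent information in its rdata (so its fdata type is Mooncake.NoFData), and the vector to carry all of it in its fdata (so its rdata type is Mooncake.NoRData).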
Very occasionally it may be necessary to specify your own tangent type. This is not an entirely trivial undertaking, as there is quite a lot of functionality that must be added to make it work properly. So, before diving in to add your own custom type, seriously consider whether it is worth the effort, and whether the default definitions given by Mooncake are really inadequate for your use case.
Testing Functionality
The interface is given in the form of three functions, each of which specifies which functions you must implement methods for when creating a custom tangent type:
Mooncake.TestUtils.test_tangent_interface — Function
test_tangent_interface(rng::AbstractRNG, p; interface_only=false)
Verify that standard functionality for tangents runs, and is consistent. This function is the de facto formal definition of the "tangent interface": if this function runs without error for a given value of p, then that p satisfies the tangent interface.
Extended Help
Verifies that the following functions are implemented correctly (as far as possible) for p / its type, and its tangents / their type:
Mooncake.tangent_type
Mooncake.zero_tangent_internal
Mooncake.randn_tangent_internal
Mooncake.TestUtils.has_equal_data
Mooncake.increment_internal!!
Mooncake.set_to_zero_internal!!
Mooncake._add_to_primal_internal
Mooncake._diff_internal
Mooncake._dot_internal
Mooncake._scale_internal
Mooncake.TestUtils.populate_address_map_internal
In conjunction with the functions tested by test_tangent_splitting, these functions constitute a complete set of functions which must be applicable to p in order to ensure that it operates correctly in the context of reverse-mode AD. This list should be up to date at any given point in time, but the best way to verify that you've implemented everything is simply to run this function, and see whether it errors / produces a failing test.
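A minimal sketch of calling this function on a plain struct that relies on Mooncake's default tangent type. MyPoint is a hypothetical type introduced only for this example, and the sketch assumes that Mooncake.TestUtils is usable in your installed version without loading further packages.

```julia
using Mooncake, Random

# Hypothetical primal type, used only for illustration.
struct MyPoint
    x::Float64
    y::Float64
end

rng = Random.Xoshiro(123)
p = MyPoint(1.0, 2.0)

# Errors, or records a failing test, if any part of the tangent
# interface is missing or inconsistent for `p`.
Mooncake.TestUtils.test_tangent_interface(rng, p)
```

The same call is the natural first check after defining a custom tangent type: any method from the list above that is missing should surface as an error here.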
Mooncake.TestUtils.test_tangent_splitting — Function
test_tangent_splitting(rng::AbstractRNG, p::P) where {P}
Verify that tangent splitting functionality associated with primal p works correctly. Ensure that test_tangent_interface runs for p before running these tests.
Extended Help
Verifies that the following functionality works correctly for p / its type / tangents:
Mooncake.fdata_type
Mooncake.rdata_type
Mooncake.fdata
Mooncake.rdata
Mooncake.uninit_fdata
Mooncake.tangent_type (binary method)
Mooncake.tangent (binary method)
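To make the splitting concrete, the sketch below splits two tangents into fdata and rdata and recombines them with the binary method of Mooncake.tangent. It assumes Mooncake's default definitions for Float64 and Vector{Float64}; the variable names are illustrative.

```julia
using Mooncake

# Zero tangents for a scalar and a vector primal.
t_scalar = Mooncake.zero_tangent(3.0)
t_vector = Mooncake.zero_tangent([1.0, 2.0])

# Split each tangent into its forwards-data and reverse-data.
f_s, r_s = Mooncake.fdata(t_scalar), Mooncake.rdata(t_scalar)
f_v, r_v = Mooncake.fdata(t_vector), Mooncake.rdata(t_vector)

# The binary method of `tangent` recombines fdata and rdata into a tangent.
t_scalar_again = Mooncake.tangent(f_s, r_s)
t_vector_again = Mooncake.tangent(f_v, r_v)

# The binary method of `tangent_type` does the same at the type level.
T_scalar_again = Mooncake.tangent_type(typeof(f_s), typeof(r_s))
```

Splitting followed by recombination should return a tangent equal to the one you started with, which is the kind of consistency that test_tangent_splitting checks.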
Mooncake.TestUtils.test_rule_and_type_interactions — Function
test_rule_and_type_interactions(rng::AbstractRNG, p)
Check that a collection of standard functions for which we ought to have a working rrule for p work, and produce the correct answer. For example, the rrule!! for typeof should work correctly on any type, we should have a working rule for getfield for any struct type, and we should have a rule for setfield! for any mutable struct type. See extended help for more info.
Extended Help
The purpose of this test is to ensure that, for any given p, the full range of primitive functions that ought to work on it do indeed work on it.
This is one part of the interface where some care might be required. If, for some reason, a function such as getfield should never be called on a particular p, then it may make no sense at all to run these tests. In such cases, the author of the type is responsible for knowing what they are doing. If you are at all unsure what to do for your type, please open an issue to discuss it.
When defining a custom tangent type for P, the functions that you will need to pay attention to writing rules for are
In all cases, you may wish to consult the current implementations of rrule!! for these functions for inspiration regarding how you might implement them for your type.
You can call all three of these functions at once using
Mooncake.TestUtils.test_data — Function
test_data(rng::AbstractRNG, p::P)
Verify that all tangent / fdata / rdata functionality works properly for p. Furthermore, verify that all primitives listed in TestUtils.test_rule_and_type_interactions work correctly on p. This functionality is particularly useful if you are writing your own custom tangent / fdata / rdata types and want to be confident that you have implemented the functionality that you need in order to make these custom types work with all the rules written in Mooncake itself.
You should consult the docstrings for test_tangent_interface, test_tangent_splitting, and test_rule_and_type_interactions in order to see what is required to satisfy the full tangent interface for p.
If all the tests in these functions pass, then you have satisfied the interface.
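One way to wire this into a package's test suite is sketched below. Counter is a hypothetical mutable type introduced only for this example, and the sketch assumes that Mooncake.TestUtils.test_data runs in your installed version without loading further packages.

```julia
using Mooncake, Random, Test

# Hypothetical mutable primal type, used only for illustration.
mutable struct Counter
    total::Float64
end

@testset "tangent interface for Counter" begin
    rng = Random.Xoshiro(123)
    # Run the combined checks on a couple of representative values.
    for p in (Counter(0.0), Counter(2.5))
        Mooncake.TestUtils.test_data(rng, p)
    end
end
```

Because Counter is mutable, this run also covers the setfield! rule mentioned in the docstring for test_rule_and_type_interactions, which an immutable struct would not.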