Law VJP customization

This tutorial explains how to customize the VJP (vector-Jacobian product) computation of laws in ODINN.jl and clarifies the runtime flow used internally by the library. It also identifies which functions belong to the public, user-facing customization API and which are internal helpers that ODINN uses when it has to fall back to an AD backend.

Advanced features

The features presented in this section are advanced, and this tutorial was written primarily for ODINN developers.

It assumes that you have followed the Laws tutorial.

using ODINN
rgi_ids = ["RGI60-11.03638"]
rgi_paths = get_rgi_paths()
params = Parameters(
    simulation = SimulationParameters(rgi_paths = rgi_paths),
    UDE = UDEparameters(grad = ContinuousAdjoint())
)
nn_model = NeuralNetwork(params)
--- NeuralNetwork ---
    architecture:
      Chain(
          layer_1 = Dense(1 => 3, #101),                # 6 parameters
          layer_2 = Dense(3 => 10, #102),               # 40 parameters
          layer_3 = Dense(10 => 3, #103),               # 33 parameters
          layer_4 = Dense(3 => 1, σ),                   # 4 parameters
      )         # Total: 83 parameters,
                #        plus 0 states.
    θ: ComponentVector of length 83

Explanations

High-level summary

At the user level, you can customize the VJPs by providing hand-written implementations through the following functions:

  • f_VJP_input!(...): VJP w.r.t. inputs
  • f_VJP_θ!(...): VJP w.r.t. parameters θ
  • p_VJP!(...): an optional precompute function that caches expensive computations. It is called before the adjoint iceflow PDE is solved.
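As a reminder of what these functions compute, the sketch below uses the textbook definition: for a law y = f(x, θ) and a cotangent vector v, the VJP with respect to θ is vᵀ ∂f/∂θ, and similarly for the inputs x. It is a self-contained toy, independent of ODINN: the law yᵢ = tanh(θ₁ xᵢ + θ₂) and the names toy_law, vjp_input, vjp_θ and jacobians are made up for illustration. Note that the ODINN functions f_VJP_input!/f_VJP_θ! take no explicit cotangent argument; how the stored derivatives are combined with the adjoint variable is handled internally and is not shown here.

```julia
# Hypothetical toy law (not an ODINN law): y_i = tanh(θ[1] * x_i + θ[2])
toy_law(x, θ) = tanh.(θ[1] .* x .+ θ[2])

# Derivative of tanh(z) is 1 - tanh(z)^2
dtanh(z) = 1 - tanh(z)^2

# VJP w.r.t. the inputs: ∂y/∂x is diagonal, so (vᵀ ∂y/∂x)_i = v_i * θ[1] * tanh'(z_i)
function vjp_input(x, θ, v)
    z = θ[1] .* x .+ θ[2]
    return v .* θ[1] .* dtanh.(z)
end

# VJP w.r.t. the parameters: vᵀ ∂y/∂θ has one entry per parameter
function vjp_θ(x, θ, v)
    z = θ[1] .* x .+ θ[2]
    w = v .* dtanh.(z)
    return [sum(w .* x), sum(w)]
end

# Explicit Jacobians, only used here to check the hand-written VJPs
function jacobians(x, θ)
    z = θ[1] .* x .+ θ[2]
    d = dtanh.(z)
    Jx = [i == j ? θ[1] * d[i] : 0.0 for i in eachindex(x), j in eachindex(x)]
    Jθ = hcat(d .* x, d)              # size (length(x), 2)
    return Jx, Jθ
end

x = [0.1, -0.4, 0.7]
θ = [1.5, -0.2]
v = [1.0, 2.0, -0.5]
y = toy_law(x, θ)                     # forward evaluation (not needed by the VJPs themselves)
Jx, Jθ = jacobians(x, θ)
@assert vjp_input(x, θ, v) ≈ Jx' * v
@assert vjp_θ(x, θ, v) ≈ Jθ' * v
```

The hand-written VJPs never build the Jacobians; they are only formed here to verify the result, which is the whole point of implementing VJPs directly.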

Internally, when the user does NOT provide VJPs, ODINN uses a default AD backend (via DifferentiationInterface.jl) to compute the VJPs of the laws. To support efficient reverse-mode execution, ODINN will:

  • compile and precompute adjoint-related helper functions, and
  • store preparation objects that are used later when solving the adjoint PDE.

This mechanism is triggered by prepare_vjp_law.

Internal function roles

  • prepare_vjp_law (internal)

    Signature used in the codebase:

    prepare_vjp_law(
        simulation,
        law::AbstractLaw,
        law_cache,
        θ,
        glacier_idx,
    )
    • Intent and behavior:
      • This is an internal routine. It is NOT intended to be called by users directly.
      • It is invoked when ODINN must fall back to the AD backend (through DifferentiationInterface.jl), either because the law does not supply explicit VJP functions (f_VJP_input!/f_VJP_θ!) or because p_VJP! is set to DIVJP().
      • Its job is to precompile and prepare the AD-based VJP code for a given law and to return preparation objects holding the DifferentiationInterface.jl preparation results.
      • prepare_vjp_law is typically called right after the iceflow model and law objects have been instantiated, i.e., early in the setup, so that the preparations are ready before any forward or adjoint run.
  • precompute_law_VJP (used before solving the adjoint PDE)

    The typical signature in the codebase is:

    precompute_law_VJP(
      law::AbstractLaw,
      cache,
      vjpsPrepLaw,
      simulation,
      glacier_idx,
      t,
      θ
    )
    • Intent and behavior:
      • This function precomputes VJP-related artifacts before the adjoint iceflow PDE is solved for given time t and parameters θ.
      • It typically uses vjpsPrepLaw (an AbstractPrepVJP instance produced earlier by prepare_vjp_law) together with the cache and the simulation object. The results are stored in cache and may be consumed later by law_VJP_input / law_VJP_θ during the adjoint solve.
  • Entry points used in the adjoint PDE

    These functions are the actual runtime entry points used when computing contributions of the laws to the gradient in the adjoint PDE:

    law_VJP_θ(law::AbstractLaw, cache, simulation, glacier_idx, t, θ)

    and

    law_VJP_input(law::AbstractLaw, cache, simulation, glacier_idx, t, θ)
    • Intent and behavior:
      • These are called during the adjoint solve to compute parameter and input VJPs for the law at time t and for parameters θ.
      • They can either compute the VJPs directly or reuse VJP information already computed by the user-supplied p_VJP! precompute function. The cache allows storing useful information from the forward pass or from the precomputation step.
      • They therefore carry the runtime context (simulation, glacier index, time, θ) which is necessary for adjoint calculations.

Workflow

Most users should not modify the VJPs. ODINN comes with sensible defaults, and the average user does not need to customize them.

Advanced users seeking maximum performance can customize the VJPs, which can significantly speed up the code.

How do the pieces compose in practice?

  • If you, as a user, provide custom VJP functions (through f_VJP_input!/f_VJP_θ! or through p_VJP!), ODINN uses them directly in the adjoint and skips the AD fallback path. You can also provide your own precompute function and cache to optimize expensive computations.
  • If you do NOT provide VJP functions, ODINN runs the AD fallback:
    1. prepare_vjp_law runs early (post-instantiation) to compile and prepare the AD-based helpers, and returns an AbstractPrepVJP object.
    2. precompute_law_VJP is skipped.
    3. During the adjoint solve, law_VJP_input and law_VJP_θ use the preparation objects precompiled in prepare_vjp_law to automatically differentiate f! with DifferentiationInterface.jl and obtain the VJPs of the law with respect to the inputs and to the parameters θ.
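The decision described above can be pictured with the following toy sketch. It is not ODINN's implementation: ToyLaw, fd_gradient and toy_vjp_θ are hypothetical names, and central finite differences merely stand in for the AD backend that ODINN would call through DifferentiationInterface.jl.

```julia
# Hypothetical stand-in for a law: a forward function f(inputs, θ) returning a scalar,
# plus an optional hand-written gradient with respect to θ (`nothing` if not supplied)
struct ToyLaw{F, G}
    f::F
    f_VJP_θ::G
end

# Stand-in for the AD fallback: central finite differences on each component of θ
function fd_gradient(f, inputs, θ; h = 1e-6)
    g = similar(θ)
    for k in eachindex(θ)
        θp = copy(θ); θp[k] += h
        θm = copy(θ); θm[k] -= h
        g[k] = (f(inputs, θp) - f(inputs, θm)) / (2h)
    end
    return g
end

# No custom VJP supplied: take the generic fallback path
toy_vjp_θ(law::ToyLaw{F, Nothing}, inputs, θ) where {F} = fd_gradient(law.f, inputs, θ)
# Custom VJP supplied: call it directly
toy_vjp_θ(law::ToyLaw, inputs, θ) = law.f_VJP_θ(inputs, θ)

f(inputs, θ) = θ[1] * inputs.T + θ[2]^2
custom = ToyLaw(f, (inputs, θ) -> [inputs.T, 2 * θ[2]])
fallback = ToyLaw(f, nothing)

inputs = (; T = -5.0)
θ = [0.3, 1.2]
toy_vjp_θ(custom, inputs, θ)    # hand-written path: [-5.0, 2.4]
toy_vjp_θ(fallback, inputs, θ)  # fallback path: approximately [-5.0, 2.4]
```

Dispatching on the type of the optional field mirrors, in spirit, how a law with or without custom VJPs ends up on a different code path; the real selection logic in ODINN is internal.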
Info

You can change the default AD backend used for laws without custom VJPs through the VJP type, for example by setting VJP_method = DiscreteVJP(regressorADBackend = DI.AutoZygote()) when defining the adjoint method.

User level customization

What is user-visible and can be customized?

  • f_VJP_input!(cache, inputs, θ): computes the VJP with respect to the inputs and stores the result in cache.vjp_inp
  • f_VJP_θ!(cache, inputs, θ): computes the VJP with respect to θ and stores the result in cache.vjp_θ
  • p_VJP!(cache, vjpsPrepLaw, inputs, θ): precomputes some components (or even the full VJPs, when possible) before the adjoint iceflow PDE is solved
  • custom cache implementations (described below)

Notes on cache definition

The cache argument that is threaded through the p_VJP!/f_VJP_* calls is the place to store artifacts useful for efficient computation, as well as the results of the VJP computations. The following fields are mandatory:

  • value: a placeholder to store the result of the forward evaluation, can be of any type
  • vjp_θ: a placeholder to store the result of the VJP with respect to θ; depending on the type of law, it can be a vector or a three-dimensional array
  • vjp_inp: a placeholder to store the result of the VJP with respect to the inputs; its type must match that of the inputs

To find the type of the inputs, run generate_inputs(law.f.inputs, simulation, glacier_idx, t).
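As an illustration, a cache for a law with a scalar output could look like the sketch below. AbstractToyCache, MyScalarCache, init_my_cache and has_mandatory_fields are hypothetical names used only here; with ODINN you would subtype the library's own cache type, as this tutorial does later with MatrixCacheInterp. The field types mirror the ScalarCache(zeros(), zeros(), zero(θ)) construction used below: zero-dimensional arrays for the scalar value and input VJP, and one entry per parameter for vjp_θ.

```julia
abstract type AbstractToyCache end   # hypothetical stand-in for ODINN's cache supertype

# Cache for a law with a scalar output and a scalar input
struct MyScalarCache <: AbstractToyCache
    value::Array{Float64, 0}      # forward result (0-dimensional array, mutable in place)
    vjp_inp::Array{Float64, 0}    # VJP w.r.t. the input: same shape as the input
    vjp_θ::Vector{Float64}        # VJP w.r.t. θ: one entry per parameter
end

init_my_cache(nθ) = MyScalarCache(zeros(), zeros(), zeros(nθ))

# Check that the three mandatory fields are present
has_mandatory_fields(T::Type) = all(in(fieldnames(T)), (:value, :vjp_inp, :vjp_θ))

cache = init_my_cache(83)
cache.value .= 2.5            # fill the 0-dimensional array in place
cache.vjp_θ .= 1.0
@assert has_mandatory_fields(MyScalarCache)
```

Using zero-dimensional arrays rather than plain Float64 fields lets the struct stay immutable while its contents are still updated in place with `.=`.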

Using the preparation object

Error

Using the preparation object at the user level is not supported yet.

Best practices and debugging tips

  • If you supply custom VJPs, test them with finite-difference checks for both inputs and parameters. ODINN does not check the correctness of your implementation!
  • If you rely on ODINN's AD fallback, be aware that prepare_vjp_law precompiles and prepares the AD helpers at model instantiation time: expect a longer setup time but faster adjoint runs afterwards.
  • Inspect and validate the cache content if you get inconsistent adjoints; a stale or incorrect cache entry is a common cause.
  • Although the API is designed to provide everything you need as arguments, if your VJP needs anything from the forward pass, ensure it is stored in the cache.
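The finite-difference check recommended above can be written with a few lines of plain Julia. The sketch below is generic and uses no ODINN function: check_vjp_θ is a hypothetical helper that compares a hand-written gradient of a scalar function with central differences, one parameter at a time, and toy/toy_grad are an illustrative function and its gradient.

```julia
# Compare a hand-written gradient `grad(θ)` of a scalar function `f(θ)` with central
# finite differences; returns the largest absolute discrepancy over the components of θ.
function check_vjp_θ(f, grad, θ::AbstractVector{<:Real}; h = 1e-6)
    g = grad(θ)
    maxerr = 0.0
    for k in eachindex(θ)
        e = zeros(length(θ))
        e[k] = h                                   # perturb component k only
        fd = (f(θ .+ e) - f(θ .- e)) / (2h)        # central difference
        maxerr = max(maxerr, abs(fd - g[k]))
    end
    return maxerr
end

# Toy scalar function and its hand-written gradient (illustrative only)
toy(θ) = θ[1]^2 * θ[2] + sin(θ[2])
toy_grad(θ) = [2 * θ[1] * θ[2], θ[1]^2 + cos(θ[2])]

θ = [0.7, -1.3]
err_ok = check_vjp_θ(toy, toy_grad, θ)                # finite-difference noise only
err_bad = check_vjp_θ(toy, θ -> ones(length(θ)), θ)   # a wrong VJP is caught
```

For real laws, a tolerance relative to the gradient magnitude is usually more robust than an absolute one, and the step h may need tuning to the scale of θ.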

Simple VJP customization

We now customize the VJP computation of the law used in the Laws tutorial. This law uses a ScalarCache, since its output is a scalar value: the creep coefficient A. We can confirm that this type defines the fields needed for the VJP computation:

fieldnames(ScalarCache)
(:value, :vjp_inp, :vjp_θ)

Before defining the law, we retrieve the model architecture and the physical parameters used inside the law's f! function, and we define the inputs:

archi = nn_model.architecture
st = nn_model.st
smodel = ODINN.StatefulLuxLayer{true}(archi, nothing, st)
min_NN = params.physical.minA
max_NN = params.physical.maxA
inputs = (; T = iAvgScalarTemp())
(T = averaged_long_term_temperature: iAvgScalarTemp(),)

We then define the f! and init_cache functions:

f! = let smodel = smodel, min_NN = min_NN, max_NN = max_NN
    function (cache, inp, θ)
        inp = collect(values(inp))
        A = only(ODINN.scale(smodel(inp, θ.A), (min_NN, max_NN)))
        ODINN.Zygote.@ignore_derivatives cache.value .= A # Hide this in-place assignment from Zygote so that f! can be differentiated below
        return A
    end
end
function init_cache(simulation, glacier_idx, θ)
    return ScalarCache(zeros(), zeros(), zero(θ))
end
init_cache (generic function with 1 method)

The declaration of the law without VJP customization would be:

law = Law{ScalarCache}(;
    inputs = inputs,
    f! = f!,
    init_cache = init_cache
)
(:T,) -> Array{Float64, 0}   (⟳  auto VJP  ❌ precomputed)
Success

The output shows that the VJPs are computed automatically with DifferentiationInterface.jl and that ODINN does not use precomputation. Now let's customize the VJPs by implementing the differentiation step manually:

law = Law{ScalarCache}(;
    inputs = inputs,
    f! = f!,
    f_VJP_input! = function (cache, inputs, θ)
        nothing # The input does not depend on the glacier state
    end,
    f_VJP_θ! = function (cache, inputs, θ)
        cache.vjp_θ .= ones(length(θ)) # Deliberately wrong VJP, used below to check that this function is actually called
    end,
    init_cache = init_cache
)
(:T,) -> Array{Float64, 0}   (⟳  custom VJP  ❌ precomputed)
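For comparison with the deliberately wrong f_VJP_θ! above, here is what a correct hand-written VJP could look like for a much simpler law, following the same (cache, inputs, θ) calling convention and writing its result in place into cache.vjp_θ. This is a self-contained toy rather than ODINN code: the single-neuron law, its bounds A_MIN/A_MAX and the ToyScalarCache type (a hypothetical stand-in with the same fields as ScalarCache) are made up for illustration.

```julia
# Hypothetical stand-in for ScalarCache, with the same three field names
struct ToyScalarCache
    value::Array{Float64, 0}
    vjp_inp::Array{Float64, 0}
    vjp_θ::Vector{Float64}
end

sigmoid(z) = 1 / (1 + exp(-z))
const A_MIN = 1.0   # made-up lower bound of the toy law
const A_MAX = 3.0   # made-up upper bound of the toy law

# Forward toy law: A = A_MIN + (A_MAX - A_MIN) * sigmoid(w * T + b), with θ = [w, b]
function toy_f!(cache, inputs, θ)
    A = A_MIN + (A_MAX - A_MIN) * sigmoid(θ[1] * inputs.T + θ[2])
    cache.value .= A
    return A
end

# Hand-written derivative of A with respect to θ, written in place into cache.vjp_θ
function toy_f_VJP_θ!(cache, inputs, θ)
    s = sigmoid(θ[1] * inputs.T + θ[2])
    dA_dz = (A_MAX - A_MIN) * s * (1 - s)   # chain rule through the sigmoid
    cache.vjp_θ[1] = dA_dz * inputs.T       # ∂A/∂w
    cache.vjp_θ[2] = dA_dz                  # ∂A/∂b
    return nothing
end

# As in the tutorial, T does not depend on the glacier state, so no input VJP is computed
toy_f_VJP_input!(cache, inputs, θ) = nothing

cache = ToyScalarCache(zeros(), zeros(), zeros(2))
inputs = (; T = -5.0)
θ = [0.2, 0.5]
toy_f!(cache, inputs, θ)
toy_f_VJP_θ!(cache, inputs, θ)
```

For the neural network of this tutorial, deriving the gradient by hand is impractical, which is why the examples further down rely on Zygote for the differentiation itself.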

In order to instantiate the cache, we need to define the model:

rgi_ids = ["RGI60-11.03638"]
glaciers = initialize_glaciers(rgi_ids, params)
model = Model(
    iceflow = SIA2Dmodel(params; A = law),
    mass_balance = nothing,
    regressors = (; A = nn_model)
)
simulation = Inversion(model, glaciers, params)
SIA2D iceflow equation  = ∇(D ∇S)  with D = U H̄
  and U = C (ρg)^(pq) H̄^(pq+1) ∇S^(p-1) + Γ H̄^(n+2) ∇S^(n-1)
      Γ = 2A (ρg)^n /(n+2)
      A: (:T,) -> Array{Float64, 0}   (⟳  custom VJP  ❌ precomputed)
      C: ConstantLaw -> Array{Float64, 0}
      n: ConstantLaw -> Array{Float64, 0}
      p: ConstantLaw -> Array{Float64, 0}
      q: ConstantLaw -> Array{Float64, 0}
  where
      T => averaged_long_term_temperature

We also need θ to call the law's VJPs manually; in practice, you do not have to retrieve it yourself:

θ = simulation.model.trainable_components.θ
ComponentVector{Float64}(A = (layer_1 = (weight = [-0.206559956073761; -0.5560190081596375; -1.624756097793579;;], bias = [0.1391124725341797, -0.6567966938018799, 0.3310384750366211]), layer_2 = (weight = [0.026965618133544922 0.7515738010406494 -0.7434771060943604; -0.9224939346313477 0.6211857795715332 0.6744678020477295; … ; -0.5732729434967041 -0.16946196556091309 0.6235699653625488; -0.9148321151733398 -0.46260619163513184 -0.5166316032409668], bias = [0.10900626331567764, 0.13164696097373962, -0.22048336267471313, -0.10826941579580307, -0.20583660900592804, -0.5592323541641235, -0.3280256986618042, -0.24063409864902496, -0.15558598935604095, -0.3062073290348053]), layer_3 = (weight = [0.41199973225593567 0.5245338082313538 … 0.4106042683124542 0.5015260577201843; -0.34935665130615234 -0.3589525818824768 … -0.5235913395881653 -0.08577283471822739; -0.43845435976982117 -0.4837602972984314 … -0.2922854423522949 0.17420873045921326], bias = [-0.28604868054389954, 0.0878324955701828, -0.2242458611726761]), layer_4 = (weight = [-0.49367260932922363 -0.6222386360168457 -0.5545613765716553], bias = [-0.20377019047737122])))

We then create the cache. Again, ODINN normally handles this internally; we instantiate it manually here only to demonstrate how the VJPs can be customized.

glacier_idx = 1
simulation.cache = ODINN.init_cache(model, simulation, glacier_idx, θ)
Sleipnir.ModelCache{SIA2DCache{Float64, Int64, ScalarCache, ScalarCacheNoVJP, ScalarCacheNoVJP, ScalarCacheNoVJP, ScalarCacheNoVJP, Array{Float64, 0}, Array{Float64, 0}, ScalarCacheNoVJP, ScalarCacheNoVJP}, Nothing}(SIA2DCache{…}(ScalarCache(fill(0.0), fill(0.0), [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), …), nothing)

Finally, we check that our custom implementation is the one being called: since f_VJP_θ! deliberately fills the VJP with ones, the resulting vector should contain only ones:

ODINN.∂law∂θ!(
    params.UDE.grad.VJP_method.regressorADBackend,
    simulation.model.iceflow.A,
    simulation.cache.iceflow.A,
    simulation.cache.iceflow.A_prep_vjps,
    (; T = 1.0), θ)
83-element Vector{Float64}:
 1.0
 1.0
 1.0
 1.0
 1.0
 1.0
 1.0
 1.0
 1.0
 1.0
 ⋮
 1.0
 1.0
 1.0
 1.0
 1.0
 1.0
 1.0
 1.0
 1.0

VJP precomputation

Since the law we have used so far does not depend on the glacier state, it could be computed once and for all at the beginning of the simulation, and its VJPs could be precomputed before solving the adjoint iceflow PDE. The law definitions below show two ways to do this:

Automatic precomputation with DI

law = Law{ScalarCache}(;
    inputs = inputs,
    f! = f!,
    init_cache = init_cache,
    p_VJP! = DIVJP(),
    callback_freq = 0
)
(:T,) -> Array{Float64, 0}   (↧@start  custom VJP  ✅ precomputed (DI))
Success

As the ↧@start marker in the output indicates, this law is applied only once, before the simulation starts, and its VJPs are precomputed automatically with DifferentiationInterface.jl.

Manual precomputation

law = Law{ScalarCache}(;
    inputs = inputs,
    f! = f!,
    init_cache = init_cache,
    p_VJP! = function (cache, vjpsPrepLaw, inputs, θ)
        cache.vjp_θ .= ones(length(θ))
    end,
    callback_freq = 0
)
(:T,) -> Array{Float64, 0}   (↧@start  custom VJP  ✅ precomputed)
Success

Here too, the law is applied only once, before the simulation starts, but its VJPs are precomputed with our own implementation.
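The benefit of precomputation for a state-independent law can be seen in a toy setting: the derivative is evaluated once before the adjoint loop, and every later adjoint step only reads it from the cache. The sketch below is illustrative only and does not reproduce ODINN's adjoint: PrecompCache, expensive_grad!, precompute! and adjoint_steps are hypothetical names, the toy law is A(θ) = θ₁·T + θ₂, and the "adjoint loop" simply accumulates λ_k ∂A/∂θ over a few made-up adjoint weights λ_k.

```julia
# Hypothetical cache with a counter recording how often the derivative is computed
mutable struct PrecompCache
    vjp_θ::Vector{Float64}
    n_evals::Int
end

# (Potentially expensive) derivative of the toy law A(θ) = θ[1] * T + θ[2]
function expensive_grad!(cache, T)
    cache.n_evals += 1
    cache.vjp_θ .= (T, 1.0)        # ∂A/∂θ = [T, 1], independent of the glacier state
    return cache
end

# Called once, before the adjoint solve (the role played by p_VJP! in this tutorial)
precompute!(cache, T) = expensive_grad!(cache, T)

# Toy adjoint loop: reuses the cached derivative at every step instead of recomputing it
function adjoint_steps(cache, λs)
    g = zeros(length(cache.vjp_θ))
    for λ in λs
        g .+= λ .* cache.vjp_θ
    end
    return g
end

cache = PrecompCache(zeros(2), 0)
precompute!(cache, -5.0)
g = adjoint_steps(cache, [0.1, 0.2, 0.3])
```

However many adjoint steps are taken, the derivative is evaluated exactly once, which is what makes precomputation attractive for laws that do not depend on the glacier state.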

Simple cache customization

In this last section, we illustrate how to define our own cache to store additional information. Our use case is the interpolation of the VJP on a coarse grid. Instead of differentiating the law for every ice thickness value of the 2D grid at each time step, we evaluate the VJP only for a few values of H (this set of values is the coarse grid) and then interpolate these precomputed VJPs at the required values of H. The VJPs on the coarse grid are precomputed before solving the adjoint PDE, and the evaluations at the exact points in the adjoint PDE are made with an interpolator stored inside the cache object.
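The coarse-grid idea can be sketched in plain Julia, independently of ODINN and Interpolations.jl. In this toy example the "law" is y(H, θ) = θ₁·H + θ₂·H², so its gradient with respect to θ is [H, H²]; the gradient is evaluated only at a few nodes and then linearly interpolated at arbitrary thickness values. The names grad_θ, precompute_nodes and interp_grad are hypothetical, and the uniform grid is an assumption made for simplicity: the tutorial code below builds its grid with ODINN.create_interpolation and interpolates with Gridded(Linear()).

```julia
# Toy law y(H, θ) = θ[1] * H + θ[2] * H^2; its gradient w.r.t. θ is [H, H^2]
grad_θ(H) = [H, H^2]

# Precompute gradient vectors on a uniform coarse grid of thickness values
function precompute_nodes(Hmin, Hmax, n)
    nodes = collect(range(Hmin, Hmax; length = n))
    return nodes, [grad_θ(h) for h in nodes]
end

# Piecewise-linear interpolation of the vector-valued gradient at thickness H
function interp_grad(nodes, grads, H)
    @assert nodes[1] <= H <= nodes[end] "H outside the coarse grid"
    i = clamp(searchsortedlast(nodes, H), 1, length(nodes) - 1)
    t = (H - nodes[i]) / (nodes[i + 1] - nodes[i])
    return (1 - t) .* grads[i] .+ t .* grads[i + 1]
end

nodes, grads = precompute_nodes(0.0, 500.0, 21)   # 21 nodes, spacing 25
H_field = [12.0 130.0; 260.0 499.0]               # fine-grid thickness values
vjps = map(H -> interp_grad(nodes, grads, H), H_field)
```

The component that is linear in H is reproduced exactly, while the quadratic one carries an interpolation error bounded by Δ²/4 for a spacing Δ; refining the grid trades more gradient evaluations during precomputation for a more accurate adjoint.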

params = Parameters(
    simulation = SimulationParameters(rgi_paths = rgi_paths),
    UDE = UDEparameters(grad = ContinuousAdjoint(),
        target = :D_hybrid)
)
nn_model = NeuralNetwork(params)

prescale_bounds = [(-25.0, 0.0), (0.0, 500.0)]
prescale = X -> ODINN._ml_model_prescale(X, prescale_bounds)
postscale = Y -> ODINN._ml_model_postscale(Y, params.physical.maxA)

archi = nn_model.architecture
st = nn_model.st
smodel = ODINN.StatefulLuxLayer{true}(archi, nothing, st)

inputs = (; T = iAvgScalarTemp(), H̄ = iH̄())

f! = let smodel = smodel, prescale = prescale, postscale = postscale
    function (cache, inp, θ)
        Y = map(h -> ODINN._pred_NN([inp.T, h], smodel, θ.Y, prescale, postscale), inp.H̄)
        ODINN.Zygote.@ignore_derivatives cache.value .= Y # Hide this in-place assignment from Zygote so that f! can be differentiated below
        return Y
    end
end
#15 (generic function with 1 method)
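Because f! maps over inp.H̄ instead of assuming a matrix, the same function can be evaluated both on the full H̄ field and on a single scalar thickness. The precomputation function defined later relies on this when it passes H̄ = h. It works because Base defines map for a Number argument and returns a scalar. A minimal Base-only illustration follows, where per_cell is a hypothetical stand-in for the NN evaluation:

```julia
# Hypothetical stand-in for the per-cell NN evaluation performed inside f!
per_cell(h) = 2.0 * h

Y_mat = map(per_cell, [1.0 0.0; 3.0 4.0])  # matrix in, matrix of the same shape out
Y_scalar = map(per_cell, 5.0)              # scalar in, scalar out
```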

We define a new cache struct to store the interpolator:

using Interpolations
mutable struct MatrixCacheInterp <: Cache
    value::Array{Float64, 2}
    vjp_inp::Array{Float64, 2}
    vjp_θ::Array{Float64, 3}
    interp_θ::Interpolations.GriddedInterpolation{
        Vector{Float64}, 1, Vector{Vector{Float64}},
        Interpolations.Gridded{Interpolations.Linear{Interpolations.Throw{OnGrid}}},
        Tuple{Vector{Float64}}}
end
Warning

The cache must be of a concrete type.
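A rough way to check, in plain Julia, that a cache type and all its declared field types are concrete is to combine isconcretetype and fieldtypes from Base. Whether ODINN itself inspects field types is an assumption not stated in this tutorial. Writing out the full GriddedInterpolation type of interp_θ, rather than an abstract type, keeps that field concretely typed, which is standard Julia practice for type stability. All type names below are hypothetical.

```julia
# Hypothetical stand-in for ODINN's abstract `Cache` supertype
abstract type AbstractToyCache end

# All fields have concrete types
mutable struct ConcreteFieldsCache <: AbstractToyCache
    value::Array{Float64, 2}
    vjp_θ::Array{Float64, 3}
end

# `value` is declared with an abstract type, so its concrete type is unknown to the compiler
mutable struct AbstractFieldCache <: AbstractToyCache
    value::AbstractMatrix
    vjp_θ::Array{Float64, 3}
end

# True when the struct type itself and every declared field type are concrete
has_concrete_fields(T::Type) = isconcretetype(T) && all(isconcretetype, fieldtypes(T))

has_concrete_fields(ConcreteFieldsCache)  # true
has_concrete_fields(AbstractFieldCache)   # false
```

For a field whose type is tedious to write by hand, calling typeof on an instance (for example on an interpolator built with interpolate) gives the concrete type to paste into the struct definition.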

function init_cache_interp(simulation, glacier_idx, θ)
    glacier = simulation.glaciers[glacier_idx]
    (; nx, ny) = glacier
    H_interp = ODINN.create_interpolation(glacier.H₀;
        n_interp_half = simulation.model.trainable_components.target.n_interp_half)
    θvec = ODINN.ComponentVector2Vector(θ)
    grads = [zero(θvec) for i in 1:length(H_interp)]
    grad_itp = interpolate((H_interp,), grads, Gridded(Linear()))
    return MatrixCacheInterp(zeros(nx-1, ny-1), zeros(nx-1, ny-1), zeros(nx-1, ny-1, length(θ)), grad_itp)
end
init_cache_interp (generic function with 1 method)

To initialize the cache, we created a placeholder interpolation grid above. The actual interpolation grid is computed during the precomputation step, from the inputs available at the beginning of the adjoint PDE.
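The placeholder interpolator is later replaced through a field assignment: the precomputation function defined below executes cache.interp_θ = interpolate(...). This is only possible because MatrixCacheInterp is declared as a mutable struct. An immutable struct could still have its arrays modified in place, but it cannot rebind a field to a new object. A Base-only sketch, with hypothetical types, showing the difference:

```julia
# Hypothetical cache types, mirroring the role of MatrixCacheInterp
mutable struct MutableInterpCache
    nodes::Vector{Float64}   # coarse-grid nodes, initialized with a placeholder
end

struct ImmutableInterpCache
    nodes::Vector{Float64}
end

# Placeholder grid at construction time, as in init_cache_interp
c = MutableInterpCache([0.0])
# Later (precomputation step): rebind the field to the actual grid
c.nodes = [0.0, 10.0, 20.0]

# The same field assignment fails on an immutable struct
ic = ImmutableInterpCache([0.0])
rebind_failed = try
    ic.nodes = [0.0, 10.0, 20.0]
    false
catch err
    true   # Julia throws an error: fields of an immutable struct cannot be reassigned
end
```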

Below we define the precomputation function, which builds a coarse grid and differentiates the neural network at each of its points.

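function p_VJP!(cache, vjpsPrepLaw, inputs, θ)
    # `simulation` refers to the global Inversion object created further below,
    # so it must be defined before p_VJP! is called
    H_interp = ODINN.create_interpolation(inputs.H̄;
        n_interp_half = simulation.model.trainable_components.target.n_interp_half)
    grads = Vector{Float64}[]
    for h in H_interp
        # Gradient of the (scalar) law output w.r.t. θ at the coarse-grid node h
        ret, = ODINN.Zygote.gradient(_θ -> f!(cache, (; T = inputs.T, H̄ = h), _θ), θ)
        push!(grads, ODINN.ComponentVector2Vector(ret))
    end
    # Replace the placeholder interpolator with one built on the actual coarse grid
    cache.interp_θ = interpolate((H_interp,), grads, Gridded(Linear()))
end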
p_VJP! (generic function with 1 method)

Then, at each iteration of the adjoint PDE, we evaluate the interpolator at the values in inputs.H̄. Since many points are zero (outside the glacier), the interpolator is evaluated at H̄ = 0 only once and the result is reused.

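function f_VJP_θ!(cache, inputs, θ)
    H̄ = inputs.H̄
    zero_interp = cache.interp_θ(0.0) # evaluated once, reused for every ice-free cell
    for i in axes(H̄, 1), j in axes(H̄, 2)
        h = H̄[i, j]
        # The ternary operator (unlike ifelse) only calls the interpolator when h ≠ 0
        cache.vjp_θ[i, j, :] .= h == 0.0 ? zero_interp : cache.interp_θ(h)
    end
end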
f_VJP_θ! (generic function with 1 method)
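The benefit of evaluating the interpolator at H̄ = 0 only once depends on how the branch is written. In Julia, ifelse is an ordinary function, so both of its value arguments are evaluated before it is called, whereas the ternary operator c ? a : b evaluates only the selected branch. The toy snippet below counts the calls in each case; expensive_interp is a hypothetical stand-in for cache.interp_θ.

```julia
# Counter for calls to the (hypothetical) expensive interpolator
const ncalls = Ref(0)
expensive_interp(h) = (ncalls[] += 1; 2.0 * h)  # stand-in for cache.interp_θ(h)

H̄ = [0.0 0.0; 3.0 0.0]            # mostly zeros, like cells outside the glacier
zero_val = expensive_interp(0.0)   # evaluated once, before the loop

# `ifelse` is an ordinary function: both branch arguments are evaluated for every h
ncalls[] = 0
out_ifelse = [ifelse(h == 0.0, zero_val, expensive_interp(h)) for h in H̄]
calls_ifelse = ncalls[]    # 4

# The ternary operator only evaluates the branch that is selected
ncalls[] = 0
out_ternary = [h == 0.0 ? zero_val : expensive_interp(h) for h in H̄]
calls_ternary = ncalls[]   # 1
```

Both forms give the same result, but with ifelse the interpolator is still called for every cell, including the zero ones, so the cached zero value saves nothing. The ternary form calls it only for the single nonzero cell.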

Finally we can define the law:

law = Law{MatrixCacheInterp}(;
    inputs = inputs,
    f! = f!,
    init_cache = init_cache_interp,
    p_VJP! = p_VJP!,
    f_VJP_θ! = f_VJP_θ!,
    f_VJP_input! = function (cache, inputs, θ) # Not implemented in this example
    end
)
(:T, :H̄) -> Matrix{Float64}   (⟳  custom VJP  ✅ precomputed)

As in the previous example, we need to create a few objects and initialize them manually in order to call the ODINN internals ODINN.precompute_law_VJP and ODINN.∂law∂θ!.

rgi_ids = ["RGI60-11.03638"]
glaciers = initialize_glaciers(rgi_ids, params)
model = Model(
    iceflow = SIA2Dmodel(params; Y = law),
    mass_balance = nothing,
    regressors = (; Y = nn_model)
)
simulation = Inversion(model, glaciers, params)
θ = simulation.model.trainable_components.θ
glacier_idx = 1
t = simulation.parameters.simulation.tspan[1]
simulation.cache = ODINN.init_cache(model, simulation, glacier_idx, θ)
Sleipnir.ModelCache{SIA2DCache{Float64, Int64, ScalarCacheNoVJP, ScalarCacheNoVJP, ScalarCacheNoVJP, ScalarCacheNoVJP, ScalarCacheNoVJP, Array{Float64, 0}, Array{Float64, 0}, Main.MatrixCacheInterp, ScalarCacheNoVJP}, Nothing}(…)

Call SIA2D! once so that the inputs of the law are computed and can be retrieved from the cache:

dH = zero(simulation.cache.iceflow.H)
ODINN.Huginn.SIA2D!(dH, simulation.cache.iceflow.H, simulation, t, θ);

Finally, we call the precomputation function, and then the VJP function that is called at each iteration of the adjoint PDE:

ODINN.precompute_law_VJP(
    simulation.model.iceflow.Y,
    simulation.cache.iceflow.Y,
    simulation.cache.iceflow.Y_prep_vjps,
    simulation,
    glacier_idx, t, θ)

ODINN.∂law∂θ!(
    params.UDE.grad.VJP_method.regressorADBackend,
    simulation.model.iceflow.Y,
    simulation.cache.iceflow.Y,
    simulation.cache.iceflow.Y_prep_vjps,
    (; T = 1.0, H̄ = simulation.cache.iceflow.H̄), θ)

Now let us check that the vjp_θ field of the cache, which is spatially varying, has been populated:

simulation.cache.iceflow.Y.vjp_θ
137×128×86 Array{Float64, 3}:
[:, :, 1] =
 1.29585e-19  1.29585e-19  1.29585e-19  …  1.29585e-19  1.29585e-19
 1.29585e-19  1.29585e-19  1.29585e-19     1.29585e-19  1.29585e-19
 1.29585e-19  1.29585e-19  1.29585e-19     1.29585e-19  1.29585e-19
 1.29585e-19  1.29585e-19  1.29585e-19     1.29585e-19  1.29585e-19
 1.29585e-19  1.29585e-19  1.29585e-19     1.29585e-19  1.29585e-19
 1.29585e-19  1.29585e-19  1.29585e-19  …  1.29585e-19  1.29585e-19
 1.29585e-19  1.29585e-19  1.29585e-19     1.29585e-19  1.29585e-19
 1.29585e-19  1.29585e-19  1.29585e-19     1.29585e-19  1.29585e-19
 1.29585e-19  1.29585e-19  1.29585e-19     1.29585e-19  1.29585e-19
 1.29585e-19  1.29585e-19  1.29585e-19     1.29585e-19  1.29585e-19
 ⋮                                      ⋱               
 1.29585e-19  1.29585e-19  1.29585e-19     1.29585e-19  1.29585e-19
 1.29585e-19  1.29585e-19  1.29585e-19     1.29585e-19  1.29585e-19
 1.29585e-19  1.29585e-19  1.29585e-19  …  1.29585e-19  1.29585e-19
 1.29585e-19  1.29585e-19  1.29585e-19     1.29585e-19  1.29585e-19
 1.29585e-19  1.29585e-19  1.29585e-19     1.29585e-19  1.29585e-19
 1.29585e-19  1.29585e-19  1.29585e-19     1.29585e-19  1.29585e-19
 1.29585e-19  1.29585e-19  1.29585e-19     1.29585e-19  1.29585e-19
 1.29585e-19  1.29585e-19  1.29585e-19  …  1.29585e-19  1.29585e-19
 1.29585e-19  1.29585e-19  1.29585e-19     1.29585e-19  1.29585e-19

[:, :, 2] =
 -1.03158e-19  -1.03158e-19  -1.03158e-19  …  -1.03158e-19  -1.03158e-19
 -1.03158e-19  -1.03158e-19  -1.03158e-19     -1.03158e-19  -1.03158e-19
 -1.03158e-19  -1.03158e-19  -1.03158e-19     -1.03158e-19  -1.03158e-19
 -1.03158e-19  -1.03158e-19  -1.03158e-19     -1.03158e-19  -1.03158e-19
 -1.03158e-19  -1.03158e-19  -1.03158e-19     -1.03158e-19  -1.03158e-19
 -1.03158e-19  -1.03158e-19  -1.03158e-19  …  -1.03158e-19  -1.03158e-19
 -1.03158e-19  -1.03158e-19  -1.03158e-19     -1.03158e-19  -1.03158e-19
 -1.03158e-19  -1.03158e-19  -1.03158e-19     -1.03158e-19  -1.03158e-19
 -1.03158e-19  -1.03158e-19  -1.03158e-19     -1.03158e-19  -1.03158e-19
 -1.03158e-19  -1.03158e-19  -1.03158e-19     -1.03158e-19  -1.03158e-19
  ⋮                                        ⋱                
 -1.03158e-19  -1.03158e-19  -1.03158e-19     -1.03158e-19  -1.03158e-19
 -1.03158e-19  -1.03158e-19  -1.03158e-19     -1.03158e-19  -1.03158e-19
 -1.03158e-19  -1.03158e-19  -1.03158e-19  …  -1.03158e-19  -1.03158e-19
 -1.03158e-19  -1.03158e-19  -1.03158e-19     -1.03158e-19  -1.03158e-19
 -1.03158e-19  -1.03158e-19  -1.03158e-19     -1.03158e-19  -1.03158e-19
 -1.03158e-19  -1.03158e-19  -1.03158e-19     -1.03158e-19  -1.03158e-19
 -1.03158e-19  -1.03158e-19  -1.03158e-19     -1.03158e-19  -1.03158e-19
 -1.03158e-19  -1.03158e-19  -1.03158e-19  …  -1.03158e-19  -1.03158e-19
 -1.03158e-19  -1.03158e-19  -1.03158e-19     -1.03158e-19  -1.03158e-19

[:, :, 3] =
 7.21833e-19  7.21833e-19  7.21833e-19  …  7.21833e-19  7.21833e-19
 7.21833e-19  7.21833e-19  7.21833e-19     7.21833e-19  7.21833e-19
 7.21833e-19  7.21833e-19  7.21833e-19     7.21833e-19  7.21833e-19
 7.21833e-19  7.21833e-19  7.21833e-19     7.21833e-19  7.21833e-19
 7.21833e-19  7.21833e-19  7.21833e-19     7.21833e-19  7.21833e-19
 7.21833e-19  7.21833e-19  7.21833e-19  …  7.21833e-19  7.21833e-19
 7.21833e-19  7.21833e-19  7.21833e-19     7.21833e-19  7.21833e-19
 7.21833e-19  7.21833e-19  7.21833e-19     7.21833e-19  7.21833e-19
 7.21833e-19  7.21833e-19  7.21833e-19     7.21833e-19  7.21833e-19
 7.21833e-19  7.21833e-19  7.21833e-19     7.21833e-19  7.21833e-19
 ⋮                                      ⋱               
 7.21833e-19  7.21833e-19  7.21833e-19     7.21833e-19  7.21833e-19
 7.21833e-19  7.21833e-19  7.21833e-19     7.21833e-19  7.21833e-19
 7.21833e-19  7.21833e-19  7.21833e-19  …  7.21833e-19  7.21833e-19
 7.21833e-19  7.21833e-19  7.21833e-19     7.21833e-19  7.21833e-19
 7.21833e-19  7.21833e-19  7.21833e-19     7.21833e-19  7.21833e-19
 7.21833e-19  7.21833e-19  7.21833e-19     7.21833e-19  7.21833e-19
 7.21833e-19  7.21833e-19  7.21833e-19     7.21833e-19  7.21833e-19
 7.21833e-19  7.21833e-19  7.21833e-19  …  7.21833e-19  7.21833e-19
 7.21833e-19  7.21833e-19  7.21833e-19     7.21833e-19  7.21833e-19

;;; … 

[:, :, 84] =
 7.40594e-18  7.40594e-18  7.40594e-18  …  7.40594e-18  7.40594e-18
 7.40594e-18  7.40594e-18  7.40594e-18     7.40594e-18  7.40594e-18
 7.40594e-18  7.40594e-18  7.40594e-18     7.40594e-18  7.40594e-18
 7.40594e-18  7.40594e-18  7.40594e-18     7.40594e-18  7.40594e-18
 7.40594e-18  7.40594e-18  7.40594e-18     7.40594e-18  7.40594e-18
 7.40594e-18  7.40594e-18  7.40594e-18  …  7.40594e-18  7.40594e-18
 7.40594e-18  7.40594e-18  7.40594e-18     7.40594e-18  7.40594e-18
 7.40594e-18  7.40594e-18  7.40594e-18     7.40594e-18  7.40594e-18
 7.40594e-18  7.40594e-18  7.40594e-18     7.40594e-18  7.40594e-18
 7.40594e-18  7.40594e-18  7.40594e-18     7.40594e-18  7.40594e-18
 ⋮                                      ⋱               
 7.40594e-18  7.40594e-18  7.40594e-18     7.40594e-18  7.40594e-18
 7.40594e-18  7.40594e-18  7.40594e-18     7.40594e-18  7.40594e-18
 7.40594e-18  7.40594e-18  7.40594e-18  …  7.40594e-18  7.40594e-18
 7.40594e-18  7.40594e-18  7.40594e-18     7.40594e-18  7.40594e-18
 7.40594e-18  7.40594e-18  7.40594e-18     7.40594e-18  7.40594e-18
 7.40594e-18  7.40594e-18  7.40594e-18     7.40594e-18  7.40594e-18
 7.40594e-18  7.40594e-18  7.40594e-18     7.40594e-18  7.40594e-18
 7.40594e-18  7.40594e-18  7.40594e-18  …  7.40594e-18  7.40594e-18
 7.40594e-18  7.40594e-18  7.40594e-18     7.40594e-18  7.40594e-18

[:, :, 85] =
 4.2713e-18  4.2713e-18  4.2713e-18  …  4.2713e-18  4.2713e-18  4.2713e-18
 4.2713e-18  4.2713e-18  4.2713e-18     4.2713e-18  4.2713e-18  4.2713e-18
 4.2713e-18  4.2713e-18  4.2713e-18     4.2713e-18  4.2713e-18  4.2713e-18
 4.2713e-18  4.2713e-18  4.2713e-18     4.2713e-18  4.2713e-18  4.2713e-18
 4.2713e-18  4.2713e-18  4.2713e-18     4.2713e-18  4.2713e-18  4.2713e-18
 4.2713e-18  4.2713e-18  4.2713e-18  …  4.2713e-18  4.2713e-18  4.2713e-18
 4.2713e-18  4.2713e-18  4.2713e-18     4.2713e-18  4.2713e-18  4.2713e-18
 4.2713e-18  4.2713e-18  4.2713e-18     4.2713e-18  4.2713e-18  4.2713e-18
 4.2713e-18  4.2713e-18  4.2713e-18     4.2713e-18  4.2713e-18  4.2713e-18
 4.2713e-18  4.2713e-18  4.2713e-18     4.2713e-18  4.2713e-18  4.2713e-18
 ⋮                                   ⋱  ⋮                       
 4.2713e-18  4.2713e-18  4.2713e-18     4.2713e-18  4.2713e-18  4.2713e-18
 4.2713e-18  4.2713e-18  4.2713e-18     4.2713e-18  4.2713e-18  4.2713e-18
 4.2713e-18  4.2713e-18  4.2713e-18  …  4.2713e-18  4.2713e-18  4.2713e-18
 4.2713e-18  4.2713e-18  4.2713e-18     4.2713e-18  4.2713e-18  4.2713e-18
 4.2713e-18  4.2713e-18  4.2713e-18     4.2713e-18  4.2713e-18  4.2713e-18
 4.2713e-18  4.2713e-18  4.2713e-18     4.2713e-18  4.2713e-18  4.2713e-18
 4.2713e-18  4.2713e-18  4.2713e-18     4.2713e-18  4.2713e-18  4.2713e-18
 4.2713e-18  4.2713e-18  4.2713e-18  …  4.2713e-18  4.2713e-18  4.2713e-18
 4.2713e-18  4.2713e-18  4.2713e-18     4.2713e-18  4.2713e-18  4.2713e-18

[:, :, 86] =
 1.34754e-17  1.34754e-17  1.34754e-17  …  1.34754e-17  1.34754e-17
 1.34754e-17  1.34754e-17  1.34754e-17     1.34754e-17  1.34754e-17
 1.34754e-17  1.34754e-17  1.34754e-17     1.34754e-17  1.34754e-17
 1.34754e-17  1.34754e-17  1.34754e-17     1.34754e-17  1.34754e-17
 1.34754e-17  1.34754e-17  1.34754e-17     1.34754e-17  1.34754e-17
 1.34754e-17  1.34754e-17  1.34754e-17  …  1.34754e-17  1.34754e-17
 1.34754e-17  1.34754e-17  1.34754e-17     1.34754e-17  1.34754e-17
 1.34754e-17  1.34754e-17  1.34754e-17     1.34754e-17  1.34754e-17
 1.34754e-17  1.34754e-17  1.34754e-17     1.34754e-17  1.34754e-17
 1.34754e-17  1.34754e-17  1.34754e-17     1.34754e-17  1.34754e-17
 ⋮                                      ⋱               
 1.34754e-17  1.34754e-17  1.34754e-17     1.34754e-17  1.34754e-17
 1.34754e-17  1.34754e-17  1.34754e-17     1.34754e-17  1.34754e-17
 1.34754e-17  1.34754e-17  1.34754e-17  …  1.34754e-17  1.34754e-17
 1.34754e-17  1.34754e-17  1.34754e-17     1.34754e-17  1.34754e-17
 1.34754e-17  1.34754e-17  1.34754e-17     1.34754e-17  1.34754e-17
 1.34754e-17  1.34754e-17  1.34754e-17     1.34754e-17  1.34754e-17
 1.34754e-17  1.34754e-17  1.34754e-17     1.34754e-17  1.34754e-17
 1.34754e-17  1.34754e-17  1.34754e-17  …  1.34754e-17  1.34754e-17
 1.34754e-17  1.34754e-17  1.34754e-17     1.34754e-17  1.34754e-17
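The printed slices only show the corners of the grid, which presumably lie outside the glacier (H̄ = 0). That would explain why every visible entry of a given slice is identical. A small helper, the hypothetical is_spatially_varying below, can confirm that at least one slice actually varies in space. It is demonstrated on a toy array, but it could be called on simulation.cache.iceflow.Y.vjp_θ in the same way.

```julia
# True if at least one slice A[:, :, k] is not constant over the spatial dimensions
function is_spatially_varying(A::AbstractArray{<:Real, 3})
    return any(axes(A, 3)) do k
        slice = view(A, :, :, k)
        any(!=(first(slice)), slice)   # some cell differs from the first cell
    end
end

# Toy array: constant everywhere except one "glacier" cell in the second slice
A = fill(1.0e-19, 4, 4, 3)
A[2, 3, 2] = 5.0e-18

is_spatially_varying(A)                    # true
is_spatially_varying(fill(1.0, 4, 4, 3))   # false
```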

Frequently Asked Questions

  • Can I use the preparation object in the p_VJP!/f_VJP_* functions?

No, it is currently not possible to use the preparation object inside these functions. The preparation object stores what DifferentiationInterface.jl prepares when p_VJP! = DIVJP(), which rules out its use in a custom p_VJP!. The f_VJP_* functions cannot access it either for now. It might be added as an argument in the future if there is a need.


This page was generated using Literate.jl.