Law VJP customization
This tutorial explains how to customize the VJP (vector-Jacobian product) computation of laws in ODINN.jl and clarifies the runtime flow used internally by the library. It distinguishes the functions that are part of the public, user-facing customization API from the internal helpers used by ODINN when an AD backend is required.
The features presented in this section are considered advanced and this tutorial was written primarily for ODINN developers.
It assumes that you have followed the Laws tutorial.
using ODINN
rgi_ids = ["RGI60-11.03638"]
rgi_paths = get_rgi_paths()
params = Parameters(
simulation = SimulationParameters(rgi_paths = rgi_paths),
UDE = UDEparameters(grad = ContinuousAdjoint())
)
nn_model = NeuralNetwork(params)
--- NeuralNetwork ---
architecture:
Chain(
    layer_1 = Dense(1 => 3, #101),  # 6 parameters
    layer_2 = Dense(3 => 10, #102),  # 40 parameters
    layer_3 = Dense(10 => 3, #103),  # 33 parameters
    layer_4 = Dense(3 => 1, σ),  # 4 parameters
)  # Total: 83 parameters,
   # plus 0 states.
θ: ComponentVector of length 83
Explanations
High-level summary
At the user level the customization can be made by implementing hand-written VJPs through the following functions:
- f_VJP_input!(...): VJP w.r.t. the inputs
- f_VJP_θ!(...): VJP w.r.t. the parameters θ
- You may also implement your own precompute function to cache expensive computations, which is the purpose of p_VJP!(...). This function is called before solving the adjoint iceflow PDE.
Internally when the user does NOT provide VJPs, ODINN uses a default AD backend (via DifferentiationInterface.jl) to compute the VJPs of the laws. To support efficient reverse-mode execution, ODINN will:
- compile and precompute adjoint-related helper functions and
- store preparation objects that are used later during the adjoint PDE.
This mechanism is triggered by prepare_vjp_law.
Internal function roles
prepare_vjp_law (internal)
Signature used in the codebase:
prepare_vjp_law(simulation, law::AbstractLaw, law_cache, θ, glacier_idx)
Intent and behavior:
- This is an internal routine. It is NOT intended to be called by users directly.
- It is invoked when ODINN must fall back to the AD backend (with DifferentiationInterface.jl) because the law did not supply explicit VJP functions (f_VJP_input! / f_VJP_θ!, or because p_VJP! is set to DIVJP()).
- Its job is to precompile and prepare the AD-based VJP code for a given law, and to produce preparation objects that store the preparation results from DifferentiationInterface.jl.
- prepare_vjp_law is typically called just after the iceflow model / law objects have been instantiated (i.e., early in the setup), so that the preparations are ready before the solve and the adjoint runs.
precompute_law_VJP (used before solving the adjoint PDE)
The typical signature in the codebase is:
precompute_law_VJP(law::AbstractLaw, cache, vjpsPrepLaw, simulation, glacier_idx, t, θ)
Intent and behavior:
- This function precomputes VJP-related artifacts before the adjoint iceflow PDE is solved, for a given time t and parameters θ.
- It typically uses the vjpsPrepLaw (an AbstractPrepVJP instance produced earlier by prepare_vjp_law) together with the cache and the simulation object. The produced results are cached in cache and are optionally consumed later by law_VJP_input / law_VJP_θ during the adjoint solve.
Entry points used in the adjoint PDE
These functions are the actual runtime entry points used when computing contributions of the laws to the gradient in the adjoint PDE:
law_VJP_θ(law::AbstractLaw, cache, simulation, glacier_idx, t, θ)
and
law_VJP_input(law::AbstractLaw, cache, simulation, glacier_idx, t, θ)
Intent and behavior:
- These are called during the adjoint solve to compute the parameter and input VJPs of the law at time t and for parameters θ.
- They can either compute the VJPs directly or use cached VJP information that has already been computed by the user-supplied p_VJP! function. The cache allows storing useful information from the forward pass or from the precomputation step.
- They therefore carry the runtime context (simulation, glacier index, time, θ) which is necessary for the adjoint calculations.
Workflow
For most users we do not recommend playing with the VJPs. ODINN comes with default parameters and the average user does not need to customize the VJPs; keeping the default values will work fine.
Advanced users seeking maximum performance can customize the VJPs, which can significantly speed up the code.
How do the pieces compose in practice?
- If you, as a user, provide custom VJP functions (through f_VJP_input! / f_VJP_θ!, or through p_VJP!), ODINN will use them directly in the adjoint and will skip the AD fallback path. You can also provide your own precompute wrapper and cache to optimize expensive computations.
- If you do NOT provide VJP functions, ODINN runs the AD fallback:
  - prepare_vjp_law runs early (post-instantiation) to compile/prepare the AD-based helpers and returns an AbstractPrepVJP object.
  - precompute_law_VJP is skipped.
  - During the adjoint solve, law_VJP_input and law_VJP_θ use the preparation objects precompiled in prepare_vjp_law to automatically differentiate f! with DifferentiationInterface.jl and obtain the VJPs of the law with respect to the inputs and to the parameters θ.
You can change the default AD backend for the laws that do not have custom VJPs through the VJP type, for example by setting VJP_method = DiscreteVJP(regressorADBackend = DI.AutoZygote()) when you define the adjoint method.
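A minimal sketch of such a configuration, assuming that ContinuousAdjoint accepts the VJP_method keyword mentioned above and that DifferentiationInterface is imported as DI, could look like this:
import DifferentiationInterface as DI
params_zygote = Parameters(
    simulation = SimulationParameters(rgi_paths = rgi_paths),
    UDE = UDEparameters(
        grad = ContinuousAdjoint(
            # Laws without custom VJPs will be differentiated with Zygote via DifferentiationInterface.jl
            VJP_method = DiscreteVJP(regressorADBackend = DI.AutoZygote())
        )
    )
)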
User level customization
What is user-visible and can be customized?
- f_VJP_input!(cache, inputs, θ): compute the VJP with respect to the inputs and store the result in cache.vjp_inp
- f_VJP_θ!(cache, inputs, θ): compute the VJP with respect to θ and store the result in cache.vjp_θ
- p_VJP!(cache, vjpsPrepLaw, inputs, θ): if you want to precompute some components (or even the whole VJPs when possible) before solving the adjoint iceflow PDE
- custom cache implementations (described below)
Notes on cache definition
The cache parameter that is threaded through the p_VJP!/f_VJP_* calls is the place to store artifacts useful for efficient computation as well as the results of the VJP computations. The following fields are mandatory:
- value: a placeholder to store the result of the forward evaluation; can be of any type
- vjp_θ: a placeholder to store the result of the VJP with respect to θ; depending on the type of law that is defined, it can be a vector or a 3-dimensional array
- vjp_inp: a placeholder to store the result of the VJP with respect to the inputs; must be of a type that matches the one of the inputs
In order to know the type of the inputs, simply run generate_inputs(law.f.inputs, simulation, glacier_idx, t).
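As an illustration, a minimal sketch of a custom cache for a scalar-output law could look like the following. MyScalarLawCache is a hypothetical name and the concrete field types are only indicative; they must match the output of your law, its inputs, and θ:
mutable struct MyScalarLawCache <: Cache
    value::Array{Float64, 0}    # result of the forward evaluation
    vjp_inp::Array{Float64, 0}  # VJP w.r.t. the inputs, same type as the inputs
    vjp_θ::Vector{Float64}      # VJP w.r.t. θ (a vector for a scalar-output law)
end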
Using the preparation object
Best practices and debugging tips
- If you supply custom VJPs, test them with finite-difference checks for both the inputs and the parameters; ODINN does not check the correctness of your implementation! A sketch of such a check is given after this list.
- If you rely on ODINN's AD fallback, be aware that prepare_vjp_law will precompile and prepare the AD helpers at model instantiation time; expect a longer setup time but faster adjoint runs thereafter.
- Inspect/validate the cache content if you get inconsistent adjoints; a stale or incorrect cache entry is a common cause.
- Although the API is designed to provide everything you need as arguments, if your VJP needs anything from the forward pass, ensure it is stored in the cache.
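The following is a minimal sketch of a central finite-difference check for a hand-written f_VJP_θ! of a scalar-output law. check_vjp_θ is a hypothetical helper; f!, f_VJP_θ!, cache, inputs and θ are assumed to be defined as in the examples below.
function check_vjp_θ(f!, f_VJP_θ!, cache, inputs, θ; ϵ = 1e-6, atol = 1e-4)
    f_VJP_θ!(cache, inputs, θ)  # fills cache.vjp_θ with the hand-written VJP
    g_fd = zeros(length(θ))
    for i in eachindex(θ)
        θp = copy(θ); θp[i] += ϵ
        θm = copy(θ); θm[i] -= ϵ
        g_fd[i] = (f!(cache, inputs, θp) - f!(cache, inputs, θm)) / (2ϵ)  # central difference
    end
    return maximum(abs.(collect(cache.vjp_θ) .- g_fd)) < atol
end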
Simple VJP customization
We will explore how we can customize the VJP computation of the law that is used in the Laws tutorial. The cache used for this law is a ScalarCache since the output of this law is a scalar value A, the creep coefficient. We can confirm that this type defines the fields needed for the VJP computation:
fieldnames(ScalarCache)
(:value, :vjp_inp, :vjp_θ)
Before defining the law, we retrieve the model architecture and the physical parameters to be used inside the f! function of the law, and we define the inputs:
archi = nn_model.architecture
st = nn_model.st
smodel = ODINN.StatefulLuxLayer{true}(archi, nothing, st)
min_NN = params.physical.minA
max_NN = params.physical.maxA
inputs = (; T = iAvgScalarTemp())
(T = averaged_long_term_temperature: iAvgScalarTemp(),)
And then the f! and init_cache functions:
f! = let smodel = smodel, min_NN = min_NN, max_NN = max_NN
function (cache, inp, θ)
inp = collect(values(inp))
A = only(ODINN.scale(smodel(inp, θ.A), (min_NN, max_NN)))
ODINN.Zygote.@ignore_derivatives cache.value .= A # We ignore this in-place assignment in order to be able to differentiate the function with Zygote hereafter
return A
end
end
function init_cache(simulation, glacier_idx, θ)
return ScalarCache(zeros(), zeros(), zero(θ))
end
init_cache (generic function with 1 method)
The declaration of the law without VJP customization would be:
law = Law{ScalarCache}(;
inputs = inputs,
f! = f!,
init_cache = init_cache
)
(:T,) -> Array{Float64, 0} (⟳ auto VJP ❌ precomputed)
We see from the output that the VJPs are inferred using DifferentiationInterface.jl and that ODINN does not use precomputation. Now let's try to customize the VJPs by manually implementing the AD step:
law = Law{ScalarCache}(;
inputs = inputs,
f! = f!,
f_VJP_input! = function (cache, inputs, θ)
nothing # The input does not depend on the glacier state
end,
f_VJP_θ! = function (cache, inputs, θ)
cache.vjp_θ .= ones(length(θ)) # The VJP is wrong on purpose to check that this function is properly called hereafter
end,
init_cache = init_cache
)
(:T,) -> Array{Float64, 0} (⟳ custom VJP ❌ precomputed)
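For reference, a correct parameter VJP for this scalar law could be written with Zygote, mirroring what the AD fallback computes and the pattern used in the manual precomputation example later in this tutorial. This is only a sketch and is not used in what follows:
f_VJP_θ_correct! = function (cache, inputs, θ)
    # Differentiate the forward function f! with respect to θ and store the result in the cache
    grad, = ODINN.Zygote.gradient(_θ -> f!(cache, inputs, _θ), θ)
    cache.vjp_θ .= ODINN.ComponentVector2Vector(grad)
end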
In order to instantiate the cache, we need to define the model:
rgi_ids = ["RGI60-11.03638"]
glaciers = initialize_glaciers(rgi_ids, params)
model = Model(
iceflow = SIA2Dmodel(params; A = law),
mass_balance = nothing,
regressors = (; A = nn_model)
)
simulation = Inversion(model, glaciers, params)
(The full Inversion printout is omitted; in particular it reports the A law as (:T,) -> Array{Float64, 0} (⟳ custom VJP ❌ precomputed).)
We will also need θ in order to call the VJPs of the law manually, although in practice we do not have to worry about retrieving this:
θ = simulation.model.trainable_components.θ
(θ is a ComponentVector of length 83 holding the neural network weights under the A component; printout omitted.)
We then create the cache. Again, all of this is handled internally in ODINN; we instantiate it manually here to demonstrate how the VJPs can be customized.
glacier_idx = 1
simulation.cache = ODINN.init_cache(model, simulation, glacier_idx, θ)
(The full ModelCache printout is omitted; its A entry is the ScalarCache we defined above.)
Finally we demonstrate that it is our custom implementation that is being called:
ODINN.∂law∂θ!(
params.UDE.grad.VJP_method.regressorADBackend,
simulation.model.iceflow.A,
simulation.cache.iceflow.A,
simulation.cache.iceflow.A_prep_vjps,
(; T = 1.0), θ)
83-element Vector{Float64}:
 1.0
 1.0
 ⋮
 1.0
The returned gradient is a vector of 83 ones, confirming that our custom (intentionally wrong) f_VJP_θ! implementation was called.
VJP precomputation
Since the law that we have been using so far does not depend on the glacier state, it can be computed once and for all at the beginning of the simulation, and its VJPs can be precomputed before solving the adjoint iceflow PDE. The definition of the law below illustrates how we can do this in two ways:
- by using DifferentiationInterface.jl to automatically compute the VJPs;
- by manually precomputing the VJPs in the p_VJP! function.
Automatic precomputation with DI
law = Law{ScalarCache}(;
inputs = inputs,
f! = f!,
init_cache = init_cache,
p_VJP! = DIVJP(),
callback_freq = 0
)
(:T,) -> Array{Float64, 0} (↧@start custom VJP ✅ precomputed (DI))
This law is applied only once before the beginning of the simulation, and the VJPs are precomputed automatically.
Manual precomputation
law = Law{ScalarCache}(;
inputs = inputs,
f! = f!,
init_cache = init_cache,
p_VJP! = function (cache, vjpsPrepLaw, inputs, θ)
cache.vjp_θ .= ones(length(θ))
end,
callback_freq = 0
)
(:T,) -> Array{Float64, 0} (↧@start custom VJP ✅ precomputed)
This law is applied only once before the beginning of the simulation, and the VJPs are precomputed using our own implementation.
Simple cache customization
In this last section we illustrate how to define our own cache to store additional information. Our use case is the interpolation of the VJP on a coarse grid. By coarse grid we mean that, in order to evaluate the VJP, we do not need to differentiate the law for every value of ice thickness on the 2D grid at each time step. We only need to pre-evaluate the VJP for a few values of H (this set of values corresponds to the coarse grid), and then we can interpolate the precomputed VJP at the required values of H. The VJPs on the coarse grid are precomputed before solving the adjoint PDE, and the evaluations at the exact points in the adjoint PDE are made using an interpolator stored inside the cache object.
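To illustrate the interpolation mechanics in isolation, here is a minimal sketch (independent of ODINN, with g standing in for an expensive per-H computation) that pre-evaluates a quantity on a coarse grid of H values and then interpolates it with Interpolations.jl, as done in the implementation below:
using Interpolations
g(h) = h^3  # stand-in for an expensive per-H computation
H_coarse = collect(0.0:50.0:500.0)  # coarse grid of ice thickness values
vals = g.(H_coarse)  # pre-evaluate once on the coarse grid
itp = interpolate((H_coarse,), vals, Gridded(Linear()))
itp(123.4)  # cheap evaluation at any required H
The actual setup for this example follows: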
params = Parameters(
simulation = SimulationParameters(rgi_paths = rgi_paths),
UDE = UDEparameters(grad = ContinuousAdjoint(),
target = :D_hybrid)
)
nn_model = NeuralNetwork(params)
prescale_bounds = [(-25.0, 0.0), (0.0, 500.0)]
prescale = X -> ODINN._ml_model_prescale(X, prescale_bounds)
postscale = Y -> ODINN._ml_model_postscale(Y, params.physical.maxA)
archi = nn_model.architecture
st = nn_model.st
smodel = ODINN.StatefulLuxLayer{true}(archi, nothing, st)
inputs = (; T = iAvgScalarTemp(), H̄ = iH̄())
f! = let smodel = smodel, prescale = prescale, postscale = postscale
function (cache, inp, θ)
Y = map(h -> ODINN._pred_NN([inp.T, h], smodel, θ.Y, prescale, postscale), inp.H̄)
ODINN.Zygote.@ignore_derivatives cache.value .= Y # We ignore this in-place assignment in order to be able to differentiate the function with Zygote hereafter
return Y
end
end
#15 (generic function with 1 method)
We define a new cache struct to store the interpolator:
using Interpolations
mutable struct MatrixCacheInterp <: Cache
value::Array{Float64, 2}
vjp_inp::Array{Float64, 2}
vjp_θ::Array{Float64, 3}
interp_θ::Interpolations.GriddedInterpolation{
Vector{Float64}, 1, Vector{Vector{Float64}},
Interpolations.Gridded{Interpolations.Linear{Interpolations.Throw{OnGrid}}},
Tuple{Vector{Float64}}}
end
function init_cache_interp(simulation, glacier_idx, θ)
glacier = simulation.glaciers[glacier_idx]
(; nx, ny) = glacier
H_interp = ODINN.create_interpolation(glacier.H₀;
n_interp_half = simulation.model.trainable_components.target.n_interp_half)
θvec = ODINN.ComponentVector2Vector(θ)
grads = [zero(θvec) for i in 1:length(H_interp)]
grad_itp = interpolate((H_interp,), grads, Gridded(Linear()))
return MatrixCacheInterp(zeros(nx-1, ny-1), zeros(nx-1, ny-1), zeros(nx-1, ny-1, length(θ)), grad_itp)
end
init_cache_interp (generic function with 1 method)
In order to initialize the cache, we created a fake interpolation grid above. However, this interpolation grid will be computed during the precomputation step, based on the provided inputs, at the beginning of the adjoint PDE.
Below we define the precomputation function which defines a coarse grid and differentiates the neural network at each of these points.
function p_VJP!(cache, vjpsPrepLaw, inputs, θ)
H_interp = ODINN.create_interpolation(inputs.H̄;
n_interp_half = simulation.model.trainable_components.target.n_interp_half)
grads = Vector{Float64}[]
for h in H_interp
ret, = ODINN.Zygote.gradient(_θ -> f!(cache, (; T = inputs.T, H̄ = h), _θ), θ)
push!(grads, ODINN.ComponentVector2Vector(ret))
end
cache.interp_θ = interpolate((H_interp,), grads, Gridded(Linear()))
end
p_VJP! (generic function with 1 method)
Then, at each iteration of the adjoint PDE, we use the interpolator, which we evaluate with the values in inputs.H̄. Since many of the points are zeros (outside of the glacier), we evaluate the interpolator for H̄ = 0 only once.
function f_VJP_θ!(cache, inputs, θ)
H̄ = inputs.H̄
zero_interp = cache.interp_θ(0.0)
for i in axes(H̄, 1), j in axes(H̄, 2)
cache.vjp_θ[i, j, :] = map(h -> ifelse(h == 0.0, zero_interp, cache.interp_θ(h)), H̄[i, j])
end
end
f_VJP_θ! (generic function with 1 method)
Finally we can define the law:
law = Law{MatrixCacheInterp}(;
inputs = inputs,
f! = f!,
init_cache = init_cache_interp,
p_VJP! = p_VJP!,
f_VJP_θ! = f_VJP_θ!,
f_VJP_input! = function (cache, inputs, θ) # Not implemented in this example
end
)
(:T, :H̄) -> Matrix{Float64} (⟳ custom VJP ✅ precomputed)
As in the previous example, we need to define some objects and perform the initialization manually in order to be able to call the ODINN internals ODINN.precompute_law_VJP and ODINN.∂law∂θ!.
rgi_ids = ["RGI60-11.03638"]
glaciers = initialize_glaciers(rgi_ids, params)
model = Model(
iceflow = SIA2Dmodel(params; Y = law),
mass_balance = nothing,
regressors = (; Y = nn_model)
)
simulation = Inversion(model, glaciers, params)
θ = simulation.model.trainable_components.θ
glacier_idx = 1
t = simulation.parameters.simulation.tspan[1]
simulation.cache = ODINN.init_cache(model, simulation, glacier_idx, θ)
(The full ModelCache printout is omitted; its Y entry is our MatrixCacheInterp with zero-initialized value, vjp_inp and vjp_θ fields and the fake 150-element gridded interpolator.)
Apply the law once to be able to retrieve the inputs:
dH = zero(simulation.cache.iceflow.H)
ODINN.Huginn.SIA2D!(dH, simulation.cache.iceflow.H, simulation, t, θ);
Finally we call the precompute function, and then the VJP function that is called at each iteration of the adjoint PDE.
ODINN.precompute_law_VJP(
simulation.model.iceflow.Y,
simulation.cache.iceflow.Y,
simulation.cache.iceflow.Y_prep_vjps,
simulation,
glacier_idx, t, θ)
ODINN.∂law∂θ!(
params.UDE.grad.VJP_method.regressorADBackend,
simulation.model.iceflow.Y,
simulation.cache.iceflow.Y,
simulation.cache.iceflow.Y_prep_vjps,
(; T = 1.0, H̄ = simulation.cache.iceflow.H̄), θ)
Now let us check that the vjp_θ field of the cache, which is spatially varying, has been populated:
simulation.cache.iceflow.Y.vjp_θ
137×128×86 Array{Float64, 3}:
[:, :, 1] =
 1.29585e-19  1.29585e-19  …  1.29585e-19
⋮
[:, :, 86] =
 1.34754e-17  1.34754e-17  …  1.34754e-17
(The full printout is truncated here; the array is populated with small nonzero values for every parameter slice.)
Frequently Asked Questions
- Can I use the preparation object in the p_VJP!/f_VJP_* functions?
No, it is not possible for the moment to use the preparation object inside these functions. The preparation object is used to store what DifferentiationInterface.jl precompiles when p_VJP! = DIVJP(), which excludes its use in p_VJP!. As for the f_VJP_* functions, the preparation object cannot be accessed for the moment; if there is a need, we might add it as an argument in the future.
This page was generated using Literate.jl.