This is more of a question than an issue. I was curious whether the type stability of prepare_gradient
could be improved. Here is an example where its return type cannot be fully inferred:
julia> using DifferentiationInterface, Test; import Mooncake
julia> foo(x) = sum(x .^ 2)
foo (generic function with 1 method)
julia> Test.@inferred prepare_gradient(f, AutoMooncake(;config=nothing), rand(32))
ERROR: return type DifferentiationInterfaceMooncakeExt.MooncakeGradientPrep{Tuple{var"#13#14", AutoMooncake{Nothing}, Vector{Float64}, Tuple{}}, Mooncake.Cache{Mooncake.DerivedRule{Tuple{var"#13#14", Vector{Float64}}, Tuple{Mooncake.CoDual{var"#13#14", Mooncake.NoFData}, Mooncake.CoDual{Vector{Float64}, Vector{Float64}}}, Mooncake.CoDual{Float64, Mooncake.NoFData}, Tuple{Float64}, Tuple{Mooncake.NoRData, Mooncake.NoRData}, false, Val{2}}, Nothing, Tuple{Mooncake.NoTangent, Vector{Float64}}}} does not match inferred return type DifferentiationInterfaceMooncakeExt.MooncakeGradientPrep{Tuple{var"#13#14", AutoMooncake{Nothing}, Vector{Float64}, Tuple{}}, Tcache} where Tcache<:(Mooncake.Cache{_A, Nothing, Tuple{Mooncake.NoTangent, Vector{Float64}}} where _A)
Stacktrace:
Here's a formatted version of those types:
Return type:
DifferentiationInterfaceMooncakeExt.MooncakeGradientPrep{
    Tuple{var"#13#14",AutoMooncake{Nothing},Vector{Float64},Tuple{}},
    Mooncake.Cache{
        Mooncake.DerivedRule{
            Tuple{var"#13#14",Vector{Float64}},
            Tuple{
                Mooncake.CoDual{var"#13#14",Mooncake.NoFData},
                Mooncake.CoDual{Vector{Float64},Vector{Float64}},
            },
            Mooncake.CoDual{Float64,Mooncake.NoFData},
            Tuple{Float64},
            Tuple{Mooncake.NoRData,Mooncake.NoRData},
            false,
            Val{2},
        },
        Nothing,
        Tuple{Mooncake.NoTangent,Vector{Float64}},
    },
}
Inferred type:
DifferentiationInterfaceMooncakeExt.MooncakeGradientPrep{
    Tuple{var"#13#14",AutoMooncake{Nothing},Vector{Float64},Tuple{}},
    Tcache,
} where {Tcache<:(Mooncake.Cache{_A,Nothing,Tuple{Mooncake.NoTangent,Vector{Float64}}} where {_A})}
So it seems the compiler is not able to infer the first type parameter of Mooncake.Cache (the _A in the inferred type above), which at runtime is:
Mooncake.DerivedRule{
    Tuple{var"#13#14",Vector{Float64}},
    Tuple{
        Mooncake.CoDual{var"#13#14",Mooncake.NoFData},
        Mooncake.CoDual{Vector{Float64},Vector{Float64}},
    },
    Mooncake.CoDual{Float64,Mooncake.NoFData},
    Tuple{Float64},
    Tuple{Mooncake.NoRData,Mooncake.NoRData},
    false,
    Val{2},
}
Is the DerivedRule type too expensive for the compiler to infer, and is that why it is left to be computed at runtime?
Again, this is probably a question rather than an issue. I was just curious.
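For reference, the same kind of mismatch can be reproduced without Mooncake. The sketch below is only an illustration of the general pattern, not a description of how Mooncake or DifferentiationInterface build their rules: ToyRule, ToyPrep, build_rule, toy_prepare, toy_apply and square_sum are all hypothetical names, and the Ref{Any} is an assumed stand-in for a rule whose concrete type is only known at runtime.

```julia
using Test

# Hypothetical stand-ins; none of these names come from Mooncake
# or DifferentiationInterface.
struct ToyRule{F}
    f::F
end

struct ToyPrep{R}
    rule::R
end

# Assumption: routing the value through a Ref{Any} hides its type from
# inference, mimicking a rule that is constructed at runtime.
function build_rule(f)
    slot = Ref{Any}(ToyRule(f))
    return slot[]  # inferred as Any, concrete at runtime
end

# The preparation step inherits the uninferred type parameter R.
toy_prepare(f) = ToyPrep(build_rule(f))

# Function barrier: this method is compiled for the concrete type of
# `prep`, so the code that uses the rule is type-stable.
toy_apply(prep::ToyPrep, x) = prep.rule.f(x)

square_sum(x) = sum(abs2, x)

# The preparation step is not inferable, as in the report above.
@test_throws ErrorException @inferred toy_prepare(square_sum)

# Once the prep object exists, calls that take it as an argument infer fine.
prep = toy_prepare(square_sum)
@test (@inferred toy_apply(prep, [1.0, 2.0])) == 5.0
```

If the real preparation step behaves like this toy version, the cost is a single dynamic dispatch when the prepared object is first passed to another function, while the code behind that call is still compiled for the concrete type.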