8000 Type stability of `prepare_gradient` in DifferentiationInterface · Issue #611 · chalk-lab/Mooncake.jl · GitHub
Type stability of prepare_gradient in DifferentiationInterface #611
Open
@MilesCranmer

Description


I think this is more of a question than an issue. I was curious whether the type stability of `prepare_gradient` could be improved somehow. Here's an example of where type stability issues come up:

julia> using Test, DifferentiationInterface, Mooncake

julia> foo(x) = sum(x .^ 2)
foo (generic function with 1 method)

julia> Test.@inferred prepare_gradient(foo, AutoMooncake(; config=nothing), rand(32))
ERROR: return type DifferentiationInterfaceMooncakeExt.MooncakeGradientPrep{Tuple{var"#13#14", AutoMooncake{Nothing}, Vector{Float64}, Tuple{}}, Mooncake.Cache{Mooncake.DerivedRule{Tuple{var"#13#14", Vector{Float64}}, Tuple{Mooncake.CoDual{var"#13#14", Mooncake.NoFData}, Mooncake.CoDual{Vector{Float64}, Vector{Float64}}}, Mooncake.CoDual{Float64, Mooncake.NoFData}, Tuple{Float64}, Tuple{Mooncake.NoRData, Mooncake.NoRData}, false, Val{2}}, Nothing, Tuple{Mooncake.NoTangent, Vector{Float64}}}} does not match inferred return type DifferentiationInterfaceMooncakeExt.MooncakeGradientPrep{Tuple{var"#13#14", AutoMooncake{Nothing}, Vector{Float64}, Tuple{}}, Tcache} where Tcache<:(Mooncake.Cache{_A, Nothing, Tuple{Mooncake.NoTangent, Vector{Float64}}} where _A)
Stacktrace:

Here's a formatted version of those types:

Return type:

DifferentiationInterfaceMooncakeExt.MooncakeGradientPrep{
    Tuple{var"#13#14",AutoMooncake{Nothing},Vector{Float64},Tuple{}},
    Mooncake.Cache{
        Mooncake.DerivedRule{
            Tuple{var"#13#14",Vector{Float64}},
            Tuple{
                Mooncake.CoDual{var"#13#14",Mooncake.NoFData},
                Mooncake.CoDual{Vector{Float64},Vector{Float64}},
            },
            Mooncake.CoDual{Float64,Mooncake.NoFData},
            Tuple{Float64},
            Tuple{Mooncake.NoRData,Mooncake.NoRData},
            false,
            Val{2},
        },
        Nothing,
        Tuple{Mooncake.NoTangent,Vector{Float64}},
    },
}

Inferred type:

DifferentiationInterfaceMooncakeExt.MooncakeGradientPrep{
    Tuple{var"#13#14",AutoMooncake{Nothing},Vector{Float64},Tuple{}},
    Tcache
} where {Tcache<:(Mooncake.Cache{_A,Nothing,Tuple{Mooncake.NoTangent,Vector{Float64}}} where {_A})}
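This pattern, where the outer `MooncakeGradientPrep` parameter is inferred concretely but the cache parameter is only bounded by a `where`, can be reproduced in plain Julia with hypothetical stand-in types (an analogue for illustration only, not Mooncake's actual code):

```julia
using Test

# Hypothetical stand-ins for Mooncake.Cache and MooncakeGradientPrep
# (names invented for illustration).
struct FakeCache{R}
    rule::R
end
struct FakePrep{C}
    cache::C
end

# hide() strips type information from inference's view, mimicking a
# rule whose concrete type is only determined at runtime.
hide(x) = Ref{Any}(x)[]
make_cache(x) = FakeCache(hide(x))
make_prep(x) = FakePrep(make_cache(x))

p = make_prep(1.0)
@show typeof(p)  # concrete at runtime: FakePrep{FakeCache{Float64}}

# Like prepare_gradient above, @inferred fails because inference only
# bounds the cache parameter instead of pinning it down:
failed = try
    Test.@inferred make_prep(1.0)
    false
catch
    true
end
@show failed  # true
```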

So it seems the compiler is not able to infer this part of the Tcache parameter:

Mooncake.DerivedRule{
    Tuple{var"#13#14",Vector{Float64}},
    Tuple{
        Mooncake.CoDual{var"#13#14",Mooncake.NoFData},
        Mooncake.CoDual{Vector{Float64},Vector{Float64}},
    },
    Mooncake.CoDual{Float64,Mooncake.NoFData},
    Tuple{Float64},
    Tuple{Mooncake.NoRData,Mooncake.NoRData},
    false,
    Val{2},
}

Is this simply too heavy for the compiler to infer, which is why it's left to be computed dynamically?

Again this is probably just a question rather than an issue! Was just curious.
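For what it's worth, a common user-side mitigation when a preparation step isn't inferable is a function barrier: run the unstable preparation once, then pass the result to an inner function that specializes on its concrete type. A plain-Julia sketch with hypothetical names (a general Julia pattern, not a DifferentiationInterface API):

```julia
struct Prep{T}
    cache::T
end

hide(x) = Ref{Any}(x)[]                  # strips type info, mimicking the dynamic cache
unstable_prepare(x) = Prep(hide(copy(x)))

# Behind the barrier the concrete Prep type is known from dispatch,
# so this function is inferable on its own:
gradient_like(p::Prep, x) = sum(abs2, x) + zero(eltype(p.cache))

# Function barrier: loop() specializes on typeof(p), so calls inside
# the loop are type-stable even though p's type wasn't inferable above.
loop(p, x, n) = sum(gradient_like(p, x) for _ in 1:n)

function run_many(x, n)
    p = unstable_prepare(x)  # one dynamic dispatch happens here
    return loop(p, x, n)
end
```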

Metadata

Assignees

No one assigned

Labels

enhancement (performance): Would reduce the time it takes to run some bit of the code
