From ef659eb799ca01c1729e27722a6638a750da9ca0 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 26 Mar 2025 12:48:51 -0400 Subject: [PATCH 1/4] chore: bump version for release --- lib/LuxLib/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index c9e47b526..01e0f2231 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -1,7 +1,7 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and contributors"] -version = "1.6.3" +version = "1.7.0" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" From 3b0c102695b8133be1431f6b34bd4744823574c8 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 26 Mar 2025 13:22:00 -0400 Subject: [PATCH 2/4] ci: disable benchmarks for now [skip ci] --- .buildkite/pipeline.yml | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index b76c1243d..314e14fcf 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -129,5 +129,6 @@ steps: # Documentation buildkite-agent pipeline upload .buildkite/documentation.yml - # Benchmarks - buildkite-agent pipeline upload .buildkite/benchmarks.yml + # Disable benchmarks for now + # # Benchmarks + # buildkite-agent pipeline upload .buildkite/benchmarks.yml From d5dfb0c5a5a285eaf026d508e7be22c24b170674 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Wed, 26 Mar 2025 13:41:26 -0400 Subject: [PATCH 3/4] docs: refer to the Turing docs for BayesianNN (#1272) * CompatHelper: bump compat for Turing to 0.37 for package BayesianNN, (keep existing compat) * docs: link the official Turing docs --------- Co-authored-by: CompatHelper Julia Co-authored-by: Avik Pal --- docs/src/.vitepress/config.mts | 4 - docs/src/tutorials/index.md | 12 +- docs/tutorials.jl | 2 +- examples/BayesianNN/Project.toml | 19 --- examples/BayesianNN/main.jl | 208 +------------------------------ 5 files changed, 10 insertions(+), 235 deletions(-) delete mode 100644 examples/BayesianNN/Project.toml diff --git a/docs/src/.vitepress/config.mts b/docs/src/.vitepress/config.mts index 007748aba..1f0b6f4f9 100644 --- a/docs/src/.vitepress/config.mts +++ b/docs/src/.vitepress/config.mts @@ -226,10 +226,6 @@ export default defineConfig({ text: "MNIST Classification using Neural ODEs", link: "/tutorials/intermediate/1_NeuralODE", }, - { - text: "Bayesian Neural Network", - link: "/tutorials/intermediate/2_BayesianNN", - }, { text: "Training a HyperNetwork on MNIST and FashionMNIST", link: "/tutorials/intermediate/3_HyperNet", diff --git a/docs/src/tutorials/index.md b/docs/src/tutorials/index.md index f7e17c169..341a55ec1 100644 --- a/docs/src/tutorials/index.md +++ b/docs/src/tutorials/index.md @@ -48,12 +48,6 @@ const intermediate = [ caption: "MNIST Classification using Neural ODE", desc: "Train a Neural Ordinary Differential Equations to classify MNIST Images." }, - { - href: "intermediate/2_BayesianNN", - src: "https://github.com/TuringLang.png", - caption: "Bayesian Neural Networks", - desc: "Figure out how to use Probabilistic Programming Frameworks like Turing with Lux." - }, { href: "intermediate/3_HyperNet", src: "../hypernet.jpg", @@ -129,6 +123,12 @@ const third_party = [ caption: "GPU-Accelerated Physics-Informed Neural Networks", desc: "Use Machine Learning (PINNs) to solve the Heat Equation PDE on a GPU." 
}, + { + href: "https://turinglang.org/docs/tutorials/bayesian-neural-networks/", + src: "https://github.com/TuringLang.png", + caption: "Bayesian Neural Networks", + desc: "Figure out how to use Probabilistic Programming Frameworks like Turing with Lux." + }, { href: "https://docs.sciml.ai/DiffEqFlux/stable/examples/neural_ode_weather_forecast/", src: "../weather-neural-ode.gif", diff --git a/docs/tutorials.jl b/docs/tutorials.jl index 7559fe47e..ce1d45ba6 100644 --- a/docs/tutorials.jl +++ b/docs/tutorials.jl @@ -8,7 +8,7 @@ const BEGINNER_TUTORIALS = [ ] const INTERMEDIATE_TUTORIALS = [ "NeuralODE/main.jl" => "CUDA", - "BayesianNN/main.jl" => "CPU", + "BayesianNN/main.jl" => "CPU", # This is an empty tutorial, left to redirect to Turing "HyperNet/main.jl" => "CUDA", "PINN2DPDE/main.jl" => "CUDA", "ConvolutionalVAE/main.jl" => "CUDA", diff --git a/examples/BayesianNN/Project.toml b/examples/BayesianNN/Project.toml deleted file mode 100644 index d15a0eaae..000000000 --- a/examples/BayesianNN/Project.toml +++ /dev/null @@ -1,19 +0,0 @@ -[deps] -CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0" -Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" -LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -Lux = "b2108857-7c20-44ae-9111-449ecde12c47" -Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" -Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" -Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" - -[compat] -CairoMakie = "0.12, 0.13" -Functors = "0.4, 0.5" -LinearAlgebra = "1" -Lux = "1.2" -Random = "1" -Tracker = "0.2.37" -Turing = "0.34, 0.35, 0.36" -Zygote = "0.6.69, 0.7" diff --git a/examples/BayesianNN/main.jl b/examples/BayesianNN/main.jl index 3bc500dae..6372fa175 100644 --- a/examples/BayesianNN/main.jl +++ b/examples/BayesianNN/main.jl @@ -1,207 +1,5 @@ # # Bayesian Neural Network -# We borrow this tutorial from the -# [official Turing Docs](https://turinglang.org/docs/tutorials/03-bayesian-neural-network/index.html). -# We will show how the explicit parameterization of Lux enables first-class composability -# with packages which expect flattened out parameter vectors. - -# Note: The tutorial in the official Turing docs is now using Lux instead of Flux. - -# We will use [Turing.jl](https://turinglang.org/) with [Lux.jl](https://lux.csail.mit.edu/) -# to implement implementing a classification algorithm. Lets start by importing the relevant -# libraries. - -## Import libraries - -using Lux, Turing, CairoMakie, Random, Tracker, Functors, LinearAlgebra - -## Sampling progress -Turing.setprogress!(true); - -# ## Generating data - -# Our goal here is to use a Bayesian neural network to classify points in an artificial -# dataset. The code below generates data points arranged in a box-like pattern and displays -# a graph of the dataset we'll be working with. 
- -## Number of points to generate -N = 80 -M = round(Int, N / 4) -rng = Random.default_rng() -Random.seed!(rng, 1234) - -## Generate artificial data -x1s = rand(rng, Float32, M) * 4.5f0; -x2s = rand(rng, Float32, M) * 4.5f0; -xt1s = Array([[x1s[i] + 0.5f0; x2s[i] + 0.5f0] for i in 1:M]) -x1s = rand(rng, Float32, M) * 4.5f0; -x2s = rand(rng, Float32, M) * 4.5f0; -append!(xt1s, Array([[x1s[i] - 5.0f0; x2s[i] - 5.0f0] for i in 1:M])) - -x1s = rand(rng, Float32, M) * 4.5f0; -x2s = rand(rng, Float32, M) * 4.5f0; -xt0s = Array([[x1s[i] + 0.5f0; x2s[i] - 5.0f0] for i in 1:M]) -x1s = rand(rng, Float32, M) * 4.5f0; -x2s = rand(rng, Float32, M) * 4.5f0; -append!(xt0s, Array([[x1s[i] - 5.0f0; x2s[i] + 0.5f0] for i in 1:M])) - -## Store all the data for later -xs = [xt1s; xt0s] -ts = [ones(2 * M); zeros(2 * M)] - -## Plot data points - -function plot_data() - x1 = first.(xt1s) - y1 = last.(xt1s) - x2 = first.(xt0s) - y2 = last.(xt0s) - - fig = Figure() - ax = CairoMakie.Axis(fig[1, 1]; xlabel="x", ylabel="y") - - scatter!(ax, x1, y1; markersize=16, color=:red, strokecolor=:black, strokewidth=2) - scatter!(ax, x2, y2; markersize=16, color=:blue, strokecolor=:black, strokewidth=2) - - return fig -end - -plot_data() - -# ## Building the Neural Network - -# The next step is to define a feedforward neural network where we express our parameters as -# distributions, and not single points as with traditional neural networks. For this we will -# use `Dense` to define liner layers and compose them via `Chain`, both are neural network -# primitives from `Lux`. The network `nn` we will create will have two hidden layers with -# `tanh` activations and one output layer with `sigmoid` activation, as shown below. - -# The `nn` is an instance that acts as a function and can take data, parameters and current -# state as inputs and output predictions. We will define distributions on the neural network -# parameters. - -## Construct a neural network using Lux -nn = Chain(Dense(2 => 3, tanh), Dense(3 => 2, tanh), Dense(2 => 1, sigmoid)) - -## Initialize the model weights and state -ps, st = Lux.setup(rng, nn) - -Lux.parameterlength(nn) # number of parameters in NN - -# The probabilistic model specification below creates a parameters variable, which has IID -# normal variables. The parameters represents all parameters of our neural net (weights and -# biases). - -## Create a regularization term and a Gaussian prior variance term. -alpha = 0.09 -sig = sqrt(1.0 / alpha) - -# Construct named tuple from a sampled parameter vector. We could also use ComponentArrays -# here and simply broadcast to avoid doing this. But let's do it this way to avoid -# dependencies. -function vector_to_parameters(ps_new::AbstractVector, ps::NamedTuple) - @assert length(ps_new) == Lux.parameterlength(ps) - i = 1 - function get_ps(x) - z = reshape(view(ps_new, i:(i + length(x) - 1)), size(x)) - i += length(x) - return z - end - return fmap(get_ps, ps) -end - -# To interface with external libraries it is often desirable to use the -# [`StatefulLuxLayer`](@ref) to automatically handle the neural network states. -const model = StatefulLuxLayer{true}(nn, nothing, st) - -## Specify the probabilistic model. -@model function bayes_nn(xs, ts) - ## Sample the parameters - nparameters = Lux.parameterlength(nn) - parameters ~ MvNormal(zeros(nparameters), Diagonal(abs2.(sig .* ones(nparameters)))) - - ## Forward NN to make predictions - preds = Lux.apply(model, xs, vector_to_parameters(parameters, ps)) - - ## Observe each prediction. 
- for i in eachindex(ts) - ts[i] ~ Bernoulli(preds[i]) - end -end - -# Inference can now be performed by calling sample. We use the HMC sampler here. - -## Perform inference. -N = 5000 -ch = sample(bayes_nn(reduce(hcat, xs), ts), HMC(0.05, 4; adtype=AutoTracker()), N) - -# Now we extract the parameter samples from the sampled chain as θ (this is of size -# `5000 x 20` where `5000` is the number of iterations and `20` is the number of -# parameters). We'll use these primarily to determine how good our model's classifier is. - -## Extract all weight and bias parameters. -θ = MCMCChains.group(ch, :parameters).value; - -# ## Prediction Visualization - -## A helper to run the nn through data `x` using parameters `θ` -nn_forward(x, θ) = model(x, vector_to_parameters(θ, ps)) - -## Plot the data we have. -fig = plot_data() - -## Find the index that provided the highest log posterior in the chain. -_, i = findmax(ch[:lp]) - -## Extract the max row value from i. -i = i.I[1] - -## Plot the posterior distribution with a contour plot -x1_range = collect(range(-6; stop=6, length=25)) -x2_range = collect(range(-6; stop=6, length=25)) -Z = [nn_forward([x1, x2], θ[i, :])[1] for x1 in x1_range, x2 in x2_range] -contour!(x1_range, x2_range, Z; linewidth=3, colormap=:seaborn_bright) -fig - -# The contour plot above shows that the MAP method is not too bad at classifying our data. -# Now we can visualize our predictions. - -# $p(\tilde{x} | X, \alpha) = \int_{\theta} p(\tilde{x} | \theta) p(\theta | X, \alpha) \approx \sum_{\theta \sim p(\theta | X, \alpha)}f_{\theta}(\tilde{x})$ - -# The `nn_predict` function takes the average predicted value from a network parameterized -# by weights drawn from the MCMC chain. - -## Return the average predicted value across multiple weights. -nn_predict(x, θ, num) = mean([first(nn_forward(x, view(θ, i, :))) for i in 1:10:num]) - -# Next, we use the `nn_predict` function to predict the value at a sample of points where -# the x1 and x2 coordinates range between -6 and 6. As we can see below, we still have a -# satisfactory fit to our data, and more importantly, we can also see where the neural -# network is uncertain about its predictions much easier---those regions between cluster -# boundaries. - -# Plot the average prediction. -fig = plot_data() - -n_end = 1500 -x1_range = collect(range(-6; stop=6, length=25)) -x2_range = collect(range(-6; stop=6, length=25)) -Z = [nn_predict([x1, x2], θ, n_end)[1] for x1 in x1_range, x2 in x2_range] -contour!(x1_range, x2_range, Z; linewidth=3, colormap=:seaborn_bright) -fig - -# Suppose we are interested in how the predictive power of our Bayesian neural network -# evolved between samples. In that case, the following graph displays an animation of the -# contour plot generated from the network weights in samples 1 to 5,000. - -fig = plot_data() -Z = [first(nn_forward([x1, x2], θ[1, :])) for x1 in x1_range, x2 in x2_range] -c = contour!(x1_range, x2_range, Z; linewidth=3, colormap=:seaborn_bright) -record(fig, "results.gif", 1:250:size(θ, 1)) do i - fig.current_axis[].title = "Iteration: $i" - Z = [first(nn_forward([x1, x2], θ[i, :])) for x1 in x1_range, x2 in x2_range] - c[3] = Z - return fig -end - -# ![](results.gif) +# This tutorial has been upstreamed from Lux to the official Turing Documentation. +# See https://turinglang.org/docs/tutorials/bayesian-neural-networks/ for the updated +# version. From a7b1f112a8bb206049c1e93fa987bc8666f3cdf8 Mon Sep 17 00:00:00 2001 From: "Helmut H. 
Strey" Date: Wed, 26 Mar 2025 14:23:41 -0400 Subject: [PATCH 4/4] feat: support for ForwardDiff training (#1273) * added extension for ForwardDiff * moved compute_gradients_imp ForwardDiff dispatch to /helpers/training/jl * removed LuxForwardDiffExt from Project.toml * Update src/helpers/training.jl * Update src/helpers/training.jl * added test for ForwardDiff training * removed () * created new testitem for ForwardDiff and added ForwardDiff Limitation to docstring * added test condition at the end of ForwardDiff test, and reduced reduced function calls * feat: use caching to reduce memory allocations * Apply suggestions from code review --------- Co-authored-by: Helmut Strey Co-authored-by: Avik Pal Co-authored-by: Avik Pal --- Project.toml | 4 +- src/Lux.jl | 4 +- src/helpers/forwarddiff_training.jl | 91 +++++++++++++++++++++++++++++ src/helpers/training.jl | 3 + test/helpers/training_tests.jl | 50 ++++++++++++++++ 5 files changed, 150 insertions(+), 2 deletions(-) create mode 100644 src/helpers/forwarddiff_training.jl diff --git a/Project.toml b/Project.toml index afc2b117c..2842f403a 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Lux" uuid = "b2108857-7c20-44ae-9111-449ecde12c47" authors = ["Avik Pal and contributors"] -version = "1.10.1" +version = "1.11.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" @@ -11,6 +11,7 @@ ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471" +DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" DispatchDoctor = "8d63f2c5-f18a-4cf2-ba9d-b3f60fc568c8" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a" @@ -82,6 +83,7 @@ ChainRulesCore = "1.25" Compat = "4.16" ComponentArrays = "0.15.22" ConcreteStructs = "0.2.3" +DiffResults = "1.1" DispatchDoctor = "0.4.12" Enzyme = "0.13.35" EnzymeCore = "0.8.8" diff --git a/src/Lux.jl b/src/Lux.jl index 5255db7f9..f8aaaa3fa 100644 --- a/src/Lux.jl +++ b/src/Lux.jl @@ -82,6 +82,7 @@ include("extended_ops.jl") # Training Helpers include("helpers/optimizers.jl") include("helpers/training.jl") +include("helpers/forwarddiff_training.jl") # Experimental include("contrib/contrib.jl") @@ -155,7 +156,8 @@ export Training export jacobian_vector_product, vector_jacobian_product export batched_jacobian -export AutoEnzyme, AutoForwardDiff, AutoReverseDiff, AutoTracker, AutoZygote +export AutoEnzyme, + AutoForwardDiff, AutoReverseDiff, AutoTracker, AutoZygote, AutoForwardDiff export BinaryCrossEntropyLoss, BinaryFocalLoss, diff --git a/src/helpers/forwarddiff_training.jl b/src/helpers/forwarddiff_training.jl new file mode 100644 index 000000000..ef7b9ddab --- /dev/null +++ b/src/helpers/forwarddiff_training.jl @@ -0,0 +1,91 @@ +using ADTypes: AutoForwardDiff +using DiffResults: DiffResults +using ForwardDiff: ForwardDiff +using Setfield: @set! +using Static: True, False + +function Training.compute_gradients_impl( + ad::AutoForwardDiff, obj_fn::F, data, ts::Training.TrainState +) where {F} + @assert ts.parameters isa AbstractArray "AutoForwardDiff only supports AbstractArray \ + parameters, not $(typeof(ts.parameters)). To \ + convert the parameter structure to an array \ + use `ComponentArray(ps)`." 
+ + obj_fn_wrap, st_wrap, stats_wrap = Training.wrap_objective_function( + obj_fn, ts.model, ts.parameters, ts.states, data, True() + ) + + gradient_result = DiffResults.GradientResult(ts.parameters) + ForwardDiff.gradient!( + gradient_result, ps -> obj_fn_wrap(ts.model, ps, ts.states, data), ts.parameters + ) + + cache = Training.TrainingBackendCache( + ad, False(), gradient_result, (; obj_fn=obj_fn_wrap, st_wrap, stats_wrap) + ) + @set! ts.cache = cache + @set! ts.objective_function = obj_fn + @set! ts.states = st_wrap[] + return ( + DiffResults.gradient(gradient_result), + DiffResults.value(gradient_result), + stats_wrap[], + ts, + ) +end + +const FORWARDDIFF_CACHE_TYPE = Training.TrainingBackendCache{ + <:AutoForwardDiff,False,PS,<:NamedTuple{(:obj_fn, :st_wrap, :stats_wrap)} +} where {PS} + +function Training.compute_gradients_impl( + ::AutoForwardDiff, obj_fn::F, data, ts::Training.TrainState{<:FORWARDDIFF_CACHE_TYPE,F} +) where {F} + gradient_result = ts.cache.dparameters + + ForwardDiff.gradient!( + gradient_result, + ps -> ts.cache.extras.obj_fn(ts.model, ps, ts.states, data), + ts.parameters, + ) + + @set! ts.objective_function = obj_fn + @set! ts.states = ts.cache.extras.st_wrap[] + + return ( + DiffResults.gradient(gradient_result), + DiffResults.value(gradient_result), + ts.cache.extras.stats_wrap[], + ts, + ) +end + +function Training.compute_gradients_impl( + ::AutoForwardDiff, + obj_fn::F, + data, + ts::Training.TrainState{<:Training.TrainingBackendCache{<:AutoForwardDiff,False}}, +) where {F} + @warn "Detected calls to `compute_gradients(::AutoForwardDiff, ...)` with objective \ + function that is changing across function calls. This can lead to the \ + generation of slow code" maxlog = 1 + gradient_result = ts.cache.dparameters + + # We do exactly same thing as the first case but without caching the function + obj_fn_wrap, st_wrap, stats_wrap = Training.wrap_objective_function( + obj_fn, ts.model, ts.parameters, ts.states, data, False() + ) + + ForwardDiff.gradient!( + gradient_result, ps -> obj_fn_wrap(ts.model, ps, ts.states, data), ts.parameters + ) + + @set! ts.states = st_wrap[] + return ( + DiffResults.gradient(gradient_result), + DiffResults.value(gradient_result), + stats_wrap[], + ts, + ) +end diff --git a/src/helpers/training.jl b/src/helpers/training.jl index 7a9438904..537547a59 100644 --- a/src/helpers/training.jl +++ b/src/helpers/training.jl @@ -160,6 +160,7 @@ Compute the gradients of the objective function wrt parameters stored in `ts`. | `AutoReverseDiff(; compile)` | `ReverseDiff.jl` | | `AutoTracker` | `Tracker.jl` | | `AutoEnzyme` | `Enzyme.jl` | +| `AutoForwardDiff` | | ## Arguments @@ -185,6 +186,8 @@ A 4-Tuple containing: - `AutoReverseDiff(; compile=true)` is not supported for Lux models with non-empty state `st`. Additionally the returned stats must be empty (`NamedTuple()`). We catch these issues in most cases and throw an error. + - AutoForwardDiff only works with parameters that are AbstractArrays + (e.g. ps=ComponentVector(ps)) !!! 
danger "Aliased Gradients" diff --git a/test/helpers/training_tests.jl b/test/helpers/training_tests.jl index a8a67532a..338af8aab 100644 --- a/test/helpers/training_tests.jl +++ b/test/helpers/training_tests.jl @@ -139,6 +139,56 @@ end end end +@testitem "Training API ForwardDiff" setup = [SharedTestSetup] tags = [:misc] begin + using ADTypes, Optimisers, ComponentArrays + + mse = MSELoss() + + rng = StableRNG(12345) + + x_data = randn(rng, Float32, 4, 32) + y_data = evalpoly.(x_data, ((1, 2, 3),)) .- evalpoly.(x_data, ((5, 2),)) + y_data = (y_data .- minimum(y_data)) ./ (maximum(y_data) - minimum(y_data)) + dataset = [(x_data[:, i], y_data[:, i]) for i in Iterators.partition(1:32, 8)] + + model = Chain( + Dense(4, 32, tanh), BatchNorm(32), Dense(32, 32, tanh), BatchNorm(32), Dense(32, 4) + ) + + dataset_ = [(x, y) for (x, y) in dataset] + opt = Adam(0.001f0) + + ps, st = Lux.setup(rng, model) + tstate = Training.TrainState(model, ComponentVector(ps), st, opt) + + initial_loss = first( + mse(model, tstate.parameters, Lux.testmode(tstate.states), dataset_[1]) + ) + + for epoch in 1:100, (x, y) in dataset_ + grads, loss, _, tstate = allow_unstable() do + Training.compute_gradients(AutoForwardDiff(), mse, (x, y), tstate) + end + tstate = Training.apply_gradients!(tstate, grads) + end + + for epoch in 1:100, (x, y) in dataset_ + grads, loss, _, tstate = allow_unstable() do + Training.single_train_step!(AutoForwardDiff(), mse, (x, y), tstate) + end + end + + for epoch in 1:100, (x, y) in dataset_ + grads, loss, _, tstate = allow_unstable() do + Training.single_train_step(AutoForwardDiff(), mse, (x, y), tstate) + end + end + + final_loss = first(mse(model, tstate.parameters, tstate.states, dataset_[1])) + + @test final_loss * 50 < initial_loss +end + @testitem "Enzyme: Invalidate Cache on State Update" setup = [SharedTestSetup] tags = [ :misc ] skip = :(using LuxTestUtils; !LuxTestUtils.ENZYME_TESTING_ENABLED) begin