From 63e0e435d1416c807e36ede07cd084627f563fa5 Mon Sep 17 00:00:00 2001 From: Paolo D'Elia Date: Fri, 12 Apr 2024 17:58:22 +0200 Subject: [PATCH 1/2] #21: fixed heston gen paths --- Makefile | 9 +- jaxfin/models/heston/heston.py | 188 +++++++++++++--------------------------- tests/models/test_heston.py | 7 +- 3 files changed, 71 insertions(+), 133 deletions(-) diff --git a/Makefile b/Makefile index 34f1327..47f8749 100644 --- a/Makefile +++ b/Makefile @@ -4,15 +4,16 @@ LINT_PATHS=jaxfin/ tests/ pytest: python -m pytest --no-header -vv --html=test_report.html --self-contained-html +pylint: + pylint jaxfin --output-format=text:pylint_res.txt,colorized + type: mypy ${LINT_PATHS} lint: ruff check jaxfin --output-format=full -complete-lint: - lint - pylint jaxfin --output-format=text:pylint_res.txt,colorized +lint-complete: lint pylint format: isort ${LINT_PATHS} @@ -31,4 +32,4 @@ test-release: python -m build twine upload dist/* -r testpypi -.PHONY: clean spelling doc lint format check-codestyle commit-checks \ No newline at end of file +.PHONY: clean spelling doc lint format check-codestyle commit-checks pylint \ No newline at end of file diff --git a/jaxfin/models/heston/heston.py b/jaxfin/models/heston/heston.py index 9940898..66649db 100644 --- a/jaxfin/models/heston/heston.py +++ b/jaxfin/models/heston/heston.py @@ -5,7 +5,7 @@ import jax import jax.numpy as jnp -from jax import jit, random +import numpy as np from ..utils import check_symmetric @@ -136,7 +136,7 @@ def dtype(self) -> jax.numpy.dtype: return self._dtype def sample_paths( - self, seed: int, maturity: float, n: int, n_sim: int + self, maturity: float, n: int, n_sim: int ) -> Tuple[jax.Array, jax.Array]: """ Sample of paths from the Univariate Heston model @@ -150,20 +150,37 @@ def sample_paths( Returns: Tuple[jax.Array, jax.Array]: The simulated paths of the asset and the variance process """ - key = random.PRNGKey(seed) - - dt = maturity / n - - W = random.normal(key, shape=(n_sim, n)).T - Z = ( - self._rho
- * W - * jnp.sqrt(1 - self._rho**2) - * random.normal(key, shape=(n_sim, n)).T - ) - - return _sample_paths( - self.s0, self.v0, self.mean, self.kappa, self.theta, self.sigma, dt, W, Z, n + dt = maturity / (n - 1) + dt_sq = np.sqrt(dt) + if 2 * self.kappa * self.theta <= self.sigma**2: # Feller condition + raise ValueError("Feller condition 2 * kappa * theta > sigma**2 is not satisfied") + + # Generate random Brownian Motions for all paths and time steps + W_1 = np.random.normal(loc=0, scale=1, size=(n_sim, n - 1)) + W_2 = np.random.normal(loc=0, scale=1, size=(n_sim, n - 1)) + W_v = W_1 + W_S = self._rho * W_1 + np.sqrt(1 - self._rho**2) * W_2 + + # Initialize arrays to store trajectories + v_paths = np.zeros((n_sim, n)) + S_paths = np.zeros((n_sim, n)) + v_paths[:, 0] = self._v0 + S_paths[:, 0] = self._s0 + + # Euler steps in time (variance reflected at zero via np.abs), each step vectorized across paths + for t in range(1, n): + v_paths[:, t] = np.abs( + v_paths[:, t - 1] + + self._kappa * (self._theta - v_paths[:, t - 1]) * dt + + self._sigma * np.sqrt(v_paths[:, t - 1]) * dt_sq * W_v[:, t - 1] + ) + S_paths[:, t] = S_paths[:, t - 1] * np.exp( + (self._mean - 0.5 * v_paths[:, t - 1]) * dt + + np.sqrt(v_paths[:, t - 1]) * dt_sq * W_S[:, t - 1] + ) + + return jnp.asarray(S_paths.T, dtype=self._dtype), jnp.asarray( + v_paths.T, dtype=self._dtype ) @@ -294,7 +311,7 @@ def dtype(self) -> jax.numpy.dtype: return self._dtype def sample_paths( - self, seed: int, maturity: float, n: int, n_sim: int + self, maturity: float, n: int, n_sim: int ) -> Tuple[jax.Array, jax.Array]: """Sample paths from the Multivariate Heston model @@ -307,114 +324,31 @@ def sample_paths( Returns: Tuple[jax.Array, jax.Array]: The simulated paths and the variance processes of the assets """ - key = random.PRNGKey(seed) - - dt = maturity / n + dt = maturity / (n - 1) # n time points -> n - 1 steps, as in the univariate model - - W = random.normal(key, shape=(n_sim, n * self._dim)) - Z = random.normal(key, shape=(n_sim, n * self._dim)) - W = jnp.reshape(W, (n_sim, n, self._dim)).transpose((1, 0, 2)) - Z = jnp.reshape(Z, (n_sim, n, self._dim)).transpose((1, 0, 2)) - - L =
jnp.linalg.cholesky(self._corr) - - epsilon_S = W @ L - epsilon_v = epsilon_S @ self._corr + Z @ jnp.sqrt(1 - self._corr**2) - - return _sample_paths( - self.s0, - self.v0, - self.mean, - self.kappa, - self.theta, - self.sigma, - dt, - epsilon_S, - epsilon_v, - n, + dt_sq = np.sqrt(dt) + + W_1 = np.random.normal(loc=0, scale=1, size=(n_sim, n - 1, self._dim)) + W_2 = np.random.normal(loc=0, scale=1, size=(n_sim, n - 1, self._dim)) + W_v = W_1 + W_S = W_1 @ self._corr + W_2 @ np.sqrt(1 - self._corr**2) + # NOTE(review): Var(W_S[..., k]) = sum_j (corr_jk**2 + 1 - corr_jk**2) = dim, not 1; TODO confirm intended structure (Cholesky of corr?) + v_paths = np.zeros((n_sim, n, self._dim)) + S_paths = np.zeros((n_sim, n, self._dim)) + v_paths[:, 0, :] = self._v0 + S_paths[:, 0, :] = self._s0 + + # Euler steps in time (variance reflected at zero via np.abs), each step vectorized across paths and assets + for t in range(1, n): + v_paths[:, t, :] = np.abs( + v_paths[:, t - 1, :] + + self._kappa * (self._theta - v_paths[:, t - 1, :]) * dt + + self._sigma * np.sqrt(v_paths[:, t - 1, :]) * dt_sq * W_v[:, t - 1, :] + ) + S_paths[:, t, :] = S_paths[:, t - 1, :] * np.exp( + (self._mean - 0.5 * v_paths[:, t - 1, :]) * dt + + np.sqrt(v_paths[:, t - 1, :]) * dt_sq * W_S[:, t - 1, :] + ) + + return jnp.asarray(S_paths.transpose(1, 0, 2), dtype=self._dtype), jnp.asarray( + v_paths.transpose(1, 0, 2), dtype=self._dtype ) - - -# Helpers - - -def _variance_process_step( - v: jax.Array, - kappa: jax.Array, - theta: jax.Array, - sigma: jax.Array, - dt: jax.Array, - dZ: jax.Array, -) -> jax.Array: - return ( - v - + kappa * (theta - jnp.maximum(v, 0.0)) * dt - + sigma * jnp.sqrt(jnp.maximum(v, 0.0)) * jnp.sqrt(dt) * dZ - ) - - -@jit -def _sample_paths( - s0: jax.Array, - v0: jax.Array, - mean: jax.Array, - kappa: jax.Array, - theta: jax.Array, - sigma: jax.Array, - dt: jax.Array, - W: jax.Array, - Z: jax.Array, - N: int, -) -> Tuple[jax.Array, jax.Array]: - """Main function that simulate the path of the Heston model leveragin the - jax.lax.while_loop function that allows to perform a loop in a jax makes - it useful for reducing compilation times for jit-compiled functions, - since native Python loop constructs
in an @jit function are unrolled, - leading to large XLA computations. - - Args: - s0 (jax.Array): The initial price of the stock(s) - v0 (jax.Array): The initial variance of the stock(s) - mean (jax.Array): The mean of the stock(s) - kappa (jax.Array): The speed of the mean-reversion of the variance - theta (jax.Array): The long-term mean of the variance - sigma (jax.Array): The volatility of the variance of the stock(s) - dt (jax.Array): The time step - W (jax.Array): The Weiner process of the stock(s) - Z (jax.Array): The (correlated) Weiner process of the variance of the stock(s) - N (int): The number of steps - - Returns: - Tuple[jax.Array, jax.Array]: The simulated paths of the stock(s), and the variance process(es) - """ - - def init_fn(W, Z): - W_ = W - Z_ = Z - - S = jnp.full_like(W_, s0) - v = jnp.full_like(W_, v0) - - return 0, S, v, W_, Z_ - - def cond_fn(val): - i, *_ = val - return i < N - - def body_val(val): - i, S, v, W, Z = val - dW = W[i] - dZ = Z[i] - - S = S.at[i + 1].set( - S[i] - * jnp.exp((mean - 0.5 * v[i]) * dt + jnp.sqrt(v[i]) * jnp.sqrt(dt) * dW) - ) - - v = v.at[i + 1].set(_variance_process_step(v[i], kappa, theta, sigma, dt, dZ)) - - return i + 1, S, v, W, Z - - _, S, v, _, _ = jax.lax.while_loop(cond_fn, body_val, init_fn(W, Z)) - - return S, v diff --git a/tests/models/test_heston.py b/tests/models/test_heston.py index 8208816..f9c7b76 100644 --- a/tests/models/test_heston.py +++ b/tests/models/test_heston.py @@ -1,9 +1,12 @@ import jax.numpy as jnp +import numpy as np from jaxfin.models.heston.heston import MultiHestonModel, UnivHestonModel SEED: int = 42 +np.random.seed(SEED) + class TestUnivHestonModel: def test_init(self): @@ -42,7 +45,7 @@ def test_sample_paths(self): s0, v0, mean, kappa, theta, sigma, rho, dtype=jnp.float32 ) paths, variance_process = heston_model.sample_paths( - seed=SEED, maturity=1.0, n=100, n_sim=100 + maturity=1.0, n=100, n_sim=100 ) assert paths.shape == (100, 100) @@ -86,7 +89,7 @@ def
test_sample_paths(self): s0, v0, mean, kappa, theta, sigma, corr, dtype=jnp.float32 ) paths, variance_processes = heston_model.sample_paths( - seed=SEED, maturity=1.0, n=100, n_sim=100 + maturity=1.0, n=100, n_sim=100 ) assert paths.shape == (100, 100, 2) From b34f45c7d9fbbfb4916518977add696ebd3be5dc Mon Sep 17 00:00:00 2001 From: Paolo D'Elia Date: Tue, 16 Apr 2024 13:48:56 +0200 Subject: [PATCH 2/2] #21: Used numpy instead of jax for path generation, problem with the seed --- Makefile | 9 ++++++--- jaxfin/models/gbm/gbm.py | 24 +++++++++++------------- tests/models/test_gbm.py | 7 +++++-- 3 files changed, 22 insertions(+), 18 deletions(-) diff --git a/Makefile b/Makefile index 47f8749..522a7d5 100644 --- a/Makefile +++ b/Makefile @@ -1,6 +1,9 @@ SHELL=/bin/bash LINT_PATHS=jaxfin/ tests/ +build: + python -m build + pytest: python -m pytest --no-header -vv --html=test_report.html --self-contained-html @@ -25,11 +28,11 @@ check-codestyle: commit-checks: format type lint release: - python -m build + $(MAKE) build twine upload dist/* test-release: - python -m build + $(MAKE) build twine upload dist/* -r testpypi -.PHONY: clean spelling doc lint format check-codestyle commit-checks pylint \ No newline at end of file +.PHONY: clean spelling doc lint format check-codestyle commit-checks pylint build \ No newline at end of file diff --git a/jaxfin/models/gbm/gbm.py b/jaxfin/models/gbm/gbm.py index d29a179..bc4361a 100644 --- a/jaxfin/models/gbm/gbm.py +++ b/jaxfin/models/gbm/gbm.py @@ -3,7 +3,7 @@ import jax import jax.numpy as jnp -from jax import random +import numpy as np from ..utils import check_symmetric @@ -85,7 +85,7 @@ def dtype(self) -> jax.numpy.dtype: """ return self._dtype - def sample_paths(self, seed: int, maturity: float, n: int, n_sim: int) -> jax.Array: + def sample_paths(self, maturity: float, n: int, n_sim: int) -> jax.Array: """ Simulate a sample of paths from the Geometric Brownian Motion (GBM). @@ -97,18 +97,18 @@ def sample_paths(self, seed: int, maturity: float, n: int, n_sim: int) -> jax.Ar Returns: jax.Array: Array containing the sample paths.
""" - key = random.PRNGKey(seed) - dt = maturity / n - Xt = jnp.exp( + Xt = np.exp( (self._mean - self._sigma**2 / 2) * dt - + self._sigma * jnp.sqrt(dt) * random.normal(key, shape=(n_sim, n - 1)).T + + self._sigma + * np.sqrt(dt) + * np.random.normal(loc=0, scale=1, size=(n_sim, n - 1)).T ) - Xt = jnp.vstack([jnp.ones(n_sim), Xt]) + Xt = np.vstack([np.ones(n_sim), Xt]) - return self._s0 * Xt.cumprod(axis=0) + return jnp.asarray(self._s0 * Xt.cumprod(axis=0)) class MultiGeometricBrownianMotion: @@ -207,7 +207,7 @@ def dimension(self) -> int: """ return self._dim - def sample_paths(self, seed: int, maturity: float, n: int, n_sim: int) -> jax.Array: + def sample_paths(self, maturity: float, n: int, n_sim: int) -> jax.Array: """ Simulate a sample of paths from the Multivariate Geometric Brownian Motion (GBM). @@ -219,12 +219,10 @@ def sample_paths(self, seed: int, maturity: float, n: int, n_sim: int) -> jax.Ar Returns: jax.Array: Array containing the sample paths. """ - key = random.PRNGKey(seed) - dt = maturity / n - normal_draw = random.normal(key, shape=(n_sim, n * self._dim)) - normal_draw = jnp.reshape(normal_draw, (n_sim, n, self._dim)).transpose( + normal_draw = np.random.normal(loc=0, scale=1, size=(n_sim, n * self._dim)) + normal_draw = np.reshape(normal_draw, (n_sim, n, self._dim)).transpose( (1, 0, 2) ) diff --git a/tests/models/test_gbm.py b/tests/models/test_gbm.py index 76868c1..31d6a9b 100644 --- a/tests/models/test_gbm.py +++ b/tests/models/test_gbm.py @@ -1,9 +1,12 @@ import jax.numpy as jnp +import numpy as np from jaxfin.models.gbm import MultiGeometricBrownianMotion, UnivGeometricBrownianMotion SEED: int = 42 +np.random.seed(SEED) + class TestUnivGBM: def test_init(self): @@ -25,7 +28,7 @@ def test_sim_paths_shape(self): dtype = jnp.float32 gbm = UnivGeometricBrownianMotion(s0, mean, sigma, dtype) - stock_paths = gbm.sample_paths(SEED, 1.0, 52, 100) + stock_paths = gbm.sample_paths(1.0, 52, 100) assert stock_paths.shape == (52, 100) @@ -53,6 +56,6
@@ def test_sample_path(self): corr = jnp.array([[1, 0.1], [0.1, 1]]) dtype = jnp.float32 gbm = MultiGeometricBrownianMotion(s0, mean, sigma, corr, dtype) - sample_path = gbm.sample_paths(SEED, 1.0, 52, 100) + sample_path = gbm.sample_paths(1.0, 52, 100) assert sample_path.shape == (52, 100, 2)