From 6c221f510ebddbcfb8267b8454694a27d45942d9 Mon Sep 17 00:00:00 2001 From: Elizabeth Berrigan Date: Sat, 7 Oct 2023 15:46:27 -0700 Subject: [PATCH 1/3] Fix documentation --- sleap_roots/scanline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sleap_roots/scanline.py b/sleap_roots/scanline.py index 1d39598..b623fe9 100644 --- a/sleap_roots/scanline.py +++ b/sleap_roots/scanline.py @@ -102,7 +102,7 @@ def get_scanline_last_ind(scanline_intersection_counts: np.ndarray): Return: Scalar of count_scanline_interaction index for the last interaction. """ - # get the first scanline index using scanline_intersection_counts + # get the last scanline index using scanline_intersection_counts if np.where((scanline_intersection_counts > 0))[0].shape[0] > 0: scanline_last_ind = np.where((scanline_intersection_counts > 0))[0][-1] return scanline_last_ind From 36127f2d7ed350363ccf9a58548979c651f956a4 Mon Sep 17 00:00:00 2001 From: Elizabeth Berrigan Date: Sat, 7 Oct 2023 16:06:09 -0700 Subject: [PATCH 2/3] Import `lengths` in init --- sleap_roots/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sleap_roots/__init__.py b/sleap_roots/__init__.py index a61f078..fe5c29e 100644 --- a/sleap_roots/__init__.py +++ b/sleap_roots/__init__.py @@ -6,6 +6,7 @@ import sleap_roots.convhull import sleap_roots.ellipse import sleap_roots.networklength +import sleap_roots.lengths import sleap_roots.points import sleap_roots.scanline import sleap_roots.series @@ -16,4 +17,4 @@ # Define package version. # This is read dynamically by setuptools in pyproject.toml to determine the release version. -__version__ = "0.0.4" +__version__ = "0.0.5" From 68a7044815a2f2347281920720d9df8dbf8c8385 Mon Sep 17 00:00:00 2001 From: Elizabeth Berrigan Date: Sat, 7 Oct 2023 16:07:25 -0700 Subject: [PATCH 3/3] Change gravitropism to curvature --- sleap_roots/lengths.py | 18 +++++++------- sleap_roots/trait_pipelines.py | 18 +++++++------- tests/test_lengths.py | 44 +++++++++++++++++----------------- tests/test_trait_pipelines.py | 12 +++++----- 4 files changed, 46 insertions(+), 46 deletions(-) diff --git a/sleap_roots/lengths.py b/sleap_roots/lengths.py index a46c30e..d623f30 100644 --- a/sleap_roots/lengths.py +++ b/sleap_roots/lengths.py @@ -111,13 +111,13 @@ def get_root_lengths_max(pts: np.ndarray) -> np.ndarray: return max_length -def get_grav_index( +def get_curve_index( lengths: Union[float, np.ndarray], base_tip_dists: Union[float, np.ndarray] ) -> Union[float, np.ndarray]: - """Calculate the gravitropism index of a root. + """Calculate the curvature index of a root. - The gravitropism index quantifies the curviness of the root's growth. A higher - gravitropism index indicates a curvier root (less responsive to gravity), while a + The curvature index quantifies the curviness of the root's growth. A higher + curvature index indicates a curvier root (less responsive to gravity), while a lower index indicates a straighter root (more responsive to gravity). The index is computed as the difference between the maximum root length and straight-line distance from the base to the tip of the root, normalized by the root length. @@ -129,7 +129,7 @@ def get_grav_index( root(s). Can be a scalar or a 1D numpy array of shape `(instances,)`. Returns: - Gravitropism index of the root(s), quantifying its/their curviness. Will be a + Curvature index of the root(s), quantifying its/their curviness. Will be a scalar if input is scalar, or a 1D numpy array of shape `(instances,)` otherwise. """ @@ -144,8 +144,8 @@ def get_grav_index( if lengths.shape != base_tip_dists.shape: raise ValueError("The shapes of lengths and base_tip_dists must match.") - # Calculate the gravitropism index where possible - grav_index = np.where( + # Calculate the curvature index where possible + curve_index = np.where( (~np.isnan(lengths)) & (~np.isnan(base_tip_dists)) & (lengths > 0) @@ -156,6 +156,6 @@ def get_grav_index( # Return scalar or array based on the input type if is_scalar_input: - return grav_index.item() + return curve_index.item() else: - return grav_index + return curve_index diff --git a/sleap_roots/trait_pipelines.py b/sleap_roots/trait_pipelines.py index 6c18d55..3723f38 100644 --- a/sleap_roots/trait_pipelines.py +++ b/sleap_roots/trait_pipelines.py @@ -36,7 +36,7 @@ get_ellipse_b, get_ellipse_ratio, ) -from sleap_roots.lengths import get_grav_index, get_max_length_pts, get_root_lengths +from sleap_roots.lengths import get_curve_index, get_max_length_pts, get_root_lengths from sleap_roots.networklength import ( get_bbox, get_network_distribution, @@ -811,13 +811,13 @@ def define_traits(self) -> List[TraitDef]: description="Scalar of base median ratio.", ), TraitDef( - name="grav_index", - fn=get_grav_index, + name="curve_index", + fn=get_curve_index, input_traits=["primary_length", "primary_base_tip_dist"], scalar=True, include_in_csv=True, kwargs={}, - description="Scalar of primary root gravity index.", + description="Scalar of primary root curvature index.", ), TraitDef( name="base_length_ratio", @@ -1189,13 +1189,13 @@ def define_traits(self) -> List[TraitDef]: "tip(s) of the main root(s).", ), TraitDef( - name="main_grav_indices", + name="main_curve_indices", fn=get_base_tip_dist, input_traits=["main_base_pts", "main_tip_pts"], scalar=False, include_in_csv=True, kwargs={}, - description="Gravitropism index for each main root.", + description="Curvature index for each main root.", ), TraitDef( name="network_solidity", @@ -1291,13 +1291,13 @@ def define_traits(self) -> List[TraitDef]: "convex hull.", ), TraitDef( - name="grav_index", - fn=get_grav_index, + name="curve_index", + fn=get_curve_index, input_traits=["primary_length", "primary_base_tip_dist"], scalar=True, include_in_csv=True, kwargs={}, - description="Scalar of primary root gravity index.", + description="Scalar of primary root curvature index.", ), TraitDef( name="primary_base_tip_dist", diff --git a/tests/test_lengths.py b/tests/test_lengths.py index 1c1a3f5..6df97ed 100644 --- a/tests/test_lengths.py +++ b/tests/test_lengths.py @@ -1,5 +1,5 @@ from sleap_roots.lengths import ( - get_grav_index, + get_curve_index, get_root_lengths, get_root_lengths_max, get_max_length_pts, @@ -146,8 +146,8 @@ def lengths_all_nan(): return np.array([np.nan, np.nan, np.nan]) -# tests for get_grav_index function -def test_get_grav_index_canola(canola_h5): +# tests for get_curve_index function +def test_get_curve_index_canola(canola_h5): series = Series.load( canola_h5, primary_name="primary_multi_day", lateral_name="lateral_3_nodes" ) @@ -158,22 +158,22 @@ def test_get_grav_index_canola(canola_h5): bases = get_bases(max_length_pts) tips = get_tips(max_length_pts) base_tip_dist = get_base_tip_dist(bases, tips) - grav_index = get_grav_index(primary_length, base_tip_dist) - np.testing.assert_almost_equal(grav_index, 0.08898137324716636) + curve_index = get_curve_index(primary_length, base_tip_dist) + np.testing.assert_almost_equal(curve_index, 0.08898137324716636) -def test_get_grav_index(): +def test_get_curve_index(): # Test 1: Scalar inputs where length > base_tip_dist - # Gravitropism index should be (10 - 8) / 10 = 0.2 - assert get_grav_index(10, 8) == 0.2 + # Curvature index should be (10 - 8) / 10 = 0.2 + assert get_curve_index(10, 8) == 0.2 # Test 2: Scalar inputs where length and base_tip_dist are zero # Should return NaN as length is zero - assert np.isnan(get_grav_index(0, 0)) + assert np.isnan(get_curve_index(0, 0)) # Test 3: Scalar inputs where length < base_tip_dist # Should return NaN as it's an invalid case - assert np.isnan(get_grav_index(5, 10)) + assert np.isnan(get_curve_index(5, 10)) # Test 4: Array inputs covering various cases # Case 1: length > base_tip_dist, should return 0.2 @@ -183,13 +183,13 @@ def test_get_grav_index(): lengths = np.array([10, 0, 5, 15]) base_tip_dists = np.array([8, 0, 10, 12]) expected = np.array([0.2, np.nan, np.nan, 0.2]) - result = get_grav_index(lengths, base_tip_dists) + result = get_curve_index(lengths, base_tip_dists) assert np.allclose(result, expected, equal_nan=True) # Test 5: Mismatched shapes between lengths and base_tip_dists # Should raise a ValueError with pytest.raises(ValueError): - get_grav_index(np.array([10, 20]), np.array([8])) + get_curve_index(np.array([10, 20]), np.array([8])) # Test 6: Array inputs with NaN values # Case 1: length > base_tip_dist, should return 0.2 @@ -197,13 +197,13 @@ def test_get_grav_index(): lengths = np.array([10, np.nan, np.nan]) base_tip_dists = np.array([8, 8, np.nan]) expected = np.array([0.2, np.nan, np.nan]) - result = get_grav_index(lengths, base_tip_dists) + result = get_curve_index(lengths, base_tip_dists) assert np.allclose(result, expected, equal_nan=True) -def test_get_grav_index_shape(): +def test_get_curve_index_shape(): # Check if scalar inputs result in scalar output - result = get_grav_index(10, 8) + result = get_curve_index(10, 8) assert isinstance( result, (int, float) ), f"Expected scalar output, got {type(result)}" @@ -211,7 +211,7 @@ def test_get_grav_index_shape(): # Check if array inputs result in array output lengths = np.array([10, 15]) base_tip_dists = np.array([8, 12]) - result = get_grav_index(lengths, base_tip_dists) + result = get_curve_index(lengths, base_tip_dists) assert isinstance( result, np.ndarray ), f"Expected np.ndarray output, got {type(result)}" @@ -225,7 +225,7 @@ def test_get_grav_index_shape(): # Check the shape of output for larger array inputs lengths = np.array([10, 15, 20, 25]) base_tip_dists = np.array([8, 12, 18, 22]) - result = get_grav_index(lengths, base_tip_dists) + result = get_curve_index(lengths, base_tip_dists) assert ( result.shape == lengths.shape ), f"Output shape {result.shape} does not match input shape {lengths.shape}" @@ -235,7 +235,7 @@ def test_nan_values(): lengths = np.array([10, np.nan, 30]) base_tip_dists = np.array([8, 16, np.nan]) np.testing.assert_array_equal( - get_grav_index(lengths, base_tip_dists), np.array([0.2, np.nan, np.nan]) + get_curve_index(lengths, base_tip_dists), np.array([0.2, np.nan, np.nan]) ) @@ -243,14 +243,14 @@ def test_zero_lengths(): lengths = np.array([0, 20, 30]) base_tip_dists = np.array([0, 16, 24]) np.testing.assert_array_equal( - get_grav_index(lengths, base_tip_dists), np.array([np.nan, 0.2, 0.2]) + get_curve_index(lengths, base_tip_dists), np.array([np.nan, 0.2, 0.2]) ) def test_invalid_scalar_values(): - assert np.isnan(get_grav_index(np.nan, 8)) - assert np.isnan(get_grav_index(10, np.nan)) - assert np.isnan(get_grav_index(0, 8)) + assert np.isnan(get_curve_index(np.nan, 8)) + assert np.isnan(get_curve_index(10, np.nan)) + assert np.isnan(get_curve_index(0, 8)) # tests for `get_root_lengths` diff --git a/tests/test_trait_pipelines.py b/tests/test_trait_pipelines.py index a1deac2..feec406 100644 --- a/tests/test_trait_pipelines.py +++ b/tests/test_trait_pipelines.py @@ -63,14 +63,14 @@ def test_younger_monocot_pipeline(rice_h5, rice_folder): # Value range assertions for traits assert ( - rice_traits["grav_index"].fillna(0) >= 0 - ).all(), "grav_index in rice_traits contains negative values" + rice_traits["curve_index"].fillna(0) >= 0 + ).all(), "curve_index in rice_traits contains negative values" assert ( - all_traits["grav_index_median"] >= 0 - ).all(), "grav_index in all_traits contains negative values" + all_traits["curve_index_median"] >= 0 + ).all(), "curve_index in all_traits contains negative values" assert ( - all_traits["main_grav_indices_mean_median"] >= 0 - ).all(), "main_grav_indices_mean_median in all_traits contains negative values" + all_traits["main_curve_indices_mean_median"] >= 0 + ).all(), "main_curve_indices_mean_median in all_traits contains negative values" assert ( (0 <= rice_traits["main_angles_proximal_p95"]) & (rice_traits["main_angles_proximal_p95"] <= 180)