From 0830c454087182b880fd639d33fa8e8dc5e42da9 Mon Sep 17 00:00:00 2001 From: vincentsarago Date: Tue, 3 Jun 2025 13:08:58 +0200 Subject: [PATCH 1/2] add to_raster method --- CHANGES.md | 4 ++++ rio_tiler/models.py | 37 +++++++++++++++++++++++++++++++++++++ tests/test_models.py | 31 +++++++++++++++++++++++++++++++ 3 files changed, 72 insertions(+) diff --git a/CHANGES.md b/CHANGES.md index 3b15a16a..b446b574 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,6 +1,10 @@ # Unreleased +# 7.8.0 + +* add `to_raster()` method to `ImageData` class + # 7.7.4 (2025-05-29) * fix band names for Xarray DataArray diff --git a/rio_tiler/models.py b/rio_tiler/models.py index 1c013f3a..941c56e2 100644 --- a/rio_tiler/models.py +++ b/rio_tiler/models.py @@ -6,6 +6,7 @@ import attr import numpy +import rasterio from affine import Affine from color_operations import parse_operations, scale_dtype, to_math_type from numpy.typing import NDArray @@ -719,6 +720,42 @@ def render( return render(array.data, img_format=img_format, colormap=colormap, **kwargs) + def to_raster(self, dst_path: str, *, driver, **kwargs) -> None: + """Save ImageData array to File.""" + if driver.upper() == "GTIFF": + if "transform" not in kwargs: + kwargs.update({"transform": self.transform}) + if "crs" not in kwargs and self.crs: + kwargs.update({"crs": self.crs}) + + write_nodata = "nodata" in kwargs + count, height, width = self.array.shape + + output_profile = { + "dtype": self.array.dtype, + "count": count if write_nodata else count + 1, + "height": height, + "width": width, + } + output_profile.update(kwargs) + + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + category=NotGeoreferencedWarning, + module="rasterio", + ) + with rasterio.open(dst_path, "w", driver=driver, **output_profile) as dst: + dst.write(self.data, indexes=list(range(1, count + 1))) + + # Use Mask as an alpha band + if not write_nodata: + if ColorInterp.alpha not in dst.colorinterp: + dst.colorinterp = 
*dst.colorinterp[:-1], ColorInterp.alpha + dst.write(self.mask.astype(self.array.dtype), indexes=count + 1) + + return + def statistics( self, categorical: bool = False, diff --git a/tests/test_models.py b/tests/test_models.py index fc17452c..5e2e169e 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -534,3 +534,34 @@ def test_image_reproject(): assert reprojected.array.mask.shape[0] == 3 assert reprojected.array.mask[:, 0, 0].tolist() == [True, True, True] assert reprojected.array.mask[:, -10, -10].tolist() == [False, False, False] + + +def test_imageData_to_raster(tmp_path): + """Test ImageData to raster""" + ImageData(numpy.zeros((1, 256, 256), dtype="float32")).to_raster( + tmp_path / "img.tif", driver="GTiff" + ) + with rasterio.open(tmp_path / "img.tif") as src: + assert src.count == 2 + assert src.profile["driver"] == "GTiff" + + ImageData(numpy.zeros((1, 256, 256), dtype="float32")).to_raster( + tmp_path / "img.tif", driver="GTiff", nodata=0 + ) + with rasterio.open(tmp_path / "img.tif") as src: + assert src.count == 1 + assert src.profile["driver"] == "GTiff" + assert src.profile["nodata"] == 0 + + ImageData(numpy.zeros((3, 256, 256), dtype="uint8")).to_raster( + tmp_path / "img.tif", driver="PNG" + ) + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + category=NotGeoreferencedWarning, + module="rasterio", + ) + with rasterio.open(tmp_path / "img.tif") as src: + assert src.count == 4 + assert src.profile["driver"] == "PNG" From 9229a3eafd29f8023bfad5e61ec20021b9234db4 Mon Sep 17 00:00:00 2001 From: vincentsarago Date: Tue, 3 Jun 2025 18:13:03 +0200 Subject: [PATCH 2/2] make gtiff by default --- docs/src/models.md | 34 ++++++++++++++++++++++++++++++++++ rio_tiler/models.py | 6 +++--- tests/test_models.py | 13 +++++++++++++ 3 files changed, 50 insertions(+), 3 deletions(-) diff --git a/docs/src/models.md b/docs/src/models.md index c89ff2ae..95c4c78c 100644 --- a/docs/src/models.md +++ b/docs/src/models.md @@ -48,6 +48,25 
@@ print(ImageData(data)) - **data**: Return data part of the masked array. - **mask**: Return the mask part in form of rasterio dataset mask. +#### ClassMethods + +- **from_bytes()**: Create an ImageData instance from an in-memory raster file (bytes) + + ```python + with open("img.tif", "rb") as f: + img = ImageData.from_bytes(f.read()) + ``` + +- **create_from_list()**: Create ImageData from a sequence of ImageData objects. + + ```python + r = ImageData(numpy.zeros((1, 256, 256))) + g = ImageData(numpy.zeros((1, 256, 256))) + b = ImageData(numpy.zeros((1, 256, 256))) + + img = ImageData.create_from_list([r, g, b]) + ``` + #### Methods - **data_as_image()**: Return the data array reshaped into an image processing/visualization software friendly order @@ -102,6 +121,14 @@ print(ImageData(data)) assert img_r.height == 256 ``` +- **reproject()**: Reproject the ImageData to a user-defined projection + + ```python + data = numpy.zeros((3, 1024, 1024), dtype="uint8") + img = ImageData(data, crs="epsg:4326", bounds=(-180, -90, 180, 90)) + img = img.reproject(dst_crs="epsg:3857") + ``` + - **post_process()**: Apply rescaling or/and `color-operations` formula to the data array. Returns a new ImageData instance. ```python @@ -307,6 +334,13 @@ Note: Starting with `rio-tiler==2.1`, when the output datatype is not valid for `rio-tiler` will automatically rescale the data using the `min/max` value for the datatype (ref: https://github.com/cogeotiff/rio-tiler/pull/391). +- **to_raster()**: Save the ImageData array to a raster file (GeoTIFF by default). Unless a `nodata` value is passed, the mask is written as an extra alpha band. + + ```python + img = ImageData(numpy.zeros((1, 256, 256))) + img.to_raster("img.tif", driver="GTiff") + ``` + ## PointData !!!
info "New in version 4.0" diff --git a/rio_tiler/models.py b/rio_tiler/models.py index 941c56e2..71f61a4e 100644 --- a/rio_tiler/models.py +++ b/rio_tiler/models.py @@ -720,8 +720,8 @@ def render( return render(array.data, img_format=img_format, colormap=colormap, **kwargs) - def to_raster(self, dst_path: str, *, driver, **kwargs) -> None: - """Save ImageData array to File.""" + def to_raster(self, dst_path: str, *, driver: str = "GTIFF", **kwargs: Any) -> None: + """Save ImageData array to file.""" if driver.upper() == "GTIFF": if "transform" not in kwargs: kwargs.update({"transform": self.transform}) @@ -787,7 +787,7 @@ def get_coverage_array( shape_crs: CRS = WGS84_CRS, cover_scale: int = 10, ) -> NDArray[numpy.floating]: - """Post-process image data. + """Get Coverage array for a Geometry. Args: shape (Dict): GeoJSON geometry or Feature. diff --git a/tests/test_models.py b/tests/test_models.py index 5e2e169e..274da422 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -538,6 +538,11 @@ def test_image_reproject(): def test_imageData_to_raster(tmp_path): """Test ImageData to raster""" + ImageData(numpy.zeros((1, 256, 256), dtype="float32")).to_raster(tmp_path / "img.tif") + with rasterio.open(tmp_path / "img.tif") as src: + assert src.count == 2 + assert src.profile["driver"] == "GTiff" + ImageData(numpy.zeros((1, 256, 256), dtype="float32")).to_raster( tmp_path / "img.tif", driver="GTiff" ) @@ -545,6 +550,14 @@ def test_imageData_to_raster(tmp_path): assert src.count == 2 assert src.profile["driver"] == "GTiff" + # case insensitive GTiff + ImageData(numpy.zeros((1, 256, 256), dtype="float32")).to_raster( + tmp_path / "img.tif", driver="gtiff" + ) + with rasterio.open(tmp_path / "img.tif") as src: + assert src.count == 2 + assert src.profile["driver"] == "GTiff" + ImageData(numpy.zeros((1, 256, 256), dtype="float32")).to_raster( tmp_path / "img.tif", driver="GTiff", nodata=0 )