add to_raster method by vincentsarago · Pull Request #810 · cogeotiff/rio-tiler · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

add to_raster method #810

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@

# Unreleased

# 7.8.0

* add `to_raster()` method to `ImageData` class

# 7.7.4 (2025-05-29)

* fix band names for Xarray DataArray
Expand Down
34 changes: 34 additions & 0 deletions docs/src/models.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,25 @@ print(ImageData(data))
- **data**: Return data part of the masked array.
- **mask**: Return the mask part in form of rasterio dataset mask.

#### ClassMethods

- **from_bytes()**: Create an ImageData instance from an in-memory raster file (bytes)

```python
with open("img.tif", "rb") as f:
    img = ImageData.from_bytes(f.read())
```

- **create_from_list()**: Create ImageData from a sequence of ImageData objects.

```python
r = ImageData(numpy.zeros((1, 256, 256)))
g = ImageData(numpy.zeros((1, 256, 256)))
b = ImageData(numpy.zeros((1, 256, 256)))

img = ImageData.create_from_list([r, g, b])
```

#### Methods

- **data_as_image()**: Return the data array reshaped into an order suitable for image processing/visualization software
Expand Down Expand Up @@ -102,6 +121,14 @@ print(ImageData(data))
assert img_r.height == 256
```

- **reproject()**: Reproject the ImageData to a user-defined projection

```python
data = numpy.zeros((3, 1024, 1024), dtype="uint8")
img = ImageData(data, crs="epsg:4326", bounds=(-180, -90, 180, 90))
img = img.reproject(dst_crs="epsg:3857")
```

- **post_process()**: Apply rescaling and/or a `color-operations` formula to the data array. Returns a new ImageData instance.

```python
Expand Down Expand Up @@ -307,6 +334,13 @@ Note: Starting with `rio-tiler==2.1`, when the output datatype is not valid for
`rio-tiler` will automatically rescale the data using the `min/max` value for the datatype (ref: https://github.com/cogeotiff/rio-tiler/pull/391).


- **to_raster()**: Save the ImageData array to a raster file. Unless a `nodata` value is passed, the mask is written as an extra alpha band.

```python
img = ImageData(numpy.zeros((1, 256, 256)))
img.to_raster("img.tif", driver="GTiff")
```

## PointData

!!! info "New in version 4.0"
Expand Down
39 changes: 38 additions & 1 deletion rio_tiler/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import attr
import numpy
import rasterio
from affine import Affine
from color_operations import parse_operations, scale_dtype, to_math_type
from numpy.typing import NDArray
Expand Down Expand Up @@ -719,6 +720,42 @@ def render(

return render(array.data, img_format=img_format, colormap=colormap, **kwargs)

def to_raster(self, dst_path: str, *, driver: str = "GTIFF", **kwargs: Any) -> None:
    """Save ImageData array to file.

    Args:
        dst_path (str): Output path, passed as-is to `rasterio.open`.
        driver (str): GDAL driver name (the GTiff check is case-insensitive). Defaults to `GTIFF`.
        kwargs (optional): Extra profile entries / creation options forwarded to
            `rasterio.open`. They override the profile computed from the image.

    Behaviour:
        - For the GTiff driver, `transform` and `crs` default to the image's own
          georeferencing when they are not given in `kwargs`.
        - When `nodata` is given in `kwargs`, only the data bands are written and
          masked pixels are set to that nodata value (if it is not None).
        - Otherwise the mask is written as an extra, last band flagged as alpha.

    NOTE(review): the mask is cast to the array dtype before being written as
    alpha; if mask values go up to 255, signed 8-bit data would wrap — confirm.

    """
    if driver.upper() == "GTIFF":
        # Only GTiff gets the image georeferencing by default; other drivers
        # (e.g. PNG) are written without transform/crs unless the caller asks.
        kwargs.setdefault("transform", self.transform)
        if "crs" not in kwargs and self.crs:
            kwargs["crs"] = self.crs

    write_nodata = "nodata" in kwargs
    count, height, width = self.array.shape

    output_profile = {
        "dtype": self.array.dtype,
        # Reserve one extra band for the alpha (mask) band when no nodata is used.
        "count": count if write_nodata else count + 1,
        "height": height,
        "width": width,
    }
    output_profile.update(kwargs)

    with warnings.catch_warnings():
        warnings.filterwarnings(
            "ignore",
            category=NotGeoreferencedWarning,
            module="rasterio",
        )
        with rasterio.open(dst_path, "w", driver=driver, **output_profile) as dst:
            nodata = kwargs.get("nodata")
            if write_nodata and nodata is not None:
                # Burn masked pixels to the nodata value so the mask survives
                # on disk; raw masked values would otherwise be read as valid.
                # Done after `rasterio.open` so an invalid nodata is reported
                # by rasterio first.
                data = self.array.filled(nodata)
            else:
                data = self.data

            dst.write(data, indexes=list(range(1, count + 1)))

            # Use the mask as an alpha band (the extra, last band).
            if not write_nodata:
                if ColorInterp.alpha not in dst.colorinterp:
                    dst.colorinterp = (*dst.colorinterp[:-1], ColorInterp.alpha)
                dst.write(self.mask.astype(self.array.dtype), indexes=count + 1)

def statistics(
self,
categorical: bool = False,
Expand Down Expand Up @@ -750,7 +787,7 @@ def get_coverage_array(
shape_crs: CRS = WGS84_CRS,
cover_scale: int = 10,
) -> NDArray[numpy.floating]:
"""Post-process image data.
"""Get Coverage array for a Geometry.

Args:
shape (Dict): GeoJSON geometry or Feature.
Expand Down
44 changes: 44 additions & 0 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,3 +534,47 @@ def test_image_reproject():
assert reprojected.array.mask.shape[0] == 3
assert reprojected.array.mask[:, 0, 0].tolist() == [True, True, True]
assert reprojected.array.mask[:, -10, -10].tolist() == [False, False, False]


def test_imageData_to_raster(tmp_path):
    """Test ImageData.to_raster with GTiff (default/explicit/lowercase driver) and PNG."""
    # Without `nodata` the mask is written as an extra alpha band; the GTiff
    # driver name must be accepted regardless of case (or omitted entirely).
    for options in ({}, {"driver": "GTiff"}, {"driver": "gtiff"}):
        ImageData(numpy.zeros((1, 256, 256), dtype="float32")).to_raster(
            tmp_path / "img.tif", **options
        )
        with rasterio.open(tmp_path / "img.tif") as src:
            assert src.count == 2
            assert src.profile["driver"] == "GTiff"

    # With `nodata` no alpha band is added and the nodata value is stored.
    ImageData(numpy.zeros((1, 256, 256), dtype="float32")).to_raster(
        tmp_path / "img.tif", driver="GTiff", nodata=0
    )
    with rasterio.open(tmp_path / "img.tif") as src:
        assert src.count == 1
        assert src.profile["driver"] == "GTiff"
        assert src.profile["nodata"] == 0

    # Non-GTiff driver: use a matching file extension. The output has no
    # georeferencing, so silence rasterio's NotGeoreferencedWarning on read.
    ImageData(numpy.zeros((3, 256, 256), dtype="uint8")).to_raster(
        tmp_path / "img.png", driver="PNG"
    )
    with warnings.catch_warnings():
        warnings.filterwarnings(
            "ignore",
            category=NotGeoreferencedWarning,
            module="rasterio",
        )
        with rasterio.open(tmp_path / "img.png") as src:
            assert src.count == 4
            assert src.profile["driver"] == "PNG"
0