8000 cleanup docs, add post_conversion by d-chambers · Pull Request #6 · DASDAE/unidas · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

cleanup docs, add post_conversion #6

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

[![coverage](https://codecov.io/gh/dasdae/unidas/branch/main/graph/badge.svg)](https://codecov.io/gh/dasdae/unidas)
[![PyPI Version](https://img.shields.io/pypi/v/unidas.svg)](https://pypi.python.org/pypi/unidas)
[![supported versions](https://img.shields.io/pypi/pyversions/unidas.svg?label=python_versions)](https://pypi.python.org/pypi/unidas)
[![Licence](https://img.shields.io/badge/license-MIT-blue)](https://opensource.org/license/mit)

A DAS compatibility package.
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ packages = ["src/unidas.py"]
[project]
name = "unidas"

version = "0.0.0" # Make sure to bump dascore.__version__ as well!
version = "0.0.1" # Make sure to bump dascore.__version__ as well!

authors = [
{ name="Derrick Chambers", email="chambers.ja.derrick@gmail.com" },
Expand Down
55 changes: 35 additions & 20 deletions src/unidas.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
"""
Core functionality for unidas.

Currently, the base representation is a dictionary of the following form.
Unidas: A DAS Compatibility Package.
"""

from __future__ import annotations

# Unidas version indicator. When incrementing, be sure to update
# pyproject.toml as well.
__version__ = "0.0.0"
__version__ = "0.0.1"

# Explicitly defines unidas' public API.
# https://peps.python.org/pep-0008/#public-and-internal-interfaces
Expand All @@ -19,7 +17,7 @@
import inspect
import zoneinfo
from collections import defaultdict, deque
from collections.abc import Mapping, Sequence, Sized
from collections.abc import Mapping, Sequence
from dataclasses import dataclass
from functools import cache, wraps
from types import ModuleType
Expand All @@ -35,6 +33,7 @@
"xdas": "https://github.com/xdas-dev/xdas",
}

# Datetime precision. This can change between python versions.
DT_PRECISION = datetime.datetime.resolution.total_seconds()

# A generic type variable.
Expand Down Expand Up @@ -82,7 +81,7 @@ def optional_import(package_name: str) -> ModuleType:

def converts_to(target: str):
"""
A decorator which marks a method as conversion function.
Marks a method on a `Converter` as a conversion function.

Parameters
----------
Expand All @@ -99,8 +98,12 @@ def decorator(func):
return decorator


def get_object_key(object_class):
"""Get the tuple which defines the object's unique id."""
def get_class_key(object_class) -> str:
"""
Get a string which defines the class's identifier.

The general format is "{package_name}.{class_name}".
"""
module_name = object_class.__module__.split(".")[0]
class_name = object_class.__name__
return f"{module_name}.{class_name}"
Expand All @@ -114,7 +117,7 @@ def extract_attrs(obj, attrs_names):


def time_to_float(obj):
"""Converts a datetime or numpy datetime object to float."""
"""Converts a datetime or numpy datetime object to a float (timestamp)."""
if isinstance(obj, np.datetime64) or isinstance(obj, np.timedelta64):
obj = obj.astype("timedelta64") / np.timedelta64(1, "s")
elif hasattr(obj, "timestamp"):
Expand All @@ -137,13 +140,13 @@ def time_to_datetime(obj):


def to_stripped_utc(time: datetime.datetime):
"""Convert a datetime to UTC then strip timezone info"""
"""Convert a datetime to UTC then strip timezone info."""
out = time.astimezone(zoneinfo.ZoneInfo("UTC")).replace(tzinfo=None)
return out


@runtime_checkable
class ArrayLike(Protocol, Sized):
class ArrayLike(Protocol):
"""
Simple definition of an array for now.
"""
Expand Down Expand Up @@ -183,7 +186,7 @@ def to_xdas_coord(self):
@dataclass
class EvenlySampledCoordinate(Coordinate):
"""
A coordinate which is evenly sampled and contiguous.
A coordinate which is evenly sampled, sorted, and contiguous.

Parameters
----------
Expand Down Expand Up @@ -248,6 +251,8 @@ class ArrayCoordinate(Coordinate):
"""
A coordinate which is not evenly sampled and contiguous.

The coordinate is represented by a generic array.

Parameters
----------
data
Expand Down Expand Up @@ -345,7 +350,7 @@ class Converter:
conversion methods with the `converts_to` decorator.
"""

name: str = None # should be "{module}.{class_name}" see get_object_key.
name: str = None # should be "{module}.{class_name}" see get_class_key.
_registry: ClassVar[dict[str, Converter]] = {}
_graph: ClassVar[dict[str, list[str]]] = defaultdict(list)
_converters: ClassVar[dict[str, callable]] = {}
Expand Down Expand Up @@ -378,7 +383,8 @@ def post_conversion(self, input_obj: T, output_obj: T) -> T:

Some conversions are lossy. This optional method allows subclasses
to modify the output of `convert` before it gets returned. This might
be useful to re-attach lost metadata for example.
be useful to re-attach lost metadata for example. It doesn't work with
the `convert` function (in that case it needs to be applied manually).

Parameters
----------
Expand Down Expand Up @@ -425,6 +431,8 @@ def get_shortest_path(cls, start, target):
path.append(current)
current = visited[current]
return tuple(path[::-1])
# TODO: Maybe add a check for DASBase here so that it is tried
# before other potential conversion paths.
for neighbor in graph[current]:
if neighbor not in visited:
visited[neighbor] = current
Expand Down Expand Up @@ -711,12 +719,19 @@ def _decorator(obj, *args, **kwargs):
# Convert the incoming object to target. This should do nothing
# if it is already the correct format.
cls = obj if inspect.isclass(obj) else type(obj)
key = get_object_key(cls)
key = get_class_key(cls)
conversion_class: Converter = Converter._registry[key]
input_obj = convert(obj, to)
out = func(input_obj, *args, **kwargs)
output_obj = convert(out, key)

return output_obj
func_out = func(input_obj, *args, **kwargs)
cls_out = obj if inspect.isclass(func_out) else type(func_out)
# Sometimes a function can return a different type than its input,
# e.g., a dataframe. In this case, just return the output.
if get_class_key(cls_out) != to:
return func_out
output_obj = convert(func_out, key)
# Apply class specific logic to compensate for lossy conversion.
out = conversion_class.post_conversion(input_obj, output_obj)
return out

# Following the convention of pydantic, we attach the raw function
# in case it needs to be accessed later. Also ensures to keep the
Expand Down Expand Up @@ -748,7 +763,7 @@ def convert(obj, to: str):
The input object converted to the specified format.
"""
obj_class = obj if inspect.isclass(obj) else type(obj)
key = get_object_key(obj_class)
key = get_class_key(obj_class)
# No conversion needed, simply return object.
if key == to:
return obj
Expand Down
67 changes: 36 additions & 31 deletions test/test_unidas.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import dascore as dc
import daspy
import numpy as np
import pandas as pd
import pytest
import unidas
from unidas import BaseDAS, adapter, convert, optional_import
Expand Down Expand Up @@ -73,34 +74,6 @@ def test_version(self):
# --------- Tests for unidas conversions.


class TestFormatConversionCombinations:
"""Tests for combinations of different formats."""

# Note: we could also parametrize the base structure fixtures to make
# all of these one test, but then it can get confusing to debug so
# I am making one test for each format that then tests converting to
# all other formats.
def test_convert_blast(self, lightguide_blast, format_name):
"""Test that the base blast can be converted to all formats."""
out = convert(lightguide_blast, to=format_name)
assert isinstance(out, NAME_CLASS_MAP[format_name])

def test_convert_patch(self, dascore_patch, format_name):
"""Test that the base patch can be converted to all formats."""
out = convert(dascore_patch, to=format_name)
assert isinstance(out, NAME_CLASS_MAP[format_name])

def test_convert_data_array(self, xdas_dataarray, format_name):
"""Test that the base data array can be converted to all formats."""
out = convert(xdas_dataarray, to=format_name)
assert isinstance(out, NAME_CLASS_MAP[format_name])

def test_convert_section(self, daspy_section, format_name):
"""Test that the base section can be converted to all formats."""
out = convert(daspy_section, to=format_name)
assert isinstance(out, NAME_CLASS_MAP[format_name])


class TestDASCorePatch:
"""Test suite for converting DASCore Patches."""

Expand Down Expand Up @@ -128,6 +101,11 @@ def test_to_xdas_time_coord(self, dascore_patch):
time_coord2 = out.coords["time"].values
assert np.all(time_coord1 == time_coord2)

def test_convert_patch_to_other(self, dascore_patch, format_name):
"""Test that the base patch can be converted to all formats."""
out = convert(dascore_patch, to=format_name)
assert isinstance(out, NAME_CLASS_MAP[format_name])


class TestDASPySection:
"""Test suite for converting DASPy sections."""
Expand All @@ -145,10 +123,16 @@ def test_from_base_das(self, daspy_base_das, daspy_section):
"""Ensure the default section can round-trip."""
out = convert(daspy_base_das, "daspy.Section")
# TODO these objects aren't equal but their strings are.
# Need to fix this.
# We need to fix this.
# assert out == daspy_section
assert str(out) == str(daspy_section)
assert np.all(out.data == daspy_section.data)

def test_convert_section(self, daspy_section, format_name):
"""Test that the base section can be converted to all formats."""
out = convert(daspy_section, to=format_name)
assert isinstance(out, NAME_CLASS_MAP[format_name])


class TestXdasDataArray:
"""Tests for converting xdas DataArrays."""
Expand All @@ -162,12 +146,17 @@ def test_to_base_das(self, xdas_base_das):
"""Ensure the example data_array can be converted to BaseDAS."""
assert isinstance(xdas_base_das, BaseDAS)

def test_convert_data_array_to_other(self, xdas_dataarray, format_name):
"""Test that the base data array can be converted to all formats."""
out = convert(xdas_dataarray, to=format_name)
assert isinstance(out, NAME_CLASS_MAP[format_name])

def test_from_base_das(self, xdas_base_das, xdas_dataarray):
"""Ensure xdas DataArray can round trip."""
out = convert(xdas_base_das, "xdas.DataArray")
assert np.all(out.data == xdas_dataarray.data)
# TODO the str rep of coords are equal but not coords themselves.
# Need to look into this.
# We need to look into this.
assert str(out.coords) == str(xdas_dataarray.coords)
attr1, attr2 = out.attrs, xdas_dataarray.attrs
assert attr1 == attr2 or (not attr1 and not attr2)
Expand Down Expand Up @@ -197,14 +186,19 @@ def test_from_base_das(self, lightguide_base_das, lightguide_blast):
out = convert(lightguide_base_das, "lightguide.Blast")
# TODO here the objects also do not compare equal. Need to figure out
# why. For now just do weaker checks.
# assert np.all(out.data == lightguide_blast.data)
# assert out == lightguide_blast
assert out.start_time == lightguide_blast.start_time
assert np.all(out.data == lightguide_blast.data)
assert out.unit == lightguide_blast.unit
assert out.channel_spacing == lightguide_blast.channel_spacing
assert out.start_channel == lightguide_blast.start_channel
assert out.sampling_rate == lightguide_blast.sampling_rate

def test_convert_blast_to_other(self, lightguide_blast, format_name):
"""Test that the base blast can be converted to all formats."""
out = convert(lightguide_blast, to=format_name)
assert isinstance(out, NAME_CLASS_MAP[format_name])


class TestConvert:
"""Generic tests for the convert function."""
Expand Down Expand Up @@ -255,6 +249,17 @@ def my_patch_func(patch):
assert new2.raw_function is my_patch_func.raw_function
assert new.raw_function is my_patch_func.raw_function

def test_different_return_type(self, daspy_section):
"""Ensure wrapped functions that return different types still work."""

@adapter("dascore.Patch")
def dummy_func(patch):
"""Dummy function that returns a dataframe."""
return dc.spool(patch).get_contents()

out = dummy_func(daspy_section)
assert isinstance(out, pd.DataFrame)


class TestIntegrations:
"""Tests for integrating different data structures."""
Expand Down
Loading
0