# Source code for ch5mpy.dict

# coding: utf-8

# ====================================================
# imports
from __future__ import annotations

from collections.abc import Iterable, KeysView, MutableMapping
from pathlib import Path
from typing import Any, ItemsView, Iterator, TypeVar, Union, ValuesView, cast

import numpy as np
from h5py._hl.base import ItemsViewHDF5

from ch5mpy.h5array import H5Array
from ch5mpy.names import H5Mode
from ch5mpy.objects.dataset import AsStrWrapper, Dataset
from ch5mpy.objects.group import File, Group
from ch5mpy.objects.object import H5Object
from ch5mpy.options import _ERROR_MODE
from ch5mpy.read import read_object
from ch5mpy.write import write_object

# ====================================================
# code
# Type variable for the values exposed by H5Dict and its value/item views.
_T = TypeVar("_T")

# Unique sentinel passed as a `default` to lookups, so that "key is missing" can be told apart from any
# real stored value.
_NO_OBJECT = object()


def _get_in_memory(value: Any) -> Any:
    """
    Return an in-memory counterpart of `value`.

    The checks are applied in this order:
        - H5Dict                   -> result of its `copy()` method
        - Dataset or AsStrWrapper  -> result of indexing it with `[()]`
        - any H5Object or H5Array  -> result of its `copy()` method
        - anything else            -> returned unchanged
    """
    if isinstance(value, H5Dict):
        in_memory = value.copy()

    elif isinstance(value, (Dataset, AsStrWrapper)):
        # `[()]` selects the full content of the dataset
        in_memory = value[()]

    elif isinstance(value, (H5Object, H5Array)):
        in_memory = value.copy()

    else:
        in_memory = value

    return in_memory


def _is_group(obj: Group | Dataset[Any]) -> bool:
    """
    Tell whether `obj` is a Group whose `__h5_type__` attribute is not "object".

    A missing `__h5_type__` attribute is read as "<UNKNOWN>", so a Group without that attribute counts as a
    group.
    """
    # the attribute is looked up first, whatever the type of `obj`
    declared_type = obj.attrs.get("__h5_type__", "<UNKNOWN>")

    if not isinstance(obj, Group):
        return False

    return declared_type != "object"


def _get_repr(items: ItemsViewHDF5[str, Group | Dataset[Any]]) -> str:
    if not len(items):
        return "{}"

    return (
        "{\n\t"
        + ",\n\t".join(
            [
                str(k)
                + ": "
                + ("{...}" if _is_group(v) else repr(read_object(v, error="ignore")).replace("\n", "\n\t"))
                for k, v in items
            ]
        )
        + "\n}"
    )


def _get_note(annotation: str | None) -> str:
    return "" if annotation is None else f"[{annotation}]"


class H5DictValuesView(ValuesView[_T]):
    """Values view over an H5Dict: raw h5 objects are converted with `read_object` as they are iterated."""

    # region magic methods
    def __repr__(self) -> str:
        """Show the view's class name and its number of values."""
        cls_name = type(self).__name__
        return f"{cls_name}([{len(self)} values])"

    def __iter__(self) -> Iterator[_T]:
        """Iterate over the values, reading each one with the error mode configured for "h5dict"."""
        # the base class iterator yields the raw Group / Dataset objects
        raw_values = cast(Iterator[Union[Group, Dataset[Any]]], super().__iter__())

        return (read_object(raw, error=_ERROR_MODE["h5dict"]) for raw in raw_values)

    # endregion


class H5DictItemsView(ItemsView[str, _T]):
    """Items view over an H5Dict: raw h5 objects are converted with `read_object` as they are iterated."""

    # region magic methods
    def __repr__(self) -> str:
        """Show the view's class name and its number of items."""
        cls_name = type(self).__name__
        return f"{cls_name}([{len(self)} items])"

    def __iter__(self) -> Iterator[tuple[str, _T]]:
        """Iterate over (key, value) pairs, reading each value with the error mode configured for "h5dict"."""
        # the base class iterator yields (key, raw Group / Dataset) pairs
        raw_items = cast(Iterator[tuple[str, Union[Group, Dataset[Any]]]], super().__iter__())

        return ((key, read_object(raw, error=_ERROR_MODE["h5dict"])) for key, raw in raw_items)

    # endregion


def _diff(a: Any, b: Any) -> bool:
    return bool(np.array(a != b).any())


class H5Dict(H5Object, MutableMapping[str, _T]):
    """
    Class for managing dictionaries backed on h5 files.

    Keys are the names found in the wrapped File / Group; values are converted with `read_object` when they
    are read and stored with `write_object` when they are set. Keys can also be read as attributes.
    """

    # region magic methods
    def __init__(self, file: File | Group, annotation: str | None = None):
        """
        Args:
            file: the File or Group holding the dictionary's content.
            annotation: optional note, shown between square brackets in the object's representation.
        """
        super().__init__(file)
        self.annotation = annotation

    def __dir__(self) -> Iterable[str]:
        # expose the keys next to the regular attributes, since they can be read as attributes
        return dir(H5Dict) + list(self.keys())

    def __repr__(self) -> str:
        if self.is_closed:
            return "Closed H5Dict{}"

        return f"H5Dict{_get_note(self.annotation)}{_get_repr(self._file.items())}"

    def __getitem__(self, key: str) -> _T:
        """Read the object stored under `key`, using the error mode configured for "h5dict"."""
        return cast(_T, read_object(self._file[key], error=_ERROR_MODE["h5dict"]))

    def __setitem__(self, key: str, value: Any) -> None:
        """
        Store `value` under `key`.

        A callable `value` is not stored as is: it is called with `name=key` and `loc=<underlying file>`.
        Any other value is written (overwriting what was there) when `key` does not exist yet or when the
        value currently stored differs from `value`.
        """
        if callable(value):
            value(name=key, loc=self._file)
            return

        current = self.get(key, _NO_OBJECT)

        # A missing key must always be written. Comparing the sentinel with `value` is not reliable for
        # that: against an empty array the element-wise comparison is empty, i.e. "no difference".
        if current is _NO_OBJECT or _diff(current, value):
            write_object(self, key, value, overwrite=True)

    def __delitem__(self, key: str) -> None:
        del self._file[key]

    def __getattr__(self, item: str) -> _T:
        """
        Attribute-style access to the keys (only called when regular attribute lookup fails).

        Raises:
            AttributeError: if `item` is not a key of this H5Dict.
        """
        try:
            return self.__getitem__(item)

        except KeyError as err:
            # hasattr() and getattr(obj, name, default) only understand AttributeError
            raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'") from err

    def __len__(self) -> int:
        return len(self._file.keys())

    def __iter__(self) -> Iterator[str]:
        return iter(self.keys())

    def __deepcopy__(self, _memo: dict[Any, Any]) -> dict[str, Any]:
        # a deep copy is the in-memory copy built by copy(): a regular dict, not an H5Dict
        return self.copy()

    # endregion

    # region class methods
    @classmethod
    def read(cls, path: str | Path | File | Group, name: str | None = None, mode: H5Mode = H5Mode.READ) -> H5Dict[Any]:
        """
        Build an H5Dict from a path or from an already opened File / Group.

        Args:
            path: a str or Path (opened as a File with `mode`), or a File / Group used as is.
            name: optional name of a member of the file to use as the dictionary's root.
            mode: mode for opening the file, only used when `path` is a str or a Path.
        """
        file = File(path, mode=mode) if isinstance(path, (str, Path)) else path

        if name is not None:
            file = file[name]

        return H5Dict(file)

    # endregion

    # region methods
    def keys(self) -> KeysView[str]:
        """Get a view on the keys of the underlying file."""
        return self._file.keys()

    def values(self) -> H5DictValuesView[_T]:
        """Get a view on the values, converted with `read_object` during iteration."""
        return H5DictValuesView(self._file)  # type: ignore[arg-type]

    def items(self) -> H5DictItemsView[_T]:
        """Get a view on the (key, value) pairs, values being converted with `read_object` during iteration."""
        return H5DictItemsView(self._file)  # type: ignore[arg-type]

    def get(self, key: str, default: Any = None) -> Any:
        """Read the object stored under `key`, or return `default` if the underlying lookup finds nothing."""
        res = self._file.get(key, default)

        # the underlying lookup hands `default` back when the key is missing: return it without reading
        if res is default:
            return default

        return read_object(res, error=_ERROR_MODE["h5dict"])

    def rename(self, name: str, new_name: str) -> None:
        """Move the object stored under `name` to `new_name`."""
        self._file.move(name, new_name)

    def copy(self) -> dict[str, Any]:
        """
        Build an in-memory copy of this H5Dict object.
        """
        return {k: _get_in_memory(v) for k, v in self.items()}

    # endregion