# Source code for pado.dataset

from __future__ import annotations

import math
import os
import sys
import uuid
import warnings
from collections.abc import Iterable
from collections.abc import Sized
from typing import Any
from typing import Callable
from typing import List
from typing import NamedTuple
from typing import Optional
from typing import Sequence
from typing import Union
from typing import cast
from typing import overload

if sys.version_info >= (3, 8):
    from typing import Literal
    from typing import get_args
else:
    from typing_extensions import Literal
    from typing_extensions import get_args

import fsspec
import pandas as pd
import shapely.wkt

from pado._compat import cached_property
from pado._repr import DescribeFormat
from pado._repr import describe_format_plain_text
from pado._repr import number
from pado.annotations import AnnotationProvider
from pado.annotations import Annotations
from pado.annotations import GroupedAnnotationProvider
from pado.images.ids import ImageId
from pado.images.image import Image
from pado.images.providers import GroupedImageProvider
from pado.images.providers import ImageProvider
from pado.io.files import fsopen
from pado.io.files import urlpathlike_get_path
from pado.io.files import urlpathlike_to_fs_and_path
from pado.io.files import urlpathlike_to_string
from pado.io.store import StoreType
from pado.io.store import get_store_type
from pado.metadata import GroupedMetadataProvider
from pado.metadata import MetadataProvider
from pado.predictions.providers import ImagePredictionProvider
from pado.predictions.proxy import PredictionProxy
from pado.types import DatasetSplitter
from pado.types import IOMode
from pado.types import UrlpathLike

# names exported by `from pado.dataset import *`
__all__ = [
    "PadoDataset",
    "PadoItem",
    "DescribeFormat",
]


class PadoDataset:
    __version__ = 2

    def __init__(
        self,
        urlpath: UrlpathLike | None,
        mode: IOMode = "r",
        *,
        storage_options: dict[str, Any] | None = None,
    ) -> None:
        """open or create a new PadoDataset

        Parameters
        ----------
        urlpath:
            fsspec urlpath to `pado.dataset.toml` file, or its parent directory.
            If explicitly set to None, uses an in-memory filesystem to store the dataset.
        mode:
            'r' --> readonly, error if not there
            'r+' --> read/write, error if not there
            'a' --> read/write, create if not there, append if there
            'w' --> read/write, create if not there, truncate if there
            'x' --> read/write, create if not there, error if there
        storage_options:
            an optional dictionary with options passed to fsspec for opening the urlpath

        Raises
        ------
        TypeError
            if `urlpath` can not be converted to a urlpath string
        ValueError
            if `mode` is unsupported, or (modes 'r', 'r+') no image parquet file exists
        RuntimeError
            if the filesystem can not be instantiated or the dataset root
            is not accessible
        NotADirectoryError
            if (modes 'r', 'r+') the dataset root is not a directory
        FileExistsError
            if (mode 'x') the dataset root already contains an image parquet file
        PermissionError
            if a writable mode is requested for a dataset containing a `.frozen` file

        """
        self._mode: IOMode = mode
        self._storage_options: dict[str, Any] = storage_options or {}

        if urlpath is None:
            # enable in-memory pado datasets and change mode to enable write
            self._urlpath = f"memory://pado-{uuid.uuid4()}"
            self._mode = "r+"
            self._ensure_dir()
        else:
            try:
                self._urlpath = urlpathlike_to_string(urlpath)
            except TypeError as err:
                raise TypeError(f"incompatible urlpath {urlpath!r}") from err

            # check mode
            if mode not in get_args(IOMode):
                raise ValueError(f"unsupported mode {mode!r}")

            # if the dataset files should be there, check them
            try:
                fs = self._fs
            except OSError as err:
                # keep the original error attached as the cause
                raise RuntimeError(
                    f"can't instantiate filesystem (urlpath={self._urlpath!r}) error: {err!r}"
                ) from err
            if mode in {"r", "r+"}:
                try:
                    is_dir = fs.isdir(self._root)  # raises if not there or reachable
                except Exception as err:
                    # only wrap regular errors; KeyboardInterrupt and SystemExit
                    # must propagate untouched
                    raise RuntimeError(f"{self._urlpath!r} not accessible") from err
                if not is_dir:
                    raise NotADirectoryError(f"{self._urlpath!r} not a directory")
                if not any(fs.glob(self._get_fspath("*.image.parquet"))):
                    raise ValueError(f"{self._urlpath!r} has no image parquet file.")
            elif mode == "x":
                if fs.isdir(self._root) and fs.glob(
                    self._get_fspath("*.image.parquet")
                ):
                    raise FileExistsError(f"{self._urlpath!r} exists")
            # NOTE(review): mode 'w' is documented as truncating, but nothing is
            #   removed in this constructor -- confirm where truncation happens

            if not self.readonly:
                if fs.exists(self._get_fspath(".frozen")):
                    raise PermissionError(
                        "PadoDataset has been frozen. Can only use mode='r'"
                    )
                self._ensure_dir()

    @property
    def urlpath(self) -> str:
        """urlpath string locating this PadoDataset"""
        location = self._urlpath
        return location

    @property
    def storage_options(self) -> dict[str, Any]:
        """storage options dictionary this PadoDataset was opened with"""
        options = self._storage_options
        return options

    @property
    def _fs(self) -> fsspec.AbstractFileSystem:
        """filesystem resolved from the urlpath and the storage options"""
        filesystem, _path = urlpathlike_to_fs_and_path(
            self._urlpath,
            storage_options=self._storage_options,
        )
        return filesystem

    @property
    def _root(self) -> str:
        """path component of the dataset urlpath"""
        root_path = urlpathlike_get_path(self._urlpath)
        return root_path

    @property
    def readonly(self) -> bool:
        """True if the dataset was opened with mode 'r'"""
        current_mode = self._mode
        return current_mode == "r"

    @property
    def persistent(self) -> bool:
        """True unless the dataset lives on the `memory` filesystem protocol"""
        # todo: this might need to be extended if we find other usecases than memory fs
        protocol = self._fs.protocol
        return protocol != "memory"

    def __repr__(self):
        """constructor-like representation; storage options only shown if set"""
        args = [repr(self.urlpath), f"mode={self._mode!r}"]
        if self._storage_options:
            args.append(f"storage_options={self._storage_options!r}")
        joined = ", ".join(args)
        return f"{type(self).__name__}({joined})"

    # === data properties ===

    @cached_property
    def index(self) -> Sequence[ImageId]:
        """the image_ids of the dataset as an indexable sequence"""
        keys = self.images.keys()
        # keep the keys object if it is already a Sequence, else materialize it
        return keys if isinstance(keys, Sequence) else tuple(keys)

    @cached_property
    def images(self) -> ImageProvider:
        """image provider built from all `*.image.parquet` files in the dataset"""
        fs = self._fs
        loaded = []
        for pth in fs.glob(self._get_fspath("*.image.parquet")):
            if not fs.isfile(pth):
                continue
            loaded.append(ImageProvider.from_parquet(fsopen(fs, pth, mode="rb")))

        # no file -> empty provider, one file -> that provider, else a grouped one
        if not loaded:
            return ImageProvider()
        if len(loaded) == 1:
            return loaded[0]
        return GroupedImageProvider(*loaded)

    @cached_property
    def annotations(self) -> AnnotationProvider:
        """annotation provider built from all `*.annotation.parquet` files in the dataset"""
        fs = self._fs
        loaded = []
        for pth in fs.glob(self._get_fspath("*.annotation.parquet")):
            if not fs.isfile(pth):
                continue
            loaded.append(AnnotationProvider.from_parquet(fsopen(fs, pth, mode="rb")))

        # no file -> empty provider, one file -> that provider, else a grouped one
        if not loaded:
            return AnnotationProvider({})
        if len(loaded) == 1:
            return loaded[0]
        return GroupedAnnotationProvider(*loaded)

    @cached_property
    def metadata(self) -> MetadataProvider:
        """metadata provider built from all `*.metadata.parquet` files in the dataset"""
        fs = self._fs
        loaded = []
        for pth in fs.glob(self._get_fspath("*.metadata.parquet")):
            if not fs.isfile(pth):
                continue
            loaded.append(MetadataProvider.from_parquet(fsopen(fs, pth, mode="rb")))

        # no file -> empty provider, one file -> that provider, else a grouped one
        if not loaded:
            return MetadataProvider({})
        if len(loaded) == 1:
            return loaded[0]
        return GroupedMetadataProvider(*loaded)

    @cached_property
    def predictions(self) -> PredictionProxy:
        """a PredictionProxy constructed from this dataset"""
        proxy = PredictionProxy(self)
        return proxy

    # === access ===

    @overload
    def __getitem__(self, key: ImageId | int) -> PadoItem:
        ...

    @overload
    def __getitem__(self, key: slice) -> PadoDataset:
        ...

    def __getitem__(self, key):
        """access items by ImageId or position, or a sub-dataset by slice

        Parameters
        ----------
        key:
            an ImageId or an int position into `self.index` returns a PadoItem.
            A slice selects image ids from `self.index` and returns the
            result of `self.filter` for them.

        Raises
        ------
        TypeError
            if key is neither slice, ImageId nor int
        KeyError
            if building the PadoItem for the resolved image id raises a KeyError

        """
        if isinstance(key, slice):
            selected = self.index[key]
            return self.filter(selected)

        if isinstance(key, ImageId):
            image_id = key

        elif isinstance(key, int):
            image_id = self.index[key]

        else:
            raise TypeError(f"Unexpected type {type(key)}")

        try:
            return PadoItem(
                image_id,
                self.images[image_id],
                self.annotations.get(image_id),
                self.metadata.get(image_id),
            )
        except KeyError as err:
            # chain explicitly so the original lookup failure is kept as the cause
            raise KeyError(
                f"{key} does not match any images in this dataset."
            ) from err

    def get_by_id(self, image_id: ImageId) -> PadoItem:
        """deprecated: return the item for an ImageId (use `__getitem__`)"""
        if isinstance(image_id, ImageId):
            warnings.warn(
                "`get_by_id` is deprecated and will be removed in a future release. Use `__getitem__` instead.",
                DeprecationWarning,
                stacklevel=2,
            )
            return self[image_id]
        raise TypeError(f"Unexpected type {type(image_id)}")

    def get_by_idx(self, idx: int) -> PadoItem:
        """deprecated: return the item at an integer position (use `__getitem__`)"""
        if isinstance(idx, int):
            warnings.warn(
                "`get_by_idx` is deprecated and will be removed in a future release. Use `__getitem__` instead.",
                DeprecationWarning,
                stacklevel=2,
            )
            return self[idx]
        raise TypeError(f"Unexpected type {type(idx)}")

    def __len__(self):
        """number of entries in the image provider"""
        image_provider = self.images
        return len(image_provider)

    # === filter functionality ===

    def filter(
        self,
        ids_or_func: Sequence[ImageId] | Callable[[PadoItem], bool],
        *,
        urlpath: Optional[UrlpathLike] = None,
        mode: IOMode = "r",
        on_empty: Literal["ignore", "warn", "error"] = "warn",
    ) -> PadoDataset:
        """filter a pado dataset

        Parameters
        ----------
        ids_or_func:
            either a Sequence of ImageId instances or a function that gets
            called with each PadoItem and returns a bool indicating if it should
            be kept or not.
        urlpath:
            a urlpath to store the filtered provider. If None (default) returns
            a in-memory PadoDataset
        mode:
            set the io mode for the returned dataset
        on_empty:
            "warn" (default) will warn if the filtering returns an empty dataset.
            "error" raises a ValueError.
            "ignore" returns empty datasets without warning.

        Raises
        ------
        ValueError
            if a single ImageId is provided, if the result is empty and
            on_empty="error", or if the result is empty and on_empty is invalid
        TypeError
            if ids_or_func is neither a sized iterable nor a callable

        """
        # todo: if this is not fast enough might consider lazy filtering

        if isinstance(ids_or_func, ImageId):
            raise ValueError("must provide a list of ImageIds")

        if isinstance(ids_or_func, Iterable) and isinstance(ids_or_func, Sized):
            # select rows of each provider dataframe by stringified image id
            ids = pd.Series(ids_or_func, dtype=object).apply(str.__call__)
            _ip, _ap, _mp = self.images, self.annotations, self.metadata
            ip = ImageProvider(
                _ip.df.loc[_ip.df.index.intersection(ids), :], identifier=_ip.identifier
            )
            ap = AnnotationProvider(
                _ap.df.loc[_ap.df.index.intersection(ids), :], identifier=_ap.identifier
            )
            mp = MetadataProvider(
                _mp.df.loc[_mp.df.index.intersection(ids), :], identifier=_mp.identifier
            )

        elif callable(ids_or_func):
            # collect the kept entries in plain dicts (cast only for type checkers)
            func = ids_or_func
            ip = cast(ImageProvider, {})
            ap = cast(AnnotationProvider, {})
            mp = cast(MetadataProvider, {})
            for image_id in self.index:
                item = self[image_id]
                keep = func(item)
                if not keep:
                    continue
                image = item.image
                assert image is not None, "images currently required"
                ip[image_id] = image
                if item.annotations is not None:
                    ap[image_id] = item.annotations
                if item.metadata is not None:
                    mp[image_id] = item.metadata

        else:
            raise TypeError(
                f"requires sequence of ImageId or a callable of type FilterFunc, got {ids_or_func!r}"
            )

        if len(ip) == 0:
            if on_empty == "error":
                raise ValueError("did not match any images")
            elif on_empty == "warn":
                warnings.warn("did not match any images", stacklevel=2)
            elif on_empty == "ignore":
                pass
            else:
                raise ValueError(
                    f"on_empty not one of {'error', 'warn', 'ignore'}, got: {on_empty!r}"
                )

        # write the filtered providers to the target, then reopen in requested mode
        ds = PadoDataset(urlpath, mode="w")
        ds.ingest_obj(ImageProvider(ip, identifier=self.images.identifier))

        if len(ap) > 0:
            ds.ingest_obj(
                AnnotationProvider(ap, identifier=self.annotations.identifier)
            )
        if len(mp) > 0:
            ds.ingest_obj(MetadataProvider(mp, identifier=self.metadata.identifier))
        elif isinstance(mp, MetadataProvider) and len(mp.df.columns) > 0:
            # no rows but existing columns: keep the column layout.
            # The isinstance guard is needed because the callable branch
            # collects into a plain dict, which has no `.df` attribute.
            ds.ingest_obj(MetadataProvider(mp.df, identifier=self.metadata.identifier))

        return PadoDataset(ds.urlpath, mode=mode)

    def partition(
        self,
        splitter: DatasetSplitter,
        label_func: Optional[Callable[[PadoDataset], Sequence[Any]]] = None,
        group_func: Optional[Callable[[PadoDataset], Sequence[Any]]] = None,
    ) -> List[Split]:
        """split the dataset into (train, test) pairs using a splitter

        Parameters
        ----------
        splitter:
            a DatasetSplitter instance (basically all sklearn.model_selection splitter classes)
        label_func:
            called with this dataset; must return one label per entry of
            dataset.index. (default None)
        group_func:
            called with this dataset; must return one group per entry of
            dataset.index. (default None)

        Returns
        -------
        a list with one Split(train, test) per split yielded by the splitter

        Notes
        -----
        dependent on the provided splitter instance, label_func and group_func might be ignored.

        """
        labels = label_func(self) if label_func is not None else None
        groups = group_func(self) if group_func is not None else None

        fold_indices = splitter.split(X=self.index, y=labels, groups=groups)
        # array of image ids so each fold's index arrays can select ids directly
        id_array = pd.Series(self.index).values

        folds = []
        for train_idxs, test_idxs in fold_indices:
            train_ds = self.filter(id_array[train_idxs])
            test_ds = self.filter(id_array[test_idxs])
            folds.append(Split(train_ds, test_ds))
        return folds

    # === data ingestion ===

    def ingest_obj(
        self, obj: Any, *, identifier: Optional[str] = None, overwrite: bool = False
    ) -> None:
        """ingest an object into the dataset

        Parameters
        ----------
        obj:
            a PadoDataset, ImageProvider, AnnotationProvider, MetadataProvider
            or ImagePredictionProvider. For a PadoDataset its images, metadata
            and annotations providers are ingested one by one.
        identifier:
            used as prefix of the parquet filename. Falls back to
            `obj.identifier`. Not forwarded when `obj` is a PadoDataset.
        overwrite:
            if True the parquet file is opened with mode "wb", otherwise
            with mode "xb".

        Raises
        ------
        RuntimeError
            if the dataset is opened in readonly mode
        ValueError
            if neither `identifier` nor `obj.identifier` is set
        NotImplementedError
            if obj is a dict
        TypeError
            if obj is of an unsupported type

        """
        if self.readonly:
            raise RuntimeError(f"{self!r} opened in readonly mode")

        if isinstance(obj, PadoDataset):
            # forward `overwrite` so the flag is honored for every provider.
            # `identifier` is intentionally not forwarded: each provider
            # keeps its own identifier.
            for x in [obj.images, obj.metadata, obj.annotations]:
                self.ingest_obj(x, overwrite=overwrite)
            return

        cache: Literal["images", "annotations", "metadata", "predictions"]

        if isinstance(obj, dict):
            raise NotImplementedError("todo: guess provider type")

        elif isinstance(obj, ImageProvider):
            if identifier is None and obj.identifier is None:
                raise ValueError("need to provide an identifier for ImageProvider")
            identifier = identifier or obj.identifier
            pth = self._get_fspath(f"{identifier}.image.parquet")
            cache = "images"

        elif isinstance(obj, AnnotationProvider):
            if identifier is None and obj.identifier is None:
                raise ValueError("need to provide an identifier for AnnotationProvider")
            identifier = identifier or obj.identifier
            pth = self._get_fspath(f"{identifier}.annotation.parquet")
            cache = "annotations"

        elif isinstance(obj, MetadataProvider):
            if identifier is None and obj.identifier is None:
                raise ValueError("need to provide an identifier for MetadataProvider")
            identifier = identifier or obj.identifier
            pth = self._get_fspath(f"{identifier}.metadata.parquet")
            cache = "metadata"

        elif isinstance(obj, ImagePredictionProvider):
            if identifier is None and obj.identifier is None:
                raise ValueError(
                    "need to provide an identifier for ImagePredictionProvider"
                )
            identifier = identifier or obj.identifier
            pth = self._get_fspath(f"{identifier}.image_predictions.parquet")
            cache = "predictions"

        else:
            raise TypeError(f"unsupported object type {type(obj).__name__}: {obj!r}")

        # "xb" is the exclusive-creation mode used when not overwriting
        open_mode = "wb" if overwrite else "xb"
        obj.to_parquet(fsopen(self._fs, pth, mode=open_mode))
        # drop the cached provider so the next access re-reads the parquet files
        self._clear_caches(cache)

    def ingest_file(
        self, urlpath: UrlpathLike, *, identifier: Optional[str] = None
    ) -> None:
        """load a provider from a parquet file and ingest it into the dataset"""
        if self.readonly:
            raise RuntimeError(f"{self!r} opened in readonly mode")

        # pick the provider class matching the detected store type
        store_type = get_store_type(urlpath)
        if store_type == StoreType.IMAGE:
            provider_cls = ImageProvider
        elif store_type == StoreType.ANNOTATION:
            provider_cls = AnnotationProvider
        elif store_type == StoreType.METADATA:
            provider_cls = MetadataProvider
        else:
            raise NotImplementedError("todo: implement more files")

        provider = provider_cls.from_parquet(urlpath)
        self.ingest_obj(provider, identifier=identifier)

    # === describe (summarise) dataset ===

    @overload
    def describe(self) -> str:
        # the default output_format is PLAIN_TEXT, which returns a string
        ...

    @overload
    def describe(self, output_format: Literal[DescribeFormat.PLAIN_TEXT]) -> str:
        ...

    @overload
    def describe(self, output_format: Literal[DescribeFormat.DICT]) -> dict:
        ...

    @overload
    def describe(self, output_format: str) -> Union[str, dict]:
        ...

    def describe(
        self, output_format: DescribeFormat | str = DescribeFormat.PLAIN_TEXT
    ) -> Union[str, dict]:
        """A 'to string' method for essential PadoDataset information

        Parameters
        ----------
        output_format:
            a DescribeFormat member (or its equivalent value). DICT and JSON
            return the summary dictionary, PLAIN_TEXT returns the result of
            `describe_format_plain_text` for that dictionary.

        Raises
        ------
        ValueError
            if output_format is not a member of DescribeFormat

        """
        if output_format not in list(DescribeFormat):
            raise ValueError(f"{output_format!r} is not a valid output format.")

        # convert annotations df
        idf = self.images.df
        adf = self.annotations.df
        # `assign` returns a new frame, so the helper "area" column is not
        # written into the frame handed out by the annotation provider
        adf = adf.assign(
            area=adf["geometry"].apply(lambda x: shapely.wkt.loads(x).area)
        )
        agg_annotations = adf.groupby("classification")["area"].agg(["sum", "count"])

        # get metadata columns
        try:
            mp = self.metadata
        except TypeError:
            # todo: self.metadata currently raises TypeError if no provider found
            md_columns = []
        else:
            md_columns = mp.df.columns.to_list()

        def make_replace_nan_cast(cast_func: Callable, default: Any) -> Callable:
            """wrap cast_func so that float NaN inputs map to default"""

            def _cast(x: Any) -> Any:
                if isinstance(x, float) and math.isnan(x):
                    return default
                else:
                    return cast_func(x)

            return _cast

        data = {
            "path": self.urlpath,
            "num_images": len(self.images),
            "num_mpps": [
                {"mpp": k, "num": v}
                for k, v in idf[["mpp_x", "mpp_y"]].value_counts().items()
            ],
            "avg_image_width": number(idf["width"], agg="avg", unit="px"),
            "avg_image_height": number(idf["height"], agg="avg", unit="px"),
            "avg_image_size": number(idf["size_bytes"], agg="avg", unit="b"),
            "avg_annotations_per_image": number(
                adf.groupby("image_id")["geometry"].count(),
                agg="avg",
                cast_to=make_replace_nan_cast(int, default=0),
            ),
            "metadata_columns": md_columns,
            "total_size_images": number(idf["size_bytes"], agg="sum", unit="b"),
            "total_num_annotations": sum(len(x) for x in self.annotations.values()),
            "common_classes_count": dict(
                agg_annotations["count"].sort_values(ascending=False)[:5].items()
            ),
            "common_classes_area": {
                k: number(v, cast_to=float, unit="px")
                for k, v in agg_annotations["sum"]
                .sort_values(ascending=False)[:5]
                .items()
            },
        }

        if output_format in {DescribeFormat.DICT, DescribeFormat.JSON}:
            return data

        elif output_format == DescribeFormat.PLAIN_TEXT:
            return describe_format_plain_text(data)

        else:
            raise NotImplementedError(f'Format "{output_format}" is not allowed.')

    # === internal utility methods ===

    def _clear_caches(
        self,
        *caches: Literal["images", "metadata", "annotations", "predictions"],
        _target: dict | None = None,
    ) -> None:
        """drop the requested cached_property values (all if none are named)

        Entries are removed from `_target` if given, otherwise from the
        instance `__dict__`. Clearing "images" also clears "index".
        """
        known = ("images", "metadata", "annotations", "predictions")
        unknown = set(caches).difference(known)
        if unknown:
            raise ValueError(f"unsupported cache: {unknown}")

        names: list[str] = list(caches) if caches else list(known)
        if "images" in names:
            # the index is derived from the images provider
            names.append("index")

        store = self.__dict__ if _target is None else _target
        for name in names:
            # missing entries are ignored
            store.pop(name, None)

    def _get_fspath(self, *parts: Union[str, os.PathLike]) -> str:
        """join `parts` onto the dataset root and return the path as a string"""
        joined = os.path.join(self._root, *parts)
        return os.fspath(joined)

    def _ensure_dir(self, *parts: Union[str, os.PathLike]) -> str:
        """create the folder below the dataset root if missing and return its path"""
        filesystem = self._fs
        folder = self._get_fspath(*parts)
        if filesystem.isdir(folder):
            return folder
        filesystem.mkdir(folder)
        return folder

    # === pickling ===

    def __getstate__(self) -> dict[str, Any]:
        """return the instance state for pickling

        Cached providers are removed from the state. For datasets on a
        MemoryFileSystem the store entries whose key starts with the dataset
        path are embedded in the state (and a warning is emitted).
        """
        # clear caches and specialize for memory:// datasets
        state = self.__dict__.copy()
        self._clear_caches(_target=state)
        if type(self._fs).__name__ != "MemoryFileSystem":
            return state

        from fsspec.implementations.memory import MemoryFileSystem

        if not isinstance(self._fs, MemoryFileSystem):
            raise RuntimeError(f"unexpected error: {self._fs!r}")

        prefix = urlpathlike_get_path(self._urlpath, fs_cls=type(self._fs))
        store = {}
        for key, value in MemoryFileSystem.store.items():
            if key.startswith(prefix):
                store[key] = value

        if store:
            warnings.warn(
                "Pickling a `memory://` filesystem backed pado dataset.",
                stacklevel=2,
            )
            state["__pado_fsspec_memory_store__"] = store

        return state

    def __setstate__(self, state: dict[str, Any]) -> None:
        """restore the instance state when unpickling

        If the state embeds a memory store, its entries are written back into
        the MemoryFileSystem class-level store before the state is applied.
        """
        # specialized for memory:// datasets
        memory_store = state.pop("__pado_fsspec_memory_store__", None)
        if memory_store is None:
            self.__dict__.update(state)
            return

        from fsspec.implementations.memory import MemoryFileSystem

        # warn if overwriting pseudo files in the MemoryFileSystem
        collisions = set(memory_store).intersection(MemoryFileSystem.store)
        if collisions:
            warnings.warn(
                "Key collision when unpickling a `memory://` filesystem backed pado dataset:"
                f" {collisions!r}",
                stacklevel=2,
            )

        # reconstruct pseudo dirs in the MemoryFileSystem
        parent_dirs = {os.path.dirname(key) for key in memory_store}
        for folder in sorted(parent_dirs):
            if folder not in MemoryFileSystem.pseudo_dirs:
                MemoryFileSystem.pseudo_dirs.append(folder)

        MemoryFileSystem.store.update(memory_store)

        self.__dict__.update(state)


# === helpers and utils =======================================================


class PadoItem(NamedTuple):
    """A 'row' of a dataset as returned by PadoDataset.__getitem__"""

    # the image id of the row
    id: Optional[ImageId]
    # the image stored under the image id
    image: Optional[Image]
    # the annotations for the image id, None if there are none
    annotations: Optional[Annotations]
    # the metadata for the image id, None if there is none
    metadata: Optional[pd.DataFrame]
class Split(NamedTuple):
    """train test tuple as returned by PadoDataset.partition method"""

    # dataset holding the train part of the split
    train: PadoDataset
    # dataset holding the test part of the split
    test: PadoDataset