"""some common pado classes"""
from __future__ import annotations
import re
import sys
import uuid
from collections import deque
from functools import lru_cache
from itertools import repeat
from reprlib import Repr
from textwrap import dedent
from typing import Any
from typing import Callable
from typing import Dict
from typing import Generic
from typing import ItemsView
from typing import Iterable
from typing import Iterator
from typing import Mapping
from typing import MutableMapping
from typing import MutableSequence
from typing import Optional
from typing import Type
from typing import TypeVar
from typing import cast
from typing import overload
if sys.version_info >= (3, 11):
from typing import Self
else:
from typing_extensions import Self
import pandas as pd
from pado._compat import cached_property
from pado.images.ids import ImageId
from pado.io.store import Store
from pado.types import SerializableItem
from pado.types import UrlpathLike
__all__ = [
"PadoMutableSequence",
"PadoMutableSequenceMapping",
"SerializableProviderMixin",
"ProviderStoreMixin",
"GroupedProviderMixin",
"validate_dataframe_index",
"is_valid_identifier",
"clear_provider_getitem_cache",
]
_r = Repr()
_r.maxlist = 3
# === collections =============================================================
PI = TypeVar("PI", bound=SerializableItem)
class PadoMutableMapping(MutableMapping[ImageId, PI]):
__value_type__: Type[PI]
df: pd.DataFrame
identifier: str
def __init__(
self,
provider: Mapping[ImageId, PI] | pd.DataFrame | dict | None = None,
*,
identifier: Optional[str] = None,
):
if provider is None:
provider = {}
if isinstance(provider, type(self)):
self.df = provider.df.copy()
self.identifier = str(identifier) if identifier else provider.identifier
elif isinstance(provider, pd.DataFrame):
validate_dataframe_index(provider)
self.df = provider.copy()
self.identifier = str(identifier) if identifier else str(uuid.uuid4())
elif isinstance(provider, dict):
if not provider:
self.df = pd.DataFrame(columns=self.__value_type__.__fields__)
else:
self.df = pd.DataFrame.from_records(
index=list(map(ImageId.to_str, provider.keys())),
data=list(map(lambda x: x.to_record(), provider.values())),
columns=self.__value_type__.__fields__,
)
self.identifier = str(identifier) if identifier else str(uuid.uuid4())
else:
raise TypeError(
f"expected `{type(self).__name__}`, got: {type(provider).__name__!r}"
)
self.__getitem_cached__ = lru_cache(maxsize=None)(self.__getitem_uncached__)
def __getitem__(self, image_id: ImageId) -> PI:
return self.__getitem_cached__(image_id)
def __getitem_uncached__(self, image_id: ImageId) -> PI:
if not isinstance(image_id, ImageId):
raise TypeError(
f"keys must be ImageId instances, got {type(image_id).__name__!r}"
)
row = self.df.loc[image_id.to_str()]
return self.__value_type__.from_obj(row)
def __setitem__(self, image_id: ImageId, value: PI) -> None:
if not isinstance(image_id, ImageId):
raise TypeError(
f"keys must be ImageId instances, got {type(image_id).__name__!r}"
)
if not isinstance(value, self.__value_type__):
raise TypeError(
f"values must be {self.__value_type__.__name__} instances, got {type(value).__name__!r}"
)
dct = value.to_record()
self.df.loc[image_id.to_str()] = pd.Series(dct)
def __delitem__(self, image_id: ImageId) -> None:
if not isinstance(image_id, ImageId):
raise TypeError(
f"keys must be ImageId instances, got {type(image_id).__name__!r}"
)
self.df.drop(image_id.to_str(), inplace=True)
def __len__(self) -> int:
return len(self.df)
def __iter__(self) -> Iterator[ImageId]:
return iter(map(ImageId.from_str, self.df.index))
def items(self) -> PadoItemsView[ImageId, PI]:
return PadoItemsView(self)
def __repr__(self):
_akw = [_r.repr_dict(cast(dict, self), 0)]
if self.identifier is not None:
_akw.append(f"identifier={self.identifier!r}")
return f"{type(self).__name__}({', '.join(_akw)})"
K = TypeVar("K")
class PadoItemsView(ItemsView, Generic[K, PI]):
_mapping: PadoMutableMapping
def __init__(self, mapping, *, value_type=None, value_transform=None) -> None:
super().__init__(mapping)
if not hasattr(mapping, "__value_type__"):
self._value_type = value_type
else:
self._value_type = self._mapping.__value_type__
self._value_transform = value_transform
def __iter__(self) -> Iterator[tuple[K, PI]]:
iid_from_str = ImageId.from_str
value_from_obj = self._value_type.from_obj
if self._value_transform is None:
for row in self._mapping.df.itertuples(index=True, name="ValueAsRow"):
# noinspection PyProtectedMember
x = row._asdict()
i = x.pop("Index")
yield iid_from_str(i), value_from_obj(x)
else:
vt = self._value_transform
mapping = self._mapping
for row in mapping.df.itertuples(index=True, name="ValueAsRow"):
# noinspection PyProtectedMember
x = row._asdict()
i = x.pop("Index")
yield iid_from_str(i), vt(mapping, value_from_obj(x))
[docs]class PadoMutableSequence(MutableSequence[PI]):
# subclasses must provide this
__item_class__: Type[PI]
# annotations for better IDE support
df: pd.DataFrame
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
if not hasattr(cls, "__item_class__"):
raise AttributeError(f"subclass {cls.__name__} must define __item_class__")
def __init__(
self,
df: pd.DataFrame | None = None,
*,
image_id: ImageId | None = None,
) -> None:
if df is None:
self.df = pd.DataFrame(columns=self.__item_class__.__fields__)
elif isinstance(df, pd.DataFrame):
self.df = df
else:
raise TypeError(f"requires a pd.DataFrame, not {type(df).__name__}")
self._image_id = image_id
if image_id is not None:
self._update_df_image_id(image_id)
def __repr__(self):
v = _r.repr_list(cast(list, self), 0)
return f"{type(self).__name__}({v}, image_id={self._image_id!r})"
def __eq__(self, other):
if (
not isinstance(other, type(self))
or other.__item_class__ != type(self).__item_class__
):
return NotImplemented
return all(a == b for a, b in zip(self, other))
@property
def image_id(self) -> ImageId | None:
return self._image_id
@image_id.setter
def image_id(self, value: ImageId):
if not isinstance(value, ImageId):
raise TypeError(
f"{value!r} not of type ImageId, got {type(value).__name__}"
)
self._update_df_image_id(image_id=value)
self._image_id = value
def _update_df_image_id(self, image_id: ImageId):
"""internal"""
if self.df.empty:
return
ids = set(self.df["image_id"].unique())
if len(ids) > 2:
raise ValueError(f"image_ids in provider not unique: {ids!r}")
if None not in ids and image_id.to_str() in ids:
return
elif {None, image_id.to_str()}.issuperset(ids):
self.df.loc[self.df["image_id"].isna(), "image_id"] = image_id.to_str()
else:
raise AssertionError(
f"unexpected image_ids in {type(self).__name__}.df: {ids!r}"
)
@overload
def __getitem__(self, index: int) -> PI:
...
@overload
def __getitem__(self, index: slice) -> Self:
...
def __getitem__(self, index: int | slice) -> PI | Self:
if isinstance(index, int):
return self.__item_class__.from_obj(self.df.iloc[index, :])
elif isinstance(index, slice):
return self.__class__(self.df.loc[index, :], image_id=self.image_id)
else:
raise TypeError(
f"{type(self).__name__}: indices must be integers or slices, not {type(index).__name__}"
)
@overload
def __setitem__(self, index: int, value: PI) -> None:
...
@overload
def __setitem__(self, index: slice, value: Iterable[PI]) -> None:
...
def __setitem__(self, index: int | slice, value: PI | Iterable[PI]) -> None:
if isinstance(index, int):
if not isinstance(value, self.__item_class__):
raise TypeError(
f"requires `{self.__item_class__.__name__}` got: {type(value).__name__!r}"
)
self.df.iloc[index, :] = pd.DataFrame(
[value.to_record(self._image_id)],
columns=list(self.__item_class__.__fields__),
)
elif isinstance(index, slice):
if isinstance(value, self.__item_class__):
raise TypeError(
f"requires `Iterable[{self.__item_class__.__name__}]` got: {type(value).__name__!r}"
)
else:
it = iter(value) # type: ignore
self.df.iloc[index, :] = pd.DataFrame(
[x.to_record(self._image_id) for x in it],
columns=list(self.__item_class__.__fields__),
)
else:
raise TypeError(
f"{type(self).__name__}: indices must be integers or slices, not {type(index).__name__}"
)
def __delitem__(self, index: int | slice) -> None:
if isinstance(index, int):
self.df.drop(labels=index, axis=0, inplace=True)
elif isinstance(index, slice):
self.df.drop(labels=self.df.index[index], axis=0, inplace=True)
else:
raise TypeError(
f"{type(self).__name__}: indices must be integers or slices, not {type(index).__name__}"
)
[docs] def insert(self, index: int, value: PI) -> None:
if not isinstance(value, self.__item_class__):
raise TypeError(
f"can only insert type {self.__item_class__.__name__}, got {type(value).__name__!r}"
)
df_a = self.df.iloc[:index, :]
df_i = pd.DataFrame(
[value.to_record(self._image_id)], columns=self.__item_class__.__fields__
)
df_b = self.df.iloc[index:, :]
self.df = pd.concat([df_a, df_i, df_b])
def __len__(self) -> int:
return len(self.df)
@classmethod
def from_records(
cls: Type[Self],
annotation_records: Iterable[dict],
*,
image_id: ImageId | None = None,
) -> Self:
df = pd.DataFrame(
list(annotation_records), columns=cls.__item_class__.__fields__
)
return cls(df, image_id=image_id)
VT = TypeVar("VT", bound="PadoMutableSequence")
[docs]class PadoMutableSequenceMapping(MutableMapping[ImageId, VT]):
__value_class__: Type[VT]
df: pd.DataFrame
identifier: str
def __init__(
self,
provider: Mapping[ImageId, VT] | pd.DataFrame | dict | None = None,
*,
identifier: Optional[str] = None,
):
if provider is None:
provider = {}
if isinstance(provider, type(self)):
self.df = provider.df.copy()
self.identifier = str(identifier) if identifier else provider.identifier
elif isinstance(provider, pd.DataFrame):
validate_dataframe_index(provider)
self.df = provider.copy()
self.identifier = str(identifier) if identifier else str(uuid.uuid4())
elif isinstance(provider, dict):
if not provider:
self.df = pd.DataFrame(
columns=self.__value_class__.__item_class__.__fields__
)
else:
indices: list[ImageId] = []
data: list[dict] = []
for key, value in provider.items():
if value is None:
continue
indices.extend(repeat(ImageId.to_str(key), len(value)))
data.extend(a.to_record() for a in value)
self.df = pd.DataFrame.from_records(
index=indices,
data=data,
columns=self.__value_class__.__item_class__.__fields__,
)
self.identifier = str(identifier) if identifier else str(uuid.uuid4())
else:
raise TypeError(
f"expected `BaseAnnotationProvider`, got: {type(provider).__name__!r}"
)
self._store: dict[ImageId, VT] = {}
def __getitem__(self, image_id: ImageId) -> VT:
if not isinstance(image_id, ImageId):
raise TypeError(
f"keys must be ImageId instances, got {type(image_id).__name__!r}"
)
try:
return self._store[image_id]
except KeyError:
df = self.df.loc[
[image_id.to_str()], :
] # list: return DataFrame even if length == 1
df = df.reset_index(drop=True)
a = self._store[image_id] = self.__value_class__(df, image_id=image_id)
return a
def __setitem__(self, image_id: ImageId, v: VT) -> None:
if not isinstance(image_id, ImageId):
raise TypeError(
f"keys must be ImageId instances, got {type(image_id).__name__!r}"
)
if not isinstance(v, self.__value_class__):
raise TypeError(f"requires Annotations, got {type(v).__name__}")
if v.image_id is None:
v.image_id = image_id
elif v.image_id != image_id:
raise ValueError(f"image_ids don't match: {image_id!r} vs {v.image_id!r}")
self._store[image_id] = v
def __delitem__(self, image_id: ImageId) -> None:
if not isinstance(image_id, ImageId):
raise TypeError(
f"keys must be ImageId instances, got {type(image_id).__name__!r}"
)
try:
del self._store[image_id]
except KeyError:
had_store = False
else:
had_store = True
try:
self.df.drop(image_id.to_str(), inplace=True)
except KeyError:
had_df = False
else:
had_df = True
if not had_store and not had_df:
raise KeyError(image_id)
def __len__(self) -> int:
if not self._store:
return self.df.index.nunique()
else:
return len(
set(map(ImageId.from_str, self.df.index.unique())).union(self._store)
)
def __iter__(self) -> Iterator[ImageId]:
return iter(
set(map(ImageId.from_str, self.df.index.unique())).union(self._store)
)
# === mixins ==================================================================
ST = TypeVar("ST", bound=Store)
[docs]class SerializableProviderMixin(Generic[ST]):
# required attributes
__store_class__: Type[ST]
# these attributes are part of the provider
df: pd.DataFrame
identifier: str
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
if cls.__name__ == "GroupedProviderMixin":
return
if not hasattr(cls, "__store_class__"):
raise AttributeError(f"subclass {cls.__name__} must define __store_class__")
if cls.__store_class__ is Store:
raise ValueError("must use a subclass of Store for __store_class__")
def __repr__(self):
return f"{type(self).__name__}({self.identifier!r})"
def to_parquet(
self, urlpath: UrlpathLike, *, storage_options: dict[str, Any] | None = None
) -> None:
store = self.__store_class__() # type: ignore
store.to_urlpath(
self.df,
urlpath,
identifier=self.identifier,
storage_options=storage_options,
)
@classmethod
def from_parquet(cls: Type[Self], urlpath: UrlpathLike) -> Self:
store = cls.__store_class__()
df, identifier, user_metadata = store.from_urlpath(urlpath)
if {
store.METADATA_KEY_STORE_TYPE,
store.METADATA_KEY_STORE_VERSION,
store.METADATA_KEY_PADO_VERSION,
store.METADATA_KEY_PROVIDER_VERSION,
store.METADATA_KEY_CREATED_AT,
store.METADATA_KEY_CREATED_BY,
} != set(user_metadata):
raise NotImplementedError(f"currently unused {user_metadata!r}")
inst = cls(identifier=identifier) # type: ignore
inst.df = df
if hasattr(inst, "__getitem_uncached__"):
inst.__getitem_cached__ = lru_cache(maxsize=None)(inst.__getitem_uncached__) # type: ignore
return inst
PT = TypeVar(
"PT", "PadoMutableMapping", "PadoMutableSequence", "PadoMutableSequenceMapping"
)
[docs]class GroupedProviderMixin(
SerializableProviderMixin[ST],
Generic[ST, PT],
):
__provider_class__: Type[PT]
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
# todo: should be subclass of mapping with ImageId keys
# if PadoMutableSequenceMapping not in cls.__mro__:
# raise AttributeError(
# f"{cls.__name__} must also inherit from a provider class"
# )
if not hasattr(cls, "__provider_class__"):
raise AttributeError(
f"subclass {cls.__name__} must define __provider_class__"
)
def __init__(self, *providers: PT):
super().__init__()
self.providers: list[PT] = []
for p in providers:
if not isinstance(p, self.__provider_class__):
p = self.__provider_class__(p)
if isinstance(p, type(self)):
self.providers.extend(p.providers)
else:
self.providers.append(p)
self.__dict__.pop("df") # clear cache ...
@cached_property
def df(self):
return pd.concat([p.df for p in self.providers])
def __setitem__(self, image_id: ImageId, value: Any) -> None:
raise RuntimeError(f"can't add new item to {type(self).__name__}")
def __delitem__(self, image_id: ImageId) -> None:
raise RuntimeError(f"can't delete from {type(self).__name__}")
def __repr__(self: Any) -> str:
return f'{type(self).__name__}({", ".join(map(repr, self.providers))})'
def to_parquet(
self, urlpath: UrlpathLike, *, storage_options: dict[str, Any] | None = None
) -> None:
return super().to_parquet(urlpath, storage_options=storage_options)
@classmethod
def from_parquet(cls: Type[Self], urlpath: UrlpathLike) -> Self:
raise TypeError(f"unsupported operation for {cls.__name__!r}()")
[docs]class ProviderStoreMixin(Store):
"""stores the image predictions provider in a single file with metadata"""
METADATA_KEY_PROVIDER_VERSION: str
PROVIDER_VERSION: int
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
if not hasattr(cls, "METADATA_KEY_PROVIDER_VERSION"):
raise AttributeError(
f"subclass {cls.__name__} must define 'METADATA_KEY_PROVIDER_VERSION'"
)
if not hasattr(cls, "PROVIDER_VERSION"):
raise AttributeError(
f"subclass {cls.__name__} must define 'PROVIDER_VERSION'"
)
def __metadata_set_hook__(
self, dct: Dict[bytes, bytes], setter: Callable[[dict, str, Any], None]
) -> None:
setter(dct, self.METADATA_KEY_PROVIDER_VERSION, self.PROVIDER_VERSION)
super().__metadata_set_hook__(dct, setter)
def __metadata_get_hook__(
self, dct: Dict[bytes, bytes], getter: Callable[[dict, str, Any], Any]
) -> Optional[dict]:
provider_version = getter(dct, self.METADATA_KEY_PROVIDER_VERSION, None)
if provider_version is None or provider_version < self.PROVIDER_VERSION:
raise RuntimeError(
"Please migrate ImagePredictionsProvider to newer version."
)
elif provider_version > self.PROVIDER_VERSION:
raise RuntimeError(
"ImageProvider is newer. Please upgrade pado to newer version."
)
md = super().__metadata_get_hook__(dct, getter) or {}
return {
**md,
self.METADATA_KEY_PROVIDER_VERSION: provider_version,
}
# === helpers =================================================================
[docs]def validate_dataframe_index(df: pd.DataFrame, *, unique_index: bool = False) -> None:
"""raise if an incorrect index is used"""
if not isinstance(df, pd.DataFrame):
raise TypeError(f"expected pandas.DataFrame, got: {type(df).__name__!r}")
try:
deque(map(ImageId.from_str, df.index), maxlen=0)
except (TypeError, ValueError):
idx0 = df.index[0]
if isinstance(idx0, tuple):
msg = """\
Detected dataframe indices of type: tuple
Did you forget to cast the ImageIds in the index to str?
>>> df = pd.DataFrame(index=[iid.to_str() for iid in image_ids], data=...)
"""
else:
msg = f"""\
Detected dataframe indices of type: {type(idx0).__name__}
You have to provide a dataframe with string image ids as an index:
>>> df = pd.DataFrame(index=[iid.to_str() for iid in image_ids], data=...)
"""
raise ValueError(dedent(msg))
if unique_index and not df.index.is_unique:
raise ValueError("Dataframe index is required to be unique.")
IDENTIFIER_RE = re.compile(r"^[a-zA-Z0-9](?:[a-zA-Z0-9_-]*[a-zA-Z0-9_])?$")
[docs]def is_valid_identifier(identifier: str) -> bool:
"""check if an identifier is a valid identifier"""
if IDENTIFIER_RE.match(identifier):
return True
else:
return False
def clear_provider_getitem_cache(p: PadoMutableMapping) -> None:
if hasattr(p, "__getitem_cached__"):
p.__getitem_cached__.cache_clear()
elif hasattr(p, "providers"):
for _p in p.providers:
clear_provider_getitem_cache(p)
elif hasattr(p, "_provider"):
# noinspection PyProtectedMember
clear_provider_getitem_cache(p._provider)
else:
pass