Source code for openmnglab.util.dicts
from typing import Mapping, TypeVar, Optional, Iterable, Callable, MutableMapping
# Generic key type used by the mapping helpers in this module.
_KT = TypeVar("_KT")
# Generic value type used by the mapping helpers in this module; declared covariant.
_VT_co = TypeVar("_VT_co", covariant=True)
def get_any_key(map: Mapping[_KT, _VT_co], *keys: _KT) -> Optional[_KT]:
    """
    Returns the first of the given keys that exists in the mapping.

    Keys are checked in the order they are passed; the first one contained in ``map`` is returned.

    :param map: mapping in which the keys should be contained
    :param keys: keys to check for in the mapping, in order of preference
    :return: the first given key that exists in the mapping, or None if none of them exist.
        A None return is ambiguous if None itself is one of the given keys and is present in the mapping.
    """
    # Test membership on the mapping itself; building a keys() view per candidate is unnecessary.
    return next((key for key in keys if key in map), None)
def get_and_incr(dct: MutableMapping[_KT, int], key: _KT) -> int:
    """
    Returns the current counter value for a key and increments the stored value by one.

    A missing key counts as 0, so the first call for a key returns 0 and stores 1.

    :param dct: mapping holding the counters; it is modified in place
    :param key: key whose counter should be read and incremented
    :return: the value stored for ``key`` before the increment (0 if the key was absent)
    """
    n = dct.get(key, 0)
    dct[key] = n + 1
    return n
def get_any(map: Mapping[_KT, _VT_co], *keys: _KT, default: Optional[_VT_co] = None) -> Optional[_VT_co]:
    """
    Gets the value of the first given key that exists in the mapping.

    Keys are checked in the order they are passed. None is handled like any other key, so
    ``get_any({None: 1}, None)`` returns 1.

    :param map: the mapping to work on
    :param keys: keys to look for in the mapping, in order of preference
    :param default: value to return when none of the given keys exist in the mapping
    :return: the value of the first existing key, or ``default`` if none of the keys exist
    """
    for key in keys:
        # Check membership directly instead of using a "found key or None" helper:
        # that signal cannot distinguish "not found" from an existing None key.
        if key in map:
            return map[key]
    return default
def setfactory(map: MutableMapping[_KT, _VT_co], key: _KT, factory: Callable[[], _VT_co]) -> _VT_co:
    """
    Equivalent of dict.setdefault with a factory instead of a value.

    If ``key`` is already present, its stored value is returned unchanged (even if that value is None)
    and ``factory`` is not called. Otherwise ``factory()`` is called once, and its result is stored under
    ``key`` and returned.

    :param map: mapping to work on; modified in place when the key is missing
    :param key: key to look for
    :param factory: zero-argument callable producing the value for a missing key
    :return: the value for the given key
    """
    # Use a membership test rather than comparing map.get(key) to None, so that an existing
    # None value is kept as-is, exactly like dict.setdefault.
    if key in map:
        return map[key]
    value = factory()
    map[key] = value
    return value
def group_dict(vals: Iterable[_VT_co], key_getter: Callable[[_VT_co], _KT]) -> dict[_KT, list[_VT_co]]:
    """
    Groups values into lists, keyed by whatever ``key_getter`` returns for each value.

    Unlike :func:`itertools.groupby`, which only groups consecutive runs, the input does not have to be
    sorted by the key first: every value ends up in the single list for its key, no matter where it
    appears in ``vals``. Groups appear in the order their keys are first produced, and each list keeps
    the input order of its values.

    :param vals: values to be grouped
    :param key_getter: function that produces the key the given value should be grouped by
    :return: a dictionary mapping each key to the list of values that produced it
    """
    groups: dict[_KT, list[_VT_co]] = {}
    for item in vals:
        group_key = key_getter(item)
        bucket = groups.get(group_key)
        if bucket is None:
            # First value seen for this key: start a new group.
            bucket = []
            groups[group_key] = bucket
        bucket.append(item)
    return groups