Source code for compressed_lists.split_generic
from collections import defaultdict
from functools import singledispatch
from typing import Any, List, Optional, Sequence, Tuple, Union
import numpy as np
from .base import CompressedList
from .partition import Partitioning
__author__ = "Jayaram Kancherla"
__copyright__ = "Jayaram Kancherla"
__license__ = "MIT"
[docs]
def groups_to_partition(
    data: Any, groups: list, names: Optional[Sequence[str]] = None
) -> Tuple[List[Any], Partitioning]:
    """Split a flat vector into per-group lists and build the matching Partitioning.

    Elements of ``data`` are bucketed by the label found at the same position
    in ``groups``. Buckets are emitted in sorted order of their labels, and
    elements keep their original relative order inside each bucket.

    Args:
        data:
            Flat vector-like object holding the elements to split.

        groups:
            Group label for every element; must have the same length as ``data``.

        names:
            Optional names for the resulting partitions, one per unique
            group. They are handed to ``Partitioning.from_list`` together with
            the label-sorted partitions. Defaults to the string form of
            each group label.

    Returns:
        Tuple of (list of per-group element lists, ``Partitioning`` object).

    Raises:
        ValueError:
            If ``data`` and ``groups`` differ in length, or if ``names`` is
            given but its length differs from the number of unique groups.
    """
    if len(data) != len(groups):
        raise ValueError(f"Length of data ({len(data)}) must match length of groups ({len(groups)})")

    # Bucket every element under its label, preserving encounter order within a bucket.
    buckets = {}
    for element, label in zip(data, groups):
        buckets.setdefault(label, []).append(element)

    ordered_labels = sorted(buckets)
    chunks = [buckets[label] for label in ordered_labels]

    if names is not None:
        if len(names) != len(ordered_labels):
            raise ValueError(
                f"Length of names ({len(names)}) must match number of unique groups ({len(ordered_labels)})"
            )
        chunk_names = names
    else:
        chunk_names = [str(label) for label in ordered_labels]

    return chunks, Partitioning.from_list(chunks, chunk_names)
[docs]
@singledispatch
def splitAsCompressedList(
    data: Any,
    groups_or_partitions: Union[list, Partitioning],
    names: Optional[Sequence[str]] = None,
    metadata: Optional[dict] = None,
) -> CompressedList:
    """Split data into a suitable `CompressedList` subclass (generic entry point).

    This is the single-dispatch fallback: it is reached only when no
    implementation has been registered for ``type(data)``, and it always
    raises.

    Registered implementations are expected to support two modes:

    1. Group-based: a flat vector is split according to a group membership vector.
    2. Partition-based: a flat vector is split according to an explicit partitioning.

    Args:
        data:
            The data to split into a `CompressedList`.

        groups_or_partitions:
            Group membership vector (same length as data) or an
            explicit partitioning object.

        names:
            Optional names for the list elements.

        metadata:
            Optional metadata for the `CompressedList`.

    Returns:
        An appropriate `CompressedList` subclass instance (from a registered
        implementation; this fallback never returns).

    Raises:
        NotImplementedError:
            Always, when dispatch falls through to this generic function.
    """
    unsupported = type(data)
    message = f"No `splitAsCompressedList` dispatcher found for element type {unsupported}"
    raise NotImplementedError(message)
def _generic_register_helper(data, groups_or_partitions, names=None):
if groups_or_partitions is None:
raise ValueError("'groups_or_paritions' cannot be 'None'.")
if not data:
raise ValueError("'data' cannot be empty.")
if isinstance(groups_or_partitions, Partitioning):
if names is not None:
groups_or_partitions = groups_or_partitions.set_names(names, in_place=False)
# TODO: probably not necessary to split when groups is a partition object.
# unless ordering matters
# partitioned_data = []
# for i in range(len(groups_or_partitions)):
# start, end = groups_or_partitions.get_partition_range(i)
# partitioned_data.append(data[start:end])
partitioned_data = data
elif isinstance(groups_or_partitions, (list, np.ndarray)):
partitioned_data, groups_or_partitions = groups_to_partition(data, groups=groups_or_partitions, names=names)
if len(partitioned_data) == 0:
raise ValueError("No data after grouping")
else:
raise ValueError("'groups_or_paritions' must be a group vector or a Partition object.")
return partitioned_data, groups_or_partitions