Source code for biocutils.intersect

from typing import Sequence

from .is_missing_scalar import is_missing_scalar
from .map_to_index import DUPLICATE_METHOD


def intersect(*x: Sequence, duplicate_method: DUPLICATE_METHOD = "first") -> list:
    """
    Compute the values shared by every supplied sequence, reported in the
    order in which they appear in the first sequence.

    Args:
        x:
            Zero, one or more sequences of hashable values. Entries for which
            :py:meth:`~biocutils.is_missing_scalar.is_missing_scalar` returns
            a true value are skipped in every sequence.

        duplicate_method:
            Which occurrence of a repeated value in the first sequence fixes
            its position in the output: ``"first"`` keeps the earliest
            occurrence; any other value is handled as ``"last"`` and keeps the
            latest occurrence.

    Returns:
        List with one entry per distinct non-missing value that is present in
        all of ``x``. An empty list is returned when no sequences are given.
    """
    if not x:
        return []

    n_seqs = len(x)
    reference = x[0]
    keep_first = duplicate_method == "first"

    if n_seqs == 1:
        # Fast path: with a single sequence the intersection is just its
        # distinct non-missing values, so no per-value counting is needed.
        seen = set()
        unique = []
        for value in reference if keep_first else reversed(reference):
            if is_missing_scalar(value) or value in seen:
                continue
            seen.add(value)
            unique.append(value)
        if not keep_first:
            # Values were gathered back-to-front; restore forward order.
            unique.reverse()
        return unique

    # Each value of the first sequence maps to a two-element record:
    #   [number of sequences containing it, index of the last sequence counted]
    # Remembering the last counted sequence stops a value that is repeated
    # within one sequence from being counted more than once for that sequence.
    tally = {}
    for value in reference:
        if not is_missing_scalar(value):
            tally.setdefault(value, [1, 0])

    for seq_index, seq in enumerate(x[1:], start=1):
        for value in seq:
            if is_missing_scalar(value):
                continue
            record = tally.get(value)
            if record is not None and record[1] < seq_index:
                record[0] += 1
                record[1] = seq_index

    # Second walk over the first sequence so the output follows its ordering.
    shared = []
    for value in reference if keep_first else reversed(reference):
        if is_missing_scalar(value):
            continue
        record = tally.get(value)
        if record is None:
            continue
        if record[0] == n_seqs and record[1] >= 0:
            shared.append(value)
            record[1] = -1  # mark as emitted so repeats are not added again

    if not keep_first:
        # Values were gathered back-to-front; restore forward order.
        shared.reverse()
    return shared