from __future__ import annotations

from typing import TYPE_CHECKING

from polars import functions as F
from polars._utils.deprecation import deprecated
from polars._utils.parse import parse_into_list_of_expressions
from polars._utils.wrap import wrap_df, wrap_ldf

if TYPE_CHECKING:
    import sys
    from collections.abc import Callable, Iterable

    from polars import DataFrame, LazyFrame
    from polars._plr import PyLazyGroupBy
    from polars._typing import IntoExpr, QuantileMethod, SchemaDict

    if sys.version_info >= (3, 13):
        from warnings import deprecated
    else:
        from typing_extensions import deprecated  # noqa: TC004


class LazyGroupBy:
    """
    Utility class for performing a group by operation over a lazy DataFrame.

    Generated by calling `df.lazy().group_by(...)`.
    """

    def __init__(self, lgb: PyLazyGroupBy) -> None:
        """
        Wrap a low-level lazy group by object.

        Parameters
        ----------
        lgb
            The `PyLazyGroupBy` handle; it is stored on the `lgb` attribute.
        """
        self.lgb = lgb

    def having(self, *predicates: IntoExpr | Iterable[IntoExpr]) -> LazyGroupBy:
        """
        Keep only the groups for which the given predicates hold.

        The effect is the same as adding the predicates to the aggregation and
        filtering on them afterwards.

        Calls can be chained; the conditions of all calls are combined using `&`.
        The predicates are registered on this `LazyGroupBy` object itself, which is
        also the object that is returned.

        Parameters
        ----------
        *predicates
            Expressions that evaluate to a boolean value for each group. Typically, this
            requires the use of an aggregation function. Multiple predicates are
            combined using `&`.

        Examples
        --------
        Drop every group that has only a single element.

        >>> ldf = pl.DataFrame(
        ...     {
        ...         "a": ["a", "b", "a", "b", "c"],
        ...     }
        ... ).lazy()
        >>> ldf.group_by("a").having(
        ...     pl.len() > 1
        ... ).agg().collect()  # doctest: +IGNORE_RESULT
        shape: (2, 1)
        ┌─────┐
        │ a   │
        │ --- │
        │ str │
        ╞═════╡
        │ b   │
        │ a   │
        └─────┘
        """
        parsed_predicates = parse_into_list_of_expressions(*predicates)
        filtered_lgb = self.lgb.having(parsed_predicates)
        # Replace the wrapped handle in place so that chained calls accumulate.
        self.lgb = filtered_lgb
        return self

    def agg(
        self,
        *aggs: IntoExpr | Iterable[IntoExpr],
        **named_aggs: IntoExpr,
    ) -> LazyFrame:
        """
        Evaluate one or more aggregations for every group.

        Parameters
        ----------
        *aggs
            Aggregations to compute for each group of the group by operation,
            given positionally.
            Accepts expression input. Strings are parsed as column names.
        **named_aggs
            Further aggregations, given as keyword arguments.
            Each resulting column is named after its keyword.

        Raises
        ------
        TypeError
            If the first positional argument is a dictionary.

        Examples
        --------
        Collect the values of the columns per group.

        >>> ldf = pl.DataFrame(
        ...     {
        ...         "a": ["a", "b", "a", "b", "c"],
        ...         "b": [1, 2, 1, 3, 3],
        ...         "c": [5, 4, 3, 2, 1],
        ...     }
        ... ).lazy()
        >>> ldf.group_by("a").agg(
        ...     [pl.col("b"), pl.col("c")]
        ... ).collect()  # doctest: +IGNORE_RESULT
        shape: (3, 3)
        ┌─────┬───────────┬───────────┐
        │ a   ┆ b         ┆ c         │
        │ --- ┆ ---       ┆ ---       │
        │ str ┆ list[i64] ┆ list[i64] │
        ╞═════╪═══════════╪═══════════╡
        │ a   ┆ [1, 1]    ┆ [5, 3]    │
        │ b   ┆ [2, 3]    ┆ [4, 2]    │
        │ c   ┆ [3]       ┆ [1]       │
        └─────┴───────────┴───────────┘

        Sum a single column per group.

        >>> ldf.group_by("a").agg(
        ...     pl.col("b").sum()
        ... ).collect()  # doctest: +IGNORE_RESULT
        shape: (3, 2)
        ┌─────┬─────┐
        │ a   ┆ b   │
        │ --- ┆ --- │
        │ str ┆ i64 │
        ╞═════╪═════╡
        │ a   ┆ 2   │
        │ b   ┆ 5   │
        │ c   ┆ 3   │
        └─────┴─────┘

        Pass a list of expressions to compute several aggregates in one go.

        >>> ldf.group_by("a").agg(
        ...     [pl.sum("b"), pl.mean("c")]
        ... ).collect()  # doctest: +IGNORE_RESULT
        shape: (3, 3)
        ┌─────┬─────┬─────┐
        │ a   ┆ b   ┆ c   │
        │ --- ┆ --- ┆ --- │
        │ str ┆ i64 ┆ f64 │
        ╞═════╪═════╪═════╡
        │ c   ┆ 3   ┆ 1.0 │
        │ a   ┆ 2   ┆ 4.0 │
        │ b   ┆ 5   ┆ 3.0 │
        └─────┴─────┴─────┘

        Several positional arguments work in the same way.

        >>> ldf.group_by("a").agg(
        ...     pl.sum("b").name.suffix("_sum"),
        ...     (pl.col("c") ** 2).mean().name.suffix("_mean_squared"),
        ... ).collect()  # doctest: +IGNORE_RESULT
        shape: (3, 3)
        ┌─────┬───────┬────────────────┐
        │ a   ┆ b_sum ┆ c_mean_squared │
        │ --- ┆ ---   ┆ ---            │
        │ str ┆ i64   ┆ f64            │
        ╞═════╪═══════╪════════════════╡
        │ a   ┆ 2     ┆ 17.0           │
        │ c   ┆ 3     ┆ 1.0            │
        │ b   ┆ 5     ┆ 10.0           │
        └─────┴───────┴────────────────┘

        Keyword arguments give the output columns their names directly.

        >>> ldf.group_by("a").agg(
        ...     b_sum=pl.sum("b"),
        ...     c_mean_squared=(pl.col("c") ** 2).mean(),
        ... ).collect()  # doctest: +IGNORE_RESULT
        shape: (3, 3)
        ┌─────┬───────┬────────────────┐
        │ a   ┆ b_sum ┆ c_mean_squared │
        │ --- ┆ ---   ┆ ---            │
        │ str ┆ i64   ┆ f64            │
        ╞═════╪═══════╪════════════════╡
        │ a   ┆ 2     ┆ 17.0           │
        │ c   ┆ 3     ┆ 1.0            │
        │ b   ┆ 5     ┆ 10.0           │
        └─────┴───────┴────────────────┘
        """
        # Reject a dictionary passed as the first positional argument up front and
        # point the caller at the keyword syntax instead.
        if aggs:
            leading_agg = aggs[0]
            if isinstance(leading_agg, dict):
                msg = (
                    "specifying aggregations as a dictionary is not supported"
                    "\n\nTry unpacking the dictionary to take advantage of the keyword syntax"
                    " of the `agg` method."
                )
                raise TypeError(msg)

        parsed_aggs = parse_into_list_of_expressions(*aggs, **named_aggs)
        aggregated = self.lgb.agg(parsed_aggs)
        return wrap_ldf(aggregated)

    def map_groups(
        self,
        function: Callable[[DataFrame], DataFrame],
        schema: SchemaDict | None,
    ) -> LazyFrame:
        """
        Run a custom/user-defined function (UDF) on each group, passed as a DataFrame.

        .. warning::
            This method is much slower than the native expressions API.
            Only use it if you cannot implement your logic otherwise.

        Using this is considered an anti-pattern as it will be very slow because:

        - it forces the engine to materialize the whole `DataFrames` for the groups.
        - it is not parallelized
        - it blocks optimizations as the passed python function is opaque to the
          optimizer

        The idiomatic way to apply custom functions over multiple columns is using:

        `pl.struct([my_columns]).map_elements(lambda struct_series: ..)`

        Parameters
        ----------
        function
            Function to apply over each group of the `LazyFrame`. It receives the
            group as a `DataFrame` and must return a `DataFrame`.
        schema
            Schema of the output function. This has to be known statically. If the
            given schema is incorrect, this is a bug in the caller's query and may
            lead to errors. If set to None, polars assumes the schema is unchanged.

        Examples
        --------
        For each color group sample two rows:

        >>> df = pl.DataFrame(
        ...     {
        ...         "id": [0, 1, 2, 3, 4],
        ...         "color": ["red", "green", "green", "red", "red"],
        ...         "shape": ["square", "triangle", "square", "triangle", "square"],
        ...     }
        ... )
        >>> (
        ...     df.lazy()
        ...     .group_by("color")
        ...     .map_groups(lambda group_df: group_df.sample(2), schema=None)
        ...     .collect()
        ... )  # doctest: +IGNORE_RESULT
        shape: (4, 3)
        ┌─────┬───────┬──────────┐
        │ id  ┆ color ┆ shape    │
        │ --- ┆ ---   ┆ ---      │
        │ i64 ┆ str   ┆ str      │
        ╞═════╪═══════╪══════════╡
        │ 1   ┆ green ┆ triangle │
        │ 2   ┆ green ┆ square   │
        │ 4   ┆ red   ┆ square   │
        │ 3   ┆ red   ┆ triangle │
        └─────┴───────┴──────────┘

        It is better to implement this with an expression:

        >>> df.lazy().filter(
        ...     pl.int_range(pl.len()).shuffle().over("color") < 2
        ... ).collect()  # doctest: +IGNORE_RESULT
        """

        def _call_on_group(py_group):  # noqa: ANN001, ANN202
            # Wrap the incoming group for the user function, then hand back the
            # `_df` attribute of the frame it returns.
            group_frame = wrap_df(py_group)
            return function(group_frame)._df

        mapped = self.lgb.map_groups(_call_on_group, schema)
        return wrap_ldf(mapped)

    def head(self, n: int = 5) -> LazyFrame:
        """
        Return the leading `n` rows of every group.

        Parameters
        ----------
        n
            Number of rows to return per group (default 5).

        Examples
        --------
        >>> df = pl.DataFrame(
        ...     {
        ...         "letters": ["c", "c", "a", "c", "a", "b"],
        ...         "nrs": [1, 2, 3, 4, 5, 6],
        ...     }
        ... )
        >>> df
        shape: (6, 2)
        ┌─────────┬─────┐
        │ letters ┆ nrs │
        │ ---     ┆ --- │
        │ str     ┆ i64 │
        ╞═════════╪═════╡
        │ c       ┆ 1   │
        │ c       ┆ 2   │
        │ a       ┆ 3   │
        │ c       ┆ 4   │
        │ a       ┆ 5   │
        │ b       ┆ 6   │
        └─────────┴─────┘
        >>> df.group_by("letters").head(2).sort("letters")
        shape: (5, 2)
        ┌─────────┬─────┐
        │ letters ┆ nrs │
        │ ---     ┆ --- │
        │ str     ┆ i64 │
        ╞═════════╪═════╡
        │ a       ┆ 3   │
        │ a       ┆ 5   │
        │ b       ┆ 6   │
        │ c       ┆ 1   │
        │ c       ┆ 2   │
        └─────────┴─────┘
        """
        leading_rows = self.lgb.head(n)
        return wrap_ldf(leading_rows)

    def tail(self, n: int = 5) -> LazyFrame:
        """
        Return the trailing `n` rows of every group.

        Parameters
        ----------
        n
            Number of rows to return per group (default 5).

        Examples
        --------
        >>> df = pl.DataFrame(
        ...     {
        ...         "letters": ["c", "c", "a", "c", "a", "b"],
        ...         "nrs": [1, 2, 3, 4, 5, 6],
        ...     }
        ... )
        >>> df
        shape: (6, 2)
        ┌─────────┬─────┐
        │ letters ┆ nrs │
        │ ---     ┆ --- │
        │ str     ┆ i64 │
        ╞═════════╪═════╡
        │ c       ┆ 1   │
        │ c       ┆ 2   │
        │ a       ┆ 3   │
        │ c       ┆ 4   │
        │ a       ┆ 5   │
        │ b       ┆ 6   │
        └─────────┴─────┘
        >>> df.group_by("letters").tail(2).sort("letters")
        shape: (5, 2)
        ┌─────────┬─────┐
        │ letters ┆ nrs │
        │ ---     ┆ --- │
        │ str     ┆ i64 │
        ╞═════════╪═════╡
        │ a       ┆ 3   │
        │ a       ┆ 5   │
        │ b       ┆ 6   │
        │ c       ┆ 2   │
        │ c       ┆ 4   │
        └─────────┴─────┘
        """
        trailing_rows = self.lgb.tail(n)
        return wrap_ldf(trailing_rows)

    def all(self) -> LazyFrame:
        """
        Collect the values of each group into lists.

        Shorthand for `agg(pl.all())`.

        Examples
        --------
        >>> ldf = pl.DataFrame(
        ...     {
        ...         "a": ["one", "two", "one", "two"],
        ...         "b": [1, 2, 3, 4],
        ...     }
        ... ).lazy()
        >>> ldf.group_by("a", maintain_order=True).all().collect()
        shape: (2, 2)
        ┌─────┬───────────┐
        │ a   ┆ b         │
        │ --- ┆ ---       │
        │ str ┆ list[i64] │
        ╞═════╪═══════════╡
        │ one ┆ [1, 3]    │
        │ two ┆ [2, 4]    │
        └─────┴───────────┘
        """
        every_column = F.all()
        return self.agg(every_column)

    def len(self, name: str | None = None) -> LazyFrame:
        """
        Count the rows in each group.

        Parameters
        ----------
        name
            Name for the resulting column; when left as `None`, the column is
            called "len".

        Examples
        --------
        >>> lf = pl.LazyFrame({"a": ["Apple", "Apple", "Orange"], "b": [1, None, 2]})
        >>> lf.group_by("a").len().collect()  # doctest: +IGNORE_RESULT
        shape: (2, 2)
        ┌────────┬─────┐
        │ a      ┆ len │
        │ ---    ┆ --- │
        │ str    ┆ u32 │
        ╞════════╪═════╡
        │ Apple  ┆ 2   │
        │ Orange ┆ 1   │
        └────────┴─────┘
        >>> lf.group_by("a").len(name="n").collect()  # doctest: +IGNORE_RESULT
        shape: (2, 2)
        ┌────────┬─────┐
        │ a      ┆ n   │
        │ ---    ┆ --- │
        │ str    ┆ u32 │
        ╞════════╪═════╡
        │ Apple  ┆ 2   │
        │ Orange ┆ 1   │
        └────────┴─────┘
        """
        row_count = F.len()
        if name is None:
            # No explicit name requested: keep the expression's default name.
            return self.agg(row_count)
        return self.agg(row_count.alias(name))

    @deprecated("`count` was renamed; use `len` instead")
    def count(self) -> LazyFrame:
        """
        Count the rows in each group.

        .. deprecated:: 0.20.5
            This method has been renamed to :func:`LazyGroupBy.len`.

        Rows containing null values count towards the total. The resulting column
        is named "count".

        Examples
        --------
        >>> lf = pl.LazyFrame(
        ...     {
        ...         "a": ["Apple", "Apple", "Orange"],
        ...         "b": [1, None, 2],
        ...     }
        ... )
        >>> lf.group_by("a").count().collect()  # doctest: +SKIP
        shape: (2, 2)
        ┌────────┬───────┐
        │ a      ┆ count │
        │ ---    ┆ ---   │
        │ str    ┆ u32   │
        ╞════════╪═══════╡
        │ Apple  ┆ 2     │
        │ Orange ┆ 1     │
        └────────┴───────┘
        """
        # Same expression as `len`, but keeping the legacy column name.
        row_count = F.len().alias("count")
        return self.agg(row_count)

    def first(self, *, ignore_nulls: bool = False) -> LazyFrame:
        """
        Take the first value of each column within every group.

        Parameters
        ----------
        ignore_nulls
            Ignore null values (default `False`).
            If set to `True`, the first non-null value for each aggregation is returned,
            otherwise `None` is returned if no non-null value exists.

        Examples
        --------
        >>> ldf = pl.DataFrame(
        ...     {
        ...         "a": [1, 2, 2, 3, 4, 5],
        ...         "b": [0.5, 0.5, 4, 10, 13, 14],
        ...         "c": [None, True, True, False, False, True],
        ...         "d": ["Apple", "Orange", "Apple", "Apple", "Banana", "Banana"],
        ...     }
        ... ).lazy()
        >>> ldf.group_by("d", maintain_order=True).first().collect()
        shape: (3, 4)
        ┌────────┬─────┬──────┬───────┐
        │ d      ┆ a   ┆ b    ┆ c     │
        │ ---    ┆ --- ┆ ---  ┆ ---   │
        │ str    ┆ i64 ┆ f64  ┆ bool  │
        ╞════════╪═════╪══════╪═══════╡
        │ Apple  ┆ 1   ┆ 0.5  ┆ null  │
        │ Orange ┆ 2   ┆ 0.5  ┆ true  │
        │ Banana ┆ 4   ┆ 13.0 ┆ false │
        └────────┴─────┴──────┴───────┘
        >>> ldf.group_by("d", maintain_order=True).first(ignore_nulls=True).collect()
        shape: (3, 4)
        ┌────────┬─────┬──────┬───────┐
        │ d      ┆ a   ┆ b    ┆ c     │
        │ ---    ┆ --- ┆ ---  ┆ ---   │
        │ str    ┆ i64 ┆ f64  ┆ bool  │
        ╞════════╪═════╪══════╪═══════╡
        │ Apple  ┆ 1   ┆ 0.5  ┆ true  │
        │ Orange ┆ 2   ┆ 0.5  ┆ true  │
        │ Banana ┆ 4   ┆ 13.0 ┆ false │
        └────────┴─────┴──────┴───────┘
        """
        every_column = F.all()
        first_values = every_column.first(ignore_nulls=ignore_nulls)
        return self.agg(first_values)

    def last(self, *, ignore_nulls: bool = False) -> LazyFrame:
        """
        Take the last value of each column within every group.

        Parameters
        ----------
        ignore_nulls
            Ignore null values (default `False`).
            If set to `True`, the last non-null value for each aggregation is returned,
            otherwise `None` is returned if no non-null value exists.

        Examples
        --------
        >>> ldf = pl.DataFrame(
        ...     {
        ...         "a": [1, 2, 2, 3, 4, 5],
        ...         "b": [0.5, 0.5, 4, 10, 14, 13],
        ...         "c": [True, True, False, None, False, True],
        ...         "d": ["Apple", "Orange", "Apple", "Apple", "Banana", "Banana"],
        ...     }
        ... ).lazy()
        >>> ldf.group_by("d", maintain_order=True).last().collect()
        shape: (3, 4)
        ┌────────┬─────┬──────┬──────┐
        │ d      ┆ a   ┆ b    ┆ c    │
        │ ---    ┆ --- ┆ ---  ┆ ---  │
        │ str    ┆ i64 ┆ f64  ┆ bool │
        ╞════════╪═════╪══════╪══════╡
        │ Apple  ┆ 3   ┆ 10.0 ┆ null │
        │ Orange ┆ 2   ┆ 0.5  ┆ true │
        │ Banana ┆ 5   ┆ 13.0 ┆ true │
        └────────┴─────┴──────┴──────┘
        >>> ldf.group_by("d", maintain_order=True).last(ignore_nulls=True).collect()
        shape: (3, 4)
        ┌────────┬─────┬──────┬───────┐
        │ d      ┆ a   ┆ b    ┆ c     │
        │ ---    ┆ --- ┆ ---  ┆ ---   │
        │ str    ┆ i64 ┆ f64  ┆ bool  │
        ╞════════╪═════╪══════╪═══════╡
        │ Apple  ┆ 3   ┆ 10.0 ┆ false │
        │ Orange ┆ 2   ┆ 0.5  ┆ true  │
        │ Banana ┆ 5   ┆ 13.0 ┆ true  │
        └────────┴─────┴──────┴───────┘
        """
        every_column = F.all()
        last_values = every_column.last(ignore_nulls=ignore_nulls)
        return self.agg(last_values)

    def max(self) -> LazyFrame:
        """
        Compute the maximum of each column within every group.

        Examples
        --------
        >>> ldf = pl.DataFrame(
        ...     {
        ...         "a": [1, 2, 2, 3, 4, 5],
        ...         "b": [0.5, 0.5, 4, 10, 13, 14],
        ...         "c": [True, True, True, False, False, True],
        ...         "d": ["Apple", "Orange", "Apple", "Apple", "Banana", "Banana"],
        ...     }
        ... ).lazy()
        >>> ldf.group_by("d", maintain_order=True).max().collect()
        shape: (3, 4)
        ┌────────┬─────┬──────┬──────┐
        │ d      ┆ a   ┆ b    ┆ c    │
        │ ---    ┆ --- ┆ ---  ┆ ---  │
        │ str    ┆ i64 ┆ f64  ┆ bool │
        ╞════════╪═════╪══════╪══════╡
        │ Apple  ┆ 3   ┆ 10.0 ┆ true │
        │ Orange ┆ 2   ┆ 0.5  ┆ true │
        │ Banana ┆ 5   ┆ 14.0 ┆ true │
        └────────┴─────┴──────┴──────┘
        """
        maxima = F.all().max()
        return self.agg(maxima)

    def mean(self) -> LazyFrame:
        """
        Compute the mean of each column within every group.

        Examples
        --------
        >>> ldf = pl.DataFrame(
        ...     {
        ...         "a": [1, 2, 2, 3, 4, 5],
        ...         "b": [0.5, 0.5, 4, 10, 13, 14],
        ...         "c": [True, True, True, False, False, True],
        ...         "d": ["Apple", "Orange", "Apple", "Apple", "Banana", "Banana"],
        ...     }
        ... ).lazy()
        >>> ldf.group_by("d", maintain_order=True).mean().collect()
        shape: (3, 4)
        ┌────────┬─────┬──────────┬──────────┐
        │ d      ┆ a   ┆ b        ┆ c        │
        │ ---    ┆ --- ┆ ---      ┆ ---      │
        │ str    ┆ f64 ┆ f64      ┆ f64      │
        ╞════════╪═════╪══════════╪══════════╡
        │ Apple  ┆ 2.0 ┆ 4.833333 ┆ 0.666667 │
        │ Orange ┆ 2.0 ┆ 0.5      ┆ 1.0      │
        │ Banana ┆ 4.5 ┆ 13.5     ┆ 0.5      │
        └────────┴─────┴──────────┴──────────┘
        """
        means = F.all().mean()
        return self.agg(means)

    def median(self) -> LazyFrame:
        """
        Compute the median of each column within every group.

        Examples
        --------
        >>> ldf = pl.DataFrame(
        ...     {
        ...         "a": [1, 2, 2, 3, 4, 5],
        ...         "b": [0.5, 0.5, 4, 10, 13, 14],
        ...         "d": ["Apple", "Banana", "Apple", "Apple", "Banana", "Banana"],
        ...     }
        ... ).lazy()
        >>> ldf.group_by("d", maintain_order=True).median().collect()
        shape: (2, 3)
        ┌────────┬─────┬──────┐
        │ d      ┆ a   ┆ b    │
        │ ---    ┆ --- ┆ ---  │
        │ str    ┆ f64 ┆ f64  │
        ╞════════╪═════╪══════╡
        │ Apple  ┆ 2.0 ┆ 4.0  │
        │ Banana ┆ 4.0 ┆ 13.0 │
        └────────┴─────┴──────┘
        """
        medians = F.all().median()
        return self.agg(medians)

    def min(self) -> LazyFrame:
        """
        Compute the minimum of each column within every group.

        Examples
        --------
        >>> ldf = pl.DataFrame(
        ...     {
        ...         "a": [1, 2, 2, 3, 4, 5],
        ...         "b": [0.5, 0.5, 4, 10, 13, 14],
        ...         "c": [True, True, True, False, False, True],
        ...         "d": ["Apple", "Orange", "Apple", "Apple", "Banana", "Banana"],
        ...     }
        ... ).lazy()
        >>> ldf.group_by("d", maintain_order=True).min().collect()
        shape: (3, 4)
        ┌────────┬─────┬──────┬───────┐
        │ d      ┆ a   ┆ b    ┆ c     │
        │ ---    ┆ --- ┆ ---  ┆ ---   │
        │ str    ┆ i64 ┆ f64  ┆ bool  │
        ╞════════╪═════╪══════╪═══════╡
        │ Apple  ┆ 1   ┆ 0.5  ┆ false │
        │ Orange ┆ 2   ┆ 0.5  ┆ true  │
        │ Banana ┆ 4   ┆ 13.0 ┆ false │
        └────────┴─────┴──────┴───────┘
        """
        minima = F.all().min()
        return self.agg(minima)

    def n_unique(self) -> LazyFrame:
        """
        Count the distinct values of each column within every group.

        Examples
        --------
        >>> ldf = pl.DataFrame(
        ...     {
        ...         "a": [1, 2, 1, 3, 4, 5],
        ...         "b": [0.5, 0.5, 0.5, 10, 13, 14],
        ...         "d": ["Apple", "Banana", "Apple", "Apple", "Banana", "Banana"],
        ...     }
        ... ).lazy()
        >>> ldf.group_by("d", maintain_order=True).n_unique().collect()
        shape: (2, 3)
        ┌────────┬─────┬─────┐
        │ d      ┆ a   ┆ b   │
        │ ---    ┆ --- ┆ --- │
        │ str    ┆ u32 ┆ u32 │
        ╞════════╪═════╪═════╡
        │ Apple  ┆ 2   ┆ 2   │
        │ Banana ┆ 3   ┆ 3   │
        └────────┴─────┴─────┘
        """
        distinct_counts = F.all().n_unique()
        return self.agg(distinct_counts)

    def quantile(
        self, quantile: float, interpolation: QuantileMethod = "nearest"
    ) -> LazyFrame:
        """
        Compute the requested quantile of each column within every group.

        Parameters
        ----------
        quantile
            Quantile between 0.0 and 1.0.
        interpolation : {'nearest', 'higher', 'lower', 'midpoint', 'linear', 'equiprobable'}
            Interpolation method (default 'nearest').

        Examples
        --------
        >>> ldf = pl.DataFrame(
        ...     {
        ...         "a": [1, 2, 2, 3, 4, 5],
        ...         "b": [0.5, 0.5, 4, 10, 13, 14],
        ...         "d": ["Apple", "Orange", "Apple", "Apple", "Banana", "Banana"],
        ...     }
        ... ).lazy()
        >>> ldf.group_by("d", maintain_order=True).quantile(1).collect()
        shape: (3, 3)
        ┌────────┬─────┬──────┐
        │ d      ┆ a   ┆ b    │
        │ ---    ┆ --- ┆ ---  │
        │ str    ┆ f64 ┆ f64  │
        ╞════════╪═════╪══════╡
        │ Apple  ┆ 3.0 ┆ 10.0 │
        │ Orange ┆ 2.0 ┆ 0.5  │
        │ Banana ┆ 5.0 ┆ 14.0 │
        └────────┴─────┴──────┘
        """  # noqa: W505
        every_column = F.all()
        quantiles = every_column.quantile(quantile, interpolation=interpolation)
        return self.agg(quantiles)

    def sum(self) -> LazyFrame:
        """
        Compute the sum of each column within every group.

        Examples
        --------
        >>> ldf = pl.DataFrame(
        ...     {
        ...         "a": [1, 2, 2, 3, 4, 5],
        ...         "b": [0.5, 0.5, 4, 10, 13, 14],
        ...         "c": [True, True, True, False, False, True],
        ...         "d": ["Apple", "Orange", "Apple", "Apple", "Banana", "Banana"],
        ...     }
        ... ).lazy()
        >>> ldf.group_by("d", maintain_order=True).sum().collect()
        shape: (3, 4)
        ┌────────┬─────┬──────┬─────┐
        │ d      ┆ a   ┆ b    ┆ c   │
        │ ---    ┆ --- ┆ ---  ┆ --- │
        │ str    ┆ i64 ┆ f64  ┆ u32 │
        ╞════════╪═════╪══════╪═════╡
        │ Apple  ┆ 6   ┆ 14.5 ┆ 2   │
        │ Orange ┆ 2   ┆ 0.5  ┆ 1   │
        │ Banana ┆ 9   ┆ 27.0 ┆ 1   │
        └────────┴─────┴──────┴─────┘
        """
        totals = F.all().sum()
        return self.agg(totals)
