跳到内容

折叠

Polars 提供了许多表达式来对列进行计算,例如 sum_horizontalmean_horizontalmin_horizontal。然而,这些都只是称为“折叠(fold)”的通用算法的特例,当 Polars 的专用版本不足时,Polars 提供了一种通用机制来让您计算自定义折叠。

使用 fold 函数计算的折叠操作作用于整个列以获得最大速度。它们能非常有效地利用数据布局,并且通常具有向量化执行能力。

基本示例

作为第一个示例,我们将使用 fold 函数重新实现 sum_horizontal

fold

import operator
import polars as pl

df = pl.DataFrame(
    {
        "label": ["foo", "bar", "spam"],
        "a": [1, 2, 3],
        "b": [10, 20, 30],
    }
)

result = df.select(
    pl.fold(
        acc=pl.lit(0),
        function=operator.add,
        exprs=pl.col("a", "b"),
    ).alias("sum_fold"),
    pl.sum_horizontal(pl.col("a", "b")).alias("sum_horz"),
)

print(result)

fold_exprs

use polars::lazy::dsl::sum_horizontal;
use polars::prelude::*;

let df = df!(
    "label" => ["foo", "bar", "spam"],
    "a" => [1, 2, 3],
    "b" => [10, 20, 30],
)?;

let result = df
    .clone()
    .lazy()
    .select([
        fold_exprs(
            lit(0),
            |acc, val| (&acc + &val).map(Some),
            [col("a"), col("b")],
            false,
            None,
        )
        .alias("sum_fold"),
        sum_horizontal([col("a"), col("b")], true)?.alias("sum_horz"),
    ])
    .collect()?;

println!("{result:?}");

shape: (3, 2)
┌──────────┬──────────┐
│ sum_fold ┆ sum_horz │
│ ---      ┆ ---      │
│ i64      ┆ i64      │
╞══════════╪══════════╡
│ 11       ┆ 11       │
│ 22       ┆ 22       │
│ 33       ┆ 33       │
└──────────┴──────────┘

函数 fold 需要一个函数 f 作为参数 function,并且 f 应该接受两个参数。第一个参数是累积结果(我们将其初始化为零),第二个参数则接受参数 exprs 中列出的表达式的连续值。在本例中,它们是“a”和“b”两列。

下面的代码片段包含第三个显式表达式,它表示了 fold 函数在上面所做的事情。

fold

acc = pl.lit(0)
f = operator.add

result = df.select(
    f(f(acc, pl.col("a")), pl.col("b")),
    pl.fold(acc=acc, function=f, exprs=pl.col("a", "b")).alias("sum_fold"),
)

print(result)

fold_exprs

let acc = lit(0);
let f = |acc: Expr, val: Expr| acc + val;

let result = df
    .clone()
    .lazy()
    .select([
        f(f(acc, col("a")), col("b")),
        fold_exprs(
            lit(0),
            |acc, val| (&acc + &val).map(Some),
            [col("a"), col("b")],
            false,
            None,
        )
        .alias("sum_fold"),
    ])
    .collect()?;

println!("{result:?}");

shape: (3, 2)
┌─────────┬──────────┐
│ literal ┆ sum_fold │
│ ---     ┆ ---      │
│ i64     ┆ i64      │
╞═════════╪══════════╡
│ 11      ┆ 11       │
│ 22      ┆ 22       │
│ 33      ┆ 33       │
└─────────┴──────────┘
Python 中的 fold

大多数编程语言都包含一个高阶函数,用于实现 Polars 中 fold 函数所实现的算法。Polars 的 fold 与 Python 的 functools.reduce 非常相似。您可以在这篇文章中了解更多关于 functools.reduce 的强大功能

初始值 acc

为累加器 acc 选择的初始值通常是(但并非总是)您想要应用的操作的单位元。例如,如果我们想对列进行乘法运算,如果我们的累加器设置为零,我们将无法得到正确的结果。

fold

result = df.select(
    pl.fold(
        acc=pl.lit(0),
        function=operator.mul,
        exprs=pl.col("a", "b"),
    ).alias("prod"),
)

print(result)

fold_exprs

let result = df
    .clone()
    .lazy()
    .select([fold_exprs(
        lit(0),
        |acc, val| (&acc * &val).map(Some),
        [col("a"), col("b")],
        false,
        None,
    )
    .alias("prod")])
    .collect()?;

println!("{result:?}");

shape: (3, 1)
┌──────┐
│ prod │
│ ---  │
│ i64  │
╞══════╡
│ 0    │
│ 0    │
│ 0    │
└──────┘

为了解决这个问题,累加器 acc 应该设置为 1

fold

result = df.select(
    pl.fold(
        acc=pl.lit(1),
        function=operator.mul,
        exprs=pl.col("a", "b"),
    ).alias("prod"),
)

print(result)

fold_exprs

let result = df
    .clone()
    .lazy()
    .select([fold_exprs(
        lit(1),
        |acc, val| (&acc * &val).map(Some),
        [col("a"), col("b")],
        false,
        None,
    )
    .alias("prod")])
    .collect()?;

println!("{result:?}");

shape: (3, 1)
┌──────┐
│ prod │
│ ---  │
│ i64  │
╞══════╡
│ 10   │
│ 40   │
│ 90   │
└──────┘

条件

如果您想在数据框的所有列上应用条件/谓词,使用折叠(fold)是一种非常简洁的表达方式。

fold

df = pl.DataFrame(
    {
        "a": [1, 2, 3],
        "b": [0, 1, 2],
    }
)

result = df.filter(
    pl.fold(
        acc=pl.lit(True),
        function=lambda acc, x: acc & x,
        exprs=pl.all() > 1,
    )
)
print(result)

fold_exprs

let df = df!(
    "a" => [1, 2, 3],
    "b" => [0, 1, 2],
)?;

let result = df
    .clone()
    .lazy()
    .filter(fold_exprs(
        lit(true),
        |acc, val| (&acc & &val).map(Some),
        [col("*").gt(1)],
        false,
        None,
    ))
    .collect()?;

println!("{result:?}");

shape: (1, 2)
┌─────┬─────┐
│ a   ┆ b   │
│ --- ┆ --- │
│ i64 ┆ i64 │
╞═════╪═════╡
│ 3   ┆ 2   │
└─────┴─────┘

上面的代码片段过滤了所有列都大于 1 的行。

折叠与字符串数据

折叠(Folds)可以用于连接字符串数据。但是,由于中间列的物化,此操作的复杂度将是平方级的。

因此,我们建议为此使用 concat_str 函数。

concat_str

df = pl.DataFrame(
    {
        "a": ["a", "b", "c"],
        "b": [1, 2, 3],
    }
)

result = df.select(pl.concat_str(["a", "b"]))
print(result)

concat_str · 在功能 concat_str 上可用

let df = df!(
    "a" => ["a", "b", "c"],
    "b" => [1, 2, 3],
)?;

let result = df
    .lazy()
    .select([concat_str([col("a"), col("b")], "", false)])
    .collect()?;
println!("{result:?}");

shape: (3, 1)
┌─────┐
│ a   │
│ --- │
│ str │
╞═════╡
│ a1  │
│ b2  │
│ c3  │
└─────┘