从 Apache Spark 迁移
基于列的 API 对比 基于行的 API
与 Spark
的 DataFrame
类似于行的集合不同,Polars 的 DataFrame
更接近于列的集合。这意味着你可以在 Polars 中以 Spark
不可能的方式组合列,因为 Spark
保留了每行数据之间的关系。
考虑这个示例数据集
import polars as pl
df = pl.DataFrame({
"foo": ["a", "b", "c", "d", "d"],
"bar": [1, 2, 3, 4, 5],
})
dfs = spark.createDataFrame(
[
("a", 1),
("b", 2),
("c", 3),
("d", 4),
("d", 5),
],
schema=["foo", "bar"],
)
示例 1: 组合 head
和 sum
在 Polars 中,你可以这样写
df.select(
pl.col("foo").sort().head(2),
pl.col("bar").filter(pl.col("foo") == "d").sum()
)
输出
shape: (2, 2)
┌─────┬─────┐
│ foo ┆ bar │
│ --- ┆ --- │
│ str ┆ i64 │
╞═════╪═════╡
│ a ┆ 9 │
├╌╌╌╌╌┼╌╌╌╌╌┤
│ b ┆ 9 │
└─────┴─────┘
列 foo
和 bar
上的表达式是完全独立的。由于 bar
上的表达式返回一个单一值,该值会为 foo
上的表达式输出的每个值重复。但是 a
和 b
与产生和为 9
的数据没有关系。
要在 Spark
中做类似的事情,你需要单独计算总和并将其作为字面量提供。
from pyspark.sql.functions import col, sum, lit
bar_sum = (
dfs
.where(col("foo") == "d")
.groupBy()
.agg(sum(col("bar")))
.take(1)[0][0]
)
(
dfs
.orderBy("foo")
.limit(2)
.withColumn("bar", lit(bar_sum))
.show()
)
输出
+---+---+
|foo|bar|
+---+---+
| a| 9|
| b| 9|
+---+---+
示例 2: 组合两个 head
在 Polars 中,你可以在同一个 DataFrame 上组合两个不同的 head
表达式,前提是它们返回相同数量的值。
df.select(
pl.col("foo").sort().head(2),
pl.col("bar").sort(descending=True).head(2),
)
输出
shape: (3, 2)
┌─────┬─────┐
│ foo ┆ bar │
│ --- ┆ --- │
│ str ┆ i64 │
╞═════╪═════╡
│ a ┆ 5 │
├╌╌╌╌╌┼╌╌╌╌╌┤
│ b ┆ 4 │
└─────┴─────┘
同样,这里的两个 head
表达式是完全独立的,而 a
与 5
、b
与 4
的配对纯粹是表达式输出的两列并置的结果。
要在 Spark
中实现类似的功能,你需要生成一个人工键,以便能够以这种方式连接这些值。
from pyspark.sql import Window
from pyspark.sql.functions import row_number
foo_dfs = (
dfs
.withColumn(
"rownum",
row_number().over(Window.orderBy("foo"))
)
)
bar_dfs = (
dfs
.withColumn(
"rownum",
row_number().over(Window.orderBy(col("bar").desc()))
)
)
(
foo_dfs.alias("foo")
.join(bar_dfs.alias("bar"), on="rownum")
.select("foo.foo", "bar.bar")
.limit(2)
.show()
)
输出
+---+---+
|foo|bar|
+---+---+
| a| 5|
| b| 4|
+---+---+
示例 3: 组合表达式
Polars 允许你非常自由地组合表达式。例如,如果你想找到滞后变量的滚动平均值,你可以组合 shift
和 rolling_mean
,并在一个 over
表达式中评估它们。
df.with_columns(
feature=pl.col('price').shift(7).rolling_mean(7).over('store', order_by='date')
)
然而,在 PySpark 中,这是不允许的。它们允许组合像 F.mean(F.abs("price")).over(window)
这样的表达式,因为 F.abs
是一个逐元素函数,但不能组合 F.mean(F.lag("price", 1)).over(window)
,因为 F.lag
是一个窗口函数。要产生相同的结果,F.lag
和 F.mean
都需要自己的窗口。
from pyspark.sql import Window
from pyspark.sql import functions as F
window = Window().partitionBy("store").orderBy("date")
rolling_window = window.rowsBetween(-6, 0)
(
df.withColumn("lagged_price", F.lag("price", 7).over(window)).withColumn(
"feature",
F.when(
F.count("lagged_price").over(rolling_window) >= 7,
F.mean("lagged_price").over(rolling_window),
),
)
)