IO 插件
除了表达式插件之外,我们还支持IO插件。这些插件允许您将不同的文件格式注册为Polars引擎的来源。由于来源可以通过Arrow FFI零拷贝移动数据,并且可以在返回前生成大量数据块,我们目前决定通过Python来接口IO插件,因为我们认为GIL所需的短暂时间不应导致任何争用。
例如,一个IO来源可以在Rust中读取其数据帧,并且只有在汇合点才能零拷贝地移动数据,而GIL所需的占用时间很短。
用例
如果您有一个Polars不支持的源文件,并且您希望从诸如投影下推、谓词下推、提前停止以及我们流式引擎的支持等优化中受益,那么您就需要IO插件。
示例
那么,让我们编写一个简单、非常糟糕的自定义CSV来源,并将其注册为IO插件。我想强调的是,这是一个非常糟糕的例子,仅用于学习目的。
首先,我们定义一些所需的导入
# Use python for csv parsing.
import csv
import polars as pl
# Used to register a new generator on every instantiation.
from polars.io.plugins import register_io_source
from typing import Iterator
import io
解析模式
Polars中的每个scan
函数都必须能够提供其读取数据的模式。对于这个简单的CSV解析器,我们总是将数据读取为pl.String
。唯一不同的是字段名称和字段数量。
def parse_schema(csv_str: str) -> pl.Schema:
first_line = csv_str.split("\n")[0]
return pl.Schema({k: pl.String for k in first_line.split(",")})
如果我们使用小的CSV文件"a,b,c\n1,2,3"
运行此代码,我们将得到模式:Schema([('a', String), ('b', String), ('c', String)])
。
>>> print(parse_schema("a,b,c\n1,2,3"))
Schema([('a', String), ('b', String), ('c', String)])
编写来源
接下来是实际的来源。为此,我们创建一个外部函数和一个内部函数。外部函数my_scan_csv
是面向用户的函数。此函数将接受文件名和读取来源所需的其他潜在参数。对于CSV文件,这些参数可以是“delimiter”、“quote_char”等。
这个外部函数调用register_io_source
,它接受一个callable
和一个schema
。该模式是完整源文件的Polars模式(独立于投影下推)。
callable
是一个将返回一个生成器(该生成器产生pl.DataFrame
对象)的函数。
此函数的参数是预定义的,并且此函数必须接受
with_columns
被投影的列。如果应用,读取器必须投影这些列
predicate
Polars表达式。读取器必须相应地过滤其行。
n_rows
仅从来源具体化n行。当读取n_rows
时,读取器可以停止。
batch_size
读取器的生成器必须产生的理想批处理大小的提示。
内部函数是IO来源的实际实现,也可以调用Rust/C++或IO插件编写的任何地方。如果您想查看用Rust实现的IO来源,请查看我们的插件仓库。
def my_scan_csv(csv_str: str) -> pl.LazyFrame:
schema = parse_schema(csv_str)
def source_generator(
with_columns: list[str] | None,
predicate: pl.Expr | None,
n_rows: int | None,
batch_size: int | None,
) -> Iterator[pl.DataFrame]:
"""
Generator function that creates the source.
This function will be registered as IO source.
"""
if batch_size is None:
batch_size = 100
# Initialize the reader.
reader = csv.reader(io.StringIO(csv_str), delimiter=',')
# Skip the header.
_ = next(reader)
# Ensure we don't read more rows than requested from the engine
while n_rows is None or n_rows > 0:
if n_rows is not None:
batch_size = min(batch_size, n_rows)
rows = []
for _ in range(batch_size):
try:
row = next(reader)
except StopIteration:
n_rows = 0
break
rows.append(row)
df = pl.from_records(rows, schema=schema, orient="row")
n_rows -= df.height
# If we would make a performant reader, we would not read these
# columns at all.
if with_columns is not None:
df = df.select(with_columns)
# If the source supports predicate pushdown, the expression can be parsed
# to skip rows/groups.
if predicate is not None:
df = df.filter(predicate)
yield df
return register_io_source(io_source=source_generator, schema=schema)
进行一次(非常慢的)尝试
最后我们可以测试我们的来源
csv_str1 = """a,b,c,d
1,2,3,4
9,10,11,2
1,2,3,4
1,122,3,4"""
print(my_scan_csv(csv_str1).collect())
csv_str2 = """a,b
1,2
9,10
1,2
1,122"""
print(my_scan_csv(csv_str2).head(2).collect())
运行上述脚本会将以下输出打印到控制台
shape: (4, 4)
┌─────┬─────┬─────┬─────┐
│ a ┆ b ┆ c ┆ d │
│ --- ┆ --- ┆ --- ┆ --- │
│ str ┆ str ┆ str ┆ str │
╞═════╪═════╪═════╪═════╡
│ 1 ┆ 2 ┆ 3 ┆ 4 │
│ 9 ┆ 10 ┆ 11 ┆ 2 │
│ 1 ┆ 2 ┆ 3 ┆ 4 │
│ 1 ┆ 122 ┆ 3 ┆ 4 │
└─────┴─────┴─────┴─────┘
shape: (2, 2)
┌─────┬─────┐
│ a ┆ b │
│ --- ┆ --- │
│ str ┆ str │
╞═════╪═════╡
│ 1 ┆ 2 │
│ 9 ┆ 10 │
└─────┴─────┘