跳到内容

IO 插件

除了表达式插件之外,我们还支持IO插件。这些插件允许您将不同的文件格式注册为Polars引擎的来源。由于来源可以通过Arrow FFI零拷贝移动数据,并且可以在返回前生成大量数据块,我们目前决定通过Python来接口IO插件,因为我们认为GIL所需的短暂时间不应导致任何争用。

例如,一个IO来源可以在Rust中读取其数据帧,并且只有在汇合点才能零拷贝地移动数据,而GIL所需的占用时间很短。

用例

如果您有一个Polars不支持的源文件,并且您希望从诸如投影下推、谓词下推、提前停止以及我们流式引擎的支持等优化中受益,那么您就需要IO插件。

示例

那么,让我们编写一个简单、非常糟糕的自定义CSV来源,并将其注册为IO插件。我想强调的是,这是一个非常糟糕的例子,仅用于学习目的。

首先,我们定义一些所需的导入

# Use python for csv parsing.
import csv
import polars as pl
# Used to register a new generator on every instantiation.
from polars.io.plugins import register_io_source
from typing import Iterator
import io

解析模式

Polars中的每个scan函数都必须能够提供其读取数据的模式。对于这个简单的CSV解析器,我们总是将数据读取为pl.String。唯一不同的是字段名称和字段数量。

def parse_schema(csv_str: str) -> pl.Schema:
    first_line = csv_str.split("\n")[0]

    return pl.Schema({k: pl.String for k in first_line.split(",")})

如果我们使用小的CSV文件"a,b,c\n1,2,3"运行此代码,我们将得到模式:Schema([('a', String), ('b', String), ('c', String)])

>>> print(parse_schema("a,b,c\n1,2,3"))
Schema([('a', String), ('b', String), ('c', String)])

编写来源

接下来是实际的来源。为此,我们创建一个外部函数和一个内部函数。外部函数my_scan_csv是面向用户的函数。此函数将接受文件名和读取来源所需的其他潜在参数。对于CSV文件,这些参数可以是“delimiter”、“quote_char”等。

这个外部函数调用register_io_source,它接受一个callable和一个schema。该模式是完整源文件的Polars模式(独立于投影下推)。

callable是一个将返回一个生成器(该生成器产生pl.DataFrame对象)的函数。

此函数的参数是预定义的,并且此函数必须接受

  • with_columns

被投影的列。如果应用,读取器必须投影这些列

  • predicate

Polars表达式。读取器必须相应地过滤其行。

  • n_rows

仅从来源具体化n行。当读取n_rows时,读取器可以停止。

  • batch_size

读取器的生成器必须产生的理想批处理大小的提示。

内部函数是IO来源的实际实现,也可以调用Rust/C++或IO插件编写的任何地方。如果您想查看用Rust实现的IO来源,请查看我们的插件仓库

def my_scan_csv(csv_str: str) -> pl.LazyFrame:
    schema = parse_schema(csv_str)

    def source_generator(
        with_columns: list[str] | None,
        predicate: pl.Expr | None,
        n_rows: int | None,
        batch_size: int | None,
    ) -> Iterator[pl.DataFrame]:
        """
        Generator function that creates the source.
        This function will be registered as IO source.
        """
        if batch_size is None:
            batch_size = 100

        # Initialize the reader.
        reader = csv.reader(io.StringIO(csv_str), delimiter=',')
        # Skip the header.
        _ = next(reader)

        # Ensure we don't read more rows than requested from the engine
        while n_rows is None or n_rows > 0:
            if n_rows is not None:
                batch_size = min(batch_size, n_rows)

            rows = []

            for _ in range(batch_size):
                try:
                    row = next(reader)
                except StopIteration:
                    n_rows = 0
                    break
                rows.append(row)

            df = pl.from_records(rows, schema=schema, orient="row")
            n_rows -= df.height

            # If we would make a performant reader, we would not read these
            # columns at all.
            if with_columns is not None:
                df = df.select(with_columns)

            # If the source supports predicate pushdown, the expression can be parsed
            # to skip rows/groups.
            if predicate is not None:
                df = df.filter(predicate)

            yield df

    return register_io_source(io_source=source_generator, schema=schema)

进行一次(非常慢的)尝试

最后我们可以测试我们的来源

csv_str1 = """a,b,c,d
1,2,3,4
9,10,11,2
1,2,3,4
1,122,3,4"""

print(my_scan_csv(csv_str1).collect())


csv_str2 = """a,b
1,2
9,10
1,2
1,122"""

print(my_scan_csv(csv_str2).head(2).collect())

运行上述脚本会将以下输出打印到控制台

shape: (4, 4)
┌─────┬─────┬─────┬─────┐
│ a   ┆ b   ┆ c   ┆ d   │
│ --- ┆ --- ┆ --- ┆ --- │
│ str ┆ str ┆ str ┆ str │
╞═════╪═════╪═════╪═════╡
│ 1   ┆ 2   ┆ 3   ┆ 4   │
│ 9   ┆ 10  ┆ 11  ┆ 2   │
│ 1   ┆ 2   ┆ 3   ┆ 4   │
│ 1   ┆ 122 ┆ 3   ┆ 4   │
└─────┴─────┴─────┴─────┘
shape: (2, 2)
┌─────┬─────┐
│ a   ┆ b   │
│ --- ┆ --- │
│ str ┆ str │
╞═════╪═════╡
│ 1   ┆ 2   │
│ 9   ┆ 10  │
└─────┴─────┘

进一步阅读