# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""Tests for user-defined scalar functions (UDFs) registered through
pyarrow.compute.register_scalar_function."""

import pytest

import pyarrow as pa
from pyarrow import compute as pc

# UDFs are all tested with a dataset scan
pytestmark = pytest.mark.dataset

try:
    import pyarrow.dataset as ds
except ImportError:
    ds = None


def mock_udf_context(batch_length=10):
    """Return a scalar-UDF context with the default memory pool and the
    given batch length, for calling UDF implementations directly."""
    from pyarrow._compute import _get_scalar_udf_context
    return _get_scalar_udf_context(pa.default_memory_pool(), batch_length)


class MyError(RuntimeError):
    """Custom exception used to check that UDF errors propagate intact."""
    pass


@pytest.fixture(scope="session")
def unary_func_fixture():
    """
    Register a unary scalar function.
    """
    def unary_function(ctx, x):
        return pc.call_function("add", [x, 1],
                                memory_pool=ctx.memory_pool)
    func_name = "y=x+1"
    unary_doc = {"summary": "add function",
                 "description": "test add function"}
    pc.register_scalar_function(unary_function,
                                func_name,
                                unary_doc,
                                {"array": pa.int64()},
                                pa.int64())
    return unary_function, func_name


@pytest.fixture(scope="session")
def binary_func_fixture():
    """
    Register a binary scalar function.
    """
    def binary_function(ctx, m, x):
        return pc.call_function("multiply", [m, x],
                                memory_pool=ctx.memory_pool)
    func_name = "y=mx"
    binary_doc = {"summary": "y=mx",
                  "description": "find y from y = mx"}
    pc.register_scalar_function(binary_function,
                                func_name,
                                binary_doc,
                                {"m": pa.int64(),
                                 "x": pa.int64(),
                                 },
                                pa.int64())
    return binary_function, func_name


@pytest.fixture(scope="session")
def ternary_func_fixture():
    """
    Register a ternary scalar function.
    """
    def ternary_function(ctx, m, x, c):
        mx = pc.call_function("multiply", [m, x],
                              memory_pool=ctx.memory_pool)
        return pc.call_function("add", [mx, c],
                                memory_pool=ctx.memory_pool)
    ternary_doc = {"summary": "y=mx+c",
                   "description": "find y from y = mx + c"}
    func_name = "y=mx+c"
    pc.register_scalar_function(ternary_function,
                                func_name,
                                ternary_doc,
                                {
                                    "array1": pa.int64(),
                                    "array2": pa.int64(),
                                    "array3": pa.int64(),
                                },
                                pa.int64())
    return ternary_function, func_name


@pytest.fixture(scope="session")
def varargs_func_fixture():
    """
    Register a varargs scalar function with at least two arguments.
    """
    def varargs_function(ctx, first, *values):
        acc = first
        for val in values:
            acc = pc.call_function("add", [acc, val],
                                   memory_pool=ctx.memory_pool)
        return acc
    func_name = "z=ax+by+c"
    varargs_doc = {"summary": "z=ax+by+c",
                   "description": "find z from z = ax + by + c"
                   }
    pc.register_scalar_function(varargs_function,
                                func_name,
                                varargs_doc,
                                {
                                    "array1": pa.int64(),
                                    "array2": pa.int64(),
                                },
                                pa.int64())
    return varargs_function, func_name


@pytest.fixture(scope="session")
def nullary_func_fixture():
    """
    Register a nullary scalar function.
    """
    def nullary_func(context):
        return pa.array([42] * context.batch_length, type=pa.int64(),
                        memory_pool=context.memory_pool)

    func_doc = {
        "summary": "random function",
        "description": "generates a random value"
    }
    func_name = "test_nullary_func"
    pc.register_scalar_function(nullary_func,
                                func_name,
                                func_doc,
                                {},
                                pa.int64())

    return nullary_func, func_name


@pytest.fixture(scope="session")
def wrong_output_type_func_fixture():
    """
    Register a scalar function which returns something that is neither
    a Arrow scalar or array.
    """
    def wrong_output_type(ctx):
        return 42

    func_name = "test_wrong_output_type"
    in_types = {}
    out_type = pa.int64()
    doc = {
        "summary": "return wrong output type",
        "description": ""
    }
    pc.register_scalar_function(wrong_output_type, func_name, doc,
                                in_types, out_type)
    return wrong_output_type, func_name


@pytest.fixture(scope="session")
def wrong_output_datatype_func_fixture():
    """
    Register a scalar function whose actual output DataType doesn't
    match the declared output DataType.
    """
    def wrong_output_datatype(ctx, array):
        return pc.call_function("add", [array, 1])
    func_name = "test_wrong_output_datatype"
    in_types = {"array": pa.int64()}
    # The actual output DataType will be int64.
    out_type = pa.int16()
    doc = {
        "summary": "return wrong output datatype",
        "description": ""
    }
    pc.register_scalar_function(wrong_output_datatype, func_name, doc,
                                in_types, out_type)
    return wrong_output_datatype, func_name


@pytest.fixture(scope="session")
def wrong_signature_func_fixture():
    """
    Register a scalar function with the wrong signature.
    """
    # Missing the context argument
    def wrong_signature():
        return pa.scalar(1, type=pa.int64())

    func_name = "test_wrong_signature"
    in_types = {}
    out_type = pa.int64()
    doc = {
        "summary": "UDF with wrong signature",
        "description": ""
    }
    pc.register_scalar_function(wrong_signature, func_name, doc,
                                in_types, out_type)
    return wrong_signature, func_name


@pytest.fixture(scope="session")
def raising_func_fixture():
    """
    Register a scalar function which raises a custom exception.
    """
    def raising_func(ctx):
        raise MyError("error raised by scalar UDF")
    func_name = "test_raise"
    doc = {
        "summary": "raising function",
        "description": ""
    }
    pc.register_scalar_function(raising_func, func_name, doc,
                                {}, pa.int64())
    return raising_func, func_name


def check_scalar_function(func_fixture,
                          inputs, *,
                          run_in_dataset=True,
                          batch_length=None):
    """Call a registered UDF through pc.call_function (and optionally a
    dataset scan) and compare against invoking the Python implementation
    directly with a mock UDF context.

    Parameters
    ----------
    func_fixture : tuple of (callable, str)
        The Python implementation and its registered function name.
    inputs : list
        Arguments (Arrow arrays and/or scalars) to call the UDF with.
    run_in_dataset : bool
        Also exercise the UDF inside a dataset projection expression.
    batch_length : int or None
        Explicit batch length; inferred from the first array input
        (or 1 if all inputs are scalars) when None.
    """
    function, name = func_fixture
    if batch_length is None:
        # Infer the batch length from the inputs: the length of any
        # array argument, or 1 when every argument is a scalar.
        all_scalar = True
        for arg in inputs:
            if isinstance(arg, pa.Array):
                all_scalar = False
                batch_length = len(arg)
        if all_scalar:
            batch_length = 1
    expected_output = function(mock_udf_context(batch_length), *inputs)
    func = pc.get_function(name)
    assert func.name == name

    result = pc.call_function(name, inputs, length=batch_length)
    assert result == expected_output
    # At the moment there is an issue when handling nullary functions.
    # See: ARROW-15286 and ARROW-16290.
    if run_in_dataset:
        # BUG FIX: the original iterated `inputs` directly
        # (`for index, in_arr in inputs`), which tries to unpack each
        # Arrow array into two names and fails; enumerate() is needed
        # to produce the (index, array) pairs.
        field_names = [f'field{index}'
                       for index, in_arr in enumerate(inputs)]
        table = pa.Table.from_arrays(inputs, field_names)
        dataset = ds.dataset(table)
        func_args = [ds.field(field_name) for field_name in field_names]
        result_table = dataset.to_table(
            columns={'result': ds.field('')._call(name, func_args)})
        assert result_table.column(0).chunks[0] == expected_output


def test_scalar_udf_array_unary(unary_func_fixture):
    check_scalar_function(unary_func_fixture,
                          [
                              pa.array([10, 20], pa.int64())
                          ]
                          )


def test_scalar_udf_array_binary(binary_func_fixture):
    check_scalar_function(binary_func_fixture,
                          [
                              pa.array([10, 20], pa.int64()),
                              pa.array([2, 4], pa.int64())
                          ]
                          )


def test_scalar_udf_array_ternary(ternary_func_fixture):
    check_scalar_function(ternary_func_fixture,
                          [
                              pa.array([10, 20], pa.int64()),
                              pa.array([2, 4], pa.int64()),
                              pa.array([5, 10], pa.int64())
                          ]
                          )


def test_scalar_udf_array_varargs(varargs_func_fixture):
    check_scalar_function(varargs_func_fixture,
                          [
                              pa.array([2, 3], pa.int64()),
                              pa.array([10, 20], pa.int64()),
                              pa.array([3, 7], pa.int64()),
                              pa.array([20, 30], pa.int64()),
                              pa.array([5, 10], pa.int64())
                          ]
                          )


def test_registration_errors():
    # validate function name
    doc = {
        "summary": "test udf input",
        "description": "parameters are validated"
    }
    in_types = {"scalar": pa.int64()}
    out_type = pa.int64()

    def test_reg_function(context):
        return pa.array([10])

    with pytest.raises(TypeError):
        pc.register_scalar_function(test_reg_function,
                                    None, doc, in_types,
                                    out_type)

    # validate function
    with pytest.raises(TypeError, match="func must be a callable"):
        pc.register_scalar_function(None, "test_none_function", doc,
                                    in_types, out_type)

    # validate output type
    expected_expr = "DataType expected, got "
    with pytest.raises(TypeError, match=expected_expr):
        pc.register_scalar_function(test_reg_function,
                                    "test_output_function", doc,
                                    in_types, None)

    # validate input type
    expected_expr = "in_types must be a dictionary of DataType"
    with pytest.raises(TypeError, match=expected_expr):
        pc.register_scalar_function(test_reg_function,
                                    "test_input_function", doc,
                                    None, out_type)

    # register an already registered function
    # first registration
    pc.register_scalar_function(test_reg_function,
                                "test_reg_function", doc,
                                {}, out_type)
    # second registration
    expected_expr = "Already have a function registered with name:" \
        + " test_reg_function"
    with pytest.raises(KeyError, match=expected_expr):
        pc.register_scalar_function(test_reg_function,
                                    "test_reg_function", doc,
                                    {}, out_type)


def test_varargs_function_validation(varargs_func_fixture):
    _, func_name = varargs_func_fixture
    error_msg = r"VarArgs function 'z=ax\+by\+c' needs at least 2 arguments"
    with pytest.raises(ValueError, match=error_msg):
        pc.call_function(func_name, [42])


def test_function_doc_validation():
    # validate arity
    in_types = {"scalar": pa.int64()}
    out_type = pa.int64()

    # doc with no summary
    func_doc = {
        "description": "desc"
    }

    def add_const(ctx, scalar):
        return pc.call_function("add", [scalar, 1])

    with pytest.raises(ValueError,
                       match="Function doc must contain a summary"):
        pc.register_scalar_function(add_const, "test_no_summary",
                                    func_doc, in_types,
                                    out_type)

    # doc with no decription
    func_doc = {
        "summary": "test summary"
    }

    with pytest.raises(ValueError,
                       match="Function doc must contain a description"):
        pc.register_scalar_function(add_const, "test_no_desc",
                                    func_doc, in_types,
                                    out_type)


def test_nullary_function(nullary_func_fixture):
    # XXX the Python compute layer API doesn't let us override batch_length,
    # so only test with the default value of 1.
    check_scalar_function(nullary_func_fixture, [],
                          run_in_dataset=False, batch_length=1)


def test_wrong_output_type(wrong_output_type_func_fixture):
    _, func_name = wrong_output_type_func_fixture

    with pytest.raises(TypeError, match="Unexpected output type: int"):
        pc.call_function(func_name, [], length=1)


def test_wrong_output_datatype(wrong_output_datatype_func_fixture):
    _, func_name = wrong_output_datatype_func_fixture

    expected_expr = ("Expected output datatype int16, "
                     "but function returned datatype int64")
    with pytest.raises(TypeError, match=expected_expr):
        pc.call_function(func_name, [pa.array([20, 30])])


def test_wrong_signature(wrong_signature_func_fixture):
    _, func_name = wrong_signature_func_fixture

    expected_expr = (r"wrong_signature\(\) takes 0 positional arguments "
                     "but 1 was given")
    with pytest.raises(TypeError, match=expected_expr):
        pc.call_function(func_name, [], length=1)


def test_wrong_datatype_declaration():
    def identity(ctx, val):
        return val

    func_name = "test_wrong_datatype_declaration"
    in_types = {"array": pa.int64()}
    out_type = {}
    doc = {
        "summary": "test output value",
        "description": "test output"
    }
    with pytest.raises(TypeError, match="DataType expected, got "):
        pc.register_scalar_function(identity, func_name, doc,
                                    in_types, out_type)


def test_wrong_input_type_declaration():
    def identity(ctx, val):
        return val

    func_name = "test_wrong_input_type_declaration"
    in_types = {"array": None}
    out_type = pa.int64()
    doc = {
        "summary": "test invalid input type",
        "description": "invalid input function"
    }
    with pytest.raises(TypeError, match="DataType expected, got "):
        pc.register_scalar_function(identity, func_name, doc,
                                    in_types, out_type)


def test_udf_context(unary_func_fixture):
    # Check the memory_pool argument is properly propagated
    proxy_pool = pa.proxy_memory_pool(pa.default_memory_pool())
    _, func_name = unary_func_fixture

    res = pc.call_function(func_name,
                           [pa.array([1] * 1000, type=pa.int64())],
                           memory_pool=proxy_pool)
    assert res == pa.array([2] * 1000, type=pa.int64())
    assert proxy_pool.bytes_allocated() == 1000 * 8
    # Destroying Python array should destroy underlying C++ memory
    res = None
    assert proxy_pool.bytes_allocated() == 0


def test_raising_func(raising_func_fixture):
    _, func_name = raising_func_fixture
    with pytest.raises(MyError, match="error raised by scalar UDF"):
        pc.call_function(func_name, [], length=1)


def test_scalar_input(unary_func_fixture):
    function, func_name = unary_func_fixture

    res = pc.call_function(func_name, [pa.scalar(10)])
    assert res == pa.scalar(11)


def test_input_lifetime(unary_func_fixture):
    function, func_name = unary_func_fixture

    proxy_pool = pa.proxy_memory_pool(pa.default_memory_pool())
    assert proxy_pool.bytes_allocated() == 0

    v = pa.array([1] * 1000, type=pa.int64(), memory_pool=proxy_pool)
    assert proxy_pool.bytes_allocated() == 1000 * 8
    pc.call_function(func_name, [v])
    assert proxy_pool.bytes_allocated() == 1000 * 8
    # Calling a UDF should not have kept `v` alive longer than required
    v = None
    assert proxy_pool.bytes_allocated() == 0