Initial commit

This commit is contained in:
kdusek
2025-12-09 12:13:01 +01:00
commit 8e654ed209
13332 changed files with 2695056 additions and 0 deletions

View File

@@ -0,0 +1,37 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# This file is called from a test in test_flight.py.
import time
import pyarrow as pa
import pyarrow.flight as flight
class Server(flight.FlightServerBase):
    """Flight server whose DoPut handler always cancels the call."""

    def do_put(self, context, descriptor, reader, writer):
        # Give the client time to reach its write path before cancelling,
        # so the test exercises cancellation of an in-flight DoPut.
        time.sleep(1)
        raise flight.FlightCancelledError("")
if __name__ == "__main__":
    # Start an in-process Flight server on an ephemeral port and open a
    # DoPut stream against it; the server cancels the call after ~1 s.
    server = Server("grpc://localhost:0")
    client = flight.connect(f"grpc://localhost:{server.port}")
    schema = pa.schema([])
    writer, reader = client.do_put(
        flight.FlightDescriptor.for_command(b""), schema)
    # Signal end-of-stream; the caller (test_flight.py) observes how the
    # client reacts to the server-side cancellation.
    writer.done_writing()

View File

@@ -0,0 +1,47 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# This file is called from a test in test_pandas.py.
from threading import Thread
import pandas as pd
from pyarrow.pandas_compat import _pandas_api
if __name__ == "__main__":
    # Regression check: call _pandas_api.is_data_frame() from many threads
    # at once to exercise its (lazy) initialization concurrently.
    wait = True
    num_threads = 10
    df = pd.DataFrame()
    results = []

    def rc():
        # Busy-wait until the main thread flips `wait`, so all threads
        # race into is_data_frame() at roughly the same time.
        while wait:
            pass
        results.append(_pandas_api.is_data_frame(df))

    threads = [Thread(target=rc) for _ in range(num_threads)]
    for t in threads:
        t.start()
    # Release all spinning threads at once.
    wait = False
    for t in threads:
        t.join()

    assert len(results) == num_threads
    assert all(results), "`is_data_frame` returned False when given a DataFrame"

View File

@@ -0,0 +1,30 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# This file is called from a test in test_schema.py.
import pyarrow as pa
# the types where to_pandas_dtype returns a non-numpy dtype
cases = [
    (pa.timestamp('ns', tz='UTC'), "datetime64[ns, UTC]"),
]

# Each Arrow type must map to the expected (string form of the) pandas dtype.
for arrow_type, expected_repr in cases:
    actual = arrow_type.to_pandas_dtype()
    assert str(actual) == expected_repr

View File

@@ -0,0 +1,67 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# distutils: language=c++
# cython: language_level = 3
from pyarrow.lib cimport *
from pyarrow.lib import frombytes, tobytes
# basic test to roundtrip through a BoundFunction
# Basic test to roundtrip through a BoundFunction: a Python callable is
# bound into a std::function and invoked from C++ for each string.
ctypedef CStatus visit_string_cb(const c_string&)

cdef extern from * namespace "arrow::py" nogil:
    """
    #include <functional>
    #include <string>
    #include <vector>
    #include "arrow/status.h"

    namespace arrow {
    namespace py {

    Status VisitStrings(const std::vector<std::string>& strs,
                        std::function<Status(const std::string&)> cb) {
      for (const std::string& str : strs) {
        RETURN_NOT_OK(cb(str));
      }
      return Status::OK();
    }

    } // namespace py
    } // namespace arrow
    """

    # C++ entry point declared above; stops at the first non-OK status
    # returned by the callback.
    cdef CStatus CVisitStrings" arrow::py::VisitStrings"(
        vector[c_string], function[visit_string_cb])
cdef void _visit_strings_impl(py_cb, const c_string& s) except *:
    # C++-to-Python bridge: decode the C string and forward it to the
    # Python callback.  `except *` lets Python exceptions propagate out
    # through the C++ call.
    py_cb(frombytes(s))
def _visit_strings(strings, cb):
    """Invoke ``cb(s)`` for every string in ``strings`` via C++ VisitStrings.

    Roundtrips the Python callback through BindFunction to exercise the
    BoundFunction machinery; ``check_status`` raises if the C++ side
    reports a non-OK status (e.g. when ``cb`` raises).
    """
    cdef:
        function[visit_string_cb] c_cb
        vector[c_string] c_strings
    c_cb = BindFunction[visit_string_cb](&_visit_strings_impl, cb)
    for s in strings:
        c_strings.push_back(tobytes(s))
    check_status(CVisitStrings(c_strings, c_cb))

View File

@@ -0,0 +1,330 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import functools
import os
import pathlib
import subprocess
import sys
import time
import urllib.request
import pytest
import hypothesis as h
from ..conftest import groups, defaults
from pyarrow import set_timezone_db_path
from pyarrow.util import find_free_port
# setup hypothesis profiles
h.settings.register_profile('ci', max_examples=1000)
h.settings.register_profile('dev', max_examples=50)
h.settings.register_profile('debug', max_examples=10,
                            verbosity=h.Verbosity.verbose)

# load default hypothesis profile, either set HYPOTHESIS_PROFILE environment
# variable or pass --hypothesis-profile option to pytest, to see the generated
# examples try:
# pytest pyarrow -sv --enable-hypothesis --hypothesis-profile=debug
h.settings.load_profile(os.environ.get('HYPOTHESIS_PROFILE', 'dev'))

# Set this at the beginning before the AWS SDK was loaded to avoid reading in
# user configuration values.
os.environ['AWS_CONFIG_FILE'] = "/dev/null"

if sys.platform == 'win32':
    # Allow pointing Arrow's timezone database at a user-provided copy
    # (needed on Windows, which has no system tzdb in the expected place).
    tzdata_set_path = os.environ.get('PYARROW_TZDATA_PATH', None)
    if tzdata_set_path:
        set_timezone_db_path(tzdata_set_path)

# GH-45295: For ORC, try to populate TZDIR env var from tzdata package resource
# path.
#
# Note this is a different kind of database than what we allow to be set by
# `PYARROW_TZDATA_PATH` and passed to set_timezone_db_path.
if sys.platform == 'win32':
    if os.environ.get('TZDIR', None) is None:
        from importlib import resources
        try:
            os.environ['TZDIR'] = os.path.join(resources.files('tzdata'), 'zoneinfo')
        except ModuleNotFoundError:
            print(
                'Package "tzdata" not found. Not setting TZDIR environment variable.'
            )
def pytest_addoption(parser):
    """Register --enable-<group>/--disable-<group> flags for each test group.

    The enable default for each group can be overridden through a
    PYARROW_TEST_<GROUP> environment variable.
    """
    def bool_env(name, default=None):
        # Interpret an environment variable as a boolean; unset or empty
        # falls back to `default`.
        raw = os.environ.get(name.upper())
        if not raw:  # missing or empty
            return default
        normalized = raw.lower()
        if normalized in {'1', 'true', 'on', 'yes', 'y'}:
            return True
        if normalized in {'0', 'false', 'off', 'no', 'n'}:
            return False
        raise ValueError(f'{name.upper()}={normalized} is not parsable as boolean')

    for group in groups:
        env_default = bool_env(f'PYARROW_TEST_{group}', defaults[group])
        parser.addoption(f'--enable-{group}',
                         action='store_true', default=env_default,
                         help=(f'Enable the {group} test group'))
        parser.addoption(f'--disable-{group}',
                         action='store_true', default=False,
                         help=(f'Disable the {group} test group'))
class PyArrowConfig:
    """Tracks which optional test groups are enabled for this session."""

    def __init__(self):
        # Maps group name -> bool; populated by pytest_configure.
        self.is_enabled = {}

    def apply_mark(self, mark):
        """If `mark` names a known test group, enforce that it is enabled."""
        name = mark.name
        if name in groups:
            self.requires(name)

    def requires(self, group):
        """Skip the current test unless `group` was enabled."""
        if not self.is_enabled[group]:
            pytest.skip(f'{group} NOT enabled')
def pytest_configure(config):
    """Translate the enable/disable CLI flags into a PyArrowConfig."""
    # Apply command-line options to initialize PyArrow-specific config object
    config.pyarrow = PyArrowConfig()
    for group in groups:
        config.addinivalue_line(
            "markers", group,
        )
        enabled = config.getoption(f'--enable-{group}')
        disabled = config.getoption(f'--disable-{group}')
        # A --disable flag always wins over --enable / env defaults.
        config.pyarrow.is_enabled[group] = enabled and not disabled
def pytest_runtest_setup(item):
    """Skip a test when any of its group markers is not enabled."""
    for marker in item.iter_markers():
        item.config.pyarrow.apply_mark(marker)
@pytest.fixture
def tempdir(tmpdir):
    """Per-test temporary directory as a pathlib.Path."""
    # convert pytest's LocalPath to pathlib.Path
    return pathlib.Path(str(tmpdir))
@pytest.fixture(scope='session')
def base_datadir():
    """Directory containing the static test data files, next to this file."""
    return pathlib.Path(__file__).parent / 'data'
@pytest.fixture(autouse=True)
def disable_aws_metadata(monkeypatch):
    """Stop the AWS SDK from trying to contact the EC2 metadata server.

    Otherwise, this causes a 5 second delay in tests that exercise the
    S3 filesystem.
    """
    monkeypatch.setenv("AWS_EC2_METADATA_DISABLED", "true")
# TODO(kszucs): move the following fixtures to test_fs.py once the previous
# parquet dataset implementation and hdfs implementation are removed.
@pytest.fixture(scope='session')
def hdfs_connection():
    """(host, port, user) for the HDFS test cluster, from ARROW_HDFS_* env vars."""
    return (
        os.environ.get('ARROW_HDFS_TEST_HOST', 'default'),
        int(os.environ.get('ARROW_HDFS_TEST_PORT', 0)),
        os.environ.get('ARROW_HDFS_TEST_USER', 'hdfs'),
    )
@pytest.fixture(scope='session')
def s3_connection():
    """(host, port, access_key, secret_key) for the local minio server."""
    return '127.0.0.1', find_free_port(), 'arrow', 'apachearrow'
def retry(attempts=3, delay=1.0, max_delay=None, backoff=1):
    """
    Retry decorator

    Parameters
    ----------
    attempts : int, default 3
        The number of attempts.
    delay : float, default 1
        Initial delay in seconds.
    max_delay : float, optional
        The max delay between attempts.
    backoff : float, default 1
        The multiplier to delay after each attempt.
    """
    def decorate(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            remaining_attempts = attempts
            curr_delay = delay
            last_exception = None
            while remaining_attempts > 0:
                try:
                    return func(*args, **kwargs)
                except Exception as err:
                    remaining_attempts -= 1
                    last_exception = err
                    # BUG FIX: only sleep when another attempt will follow;
                    # the original slept once more after the final failure,
                    # delaying the re-raise for no reason.
                    if remaining_attempts > 0:
                        curr_delay *= backoff
                        if max_delay:
                            curr_delay = min(curr_delay, max_delay)
                        time.sleep(curr_delay)
            if last_exception is None:
                # attempts <= 0: func was never called (the original raised
                # NameError here); make the misuse explicit instead.
                raise ValueError("retry: `attempts` must be a positive integer")
            raise last_exception
        return wrapper
    return decorate
@pytest.fixture(scope='session')
def s3_server(s3_connection, tmpdir_factory):
    """Run a local minio server for the duration of the test session.

    Yields a dict with the connection tuple, the server process, and the
    directory minio stores data in.  Skips when `minio` is not on PATH.
    """
    @retry(attempts=5, delay=1, backoff=2)
    def minio_server_health_check(address):
        # minio's liveness endpoint; retried because startup takes a moment.
        resp = urllib.request.urlopen(f"http://{address}/minio/health/live")
        assert resp.getcode() == 200

    tmpdir = tmpdir_factory.getbasetemp()
    host, port, access_key, secret_key = s3_connection
    address = f'{host}:{port}'
    env = os.environ.copy()
    env.update({
        'MINIO_ACCESS_KEY': access_key,
        'MINIO_SECRET_KEY': secret_key
    })
    args = ['minio', '--compat', 'server', '--quiet', '--address',
            address, tmpdir]
    proc = None
    try:
        proc = subprocess.Popen(args, env=env)
    except OSError:
        pytest.skip('`minio` command cannot be located')
    else:
        # Wait for the server to startup before yielding
        minio_server_health_check(address)
        yield {
            'connection': s3_connection,
            'process': proc,
            'tempdir': tmpdir
        }
    finally:
        # Always reap the child process, even if startup or a test failed.
        if proc is not None:
            proc.kill()
            proc.wait()
@pytest.fixture(scope='session')
def gcs_server():
    """Run the GCS `storage-testbench` emulator for the test session.

    Yields the (host, port) connection info and the server process.
    Skips when the testbench cannot be started.
    """
    port = find_free_port()
    env = os.environ.copy()
    exe = 'storage-testbench'
    args = [exe, '--port', str(port)]
    proc = None
    try:
        # start server
        proc = subprocess.Popen(args, env=env)
        # Make sure the server is alive.
        if proc.poll() is not None:
            pytest.skip(f"Command {args} did not start server successfully!")
    except OSError as e:
        pytest.skip(f"Command {args} failed to execute: {e}")
    else:
        yield {
            'connection': ('localhost', port),
            'process': proc,
        }
    finally:
        # Always reap the child process.
        if proc is not None:
            proc.kill()
            proc.wait()
@pytest.fixture(scope='session')
def azure_server(tmpdir_factory):
    """Run an Azurite blob-storage emulator for the test session.

    Yields the connection tuple (host, port, account name, account key),
    the server process, and its data directory.  Skips when the
    `azurite-blob` command cannot be started.
    """
    port = find_free_port()
    env = os.environ.copy()
    tmpdir = tmpdir_factory.getbasetemp()
    # We only need blob service emulator, not queue or table.
    args = ['azurite-blob', "--location", tmpdir, "--blobPort", str(port)]
    # For old Azurite. We can't install the latest Azurite with old
    # Node.js on old Ubuntu.
    args += ["--skipApiVersionCheck"]
    proc = None
    try:
        proc = subprocess.Popen(args, env=env)
        # Make sure the server is alive.
        if proc.poll() is not None:
            pytest.skip(f"Command {args} did not start server successfully!")
    except (ModuleNotFoundError, OSError) as e:
        pytest.skip(f"Command {args} failed to execute: {e}")
    else:
        yield {
            # Use the standard azurite account_name and account_key.
            # https://learn.microsoft.com/en-us/azure/storage/common/storage-use-emulator#authorize-with-shared-key-credentials
            'connection': ('127.0.0.1', port, 'devstoreaccount1',
                           'Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2'
                           'UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw=='),
            'process': proc,
            'tempdir': tmpdir,
        }
    finally:
        # Always reap the child process.
        if proc is not None:
            proc.kill()
            proc.wait()
@pytest.fixture(
    params=[
        'builtin_pickle',
        'cloudpickle'
    ],
    scope='session'
)
def pickle_module(request):
    # Indirectly parametrizes tests over the stdlib pickle module and
    # cloudpickle (which is skipped when not installed).
    return request.getfixturevalue(request.param)
@pytest.fixture(scope='session')
def builtin_pickle():
    """The standard-library pickle module."""
    import pickle
    return pickle
@pytest.fixture(scope='session')
def cloudpickle():
    """The cloudpickle module; skips the test when it is not installed."""
    cp = pytest.importorskip('cloudpickle')
    # Older cloudpickle releases lack HIGHEST_PROTOCOL; alias it so callers
    # can treat cloudpickle like the stdlib pickle module.
    if not hasattr(cp, 'HIGHEST_PROTOCOL'):
        cp.HIGHEST_PROTOCOL = cp.DEFAULT_PROTOCOL
    return cp

View File

@@ -0,0 +1,22 @@
<!---
Licensed to the Apache Software Foundation (ASF) under one
or more contributor license agreements. See the NOTICE file
distributed with this work for additional information
regarding copyright ownership. The ASF licenses this file
to you under the Apache License, Version 2.0 (the
"License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied. See the License for the
specific language governing permissions and limitations
under the License.
-->
The ORC and JSON files come from the `examples` directory in the Apache ORC
source tree:
https://github.com/apache/orc/tree/main/examples

View File

@@ -0,0 +1,94 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# distutils: language=c++
# cython: language_level = 3
from pyarrow.lib cimport *
cdef extern from * namespace "arrow::py" nogil:
    """
    #include "arrow/status.h"
    #include "arrow/extension_type.h"
    #include "arrow/json/from_string.h"

    namespace arrow {
    namespace py {

    class UuidArray : public ExtensionArray {
     public:
      using ExtensionArray::ExtensionArray;
    };

    class UuidType : public ExtensionType {
     public:
      UuidType() : ExtensionType(fixed_size_binary(16)) {}
      std::string extension_name() const override { return "example-uuid"; }
      bool ExtensionEquals(const ExtensionType& other) const override {
        return other.extension_name() == this->extension_name();
      }
      std::shared_ptr<Array> MakeArray(std::shared_ptr<ArrayData> data) const override {
        return std::make_shared<ExtensionArray>(data);
      }
      Result<std::shared_ptr<DataType>> Deserialize(
          std::shared_ptr<DataType> storage_type,
          const std::string& serialized) const override {
        return std::make_shared<UuidType>();
      }
      std::string Serialize() const override { return ""; }
    };

    std::shared_ptr<DataType> MakeUuidType() {
      return std::make_shared<UuidType>();
    }

    std::shared_ptr<Array> MakeUuidArray() {
      auto uuid_type = MakeUuidType();
      auto json = "[\\"abcdefghijklmno0\\", \\"0onmlkjihgfedcba\\"]";
      auto result = json::ArrayFromJSONString(fixed_size_binary(16), json);
      return ExtensionType::WrapArray(uuid_type, result.ValueOrDie());
    }

    std::once_flag uuid_registered;
    static bool RegisterUuidType() {
      std::call_once(uuid_registered, RegisterExtensionType,
                     std::make_shared<UuidType>());
      return true;
    }
    static auto uuid_type_registered = RegisterUuidType();

    } // namespace py
    } // namespace arrow
    """

    # C-level constructors for the example UUID extension type defined in
    # the verbatim C++ block above.  The type registers itself once via
    # static initialization when this module is loaded.
    cdef shared_ptr[CDataType] CMakeUuidType" arrow::py::MakeUuidType"()
    cdef shared_ptr[CArray] CMakeUuidArray" arrow::py::MakeUuidArray"()
def _make_uuid_type():
    """Return the example-uuid extension type as a Python DataType."""
    return pyarrow_wrap_data_type(CMakeUuidType())
def _make_uuid_array():
    """Return a small example-uuid ExtensionArray built on the C++ side."""
    return pyarrow_wrap_array(CMakeUuidArray())

View File

@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

View File

@@ -0,0 +1,529 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from datetime import datetime as dt
import pyarrow as pa
from pyarrow.vendored.version import Version
import pytest
try:
import numpy as np
except ImportError:
np = None
import pyarrow.interchange as pi
from pyarrow.interchange.column import (
_PyArrowColumn,
ColumnNullType,
DtypeKind,
)
from pyarrow.interchange.from_dataframe import _from_dataframe
try:
import pandas as pd
# import pandas.testing as tm
except ImportError:
pass
@pytest.mark.parametrize("unit", ['s', 'ms', 'us', 'ns'])
@pytest.mark.parametrize("tz", ['', 'America/New_York', '+07:30', '-04:30'])
def test_datetime(unit, tz):
    """Timestamp columns report DATETIME dtype and bitmask-encoded nulls."""
    values = [dt(2007, 7, 13), dt(2007, 7, 14), None]
    tbl = pa.table({"A": pa.array(values, type=pa.timestamp(unit, tz=tz))})

    column = tbl.__dataframe__().get_column_by_name("A")
    assert column.size() == 3
    assert column.offset == 0
    assert column.null_count == 1
    assert column.dtype[0] == DtypeKind.DATETIME
    assert column.describe_null == (ColumnNullType.USE_BITMASK, 0)
@pytest.mark.parametrize(
    ["test_data", "kind"],
    [
        (["foo", "bar"], 21),
        ([1.5, 2.5, 3.5], 2),
        ([1, 2, 3, 4], 0),
    ],
)
def test_array_to_pyarrowcolumn(test_data, kind):
    """Wrapping a plain Array in _PyArrowColumn preserves its contents."""
    arr = pa.array(test_data)
    col = _PyArrowColumn(arr)

    assert col._col == arr
    assert col.size() == len(test_data)
    assert col.dtype[0] == kind
    assert col.num_chunks() == 1
    assert col.null_count == 0
    assert col.get_buffers()["validity"] is None

    chunks = list(col.get_chunks())
    assert len(chunks) == 1
    assert all(chunk == col for chunk in chunks)
def test_offset_of_sliced_array():
    """A sliced table keeps its buffer offset through the protocol."""
    full = pa.array([1, 2, 3, 4])
    sliced = full.slice(2, 2)
    whole_table = pa.table([full], names=["arr"])
    sliced_table = pa.table([sliced], names=["arr_sliced"])

    assert sliced_table.__dataframe__().get_column(0).offset == 2

    roundtripped = _from_dataframe(sliced_table.__dataframe__())
    assert sliced_table.equals(roundtripped)
    assert not whole_table.equals(roundtripped)

    # pandas hardcodes offset to 0:
    # https://github.com/pandas-dev/pandas/blob/5c66e65d7b9fef47ccb585ce2fd0b3ea18dc82ea/pandas/core/interchange/from_dataframe.py#L247
    # so conversion to pandas can't be tested currently
@pytest.mark.pandas
@pytest.mark.parametrize(
    "uint", [pa.uint8(), pa.uint16(), pa.uint32()]
)
@pytest.mark.parametrize(
    "int", [pa.int8(), pa.int16(), pa.int32(), pa.int64()]
)
@pytest.mark.parametrize(
    "float, np_float_str", [
        # (pa.float16(), np.float16), #not supported by pandas
        (pa.float32(), "float32"),
        (pa.float64(), "float64")
    ]
)
def test_pandas_roundtrip(uint, int, float, np_float_str):
    """Table -> pandas -> Table via the interchange protocol is lossless."""
    if Version(pd.__version__) < Version("1.5.0"):
        pytest.skip("__dataframe__ added to pandas in 1.5.0")
    from pandas.api.interchange import (
        from_dataframe as pandas_from_dataframe
    )

    values = [1, 2, 3]
    table = pa.table(
        {
            "a": pa.array(values, type=uint),
            "b": pa.array(values, type=int),
            "c": pa.array(np.array(values, dtype=np.dtype(np_float_str)),
                          type=float),
            "d": [True, False, True],
        }
    )

    result = pi.from_dataframe(pandas_from_dataframe(table))
    assert table.equals(result)

    left, right = table.__dataframe__(), result.__dataframe__()
    assert left.num_columns() == right.num_columns()
    assert left.num_rows() == right.num_rows()
    assert left.num_chunks() == right.num_chunks()
    assert left.column_names() == right.column_names()
@pytest.mark.pandas
def test_pandas_roundtrip_string():
    """Roundtrip a string column; pandas hands the data back as large_string."""
    # See https://github.com/pandas-dev/pandas/issues/50554
    if Version(pd.__version__) < Version("1.6"):
        pytest.skip("Column.size() bug in pandas")
    from pandas.api.interchange import (
        from_dataframe as pandas_from_dataframe
    )

    table = pa.table({"a": pa.array(["a", "", "c"])})
    result = pi.from_dataframe(pandas_from_dataframe(table))

    assert result["a"].to_pylist() == table["a"].to_pylist()
    assert pa.types.is_string(table["a"].type)
    assert pa.types.is_large_string(result["a"].type)

    left, right = table.__dataframe__(), result.__dataframe__()
    assert left.num_columns() == right.num_columns()
    assert left.num_rows() == right.num_rows()
    assert left.num_chunks() == right.num_chunks()
    assert left.column_names() == right.column_names()
@pytest.mark.pandas
def test_pandas_roundtrip_large_string():
    """Roundtrip a large_string column (supported by pandas >= 2.0.1)."""
    # See https://github.com/pandas-dev/pandas/issues/50554
    if Version(pd.__version__) < Version("1.6"):
        pytest.skip("Column.size() bug in pandas")
    from pandas.api.interchange import (
        from_dataframe as pandas_from_dataframe
    )

    table = pa.table(
        {"a_large": pa.array(["a", "", "c"], type=pa.large_string())})

    if Version(pd.__version__) < Version("2.0.1"):
        # large string not supported by pandas implementation for
        # older versions of pandas
        # https://github.com/pandas-dev/pandas/issues/52795
        with pytest.raises(AssertionError):
            pandas_from_dataframe(table)
        return

    result = pi.from_dataframe(pandas_from_dataframe(table))

    assert result["a_large"].to_pylist() == table["a_large"].to_pylist()
    assert pa.types.is_large_string(table["a_large"].type)
    assert pa.types.is_large_string(result["a_large"].type)

    left, right = table.__dataframe__(), result.__dataframe__()
    assert left.num_columns() == right.num_columns()
    assert left.num_rows() == right.num_rows()
    assert left.num_chunks() == right.num_chunks()
    assert left.column_names() == right.column_names()
@pytest.mark.pandas
def test_pandas_roundtrip_string_with_missing():
    """Roundtrip string columns containing nulls (needs pandas >= 2.0.2)."""
    # See https://github.com/pandas-dev/pandas/issues/50554
    if Version(pd.__version__) < Version("1.6"):
        pytest.skip("Column.size() bug in pandas")
    from pandas.api.interchange import (
        from_dataframe as pandas_from_dataframe
    )

    values = ["a", "", "c", None]
    table = pa.table({"a": pa.array(values),
                      "a_large": pa.array(values, type=pa.large_string())})

    if Version(pd.__version__) < Version("2.0.2"):
        # older versions of pandas do not have bitmask support
        # https://github.com/pandas-dev/pandas/issues/49888
        with pytest.raises(NotImplementedError):
            pandas_from_dataframe(table)
        return

    result = pi.from_dataframe(pandas_from_dataframe(table))

    assert result["a"].to_pylist() == table["a"].to_pylist()
    assert pa.types.is_string(table["a"].type)
    assert pa.types.is_large_string(result["a"].type)

    assert result["a_large"].to_pylist() == table["a_large"].to_pylist()
    assert pa.types.is_large_string(table["a_large"].type)
    assert pa.types.is_large_string(result["a_large"].type)
@pytest.mark.pandas
def test_pandas_roundtrip_categorical():
    """Roundtrip a dictionary-encoded column through pandas.

    Also checks that the interchange metadata (dtype kind, size, offset,
    categorical description) of the result matches the original table's.
    """
    if Version(pd.__version__) < Version("2.0.2"):
        pytest.skip("Bitmasks not supported in pandas interchange implementation")
    arr = ["Mon", "Tue", "Mon", "Wed", "Mon", "Thu", "Fri", "Sat", None]
    table = pa.table(
        {"weekday": pa.array(arr).dictionary_encode()}
    )
    from pandas.api.interchange import (
        from_dataframe as pandas_from_dataframe
    )
    pandas_df = pandas_from_dataframe(table)
    result = pi.from_dataframe(pandas_df)

    assert result["weekday"].to_pylist() == table["weekday"].to_pylist()
    assert pa.types.is_dictionary(table["weekday"].type)
    assert pa.types.is_dictionary(result["weekday"].type)
    # pandas re-encodes the dictionary with its own value/index types
    assert pa.types.is_string(table["weekday"].chunk(0).dictionary.type)
    assert pa.types.is_large_string(result["weekday"].chunk(0).dictionary.type)
    assert pa.types.is_int32(table["weekday"].chunk(0).indices.type)
    assert pa.types.is_int8(result["weekday"].chunk(0).indices.type)

    table_protocol = table.__dataframe__()
    result_protocol = result.__dataframe__()

    assert table_protocol.num_columns() == result_protocol.num_columns()
    assert table_protocol.num_rows() == result_protocol.num_rows()
    assert table_protocol.num_chunks() == result_protocol.num_chunks()
    assert table_protocol.column_names() == result_protocol.column_names()

    col_table = table_protocol.get_column(0)
    col_result = result_protocol.get_column(0)

    assert col_result.dtype[0] == DtypeKind.CATEGORICAL
    assert col_result.dtype[0] == col_table.dtype[0]
    assert col_result.size() == col_table.size()
    assert col_result.offset == col_table.offset

    # BUG FIX: the original read describe_categorical from col_result for
    # BOTH sides, so the assertions below compared the result with itself.
    # Take the table's side from col_table, matching the sibling
    # test_pyarrow_roundtrip_categorical.
    desc_cat_table = col_table.describe_categorical
    desc_cat_result = col_result.describe_categorical

    assert desc_cat_table["is_ordered"] == desc_cat_result["is_ordered"]
    assert desc_cat_table["is_dictionary"] == desc_cat_result["is_dictionary"]
    assert isinstance(desc_cat_result["categories"]._col, pa.Array)
@pytest.mark.pandas
@pytest.mark.parametrize("unit", ['s', 'ms', 'us', 'ns'])
def test_pandas_roundtrip_datetime(unit):
    """Roundtrip timestamps; pandas < 2.0 coerces everything to ns."""
    if Version(pd.__version__) < Version("1.5.0"):
        pytest.skip("__dataframe__ added to pandas in 1.5.0")
    from pandas.api.interchange import (
        from_dataframe as pandas_from_dataframe
    )
    from datetime import datetime as dt

    # timezones not included as they are not yet supported in
    # the pandas implementation
    values = [dt(2007, 7, 13), dt(2007, 7, 14), dt(2007, 7, 15)]
    table = pa.table({"a": pa.array(values, type=pa.timestamp(unit))})

    if Version(pd.__version__) < Version("1.6"):
        # pandas < 2.0 always creates datetime64 in "ns"
        # resolution
        expected = pa.table({"a": pa.array(values, type=pa.timestamp('ns'))})
    else:
        expected = table

    result = pi.from_dataframe(pandas_from_dataframe(table))
    assert expected.equals(result)

    left, right = expected.__dataframe__(), result.__dataframe__()
    assert left.num_columns() == right.num_columns()
    assert left.num_rows() == right.num_rows()
    assert left.num_chunks() == right.num_chunks()
    assert left.column_names() == right.column_names()
@pytest.mark.pandas
@pytest.mark.parametrize(
    "np_float_str", ["float32", "float64"]
)
def test_pandas_to_pyarrow_with_missing(np_float_str):
    """pandas NaN (floats) and NaT (datetimes) convert to Arrow nulls."""
    if Version(pd.__version__) < Version("1.5.0"):
        pytest.skip("__dataframe__ added to pandas in 1.5.0")

    floats = np.array([0, np.nan, 2], dtype=np.dtype(np_float_str))
    datetimes = [None, dt(2007, 7, 14), dt(2007, 7, 15)]
    df = pd.DataFrame({
        # float, ColumnNullType.USE_NAN
        "a": floats,
        # ColumnNullType.USE_SENTINEL
        "dt": np.array(datetimes, dtype="datetime64[ns]")
    })
    expected = pa.table({
        "a": pa.array(floats, from_pandas=True),
        "dt": pa.array(datetimes, type=pa.timestamp("ns"))
    })

    assert pi.from_dataframe(df).equals(expected)
@pytest.mark.pandas
def test_pandas_to_pyarrow_float16_with_missing():
    """float16 with NaN is rejected: Arrow has no is_nan kernel for halffloat."""
    if Version(pd.__version__) < Version("1.5.0"):
        pytest.skip("__dataframe__ added to pandas in 1.5.0")
    # np.float16 errors if ps.is_nan is used
    # pyarrow.lib.ArrowNotImplementedError: Function 'is_nan' has no kernel
    # matching input types (halffloat)
    df = pd.DataFrame({"a": np.array([0, np.nan, 2], dtype=np.float16)})
    with pytest.raises(NotImplementedError):
        pi.from_dataframe(df)
@pytest.mark.numpy
@pytest.mark.parametrize(
    "uint", [pa.uint8(), pa.uint16(), pa.uint32()]
)
@pytest.mark.parametrize(
    "int", [pa.int8(), pa.int16(), pa.int32(), pa.int64()]
)
@pytest.mark.parametrize(
    "float, np_float_str", [
        (pa.float16(), "float16"),
        (pa.float32(), "float32"),
        (pa.float64(), "float64")
    ]
)
@pytest.mark.parametrize("unit", ['s', 'ms', 'us', 'ns'])
@pytest.mark.parametrize("tz", ['America/New_York', '+07:30', '-04:30'])
@pytest.mark.parametrize("offset, length", [(0, 3), (0, 2), (1, 2), (2, 1)])
def test_pyarrow_roundtrip(uint, int, float, np_float_str,
                           unit, tz, offset, length):
    """Roundtrip (sliced) tables of assorted types through the protocol."""
    from datetime import datetime as dt

    ints = [1, 2, None]
    stamps = [dt(2007, 7, 13), None, dt(2007, 7, 15)]
    table = pa.table(
        {
            "a": pa.array(ints, type=uint),
            "b": pa.array(ints, type=int),
            "c": pa.array(np.array(ints, dtype=np.dtype(np_float_str)),
                          type=float, from_pandas=True),
            "d": [True, False, True],
            "e": [True, False, None],
            "f": ["a", None, "c"],
            "g": pa.array(stamps, type=pa.timestamp(unit, tz=tz))
        }
    ).slice(offset, length)

    result = _from_dataframe(table.__dataframe__())
    assert table.equals(result)

    left, right = table.__dataframe__(), result.__dataframe__()
    assert left.num_columns() == right.num_columns()
    assert left.num_rows() == right.num_rows()
    assert left.num_chunks() == right.num_chunks()
    assert left.column_names() == right.column_names()
@pytest.mark.parametrize("offset, length", [(0, 10), (0, 2), (7, 3), (2, 1)])
def test_pyarrow_roundtrip_categorical(offset, length):
    """Roundtrip a sliced dictionary-encoded column and compare the
    categorical metadata exposed by the two protocol objects."""
    arr = ["Mon", "Tue", "Mon", "Wed", "Mon", "Thu", "Fri", None, "Sun"]
    table = pa.table(
        {"weekday": pa.array(arr).dictionary_encode()}
    )
    table = table.slice(offset, length)
    result = _from_dataframe(table.__dataframe__())
    assert table.equals(result)
    table_protocol = table.__dataframe__()
    result_protocol = result.__dataframe__()
    assert table_protocol.num_columns() == result_protocol.num_columns()
    assert table_protocol.num_rows() == result_protocol.num_rows()
    assert table_protocol.num_chunks() == result_protocol.num_chunks()
    assert table_protocol.column_names() == result_protocol.column_names()
    col_table = table_protocol.get_column(0)
    col_result = result_protocol.get_column(0)
    assert col_result.dtype[0] == DtypeKind.CATEGORICAL
    assert col_result.dtype[0] == col_table.dtype[0]
    assert col_result.size() == col_table.size()
    assert col_result.offset == col_table.offset
    desc_cat_table = col_table.describe_categorical
    desc_cat_result = col_result.describe_categorical
    assert desc_cat_table["is_ordered"] == desc_cat_result["is_ordered"]
    assert desc_cat_table["is_dictionary"] == desc_cat_result["is_dictionary"]
    # The categories come back as a protocol column backed by a plain
    # pyarrow Array.
    assert isinstance(desc_cat_result["categories"]._col, pa.Array)
@pytest.mark.large_memory
def test_pyarrow_roundtrip_large_string():
    """String data whose total size exceeds 2 GB must stay large_string
    (64-bit offsets) through a protocol roundtrip."""
    data = np.array([b'x'*1024]*(3*1024**2), dtype='object')  # 3GB bytes data
    arr = pa.array(data, type=pa.large_string())
    table = pa.table([arr], names=["large_string"])
    result = _from_dataframe(table.__dataframe__())
    col = result.__dataframe__().get_column(0)
    assert col.size() == 3*1024**2
    assert pa.types.is_large_string(table[0].type)
    assert pa.types.is_large_string(result[0].type)
    assert table.equals(result)
def test_nan_as_null():
    """Requesting nan_as_null=True is unsupported by the pyarrow
    protocol implementation and must raise RuntimeError."""
    tbl = pa.table({"a": [1, 2, 3, 4]})
    with pytest.raises(RuntimeError):
        tbl.__dataframe__(nan_as_null=True)
@pytest.mark.pandas
def test_allow_copy_false():
    """from_dataframe(..., allow_copy=False) must raise when the pandas
    data cannot be taken zero-copy (a validity bitmask must be built)."""
    if Version(pd.__version__) < Version("1.5.0"):
        pytest.skip("__dataframe__ added to pandas in 1.5.0")
    # Test that an error is raised when a copy is needed
    # to create a bitmask
    df = pd.DataFrame({"a": [0, 1.0, 2.0]})
    with pytest.raises(RuntimeError):
        pi.from_dataframe(df, allow_copy=False)
    # Same for datetime data containing a missing value.
    df = pd.DataFrame({
        "dt": [None, dt(2007, 7, 14), dt(2007, 7, 15)]
    })
    with pytest.raises(RuntimeError):
        pi.from_dataframe(df, allow_copy=False)
@pytest.mark.pandas
def test_allow_copy_false_bool_categorical():
    """Boolean and categorical pandas columns always need a copy, so
    from_dataframe(..., allow_copy=False) must raise for them."""
    if Version(pd.__version__) < Version("1.5.0"):
        pytest.skip("__dataframe__ added to pandas in 1.5.0")
    # Boolean dtype, with and without nulls: a copy is always made.
    for values in ([None, False, True], [True, False, True]):
        df = pd.DataFrame({"a": values})
        with pytest.raises(RuntimeError):
            pi.from_dataframe(df, allow_copy=False)
    # Categorical dtype, with and without nulls: a copy is always made.
    for values in (["a", "b", None], ["a", "b", "c"]):
        df = pd.DataFrame({"weekday": values}).astype("category")
        with pytest.raises(RuntimeError):
            pi.from_dataframe(df, allow_copy=False)
def test_empty_dataframe():
    """A zero-row table survives an interchange-protocol roundtrip."""
    empty = pa.table([[]], schema=pa.schema([('col1', pa.int8())]))
    roundtripped = pi.from_dataframe(empty.__dataframe__())
    assert roundtripped == empty

View File

@@ -0,0 +1,294 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import ctypes
import hypothesis as h
import hypothesis.strategies as st
import pytest
try:
import numpy as np
except ImportError:
np = None
import pyarrow as pa
import pyarrow.tests.strategies as past
# Hypothesis strategy covering every primitive type exercised by this
# module; st.deferred delays evaluation of the lambda until hypothesis
# actually draws from the strategy.
all_types = st.deferred(
    lambda: (
        past.signed_integer_types |
        past.unsigned_integer_types |
        past.floating_types |
        past.bool_type |
        past.string_type |
        past.large_string_type
    )
)
# datetime is tested in test_extra.py
# dictionary is tested in test_categorical()
@pytest.mark.numpy
@h.settings(suppress_health_check=(h.HealthCheck.too_slow,))
@h.given(past.arrays(all_types, size=3))
def test_dtypes(arr):
    """For hypothesis-generated 3-element arrays of each primitive type,
    the protocol column reports a matching integer null_count, size 3
    and offset 0."""
    table = pa.table([arr], names=["a"])
    df = table.__dataframe__()
    null_count = df.get_column(0).null_count
    assert null_count == arr.null_count
    assert isinstance(null_count, int)
    assert df.get_column(0).size() == 3
    assert df.get_column(0).offset == 0
@pytest.mark.numpy
@pytest.mark.parametrize(
    "uint, uint_bw",
    [
        (pa.uint8(), 8),
        (pa.uint16(), 16),
        (pa.uint32(), 32)
    ]
)
@pytest.mark.parametrize(
    "int, int_bw", [
        (pa.int8(), 8),
        (pa.int16(), 16),
        (pa.int32(), 32),
        (pa.int64(), 64)
    ]
)
@pytest.mark.parametrize(
    "float, float_bw, np_float_str", [
        (pa.float16(), 16, "float16"),
        (pa.float32(), 32, "float32"),
        (pa.float64(), 64, "float64")
    ]
)
@pytest.mark.parametrize("unit", ['s', 'ms', 'us', 'ns'])
@pytest.mark.parametrize("tz", ['', 'America/New_York', '+07:30', '-04:30'])
@pytest.mark.parametrize("use_batch", [False, True])
def test_mixed_dtypes(uint, uint_bw, int, int_bw,
                      float, float_bw, np_float_str, unit, tz,
                      use_batch):
    """A table mixing uint/int/float/bool/string/timestamp columns must
    report the right DtypeKind and bit width for each column, on both
    the Table and the RecordBatch code paths.

    NOTE(review): the ``int``/``float`` parameter names shadow builtins
    but are part of the parametrize interface, so they are kept.
    """
    from datetime import datetime as dt
    arr = [1, 2, 3]
    dt_arr = [dt(2007, 7, 13), dt(2007, 7, 14), dt(2007, 7, 15)]
    table = pa.table(
        {
            "a": pa.array(arr, type=uint),
            "b": pa.array(arr, type=int),
            "c": pa.array(np.array(arr, dtype=np.dtype(np_float_str)), type=float),
            "d": [True, False, True],
            "e": ["a", "", "c"],
            "f": pa.array(dt_arr, type=pa.timestamp(unit, tz=tz))
        }
    )
    if use_batch:
        table = table.to_batches()[0]
    df = table.__dataframe__()
    # 0 = DtypeKind.INT, 1 = DtypeKind.UINT, 2 = DtypeKind.FLOAT,
    # 20 = DtypeKind.BOOL, 21 = DtypeKind.STRING, 22 = DtypeKind.DATETIME
    # see DtypeKind class in column.py
    columns = {"a": 1, "b": 0, "c": 2, "d": 20, "e": 21, "f": 22}
    for column, kind in columns.items():
        col = df.get_column_by_name(column)
        assert col.null_count == 0
        assert col.size() == 3
        assert col.offset == 0
        assert col.dtype[0] == kind
    # Bit widths come straight from the parametrized Arrow types.
    assert df.get_column_by_name("a").dtype[1] == uint_bw
    assert df.get_column_by_name("b").dtype[1] == int_bw
    assert df.get_column_by_name("c").dtype[1] == float_bw
def test_na_float():
    """A float column with one missing value reports null_count == 1,
    and the count is a plain Python int."""
    interchange = pa.table({"a": [1.0, None, 2.0]}).__dataframe__()
    column = interchange.get_column_by_name("a")
    nulls = column.null_count
    assert nulls == 1
    assert isinstance(nulls, int)
def test_noncategorical():
    """describe_categorical must raise TypeError for a column that is
    not dictionary-encoded."""
    interchange = pa.table({"a": [1, 2, 3]}).__dataframe__()
    column = interchange.get_column_by_name("a")
    with pytest.raises(TypeError, match=".*categorical.*"):
        column.describe_categorical
@pytest.mark.parametrize("use_batch", [False, True])
def test_categorical(use_batch):
    """describe_categorical of a dictionary-encoded column exposes
    boolean 'is_ordered' and 'is_dictionary' flags.

    Parameters
    ----------
    use_batch : bool
        Also exercise the RecordBatch code path, not just Table.
    """
    # The redundant local ``import pyarrow as pa`` was removed: ``pa``
    # is already imported at module level.
    arr = ["Mon", "Tue", "Mon", "Wed", "Mon", "Thu", "Fri", "Sat", None]
    table = pa.table(
        {"weekday": pa.array(arr).dictionary_encode()}
    )
    if use_batch:
        table = table.to_batches()[0]
    col = table.__dataframe__().get_column_by_name("weekday")
    categorical = col.describe_categorical
    assert isinstance(categorical["is_ordered"], bool)
    assert isinstance(categorical["is_dictionary"], bool)
@pytest.mark.parametrize("use_batch", [False, True])
def test_dataframe(use_batch):
    """Top-level protocol accessors: column/row/chunk counts, column
    names, and select_columns by index versus by name."""
    n = pa.chunked_array([[2, 2, 4], [4, 5, 100]])
    a = pa.chunked_array([["Flamingo", "Parrot", "Cow"],
                          ["Horse", "Brittle stars", "Centipede"]])
    table = pa.table([n, a], names=['n_legs', 'animals'])
    if use_batch:
        table = table.combine_chunks().to_batches()[0]
    df = table.__dataframe__()
    assert df.num_columns() == 2
    assert df.num_rows() == 6
    # A RecordBatch is a single chunk; the Table keeps its two chunks.
    if use_batch:
        assert df.num_chunks() == 1
    else:
        assert df.num_chunks() == 2
    assert list(df.column_names()) == ['n_legs', 'animals']
    # Selecting by position and by name must agree.
    assert list(df.select_columns((1,)).column_names()) == list(
        df.select_columns_by_name(("animals",)).column_names()
    )
@pytest.mark.parametrize("use_batch", [False, True])
@pytest.mark.parametrize(["size", "n_chunks"], [(10, 3), (12, 3), (12, 5)])
def test_df_get_chunks(use_batch, size, n_chunks):
    """get_chunks(n) yields exactly n dataframe chunks whose row counts
    add up to the original size."""
    data = pa.table({"x": list(range(size))})
    if use_batch:
        data = data.to_batches()[0]
    pieces = list(data.__dataframe__().get_chunks(n_chunks))
    assert len(pieces) == n_chunks
    assert size == sum(piece.num_rows() for piece in pieces)
@pytest.mark.parametrize("use_batch", [False, True])
@pytest.mark.parametrize(["size", "n_chunks"], [(10, 3), (12, 3), (12, 5)])
def test_column_get_chunks(use_batch, size, n_chunks):
    """Column.get_chunks(n) yields exactly n column chunks whose sizes
    add up to the original size."""
    data = pa.table({"x": list(range(size))})
    if use_batch:
        data = data.to_batches()[0]
    pieces = list(data.__dataframe__().get_column(0).get_chunks(n_chunks))
    assert len(pieces) == n_chunks
    assert size == sum(piece.size() for piece in pieces)
@pytest.mark.pandas
@pytest.mark.parametrize(
    "uint", [pa.uint8(), pa.uint16(), pa.uint32()]
)
@pytest.mark.parametrize(
    "int", [pa.int8(), pa.int16(), pa.int32(), pa.int64()]
)
@pytest.mark.parametrize(
    "float, np_float_str", [
        (pa.float16(), "float16"),
        (pa.float32(), "float32"),
        (pa.float64(), "float64")
    ]
)
@pytest.mark.parametrize("use_batch", [False, True])
def test_get_columns(uint, int, float, np_float_str, use_batch):
    """get_columns() yields one single-chunk column of size 5 per field,
    with the expected DtypeKind per column.

    NOTE(review): the ``int``/``float`` parameter names shadow builtins
    but are part of the parametrize interface, so they are kept.
    """
    arr = [[1, 2, 3], [4, 5]]
    arr_float = np.array([1, 2, 3, 4, 5], dtype=np.dtype(np_float_str))
    table = pa.table(
        {
            "a": pa.chunked_array(arr, type=uint),
            "b": pa.chunked_array(arr, type=int),
            "c": pa.array(arr_float, type=float)
        }
    )
    if use_batch:
        table = table.combine_chunks().to_batches()[0]
    df = table.__dataframe__()
    for col in df.get_columns():
        assert col.size() == 5
        assert col.num_chunks() == 1
    # 0 = DtypeKind.INT, 1 = DtypeKind.UINT, 2 = DtypeKind.FLOAT,
    # see DtypeKind class in column.py
    assert df.get_column(0).dtype[0] == 1  # UINT
    assert df.get_column(1).dtype[0] == 0  # INT
    assert df.get_column(2).dtype[0] == 2  # FLOAT
@pytest.mark.parametrize(
    "int", [pa.int8(), pa.int16(), pa.int32(), pa.int64()]
)
@pytest.mark.parametrize("use_batch", [False, True])
def test_buffer(int, use_batch):
    """The data buffer from get_buffers() points at the raw values; on a
    CPU device we read them back via ctypes and compare element-wise.

    NOTE(review): the ``int`` parameter name shadows the builtin but is
    part of the parametrize interface, so it is kept.
    """
    arr = [0, 1, -1]
    table = pa.table({"a": pa.array(arr, type=int)})
    if use_batch:
        table = table.to_batches()[0]
    df = table.__dataframe__()
    col = df.get_column(0)
    buf = col.get_buffers()
    dataBuf, dataDtype = buf["data"]
    assert dataBuf.bufsize > 0
    assert dataBuf.ptr != 0
    device, _ = dataBuf.__dlpack_device__()
    # 0 = DtypeKind.INT
    # see DtypeKind class in column.py
    assert dataDtype[0] == 0
    if device == 1:  # CPU-only as we're going to directly read memory here
        bitwidth = dataDtype[1]
        # Map the reported bit width to the matching ctypes integer type.
        ctype = {
            8: ctypes.c_int8,
            16: ctypes.c_int16,
            32: ctypes.c_int32,
            64: ctypes.c_int64,
        }[bitwidth]
        for idx, truth in enumerate(arr):
            val = ctype.from_address(dataBuf.ptr + idx * (bitwidth // 8)).value
            assert val == truth, f"Buffer at index {idx} mismatch"
@pytest.mark.parametrize(
    "indices_type, bitwidth, f_string", [
        (pa.int8(), 8, "c"),
        (pa.int16(), 16, "s"),
        (pa.int32(), 32, "i"),
        (pa.int64(), 64, "l")
    ]
)
def test_categorical_dtype(indices_type, bitwidth, f_string):
    """A dictionary column's protocol dtype is CATEGORICAL, carrying the
    bit width and Arrow format character of its *index* type.

    Parameters
    ----------
    indices_type : pa.DataType
        Integer type of the dictionary indices.
    bitwidth : int
        Expected bit width reported in dtype[1].
    f_string : str
        Expected Arrow format character reported in dtype[2].
    """
    # Renamed from ``type`` to avoid shadowing the builtin.
    dict_type = pa.dictionary(indices_type, pa.string())
    arr = pa.array(["a", "b", None, "d"], dict_type)
    table = pa.table({'a': arr})
    df = table.__dataframe__()
    col = df.get_column(0)
    assert col.dtype[0] == 23  # <DtypeKind.CATEGORICAL: 23>
    assert col.dtype[1] == bitwidth
    assert col.dtype[2] == f_string

View File

@@ -0,0 +1,172 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from collections import OrderedDict
from datetime import date, time
import numpy as np
import pandas as pd
import pyarrow as pa
def dataframe_with_arrays(include_index=False):
    """
    Dataframe with numpy arrays columns of every possible primitive type.

    Parameters
    ----------
    include_index : bool, default False
        If True, also append an '__index_level_0__' int64 field to the
        returned schema (the DataFrame itself keeps its default index).

    Returns
    -------
    df: pandas.DataFrame
    schema: pyarrow.Schema
        Arrow schema definition that is in line with the constructed df.
    """
    dtypes = [('i1', pa.int8()), ('i2', pa.int16()),
              ('i4', pa.int32()), ('i8', pa.int64()),
              ('u1', pa.uint8()), ('u2', pa.uint16()),
              ('u4', pa.uint32()), ('u8', pa.uint64()),
              ('f4', pa.float32()), ('f8', pa.float64())]
    arrays = OrderedDict()
    fields = []
    for dtype, arrow_dtype in dtypes:
        fields.append(pa.field(dtype, pa.list_(arrow_dtype)))
        # The numpy dtype string ('i1', 'f8', ...) doubles as the column
        # name and as the dtype for np.arange.
        arrays[dtype] = [
            np.arange(10, dtype=dtype),
            np.arange(5, dtype=dtype),
            None,
            np.arange(1, dtype=dtype)
        ]
    fields.append(pa.field('str', pa.list_(pa.string())))
    arrays['str'] = [
        np.array(["1", "ä"], dtype="object"),
        None,
        np.array(["1"], dtype="object"),
        np.array(["1", "2", "3"], dtype="object")
    ]
    fields.append(pa.field('datetime64', pa.list_(pa.timestamp('ms'))))
    arrays['datetime64'] = [
        np.array(['2007-07-13T01:23:34.123456789',
                  None,
                  '2010-08-13T05:46:57.437699912'],
                 dtype='datetime64[ms]'),
        None,
        None,
        np.array(['2007-07-13T02',
                  None,
                  '2010-08-13T05:46:57.437699912'],
                 dtype='datetime64[ms]'),
    ]
    if include_index:
        fields.append(pa.field('__index_level_0__', pa.int64()))
    df = pd.DataFrame(arrays)
    schema = pa.schema(fields)
    return df, schema
def dataframe_with_lists(include_index=False, parquet_compatible=False):
    """
    Dataframe with list columns of every possible primitive type.

    Parameters
    ----------
    include_index : bool, default False
        If True, also append an '__index_level_0__' int64 field to the
        returned schema.
    parquet_compatible : bool, default False
        Exclude types not supported by parquet (time64[ns]).

    Returns
    -------
    df: pandas.DataFrame
    schema: pyarrow.Schema
        Arrow schema definition that is in line with the constructed df.
    """
    arrays = OrderedDict()
    fields = []
    fields.append(pa.field('int64', pa.list_(pa.int64())))
    arrays['int64'] = [
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
        [0, 1, 2, 3, 4],
        None,
        [],
        # Strided (non-contiguous) numpy input is exercised on purpose.
        np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9] * 2,
                 dtype=np.int64)[::2]
    ]
    fields.append(pa.field('double', pa.list_(pa.float64())))
    arrays['double'] = [
        [0., 1., 2., 3., 4., 5., 6., 7., 8., 9.],
        [0., 1., 2., 3., 4.],
        None,
        [],
        np.array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.] * 2)[::2],
    ]
    fields.append(pa.field('bytes_list', pa.list_(pa.binary())))
    arrays['bytes_list'] = [
        [b"1", b"f"],
        None,
        [b"1"],
        [b"1", b"2", b"3"],
        [],
    ]
    fields.append(pa.field('str_list', pa.list_(pa.string())))
    arrays['str_list'] = [
        ["1", "ä"],
        None,
        ["1"],
        ["1", "2", "3"],
        [],
    ]
    date_data = [
        [],
        [date(2018, 1, 1), date(2032, 12, 30)],
        [date(2000, 6, 7)],
        None,
        [date(1969, 6, 9), date(1972, 7, 3)]
    ]
    time_data = [
        [time(23, 11, 11), time(1, 2, 3), time(23, 59, 59)],
        [],
        [time(22, 5, 59)],
        None,
        [time(0, 0, 0), time(18, 0, 2), time(12, 7, 3)]
    ]
    temporal_pairs = [
        (pa.date32(), date_data),
        (pa.date64(), date_data),
        (pa.time32('s'), time_data),
        (pa.time32('ms'), time_data),
        (pa.time64('us'), time_data)
    ]
    if not parquet_compatible:
        # time64[ns] has no parquet logical type, so it is only added
        # when parquet compatibility is not required.
        temporal_pairs += [
            (pa.time64('ns'), time_data),
        ]
    for value_type, data in temporal_pairs:
        field_name = f'{value_type}_list'
        field_type = pa.list_(value_type)
        field = pa.field(field_name, field_type)
        fields.append(field)
        arrays[field_name] = data
    if include_index:
        fields.append(pa.field('__index_level_0__', pa.int64()))
    df = pd.DataFrame(arrays)
    schema = pa.schema(fields)
    return df, schema

View File

@@ -0,0 +1,44 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# This file is called from a test in test_pandas.py.
from concurrent.futures import ThreadPoolExecutor
import faulthandler
import sys
import pyarrow as pa
# Number of worker threads racing to trigger the first pandas import.
num_threads = 60
# Crash the process (via faulthandler) if the run takes longer than this.
timeout = 10  # seconds
def thread_func(i):
    """Convert a one-element Arrow array to pandas; the first call from
    any thread triggers the lazy pandas import."""
    pa.array([i]).to_pandas()
def main():
    """Hammer the lazy pandas import from many threads concurrently to
    check that it cannot deadlock."""
    # In case of import deadlock, crash after a finite timeout
    faulthandler.dump_traceback_later(timeout, exit=True)
    with ThreadPoolExecutor(num_threads) as pool:
        assert "pandas" not in sys.modules  # pandas is imported lazily
        list(pool.map(thread_func, range(num_threads)))
        assert "pandas" in sys.modules
if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,24 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest
# Apply the 'parquet' mark to every test in this module, so they can be
# deselected with: pytest ... -m 'not parquet'
pytestmark = [
    pytest.mark.parquet,
]

View File

@@ -0,0 +1,179 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import io
try:
import numpy as np
except ImportError:
np = None
import pyarrow as pa
from pyarrow.tests import util
def _write_table(table, path, **kwargs):
    """Write *table* (a pyarrow Table or a pandas DataFrame) to parquet
    at *path* and return the (possibly converted) Table."""
    # So we see the ImportError somewhere
    import pyarrow.parquet as pq
    from pyarrow.pandas_compat import _pandas_api
    if _pandas_api.is_data_frame(table):
        table = pa.Table.from_pandas(table)
    pq.write_table(table, path, **kwargs)
    return table
def _read_table(*args, **kwargs):
    """Read a parquet table and fully validate it before returning."""
    import pyarrow.parquet as pq

    result = pq.read_table(*args, **kwargs)
    result.validate(full=True)
    return result
def _roundtrip_table(table, read_table_kwargs=None,
                     write_table_kwargs=None):
    """Write *table* to an in-memory parquet buffer and read it back."""
    sink = pa.BufferOutputStream()
    _write_table(table, sink, **(write_table_kwargs or {}))
    source = pa.BufferReader(sink.getvalue())
    return _read_table(source, **(read_table_kwargs or {}))
def _check_roundtrip(table, expected=None, read_table_kwargs=None,
                     **write_table_kwargs):
    """Roundtrip *table* through parquet twice and assert schema and
    data equal *expected* after each pass.

    Parameters
    ----------
    table : pa.Table
        Table to write.
    expected : pa.Table, optional
        Expected result; defaults to ``table`` itself.
    read_table_kwargs : dict, optional
        Extra read-side arguments. ``_roundtrip_table`` already treats
        ``None`` as an empty dict, so the redundant ``or {}``
        normalization that used to live here was removed.
    **write_table_kwargs
        Extra arguments forwarded to ``pq.write_table``.
    """
    if expected is None:
        expected = table
    # intentionally check twice: the second pass catches defects that
    # only appear when re-writing an already-roundtripped table
    result = _roundtrip_table(table, read_table_kwargs=read_table_kwargs,
                              write_table_kwargs=write_table_kwargs)
    assert result.schema == expected.schema
    assert result.equals(expected)
    result = _roundtrip_table(result, read_table_kwargs=read_table_kwargs,
                              write_table_kwargs=write_table_kwargs)
    assert result.schema == expected.schema
    assert result.equals(expected)
def _roundtrip_pandas_dataframe(df, write_kwargs):
    """Roundtrip a pandas DataFrame through parquet and back to pandas."""
    as_table = pa.Table.from_pandas(df)
    roundtripped = _roundtrip_table(as_table, write_table_kwargs=write_kwargs)
    return roundtripped.to_pandas()
def _random_integers(size, dtype):
# We do not generate integers outside the int64 range
platform_int_info = np.iinfo('int_')
iinfo = np.iinfo(dtype)
return np.random.randint(max(iinfo.min, platform_int_info.min),
min(iinfo.max, platform_int_info.max),
size=size, dtype=dtype)
def _range_integers(size, dtype):
    """Return a pyarrow array of 0..size-1 with the given numpy dtype."""
    return pa.array(np.arange(size, dtype=dtype))
def _test_dict(size=10000, seed=0):
    """Deterministic dict of test columns of length *size*: all fixed
    integer dtypes, floats, bools, random strings and all-null columns.
    Seeding numpy makes the random columns reproducible."""
    np.random.seed(seed)
    return {
        'uint8': _random_integers(size, np.uint8),
        'uint16': _random_integers(size, np.uint16),
        'uint32': _random_integers(size, np.uint32),
        'uint64': _random_integers(size, np.uint64),
        'int8': _random_integers(size, np.int8),
        'int16': _random_integers(size, np.int16),
        'int32': _random_integers(size, np.int32),
        'int64': _random_integers(size, np.int64),
        'float32': np.random.randn(size).astype(np.float32),
        'float64': np.arange(size, dtype=np.float64),
        'bool': np.random.randn(size) > 0,
        'strings': [util.rands(10) for i in range(size)],
        'all_none': [None] * size,
        # stays a plain object column for now (see TODO(PARQUET-1015))
        'all_none_category': [None] * size
    }
def _test_dataframe(size=10000, seed=0):
    """Build the standard test data as a pandas DataFrame."""
    import pandas as pd

    frame = pd.DataFrame(_test_dict(size, seed))
    # TODO(PARQUET-1015)
    # frame['all_none_category'] = frame['all_none_category'].astype('category')
    return frame
def _test_table(size=10000, seed=0):
    """Build the standard test data as a pyarrow Table."""
    return pa.Table.from_pydict(_test_dict(size, seed))
def make_sample_file(table_or_df):
    """Write *table_or_df* (Table or pandas DataFrame) to an in-memory
    SNAPPY-compressed parquet file (format version 2.6) and return an
    open ParquetFile over that buffer."""
    import pyarrow.parquet as pq
    if isinstance(table_or_df, pa.Table):
        a_table = table_or_df
    else:
        a_table = pa.Table.from_pandas(table_or_df)
    buf = io.BytesIO()
    _write_table(a_table, buf, compression='SNAPPY', version='2.6')
    # Rewind so ParquetFile reads from the beginning of the buffer.
    buf.seek(0)
    return pq.ParquetFile(buf)
def alltypes_sample(size=10000, seed=0, categorical=False):
    """Build a pandas DataFrame covering (nearly) all Arrow-mappable
    dtypes: all integer widths, floats, bools, three datetime
    resolutions, timedeltas, strings (plain/empty/with nulls), all-null
    and null-list columns.

    Parameters
    ----------
    size : int, default 10000
        Number of rows in every column.
    seed : int, default 0
        Seed for the random boolean column (reproducibility).
    categorical : bool, default False
        If True, also add 'str_category' ('str' as a pandas category).

    Returns
    -------
    pandas.DataFrame
    """
    import pandas as pd
    np.random.seed(seed)
    arrays = {
        'uint8': np.arange(size, dtype=np.uint8),
        'uint16': np.arange(size, dtype=np.uint16),
        'uint32': np.arange(size, dtype=np.uint32),
        'uint64': np.arange(size, dtype=np.uint64),
        # BUGFIX: was dtype=np.int16, so the 'int8' column did not
        # actually hold int8 data.
        'int8': np.arange(size, dtype=np.int8),
        'int16': np.arange(size, dtype=np.int16),
        'int32': np.arange(size, dtype=np.int32),
        'int64': np.arange(size, dtype=np.int64),
        'float16': np.arange(size, dtype=np.float16),
        'float32': np.arange(size, dtype=np.float32),
        'float64': np.arange(size, dtype=np.float64),
        'bool': np.random.randn(size) > 0,
        # An integer stop in a datetime64 np.arange is an offset from
        # the start value, so each column has exactly `size` elements.
        'datetime_ms': np.arange("2016-01-01T00:00:00.001", size,
                                 dtype='datetime64[ms]'),
        'datetime_us': np.arange("2016-01-01T00:00:00.000001", size,
                                 dtype='datetime64[us]'),
        'datetime_ns': np.arange("2016-01-01T00:00:00.000000001", size,
                                 dtype='datetime64[ns]'),
        'timedelta': np.arange(0, size, dtype="timedelta64[s]"),
        'str': pd.Series([str(x) for x in range(size)]),
        'empty_str': [''] * size,
        'str_with_nulls': [None] + [str(x) for x in range(size - 2)] + [None],
        'null': [None] * size,
        'null_list': [None] * 2 + [[None] * (x % 4) for x in range(size - 2)],
    }
    if categorical:
        arrays['str_category'] = arrays['str'].astype('category')
    return pd.DataFrame(arrays)

Some files were not shown because too many files have changed in this diff Show More