Initial commit
This commit is contained in:
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,37 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
# This file is called from a test in test_flight.py.
|
||||
import time
|
||||
|
||||
import pyarrow as pa
|
||||
import pyarrow.flight as flight
|
||||
|
||||
|
||||
class Server(flight.FlightServerBase):
    """Flight server whose DoPut always ends in a server-side cancellation.

    Launched as a subprocess by a test in test_flight.py to exercise
    client handling of a cancelled put call.
    """

    def do_put(self, context, descriptor, reader, writer):
        # Give the client time to start the put before cancelling it.
        time.sleep(1)
        raise flight.FlightCancelledError("")
if __name__ == "__main__":
    # Bind the server to an ephemeral port, then connect a client to it.
    server = Server("grpc://localhost:0")
    client = flight.connect(f"grpc://localhost:{server.port}")
    # An empty schema/command is enough to trigger the do_put path.
    schema = pa.schema([])
    writer, reader = client.do_put(
        flight.FlightDescriptor.for_command(b""), schema)
    writer.done_writing()
||||
@@ -0,0 +1,47 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
# This file is called from a test in test_pandas.py.
|
||||
|
||||
from threading import Thread
|
||||
|
||||
import pandas as pd
|
||||
from pyarrow.pandas_compat import _pandas_api
|
||||
|
||||
if __name__ == "__main__":
    # Spin-barrier flag: workers busy-wait on this so they all reach
    # `_pandas_api.is_data_frame` at roughly the same moment.
    wait = True
    num_threads = 10
    frame = pd.DataFrame()
    outcomes = []

    def probe():
        # Busy-wait until the main thread drops the barrier, then record
        # whether the DataFrame is recognized as one.
        while wait:
            pass
        outcomes.append(_pandas_api.is_data_frame(frame))

    workers = []
    for _ in range(num_threads):
        workers.append(Thread(target=probe))

    for worker in workers:
        worker.start()

    # Release every worker at once to maximize contention.
    wait = False

    for worker in workers:
        worker.join()

    assert len(outcomes) == num_threads
    assert all(outcomes), "`is_data_frame` returned False when given a DataFrame"
@@ -0,0 +1,30 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
# This file is called from a test in test_schema.py.
|
||||
|
||||
import pyarrow as pa
|
||||
|
||||
|
||||
# the types where to_pandas_dtype returns a non-numpy dtype: each entry pairs
# an Arrow type with the pandas dtype string it is expected to map to.
cases = [
    (pa.timestamp('ns', tz='UTC'), "datetime64[ns, UTC]"),
]


for arrow_type, expected_dtype in cases:
    actual_dtype = str(arrow_type.to_pandas_dtype())
    assert actual_dtype == expected_dtype
@@ -0,0 +1,67 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
# distutils: language=c++
|
||||
# cython: language_level = 3
|
||||
|
||||
from pyarrow.lib cimport *
|
||||
from pyarrow.lib import frombytes, tobytes
|
||||
|
||||
# basic test to roundtrip through a BoundFunction

# C-level callback signature: takes one C++ string, returns a Status.
ctypedef CStatus visit_string_cb(const c_string&)


cdef extern from * namespace "arrow::py" nogil:
    """
    #include <functional>
    #include <string>
    #include <vector>

    #include "arrow/status.h"

    namespace arrow {
    namespace py {

    Status VisitStrings(const std::vector<std::string>& strs,
                        std::function<Status(const std::string&)> cb) {
      for (const std::string& str : strs) {
        RETURN_NOT_OK(cb(str));
      }
      return Status::OK();
    }

    }  // namespace py
    }  // namespace arrow
    """
    # Declaration for the C++ helper defined in the verbatim block above.
    cdef CStatus CVisitStrings" arrow::py::VisitStrings"(
        vector[c_string], function[visit_string_cb])


cdef void _visit_strings_impl(py_cb, const c_string& s) except *:
    # Bridge from the C callback into the bound Python callable, decoding
    # the C++ string to a Python str first.
    py_cb(frombytes(s))


def _visit_strings(strings, cb):
    """Invoke *cb* once per element of *strings*, roundtripping through C++.

    Encodes each string, hands the vector to the C++ VisitStrings helper
    along with *cb* bound via BindFunction, and raises if the resulting
    Status is an error.
    """
    cdef:
        function[visit_string_cb] c_cb
        vector[c_string] c_strings

    c_cb = BindFunction[visit_string_cb](&_visit_strings_impl, cb)
    for s in strings:
        c_strings.push_back(tobytes(s))

    check_status(CVisitStrings(c_strings, c_cb))
330
venv/lib/python3.10/site-packages/pyarrow/tests/conftest.py
Normal file
330
venv/lib/python3.10/site-packages/pyarrow/tests/conftest.py
Normal file
@@ -0,0 +1,330 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import functools
|
||||
import os
|
||||
import pathlib
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
import urllib.request
|
||||
|
||||
import pytest
|
||||
import hypothesis as h
|
||||
|
||||
from ..conftest import groups, defaults
|
||||
|
||||
from pyarrow import set_timezone_db_path
|
||||
from pyarrow.util import find_free_port
|
||||
|
||||
|
||||
# setup hypothesis profiles
h.settings.register_profile('ci', max_examples=1000)
h.settings.register_profile('dev', max_examples=50)
h.settings.register_profile('debug', max_examples=10,
                            verbosity=h.Verbosity.verbose)

# load default hypothesis profile, either set HYPOTHESIS_PROFILE environment
# variable or pass --hypothesis-profile option to pytest, to see the generated
# examples try:
# pytest pyarrow -sv --enable-hypothesis --hypothesis-profile=debug
h.settings.load_profile(os.environ.get('HYPOTHESIS_PROFILE', 'dev'))

# Set this at the beginning before the AWS SDK was loaded to avoid reading in
# user configuration values.
os.environ['AWS_CONFIG_FILE'] = "/dev/null"


if sys.platform == 'win32':
    # Allow pointing Arrow's own timezone database at a custom location.
    tzdata_set_path = os.environ.get('PYARROW_TZDATA_PATH', None)
    if tzdata_set_path:
        set_timezone_db_path(tzdata_set_path)


# GH-45295: For ORC, try to populate TZDIR env var from tzdata package resource
# path.
#
# Note this is a different kind of database than what we allow to be set by
# `PYARROW_TZDATA_PATH` and passed to set_timezone_db_path.
if sys.platform == 'win32':
    if os.environ.get('TZDIR', None) is None:
        from importlib import resources
        try:
            os.environ['TZDIR'] = os.path.join(resources.files('tzdata'), 'zoneinfo')
        except ModuleNotFoundError:
            print(
                'Package "tzdata" not found. Not setting TZDIR environment variable.'
            )
||||
|
||||
def pytest_addoption(parser):
    """Register --enable-/--disable- flags for every pyarrow test group.

    The enable default for each group can be overridden with a
    PYARROW_TEST_<GROUP> boolean environment variable.
    """

    def bool_env(name, default=None):
        """Parse env var *name* (upper-cased) as a boolean, or *default*."""
        raw = os.environ.get(name.upper())
        if not raw:  # missing or empty
            return default
        lowered = raw.lower()
        if lowered in {'1', 'true', 'on', 'yes', 'y'}:
            return True
        if lowered in {'0', 'false', 'off', 'no', 'n'}:
            return False
        raise ValueError(f'{name.upper()}={lowered} is not parsable as boolean')

    for group in groups:
        enabled_default = bool_env(f'PYARROW_TEST_{group}', defaults[group])
        parser.addoption(f'--enable-{group}',
                         action='store_true', default=enabled_default,
                         help=f'Enable the {group} test group')
        parser.addoption(f'--disable-{group}',
                         action='store_true', default=False,
                         help=f'Disable the {group} test group')
||||
|
||||
|
||||
class PyArrowConfig:
    """Session-level record of which pyarrow test groups are enabled."""

    def __init__(self):
        # group name -> bool; populated by pytest_configure.
        self.is_enabled = {}

    def apply_mark(self, mark):
        """Skip the current test if *mark* names a disabled test group."""
        if mark.name in groups:
            self.requires(mark.name)

    def requires(self, group):
        """Skip the current test unless *group* is enabled."""
        if not self.is_enabled[group]:
            pytest.skip(f'{group} NOT enabled')
||||
|
||||
|
||||
def pytest_configure(config):
    """Build the PyArrowConfig from the --enable-/--disable- options.

    Also registers each test group name as a pytest marker so that
    `@pytest.mark.<group>` does not trigger unknown-marker warnings.
    """
    config.pyarrow = PyArrowConfig()

    for mark in groups:
        config.addinivalue_line(
            "markers", mark,
        )

        enabled = config.getoption(f'--enable-{mark}')
        disabled = config.getoption(f'--disable-{mark}')
        # --disable- always wins over --enable-.
        config.pyarrow.is_enabled[mark] = enabled and not disabled
||||
|
||||
|
||||
def pytest_runtest_setup(item):
    """Per-test hook: skip the test if any of its group markers is disabled."""
    # Apply test markers to skip tests selectively
    for mark in item.iter_markers():
        item.config.pyarrow.apply_mark(mark)
||||
|
||||
|
||||
@pytest.fixture
def tempdir(tmpdir):
    """Per-test temporary directory as a pathlib.Path."""
    # convert pytest's LocalPath to pathlib.Path
    return pathlib.Path(tmpdir.strpath)
||||
|
||||
|
||||
@pytest.fixture(scope='session')
def base_datadir():
    """Path to the 'data' directory that sits next to this conftest file."""
    return pathlib.Path(__file__).parent / 'data'
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def disable_aws_metadata(monkeypatch):
    """Stop the AWS SDK from trying to contact the EC2 metadata server.

    Otherwise, this causes a 5 second delay in tests that exercise the
    S3 filesystem.
    """
    monkeypatch.setenv("AWS_EC2_METADATA_DISABLED", "true")
||||
|
||||
|
||||
# TODO(kszucs): move the following fixtures to test_fs.py once the previous
|
||||
# parquet dataset implementation and hdfs implementation are removed.
|
||||
|
||||
@pytest.fixture(scope='session')
def hdfs_connection():
    """(host, port, user) for HDFS tests, from ARROW_HDFS_TEST_* env vars."""
    host = os.environ.get('ARROW_HDFS_TEST_HOST', 'default')
    port = int(os.environ.get('ARROW_HDFS_TEST_PORT', 0))
    user = os.environ.get('ARROW_HDFS_TEST_USER', 'hdfs')
    return host, port, user
||||
|
||||
|
||||
@pytest.fixture(scope='session')
def s3_connection():
    """(host, port, access_key, secret_key) used by the local MinIO server."""
    host, port = '127.0.0.1', find_free_port()
    access_key, secret_key = 'arrow', 'apachearrow'
    return host, port, access_key, secret_key
||||
|
||||
|
||||
def retry(attempts=3, delay=1.0, max_delay=None, backoff=1):
    """
    Retry decorator

    Retries the wrapped callable when it raises, sleeping between attempts
    with a (optionally backed-off and capped) delay, and re-raises the last
    exception once all attempts are exhausted.

    Fixes over the previous version: no pointless sleep after the *final*
    failed attempt, and ``attempts < 1`` raises ValueError instead of an
    unbound-name NameError at call time.

    Parameters
    ----------
    attempts : int, default 3
        The number of attempts. Must be at least 1.
    delay : float, default 1
        Initial delay in seconds.
    max_delay : float, optional
        The max delay between attempts.
    backoff : float, default 1
        The multiplier to delay after each attempt.
    """
    if attempts < 1:
        raise ValueError("attempts must be >= 1")

    def decorate(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            curr_delay = delay
            last_exception = None
            for attempt in range(attempts):
                try:
                    return func(*args, **kwargs)
                except Exception as err:
                    last_exception = err
                    # Scale before sleeping, matching the original timing
                    # (the first sleep is delay * backoff).
                    curr_delay *= backoff
                    if max_delay:
                        curr_delay = min(curr_delay, max_delay)
                    # Only sleep if another attempt will actually follow.
                    if attempt + 1 < attempts:
                        time.sleep(curr_delay)
            raise last_exception
        return wrapper
    return decorate
||||
|
||||
|
||||
@pytest.fixture(scope='session')
def s3_server(s3_connection, tmpdir_factory):
    """Start a local MinIO server for S3 filesystem tests.

    Yields a dict with the connection tuple, the server process, and the
    data directory; the process is killed when the session ends. Skips
    the requesting tests if the `minio` binary is not on PATH.
    """
    @retry(attempts=5, delay=1, backoff=2)
    def minio_server_health_check(address):
        # MinIO's liveness endpoint; retried until the server is up.
        resp = urllib.request.urlopen(f"http://{address}/minio/health/live")
        assert resp.getcode() == 200

    tmpdir = tmpdir_factory.getbasetemp()
    host, port, access_key, secret_key = s3_connection

    address = f'{host}:{port}'
    env = os.environ.copy()
    env.update({
        'MINIO_ACCESS_KEY': access_key,
        'MINIO_SECRET_KEY': secret_key
    })

    args = ['minio', '--compat', 'server', '--quiet', '--address',
            address, tmpdir]
    proc = None
    try:
        proc = subprocess.Popen(args, env=env)
    except OSError:
        pytest.skip('`minio` command cannot be located')
    else:
        # Wait for the server to startup before yielding
        minio_server_health_check(address)

        yield {
            'connection': s3_connection,
            'process': proc,
            'tempdir': tmpdir
        }
    finally:
        if proc is not None:
            proc.kill()
            proc.wait()
||||
|
||||
|
||||
@pytest.fixture(scope='session')
def gcs_server():
    """Start the GCS 'storage-testbench' emulator for GCS filesystem tests.

    Yields the (host, port) connection and the server process; skips the
    requesting tests if the testbench cannot be started.
    """
    port = find_free_port()
    env = os.environ.copy()
    exe = 'storage-testbench'
    args = [exe, '--port', str(port)]
    proc = None
    try:
        # start server
        proc = subprocess.Popen(args, env=env)
        # Make sure the server is alive.
        if proc.poll() is not None:
            pytest.skip(f"Command {args} did not start server successfully!")
    except OSError as e:
        pytest.skip(f"Command {args} failed to execute: {e}")
    else:
        yield {
            'connection': ('localhost', port),
            'process': proc,
        }
    finally:
        if proc is not None:
            proc.kill()
            proc.wait()
|
||||
|
||||
|
||||
@pytest.fixture(scope='session')
def azure_server(tmpdir_factory):
    """Start an Azurite blob emulator for Azure filesystem tests.

    Yields the connection tuple (host, port, account, key) and the server
    process; skips the requesting tests if azurite-blob cannot be started.
    """
    port = find_free_port()
    env = os.environ.copy()
    tmpdir = tmpdir_factory.getbasetemp()
    # We only need blob service emulator, not queue or table.
    args = ['azurite-blob', "--location", tmpdir, "--blobPort", str(port)]
    # For old Azurite. We can't install the latest Azurite with old
    # Node.js on old Ubuntu.
    args += ["--skipApiVersionCheck"]
    proc = None
    try:
        proc = subprocess.Popen(args, env=env)
        # Make sure the server is alive.
        if proc.poll() is not None:
            pytest.skip(f"Command {args} did not start server successfully!")
    except (ModuleNotFoundError, OSError) as e:
        pytest.skip(f"Command {args} failed to execute: {e}")
    else:
        yield {
            # Use the standard azurite account_name and account_key.
            # https://learn.microsoft.com/en-us/azure/storage/common/storage-use-emulator#authorize-with-shared-key-credentials
            'connection': ('127.0.0.1', port, 'devstoreaccount1',
                           'Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2'
                           'UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw=='),
            'process': proc,
            'tempdir': tmpdir,
        }
    finally:
        if proc is not None:
            proc.kill()
            proc.wait()
|
||||
|
||||
|
||||
@pytest.fixture(
    params=[
        'builtin_pickle',
        'cloudpickle'
    ],
    scope='session'
)
def pickle_module(request):
    """Parametrized over the stdlib pickle and cloudpickle fixtures below."""
    return request.getfixturevalue(request.param)
||||
|
||||
|
||||
@pytest.fixture(scope='session')
def builtin_pickle():
    """Return the standard-library pickle module."""
    import pickle
    return pickle
||||
|
||||
|
||||
@pytest.fixture(scope='session')
def cloudpickle():
    """Return cloudpickle, patched so it can stand in for stdlib pickle."""
    cp = pytest.importorskip('cloudpickle')
    if 'HIGHEST_PROTOCOL' not in cp.__dict__:
        # Older cloudpickle versions lack HIGHEST_PROTOCOL; alias it so
        # callers can use the same attribute as on the stdlib module.
        cp.HIGHEST_PROTOCOL = cp.DEFAULT_PROTOCOL
    return cp
||||
Binary file not shown.
@@ -0,0 +1,22 @@
|
||||
<!---
|
||||
Licensed to the Apache Software Foundation (ASF) under one
|
||||
or more contributor license agreements. See the NOTICE file
|
||||
distributed with this work for additional information
|
||||
regarding copyright ownership. The ASF licenses this file
|
||||
to you under the Apache License, Version 2.0 (the
|
||||
"License"); you may not use this file except in compliance
|
||||
with the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing,
|
||||
software distributed under the License is distributed on an
|
||||
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations
|
||||
under the License.
|
||||
-->
|
||||
|
||||
The ORC and JSON files come from the `examples` directory in the Apache ORC
|
||||
source tree:
|
||||
https://github.com/apache/orc/tree/main/examples
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,94 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
# distutils: language=c++
|
||||
# cython: language_level = 3
|
||||
|
||||
from pyarrow.lib cimport *
|
||||
|
||||
cdef extern from * namespace "arrow::py" nogil:
    """
    #include "arrow/status.h"
    #include "arrow/extension_type.h"
    #include "arrow/json/from_string.h"

    namespace arrow {
    namespace py {

    class UuidArray : public ExtensionArray {
     public:
      using ExtensionArray::ExtensionArray;
    };

    class UuidType : public ExtensionType {
     public:
      UuidType() : ExtensionType(fixed_size_binary(16)) {}
      std::string extension_name() const override { return "example-uuid"; }

      bool ExtensionEquals(const ExtensionType& other) const override {
        return other.extension_name() == this->extension_name();
      }

      std::shared_ptr<Array> MakeArray(std::shared_ptr<ArrayData> data) const override {
        return std::make_shared<ExtensionArray>(data);
      }

      Result<std::shared_ptr<DataType>> Deserialize(
          std::shared_ptr<DataType> storage_type,
          const std::string& serialized) const override {
        return std::make_shared<UuidType>();
      }

      std::string Serialize() const override { return ""; }
    };


    std::shared_ptr<DataType> MakeUuidType() {
      return std::make_shared<UuidType>();
    }

    std::shared_ptr<Array> MakeUuidArray() {
      auto uuid_type = MakeUuidType();
      auto json = "[\\"abcdefghijklmno0\\", \\"0onmlkjihgfedcba\\"]";
      auto result = json::ArrayFromJSONString(fixed_size_binary(16), json);
      return ExtensionType::WrapArray(uuid_type, result.ValueOrDie());
    }

    std::once_flag uuid_registered;

    static bool RegisterUuidType() {
      std::call_once(uuid_registered, RegisterExtensionType,
                     std::make_shared<UuidType>());
      return true;
    }

    static auto uuid_type_registered = RegisterUuidType();

    }  // namespace py
    }  // namespace arrow
    """
    # Cython declarations for the helpers defined in the verbatim C++ above.
    # NOTE(review): MakeArray wraps in ExtensionArray rather than UuidArray —
    # presumably intentional for this example type; confirm against the test.
    cdef shared_ptr[CDataType] CMakeUuidType" arrow::py::MakeUuidType"()
    cdef shared_ptr[CArray] CMakeUuidArray" arrow::py::MakeUuidArray"()


def _make_uuid_type():
    """Return the example 'example-uuid' extension type as a Python object."""
    return pyarrow_wrap_data_type(CMakeUuidType())


def _make_uuid_array():
    """Return a two-element example-uuid extension array built in C++."""
    return pyarrow_wrap_array(CMakeUuidArray())
||||
@@ -0,0 +1,16 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,529 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from datetime import datetime as dt
|
||||
import pyarrow as pa
|
||||
from pyarrow.vendored.version import Version
|
||||
import pytest
|
||||
|
||||
try:
|
||||
import numpy as np
|
||||
except ImportError:
|
||||
np = None
|
||||
|
||||
import pyarrow.interchange as pi
|
||||
from pyarrow.interchange.column import (
|
||||
_PyArrowColumn,
|
||||
ColumnNullType,
|
||||
DtypeKind,
|
||||
)
|
||||
from pyarrow.interchange.from_dataframe import _from_dataframe
|
||||
|
||||
try:
|
||||
import pandas as pd
|
||||
# import pandas.testing as tm
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.parametrize("unit", ['s', 'ms', 'us', 'ns'])
@pytest.mark.parametrize("tz", ['', 'America/New_York', '+07:30', '-04:30'])
def test_datetime(unit, tz):
    """Timestamp columns report DATETIME kind and bitmask-style nulls."""
    values = [dt(2007, 7, 13), dt(2007, 7, 14), None]
    arrow_type = pa.timestamp(unit, tz=tz)
    table = pa.table({"A": pa.array(values, type=arrow_type)})
    column = table.__dataframe__().get_column_by_name("A")

    assert column.size() == 3
    assert column.offset == 0
    assert column.null_count == 1
    assert column.dtype[0] == DtypeKind.DATETIME
    assert column.describe_null == (ColumnNullType.USE_BITMASK, 0)
||||
|
||||
|
||||
@pytest.mark.parametrize(
    ["test_data", "kind"],
    [
        (["foo", "bar"], 21),
        ([1.5, 2.5, 3.5], 2),
        ([1, 2, 3, 4], 0),
    ],
)
def test_array_to_pyarrowcolumn(test_data, kind):
    """Wrapping a plain Array in _PyArrowColumn preserves its metadata."""
    arr = pa.array(test_data)
    column = _PyArrowColumn(arr)

    assert column._col == arr
    assert column.size() == len(test_data)
    assert column.dtype[0] == kind
    assert column.num_chunks() == 1
    assert column.null_count == 0
    assert column.get_buffers()["validity"] is None

    # A single-chunk column yields exactly itself from get_chunks().
    chunks = list(column.get_chunks())
    assert len(chunks) == 1
    for chunk in chunks:
        assert chunk == column
||||
|
||||
|
||||
def test_offset_of_sliced_array():
    """Sliced arrays keep their offset through the interchange protocol."""
    arr = pa.array([1, 2, 3, 4])
    arr_sliced = arr.slice(2, 2)

    table = pa.table([arr], names=["arr"])
    table_sliced = pa.table([arr_sliced], names=["arr_sliced"])

    col = table_sliced.__dataframe__().get_column(0)
    assert col.offset == 2

    # Roundtrip preserves the sliced view, not the backing array.
    result = _from_dataframe(table_sliced.__dataframe__())
    assert table_sliced.equals(result)
    assert not table.equals(result)

    # pandas hardcodes offset to 0:
    # https://github.com/pandas-dev/pandas/blob/5c66e65d7b9fef47ccb585ce2fd0b3ea18dc82ea/pandas/core/interchange/from_dataframe.py#L247
    # so conversion to pandas can't be tested currently

    # df = pandas_from_dataframe(table)
    # df_sliced = pandas_from_dataframe(table_sliced)

    # tm.assert_series_equal(df["arr"][2:4], df_sliced["arr_sliced"],
    #                        check_index=False, check_names=False)
||||
|
||||
|
||||
@pytest.mark.pandas
@pytest.mark.parametrize(
    "uint", [pa.uint8(), pa.uint16(), pa.uint32()]
)
@pytest.mark.parametrize(
    "int", [pa.int8(), pa.int16(), pa.int32(), pa.int64()]
)
@pytest.mark.parametrize(
    "float, np_float_str", [
        # (pa.float16(), np.float16), #not supported by pandas
        (pa.float32(), "float32"),
        (pa.float64(), "float64")
    ]
)
def test_pandas_roundtrip(uint, int, float, np_float_str):
    """Numeric/bool tables survive a pyarrow -> pandas -> pyarrow roundtrip."""
    if Version(pd.__version__) < Version("1.5.0"):
        pytest.skip("__dataframe__ added to pandas in 1.5.0")

    arr = [1, 2, 3]
    table = pa.table(
        {
            "a": pa.array(arr, type=uint),
            "b": pa.array(arr, type=int),
            "c": pa.array(np.array(arr, dtype=np.dtype(np_float_str)), type=float),
            "d": [True, False, True],
        }
    )
    from pandas.api.interchange import (
        from_dataframe as pandas_from_dataframe
    )
    pandas_df = pandas_from_dataframe(table)
    result = pi.from_dataframe(pandas_df)
    assert table.equals(result)

    # The interchange-level metadata must also survive the roundtrip.
    table_protocol = table.__dataframe__()
    result_protocol = result.__dataframe__()

    assert table_protocol.num_columns() == result_protocol.num_columns()
    assert table_protocol.num_rows() == result_protocol.num_rows()
    assert table_protocol.num_chunks() == result_protocol.num_chunks()
    assert table_protocol.column_names() == result_protocol.column_names()
||||
|
||||
|
||||
@pytest.mark.pandas
def test_pandas_roundtrip_string():
    """String columns roundtrip via pandas, coming back as large_string."""
    # See https://github.com/pandas-dev/pandas/issues/50554
    if Version(pd.__version__) < Version("1.6"):
        pytest.skip("Column.size() bug in pandas")

    arr = ["a", "", "c"]
    table = pa.table({"a": pa.array(arr)})

    from pandas.api.interchange import (
        from_dataframe as pandas_from_dataframe
    )

    pandas_df = pandas_from_dataframe(table)
    result = pi.from_dataframe(pandas_df)

    assert result["a"].to_pylist() == table["a"].to_pylist()
    # The roundtrip widens string -> large_string.
    assert pa.types.is_string(table["a"].type)
    assert pa.types.is_large_string(result["a"].type)

    table_protocol = table.__dataframe__()
    result_protocol = result.__dataframe__()

    assert table_protocol.num_columns() == result_protocol.num_columns()
    assert table_protocol.num_rows() == result_protocol.num_rows()
    assert table_protocol.num_chunks() == result_protocol.num_chunks()
    assert table_protocol.column_names() == result_protocol.column_names()
||||
|
||||
|
||||
@pytest.mark.pandas
def test_pandas_roundtrip_large_string():
    """large_string columns roundtrip via pandas on pandas >= 2.0.1."""
    # See https://github.com/pandas-dev/pandas/issues/50554
    if Version(pd.__version__) < Version("1.6"):
        pytest.skip("Column.size() bug in pandas")

    arr = ["a", "", "c"]
    table = pa.table({"a_large": pa.array(arr, type=pa.large_string())})

    from pandas.api.interchange import (
        from_dataframe as pandas_from_dataframe
    )

    if Version(pd.__version__) >= Version("2.0.1"):
        pandas_df = pandas_from_dataframe(table)
        result = pi.from_dataframe(pandas_df)

        assert result["a_large"].to_pylist() == table["a_large"].to_pylist()
        assert pa.types.is_large_string(table["a_large"].type)
        assert pa.types.is_large_string(result["a_large"].type)

        table_protocol = table.__dataframe__()
        result_protocol = result.__dataframe__()

        assert table_protocol.num_columns() == result_protocol.num_columns()
        assert table_protocol.num_rows() == result_protocol.num_rows()
        assert table_protocol.num_chunks() == result_protocol.num_chunks()
        assert table_protocol.column_names() == result_protocol.column_names()

    else:
        # large string not supported by pandas implementation for
        # older versions of pandas
        # https://github.com/pandas-dev/pandas/issues/52795
        with pytest.raises(AssertionError):
            pandas_from_dataframe(table)
||||
|
||||
|
||||
@pytest.mark.pandas
def test_pandas_roundtrip_string_with_missing():
    """Round-trip string and large_string columns containing nulls."""
    # See https://github.com/pandas-dev/pandas/issues/50554
    if Version(pd.__version__) < Version("1.6"):
        pytest.skip("Column.size() bug in pandas")

    values = ["a", "", "c", None]
    table = pa.table({
        "a": pa.array(values),
        "a_large": pa.array(values, type=pa.large_string()),
    })

    from pandas.api.interchange import from_dataframe as pandas_from_dataframe

    if Version(pd.__version__) >= Version("2.0.2"):
        roundtripped = pi.from_dataframe(pandas_from_dataframe(table))

        # Both columns come back as large_string, regardless of input type
        for name, table_type_check in (("a", pa.types.is_string),
                                       ("a_large", pa.types.is_large_string)):
            assert roundtripped[name].to_pylist() == table[name].to_pylist()
            assert table_type_check(table[name].type)
            assert pa.types.is_large_string(roundtripped[name].type)
    else:
        # older versions of pandas do not have bitmask support
        # https://github.com/pandas-dev/pandas/issues/49888
        with pytest.raises(NotImplementedError):
            pandas_from_dataframe(table)
|
||||
|
||||
|
||||
@pytest.mark.pandas
def test_pandas_roundtrip_categorical():
    """Round-trip a dictionary-encoded column through the pandas interchange API.

    Checks both the values and the interchange-protocol metadata
    (dtype kind, size, offset, describe_categorical).
    """
    if Version(pd.__version__) < Version("2.0.2"):
        pytest.skip("Bitmasks not supported in pandas interchange implementation")

    arr = ["Mon", "Tue", "Mon", "Wed", "Mon", "Thu", "Fri", "Sat", None]
    table = pa.table(
        {"weekday": pa.array(arr).dictionary_encode()}
    )

    from pandas.api.interchange import (
        from_dataframe as pandas_from_dataframe
    )
    pandas_df = pandas_from_dataframe(table)
    result = pi.from_dataframe(pandas_df)

    assert result["weekday"].to_pylist() == table["weekday"].to_pylist()
    assert pa.types.is_dictionary(table["weekday"].type)
    assert pa.types.is_dictionary(result["weekday"].type)
    # pandas re-encodes the categorical: the dictionary values come back as
    # large_string and the int32 indices come back as int8 codes
    assert pa.types.is_string(table["weekday"].chunk(0).dictionary.type)
    assert pa.types.is_large_string(result["weekday"].chunk(0).dictionary.type)
    assert pa.types.is_int32(table["weekday"].chunk(0).indices.type)
    assert pa.types.is_int8(result["weekday"].chunk(0).indices.type)

    table_protocol = table.__dataframe__()
    result_protocol = result.__dataframe__()

    assert table_protocol.num_columns() == result_protocol.num_columns()
    assert table_protocol.num_rows() == result_protocol.num_rows()
    assert table_protocol.num_chunks() == result_protocol.num_chunks()
    assert table_protocol.column_names() == result_protocol.column_names()

    col_table = table_protocol.get_column(0)
    col_result = result_protocol.get_column(0)

    assert col_result.dtype[0] == DtypeKind.CATEGORICAL
    assert col_result.dtype[0] == col_table.dtype[0]
    assert col_result.size() == col_table.size()
    assert col_result.offset == col_table.offset

    # Bug fix: this previously read describe_categorical from col_result
    # twice, so the assertions below compared the result with itself
    # (compare with test_pyarrow_roundtrip_categorical, which is correct).
    desc_cat_table = col_table.describe_categorical
    desc_cat_result = col_result.describe_categorical

    assert desc_cat_table["is_ordered"] == desc_cat_result["is_ordered"]
    assert desc_cat_table["is_dictionary"] == desc_cat_result["is_dictionary"]
    assert isinstance(desc_cat_result["categories"]._col, pa.Array)
|
||||
|
||||
|
||||
@pytest.mark.pandas
@pytest.mark.parametrize("unit", ['s', 'ms', 'us', 'ns'])
def test_pandas_roundtrip_datetime(unit):
    """Round-trip tz-naive timestamps through the pandas interchange API."""
    if Version(pd.__version__) < Version("1.5.0"):
        pytest.skip("__dataframe__ added to pandas in 1.5.0")
    from datetime import datetime as dt

    # timezones not included as they are not yet supported in
    # the pandas implementation
    stamps = [dt(2007, 7, 13), dt(2007, 7, 14), dt(2007, 7, 15)]
    table = pa.table({"a": pa.array(stamps, type=pa.timestamp(unit))})

    # pandas < 2.0 always creates datetime64 in "ns" resolution
    expected = (
        pa.table({"a": pa.array(stamps, type=pa.timestamp('ns'))})
        if Version(pd.__version__) < Version("1.6")
        else table
    )

    from pandas.api.interchange import from_dataframe as pandas_from_dataframe

    roundtripped = pi.from_dataframe(pandas_from_dataframe(table))

    assert expected.equals(roundtripped)

    left = expected.__dataframe__()
    right = roundtripped.__dataframe__()

    for accessor in ("num_columns", "num_rows", "num_chunks", "column_names"):
        assert getattr(left, accessor)() == getattr(right, accessor)()
|
||||
|
||||
|
||||
@pytest.mark.pandas
@pytest.mark.parametrize(
    "np_float_str", ["float32", "float64"]
)
def test_pandas_to_pyarrow_with_missing(np_float_str):
    """Convert a pandas frame with NaN floats and NaT datetimes to pyarrow."""
    if Version(pd.__version__) < Version("1.5.0"):
        pytest.skip("__dataframe__ added to pandas in 1.5.0")

    floats = np.array([0, np.nan, 2], dtype=np.dtype(np_float_str))
    stamps = [None, dt(2007, 7, 14), dt(2007, 7, 15)]
    df = pd.DataFrame({
        # float, ColumnNullType.USE_NAN
        "a": floats,
        # ColumnNullType.USE_SENTINEL
        "dt": np.array(stamps, dtype="datetime64[ns]"),
    })
    expected = pa.table({
        "a": pa.array(floats, from_pandas=True),
        "dt": pa.array(stamps, type=pa.timestamp("ns")),
    })

    assert pi.from_dataframe(df).equals(expected)
|
||||
|
||||
|
||||
@pytest.mark.pandas
def test_pandas_to_pyarrow_float16_with_missing():
    """float16 with NaN cannot be converted and must raise."""
    if Version(pd.__version__) < Version("1.5.0"):
        pytest.skip("__dataframe__ added to pandas in 1.5.0")

    # np.float16 errors if is_nan is used:
    # pyarrow.lib.ArrowNotImplementedError: Function 'is_nan' has no kernel
    # matching input types (halffloat)
    half_floats = np.array([0, np.nan, 2], dtype=np.float16)

    with pytest.raises(NotImplementedError):
        pi.from_dataframe(pd.DataFrame({"a": half_floats}))
|
||||
|
||||
|
||||
@pytest.mark.numpy
@pytest.mark.parametrize(
    "uint", [pa.uint8(), pa.uint16(), pa.uint32()]
)
@pytest.mark.parametrize(
    "int", [pa.int8(), pa.int16(), pa.int32(), pa.int64()]
)
@pytest.mark.parametrize(
    "float, np_float_str", [
        (pa.float16(), "float16"),
        (pa.float32(), "float32"),
        (pa.float64(), "float64")
    ]
)
@pytest.mark.parametrize("unit", ['s', 'ms', 'us', 'ns'])
@pytest.mark.parametrize("tz", ['America/New_York', '+07:30', '-04:30'])
@pytest.mark.parametrize("offset, length", [(0, 3), (0, 2), (1, 2), (2, 1)])
def test_pyarrow_roundtrip(uint, int, float, np_float_str,
                           unit, tz, offset, length):
    """pyarrow -> interchange protocol -> pyarrow round-trip over a mix of
    numeric, boolean, string and (tz-aware) timestamp columns, with slicing."""
    from datetime import datetime as dt

    values = [1, 2, None]
    stamps = [dt(2007, 7, 13), None, dt(2007, 7, 15)]

    source = pa.table(
        {
            "a": pa.array(values, type=uint),
            "b": pa.array(values, type=int),
            "c": pa.array(np.array(values, dtype=np.dtype(np_float_str)),
                          type=float, from_pandas=True),
            "d": [True, False, True],
            "e": [True, False, None],
            "f": ["a", None, "c"],
            "g": pa.array(stamps, type=pa.timestamp(unit, tz=tz))
        }
    ).slice(offset, length)

    restored = _from_dataframe(source.__dataframe__())

    assert source.equals(restored)

    left = source.__dataframe__()
    right = restored.__dataframe__()

    for accessor in ("num_columns", "num_rows", "num_chunks", "column_names"):
        assert getattr(left, accessor)() == getattr(right, accessor)()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("offset, length", [(0, 10), (0, 2), (7, 3), (2, 1)])
def test_pyarrow_roundtrip_categorical(offset, length):
    """pyarrow -> interchange -> pyarrow round-trip of a sliced
    dictionary-encoded column, including categorical protocol metadata."""
    values = ["Mon", "Tue", "Mon", "Wed", "Mon", "Thu", "Fri", None, "Sun"]
    source = pa.table(
        {"weekday": pa.array(values).dictionary_encode()}
    ).slice(offset, length)

    restored = _from_dataframe(source.__dataframe__())

    assert source.equals(restored)

    left = source.__dataframe__()
    right = restored.__dataframe__()

    for accessor in ("num_columns", "num_rows", "num_chunks", "column_names"):
        assert getattr(left, accessor)() == getattr(right, accessor)()

    col_source = left.get_column(0)
    col_restored = right.get_column(0)

    assert col_restored.dtype[0] == DtypeKind.CATEGORICAL
    assert col_restored.dtype[0] == col_source.dtype[0]
    assert col_restored.size() == col_source.size()
    assert col_restored.offset == col_source.offset

    desc_source = col_source.describe_categorical
    desc_restored = col_restored.describe_categorical

    assert desc_source["is_ordered"] == desc_restored["is_ordered"]
    assert desc_source["is_dictionary"] == desc_restored["is_dictionary"]
    assert isinstance(desc_restored["categories"]._col, pa.Array)
|
||||
|
||||
|
||||
@pytest.mark.large_memory
def test_pyarrow_roundtrip_large_string():
    """Round-trip a large_string column whose data exceeds 2 GiB."""
    n_rows = 3 * 1024 ** 2
    payload = np.array([b'x' * 1024] * n_rows, dtype='object')  # 3GB bytes data
    source = pa.table([pa.array(payload, type=pa.large_string())],
                      names=["large_string"])

    restored = _from_dataframe(source.__dataframe__())
    column = restored.__dataframe__().get_column(0)

    assert column.size() == n_rows
    assert pa.types.is_large_string(source[0].type)
    assert pa.types.is_large_string(restored[0].type)

    assert source.equals(restored)
|
||||
|
||||
|
||||
def test_nan_as_null():
    """Requesting nan_as_null=True is unsupported and must raise."""
    with pytest.raises(RuntimeError):
        pa.table({"a": [1, 2, 3, 4]}).__dataframe__(nan_as_null=True)
|
||||
|
||||
|
||||
@pytest.mark.pandas
def test_allow_copy_false():
    """allow_copy=False must raise when a copy is needed to create a bitmask."""
    if Version(pd.__version__) < Version("1.5.0"):
        pytest.skip("__dataframe__ added to pandas in 1.5.0")

    frames = (
        pd.DataFrame({"a": [0, 1.0, 2.0]}),
        pd.DataFrame({
            "dt": [None, dt(2007, 7, 14), dt(2007, 7, 15)]
        }),
    )
    for frame in frames:
        with pytest.raises(RuntimeError):
            pi.from_dataframe(frame, allow_copy=False)
|
||||
|
||||
|
||||
@pytest.mark.pandas
def test_allow_copy_false_bool_categorical():
    """allow_copy=False must raise for boolean and categorical columns
    (converting those always makes a copy)."""
    if Version(pd.__version__) < Version("1.5.0"):
        pytest.skip("__dataframe__ added to pandas in 1.5.0")

    frames = (
        pd.DataFrame({"a": [None, False, True]}),
        pd.DataFrame({"a": [True, False, True]}),
        pd.DataFrame({"weekday": ["a", "b", None]}).astype("category"),
        pd.DataFrame({"weekday": ["a", "b", "c"]}).astype("category"),
    )
    for frame in frames:
        with pytest.raises(RuntimeError):
            pi.from_dataframe(frame, allow_copy=False)
|
||||
|
||||
|
||||
def test_empty_dataframe():
    """A zero-row table survives the interchange round-trip."""
    empty = pa.table([[]], schema=pa.schema([('col1', pa.int8())]))
    assert pi.from_dataframe(empty.__dataframe__()) == empty
|
||||
@@ -0,0 +1,294 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import ctypes
|
||||
import hypothesis as h
|
||||
import hypothesis.strategies as st
|
||||
|
||||
import pytest
|
||||
try:
|
||||
import numpy as np
|
||||
except ImportError:
|
||||
np = None
|
||||
import pyarrow as pa
|
||||
import pyarrow.tests.strategies as past
|
||||
|
||||
|
||||
# Hypothesis strategy drawing any of the primitive pyarrow types exercised
# below; st.deferred delays evaluating the lambda (and thus resolving the
# strategies from `past`) until draw time.
all_types = st.deferred(
    lambda: (
        past.signed_integer_types |
        past.unsigned_integer_types |
        past.floating_types |
        past.bool_type |
        past.string_type |
        past.large_string_type
    )
)
|
||||
|
||||
|
||||
# datetime is tested in test_extra.py
|
||||
# dictionary is tested in test_categorical()
|
||||
@pytest.mark.numpy
@h.settings(suppress_health_check=(h.HealthCheck.too_slow,))
@h.given(past.arrays(all_types, size=3))
def test_dtypes(arr):
    """Basic column metadata is correct for any primitive dtype."""
    df = pa.table([arr], names=["a"]).__dataframe__()
    column = df.get_column(0)

    null_count = column.null_count
    assert null_count == arr.null_count
    assert isinstance(null_count, int)
    assert column.size() == 3
    assert column.offset == 0
|
||||
|
||||
|
||||
@pytest.mark.numpy
@pytest.mark.parametrize(
    "uint, uint_bw",
    [
        (pa.uint8(), 8),
        (pa.uint16(), 16),
        (pa.uint32(), 32)
    ]
)
@pytest.mark.parametrize(
    "int, int_bw", [
        (pa.int8(), 8),
        (pa.int16(), 16),
        (pa.int32(), 32),
        (pa.int64(), 64)
    ]
)
@pytest.mark.parametrize(
    "float, float_bw, np_float_str", [
        (pa.float16(), 16, "float16"),
        (pa.float32(), 32, "float32"),
        (pa.float64(), 64, "float64")
    ]
)
@pytest.mark.parametrize("unit", ['s', 'ms', 'us', 'ns'])
@pytest.mark.parametrize("tz", ['', 'America/New_York', '+07:30', '-04:30'])
@pytest.mark.parametrize("use_batch", [False, True])
def test_mixed_dtypes(uint, uint_bw, int, int_bw,
                      float, float_bw, np_float_str, unit, tz,
                      use_batch):
    """Column metadata (kind, bit-width, size, offset, null count) is
    reported correctly over a table mixing all supported dtypes."""
    from datetime import datetime as dt
    values = [1, 2, 3]
    stamps = [dt(2007, 7, 13), dt(2007, 7, 14), dt(2007, 7, 15)]
    table = pa.table(
        {
            "a": pa.array(values, type=uint),
            "b": pa.array(values, type=int),
            "c": pa.array(np.array(values, dtype=np.dtype(np_float_str)),
                          type=float),
            "d": [True, False, True],
            "e": ["a", "", "c"],
            "f": pa.array(stamps, type=pa.timestamp(unit, tz=tz))
        }
    )
    if use_batch:
        table = table.to_batches()[0]
    df = table.__dataframe__()

    # 0 = DtypeKind.INT, 1 = DtypeKind.UINT, 2 = DtypeKind.FLOAT,
    # 20 = DtypeKind.BOOL, 21 = DtypeKind.STRING, 22 = DtypeKind.DATETIME
    # see DtypeKind class in column.py
    expected_kinds = {"a": 1, "b": 0, "c": 2, "d": 20, "e": 21, "f": 22}

    for name, kind in expected_kinds.items():
        column = df.get_column_by_name(name)

        assert column.null_count == 0
        assert column.size() == 3
        assert column.offset == 0
        assert column.dtype[0] == kind

    for name, bits in (("a", uint_bw), ("b", int_bw), ("c", float_bw)):
        assert df.get_column_by_name(name).dtype[1] == bits
|
||||
|
||||
|
||||
def test_na_float():
    """A null in a float column is counted (as a Python int)."""
    column = (pa.table({"a": [1.0, None, 2.0]})
              .__dataframe__()
              .get_column_by_name("a"))
    assert column.null_count == 1
    assert isinstance(column.null_count, int)
|
||||
|
||||
|
||||
def test_noncategorical():
    """describe_categorical on a non-categorical column raises TypeError."""
    column = (pa.table({"a": [1, 2, 3]})
              .__dataframe__()
              .get_column_by_name("a"))
    with pytest.raises(TypeError, match=".*categorical.*"):
        column.describe_categorical
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_batch", [False, True])
def test_categorical(use_batch):
    """describe_categorical returns native bools for a dictionary column."""
    values = ["Mon", "Tue", "Mon", "Wed", "Mon", "Thu", "Fri", "Sat", None]
    table = pa.table(
        {"weekday": pa.array(values).dictionary_encode()}
    )
    if use_batch:
        table = table.to_batches()[0]

    column = table.__dataframe__().get_column_by_name("weekday")
    description = column.describe_categorical
    assert isinstance(description["is_ordered"], bool)
    assert isinstance(description["is_dictionary"], bool)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_batch", [False, True])
def test_dataframe(use_batch):
    """DataFrame-level protocol accessors: counts, names, column selection."""
    legs = pa.chunked_array([[2, 2, 4], [4, 5, 100]])
    animals = pa.chunked_array([["Flamingo", "Parrot", "Cow"],
                                ["Horse", "Brittle stars", "Centipede"]])
    table = pa.table([legs, animals], names=['n_legs', 'animals'])
    if use_batch:
        table = table.combine_chunks().to_batches()[0]
    df = table.__dataframe__()

    assert df.num_columns() == 2
    assert df.num_rows() == 6
    # A RecordBatch is a single chunk; the table keeps its two chunks
    assert df.num_chunks() == (1 if use_batch else 2)
    assert list(df.column_names()) == ['n_legs', 'animals']
    # Selecting by position and by name must agree
    assert list(df.select_columns((1,)).column_names()) == list(
        df.select_columns_by_name(("animals",)).column_names()
    )
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_batch", [False, True])
@pytest.mark.parametrize(["size", "n_chunks"], [(10, 3), (12, 3), (12, 5)])
def test_df_get_chunks(use_batch, size, n_chunks):
    """get_chunks(n) yields exactly n pieces covering every row once."""
    table = pa.table({"x": list(range(size))})
    if use_batch:
        table = table.to_batches()[0]

    pieces = list(table.__dataframe__().get_chunks(n_chunks))
    assert len(pieces) == n_chunks
    assert sum(piece.num_rows() for piece in pieces) == size
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_batch", [False, True])
@pytest.mark.parametrize(["size", "n_chunks"], [(10, 3), (12, 3), (12, 5)])
def test_column_get_chunks(use_batch, size, n_chunks):
    """Column.get_chunks(n) yields exactly n pieces covering every value."""
    table = pa.table({"x": list(range(size))})
    if use_batch:
        table = table.to_batches()[0]

    pieces = list(table.__dataframe__().get_column(0).get_chunks(n_chunks))
    assert len(pieces) == n_chunks
    assert sum(piece.size() for piece in pieces) == size
|
||||
|
||||
|
||||
@pytest.mark.pandas
@pytest.mark.parametrize(
    "uint", [pa.uint8(), pa.uint16(), pa.uint32()]
)
@pytest.mark.parametrize(
    "int", [pa.int8(), pa.int16(), pa.int32(), pa.int64()]
)
@pytest.mark.parametrize(
    "float, np_float_str", [
        (pa.float16(), "float16"),
        (pa.float32(), "float32"),
        (pa.float64(), "float64")
    ]
)
@pytest.mark.parametrize("use_batch", [False, True])
def test_get_columns(uint, int, float, np_float_str, use_batch):
    """get_columns() exposes every column with the right size and kind."""
    chunked_values = [[1, 2, 3], [4, 5]]
    float_values = np.array([1, 2, 3, 4, 5], dtype=np.dtype(np_float_str))
    table = pa.table(
        {
            "a": pa.chunked_array(chunked_values, type=uint),
            "b": pa.chunked_array(chunked_values, type=int),
            "c": pa.array(float_values, type=float)
        }
    )
    if use_batch:
        table = table.combine_chunks().to_batches()[0]
    df = table.__dataframe__()
    for column in df.get_columns():
        assert column.size() == 5
        assert column.num_chunks() == 1

    # 0 = DtypeKind.INT, 1 = DtypeKind.UINT, 2 = DtypeKind.FLOAT,
    # see DtypeKind class in column.py
    for position, kind in enumerate((1, 0, 2)):  # UINT, INT, FLOAT
        assert df.get_column(position).dtype[0] == kind
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "int", [pa.int8(), pa.int16(), pa.int32(), pa.int64()]
)
@pytest.mark.parametrize("use_batch", [False, True])
def test_buffer(int, use_batch):
    """The data buffer exposed through get_buffers() holds the raw values."""
    values = [0, 1, -1]
    table = pa.table({"a": pa.array(values, type=int)})
    if use_batch:
        table = table.to_batches()[0]
    column = table.__dataframe__().get_column(0)
    buffers = column.get_buffers()

    data_buffer, data_dtype = buffers["data"]

    assert data_buffer.bufsize > 0
    assert data_buffer.ptr != 0
    device, _ = data_buffer.__dlpack_device__()

    # 0 = DtypeKind.INT
    # see DtypeKind class in column.py
    assert data_dtype[0] == 0

    if device == 1:  # CPU-only as we're going to directly read memory here
        bitwidth = data_dtype[1]
        ctype = {
            8: ctypes.c_int8,
            16: ctypes.c_int16,
            32: ctypes.c_int32,
            64: ctypes.c_int64,
        }[bitwidth]

        # Read each element straight out of the buffer and compare
        for idx, truth in enumerate(values):
            val = ctype.from_address(data_buffer.ptr + idx * (bitwidth // 8)).value
            assert val == truth, f"Buffer at index {idx} mismatch"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "indices_type, bitwidth, f_string", [
        (pa.int8(), 8, "c"),
        (pa.int16(), 16, "s"),
        (pa.int32(), 32, "i"),
        (pa.int64(), 64, "l")
    ]
)
def test_categorical_dtype(indices_type, bitwidth, f_string):
    """The dtype triple of a dictionary column reflects its index type."""
    dict_type = pa.dictionary(indices_type, pa.string())
    table = pa.table({'a': pa.array(["a", "b", None, "d"], dict_type)})

    column = table.__dataframe__().get_column(0)
    assert column.dtype[0] == 23  # <DtypeKind.CATEGORICAL: 23>
    assert column.dtype[1] == bitwidth
    assert column.dtype[2] == f_string
|
||||
@@ -0,0 +1,172 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from collections import OrderedDict
|
||||
from datetime import date, time
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pyarrow as pa
|
||||
|
||||
|
||||
def dataframe_with_arrays(include_index=False):
    """
    Dataframe with numpy arrays columns of every possible primitive type.

    Parameters
    ----------
    include_index : bool, default False
        If True, also append an '__index_level_0__' int64 field to the schema.

    Returns
    -------
    df: pandas.DataFrame
    schema: pyarrow.Schema
        Arrow schema definition that is in line with the constructed df.
    """
    numeric_dtypes = [('i1', pa.int8()), ('i2', pa.int16()),
                      ('i4', pa.int32()), ('i8', pa.int64()),
                      ('u1', pa.uint8()), ('u2', pa.uint16()),
                      ('u4', pa.uint32()), ('u8', pa.uint64()),
                      ('f4', pa.float32()), ('f8', pa.float64())]

    arrays = OrderedDict()
    fields = []
    for np_dtype, arrow_dtype in numeric_dtypes:
        fields.append(pa.field(np_dtype, pa.list_(arrow_dtype)))
        # Arrays of varying length plus one missing (None) entry
        arrays[np_dtype] = [
            np.arange(10, dtype=np_dtype),
            np.arange(5, dtype=np_dtype),
            None,
            np.arange(1, dtype=np_dtype),
        ]

    fields.append(pa.field('str', pa.list_(pa.string())))
    arrays['str'] = [
        np.array(["1", "ä"], dtype="object"),
        None,
        np.array(["1"], dtype="object"),
        np.array(["1", "2", "3"], dtype="object"),
    ]

    fields.append(pa.field('datetime64', pa.list_(pa.timestamp('ms'))))
    arrays['datetime64'] = [
        np.array(['2007-07-13T01:23:34.123456789',
                  None,
                  '2010-08-13T05:46:57.437699912'],
                 dtype='datetime64[ms]'),
        None,
        None,
        np.array(['2007-07-13T02',
                  None,
                  '2010-08-13T05:46:57.437699912'],
                 dtype='datetime64[ms]'),
    ]

    if include_index:
        fields.append(pa.field('__index_level_0__', pa.int64()))

    return pd.DataFrame(arrays), pa.schema(fields)
|
||||
|
||||
|
||||
def dataframe_with_lists(include_index=False, parquet_compatible=False):
    """
    Dataframe with list columns of every possible primitive type.

    Parameters
    ----------
    include_index : bool, default False
        If True, also append an '__index_level_0__' int64 field to the schema.
    parquet_compatible : bool, default False
        Exclude types not supported by parquet (time64[ns]).

    Returns
    -------
    df: pandas.DataFrame
    schema: pyarrow.Schema
        Arrow schema definition that is in line with the constructed df.
    """
    arrays = OrderedDict()
    fields = []

    fields.append(pa.field('int64', pa.list_(pa.int64())))
    arrays['int64'] = [
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
        [0, 1, 2, 3, 4],
        None,
        [],
        # strided numpy array to exercise non-contiguous input
        np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9] * 2,
                 dtype=np.int64)[::2],
    ]

    fields.append(pa.field('double', pa.list_(pa.float64())))
    arrays['double'] = [
        [0., 1., 2., 3., 4., 5., 6., 7., 8., 9.],
        [0., 1., 2., 3., 4.],
        None,
        [],
        np.array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.] * 2)[::2],
    ]

    fields.append(pa.field('bytes_list', pa.list_(pa.binary())))
    arrays['bytes_list'] = [
        [b"1", b"f"],
        None,
        [b"1"],
        [b"1", b"2", b"3"],
        [],
    ]

    fields.append(pa.field('str_list', pa.list_(pa.string())))
    arrays['str_list'] = [
        ["1", "ä"],
        None,
        ["1"],
        ["1", "2", "3"],
        [],
    ]

    date_data = [
        [],
        [date(2018, 1, 1), date(2032, 12, 30)],
        [date(2000, 6, 7)],
        None,
        [date(1969, 6, 9), date(1972, 7, 3)],
    ]
    time_data = [
        [time(23, 11, 11), time(1, 2, 3), time(23, 59, 59)],
        [],
        [time(22, 5, 59)],
        None,
        [time(0, 0, 0), time(18, 0, 2), time(12, 7, 3)],
    ]

    temporal_pairs = [
        (pa.date32(), date_data),
        (pa.date64(), date_data),
        (pa.time32('s'), time_data),
        (pa.time32('ms'), time_data),
        (pa.time64('us'), time_data),
    ]
    if not parquet_compatible:
        # time64[ns] is not representable in parquet
        temporal_pairs.append((pa.time64('ns'), time_data))

    for value_type, data in temporal_pairs:
        name = f'{value_type}_list'
        fields.append(pa.field(name, pa.list_(value_type)))
        arrays[name] = data

    if include_index:
        fields.append(pa.field('__index_level_0__', pa.int64()))

    return pd.DataFrame(arrays), pa.schema(fields)
|
||||
@@ -0,0 +1,44 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
# This file is called from a test in test_pandas.py.
|
||||
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import faulthandler
|
||||
import sys
|
||||
|
||||
import pyarrow as pa
|
||||
|
||||
num_threads = 60
|
||||
timeout = 10 # seconds
|
||||
|
||||
|
||||
def thread_func(i):
    """Convert a one-element Arrow array to pandas, which triggers pyarrow's
    lazy pandas import when run concurrently from many threads."""
    pa.array([i]).to_pandas()
|
||||
|
||||
|
||||
def main():
    """Hammer the lazy pandas import from many threads at once and verify
    that pandas ends up imported without deadlocking."""
    # In case of import deadlock, crash after a finite timeout
    faulthandler.dump_traceback_later(timeout, exit=True)
    with ThreadPoolExecutor(num_threads) as pool:
        assert "pandas" not in sys.modules  # pandas is imported lazily
        # list() forces all futures to complete (and re-raises any error)
        list(pool.map(thread_func, range(num_threads)))
    assert "pandas" in sys.modules
|
||||
|
||||
|
||||
# Script entry point — this file is run from a test in test_pandas.py.
if __name__ == "__main__":
    main()
|
||||
@@ -0,0 +1,24 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import pytest

# Marks all of the tests in this module
# Ignore these with pytest ... -m 'not parquet'
pytestmark = [
    pytest.mark.parquet,
]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,179 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import io
|
||||
|
||||
try:
|
||||
import numpy as np
|
||||
except ImportError:
|
||||
np = None
|
||||
|
||||
import pyarrow as pa
|
||||
from pyarrow.tests import util
|
||||
|
||||
|
||||
def _write_table(table, path, **kwargs):
    """Write ``table`` (an Arrow Table or pandas DataFrame) to ``path``
    as Parquet and return the Arrow table that was actually written.
    """
    # Import here so a broken parquet extension raises a visible
    # ImportError at the point of use.
    import pyarrow.parquet as pq
    from pyarrow.pandas_compat import _pandas_api

    arrow_table = (pa.Table.from_pandas(table)
                   if _pandas_api.is_data_frame(table) else table)
    pq.write_table(arrow_table, path, **kwargs)
    return arrow_table
|
||||
|
||||
|
||||
def _read_table(*args, **kwargs):
    """Read a Parquet source via ``pq.read_table`` and fully validate it
    before handing it back to the caller.
    """
    import pyarrow.parquet as pq

    result = pq.read_table(*args, **kwargs)
    result.validate(full=True)
    return result
|
||||
|
||||
|
||||
def _roundtrip_table(table, read_table_kwargs=None,
                     write_table_kwargs=None):
    """Write ``table`` to an in-memory Parquet buffer and read it back."""
    sink = pa.BufferOutputStream()
    _write_table(table, sink, **(write_table_kwargs or {}))

    source = pa.BufferReader(sink.getvalue())
    return _read_table(source, **(read_table_kwargs or {}))
|
||||
|
||||
|
||||
def _check_roundtrip(table, expected=None, read_table_kwargs=None,
                     **write_table_kwargs):
    """Roundtrip ``table`` through Parquet twice, asserting that schema
    and data match ``expected`` (defaults to ``table``) after each pass.
    """
    if expected is None:
        expected = table
    read_table_kwargs = read_table_kwargs or {}

    # Two passes on purpose: the second roundtrip starts from the result
    # of the first, so it can expose drift introduced by a write/read cycle.
    result = table
    for _ in range(2):
        result = _roundtrip_table(result,
                                  read_table_kwargs=read_table_kwargs,
                                  write_table_kwargs=write_table_kwargs)
        assert result.schema == expected.schema
        assert result.equals(expected)
|
||||
|
||||
|
||||
def _roundtrip_pandas_dataframe(df, write_kwargs):
    """Roundtrip a pandas DataFrame through Parquet and back to pandas."""
    as_table = pa.Table.from_pandas(df)
    roundtripped = _roundtrip_table(as_table, write_table_kwargs=write_kwargs)
    return roundtripped.to_pandas()
|
||||
|
||||
|
||||
def _random_integers(size, dtype):
|
||||
# We do not generate integers outside the int64 range
|
||||
platform_int_info = np.iinfo('int_')
|
||||
iinfo = np.iinfo(dtype)
|
||||
return np.random.randint(max(iinfo.min, platform_int_info.min),
|
||||
min(iinfo.max, platform_int_info.max),
|
||||
size=size, dtype=dtype)
|
||||
|
||||
|
||||
def _range_integers(size, dtype):
    """Ascending integers 0..size-1 of ``dtype`` as a pyarrow array."""
    values = np.arange(size, dtype=dtype)
    return pa.array(values)
|
||||
|
||||
|
||||
def _test_dict(size=10000, seed=0):
    """Deterministic column-name -> values mapping covering random
    integers of every width, floats, bools, strings and all-null columns.
    """
    np.random.seed(seed)

    columns = {}
    # Random integer columns, unsigned widths first, then signed —
    # this order matches the RNG consumption order of the original layout.
    for prefix in ('uint', 'int'):
        for width in (8, 16, 32, 64):
            name = '%s%d' % (prefix, width)
            columns[name] = _random_integers(size, np.dtype(name).type)

    columns['float32'] = np.random.randn(size).astype(np.float32)
    columns['float64'] = np.arange(size, dtype=np.float64)
    columns['bool'] = np.random.randn(size) > 0
    columns['strings'] = [util.rands(10) for i in range(size)]
    columns['all_none'] = [None] * size
    columns['all_none_category'] = [None] * size
    return columns
|
||||
|
||||
|
||||
def _test_dataframe(size=10000, seed=0):
    """Deterministic pandas DataFrame built from ``_test_dict``."""
    import pandas as pd

    frame = pd.DataFrame(_test_dict(size, seed))

    # TODO(PARQUET-1015)
    # frame['all_none_category'] = frame['all_none_category'].astype('category')
    return frame
|
||||
|
||||
|
||||
def _test_table(size=10000, seed=0):
    """Deterministic Arrow table built from ``_test_dict``."""
    data = _test_dict(size, seed)
    return pa.Table.from_pydict(data)
|
||||
|
||||
|
||||
def make_sample_file(table_or_df):
    """Serialize ``table_or_df`` (Arrow Table or pandas DataFrame) to an
    in-memory SNAPPY-compressed Parquet file and return a ``ParquetFile``
    reader positioned at its start.
    """
    import pyarrow.parquet as pq

    arrow_table = (table_or_df if isinstance(table_or_df, pa.Table)
                   else pa.Table.from_pandas(table_or_df))

    sink = io.BytesIO()
    _write_table(arrow_table, sink, compression='SNAPPY', version='2.6')
    sink.seek(0)
    return pq.ParquetFile(sink)
|
||||
|
||||
|
||||
def alltypes_sample(size=10000, seed=0, categorical=False):
    """Build a pandas DataFrame exercising a wide range of
    Arrow-convertible column types.

    Parameters
    ----------
    size : int, default 10000
        Number of rows. Must be >= 2 (the null-padded columns place a
        None at each end).
    seed : int, default 0
        Seed for the random 'bool' column, making output deterministic.
    categorical : bool, default False
        If True, add a 'str_category' categorical copy of 'str'.

    Returns
    -------
    pandas.DataFrame
    """
    import pandas as pd

    np.random.seed(seed)
    arrays = {
        'uint8': np.arange(size, dtype=np.uint8),
        'uint16': np.arange(size, dtype=np.uint16),
        'uint32': np.arange(size, dtype=np.uint32),
        'uint64': np.arange(size, dtype=np.uint64),
        # Fixed: this column previously used dtype=np.int16, so the
        # 'int8' type was never actually exercised.
        'int8': np.arange(size, dtype=np.int8),
        'int16': np.arange(size, dtype=np.int16),
        'int32': np.arange(size, dtype=np.int32),
        'int64': np.arange(size, dtype=np.int64),
        'float16': np.arange(size, dtype=np.float16),
        'float32': np.arange(size, dtype=np.float32),
        'float64': np.arange(size, dtype=np.float64),
        'bool': np.random.randn(size) > 0,
        'datetime_ms': np.arange("2016-01-01T00:00:00.001", size,
                                 dtype='datetime64[ms]'),
        'datetime_us': np.arange("2016-01-01T00:00:00.000001", size,
                                 dtype='datetime64[us]'),
        'datetime_ns': np.arange("2016-01-01T00:00:00.000000001", size,
                                 dtype='datetime64[ns]'),
        'timedelta': np.arange(0, size, dtype="timedelta64[s]"),
        'str': pd.Series([str(x) for x in range(size)]),
        'empty_str': [''] * size,
        'str_with_nulls': [None] + [str(x) for x in range(size - 2)] + [None],
        'null': [None] * size,
        'null_list': [None] * 2 + [[None] * (x % 4) for x in range(size - 2)],
    }
    if categorical:
        arrays['str_category'] = arrays['str'].astype('category')
    return pd.DataFrame(arrays)
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user