"""Rudimentary Apache Arrow-backed ExtensionArray.

At the moment, just a boolean array / type is implemented.
Eventually, we'll want to parametrize the type and support
multiple dtypes. Not all methods are implemented yet, and the
current implementation is not efficient.
"""
import copy
import itertools

import numpy as np
import pyarrow as pa

import pandas as pd
from pandas.api.extensions import (
    ExtensionArray, ExtensionDtype, register_extension_dtype, take)


@register_extension_dtype
class ArrowBoolDtype(ExtensionDtype):

    type = np.bool_
    kind = 'b'
    name = 'arrow_bool'
    na_value = pa.NULL

    @classmethod
    def construct_from_string(cls, string):
        if string == cls.name:
            return cls()
        else:
            raise TypeError("Cannot construct a '{}' from "
                            "'{}'".format(cls, string))

    @classmethod
    def construct_array_type(cls):
        return ArrowBoolArray

    @property
    def _is_boolean(self):
        return True


class ArrowBoolArray(ExtensionArray):
    def __init__(self, values):
        if not isinstance(values, pa.ChunkedArray):
            raise ValueError("values must be a pyarrow ChunkedArray")

        assert values.type == pa.bool_()
        self._data = values
        self._dtype = ArrowBoolDtype()

    def __repr__(self):
        return "ArrowBoolArray({})".format(repr(self._data))

    @classmethod
    def from_scalars(cls, values):
        arr = pa.chunked_array([pa.array(np.asarray(values))])
        return cls(arr)

    @classmethod
    def from_array(cls, arr):
        assert isinstance(arr, pa.Array)
        return cls(pa.chunked_array([arr]))

    @classmethod
    def _from_sequence(cls, scalars, dtype=None, copy=False):
        return cls.from_scalars(scalars)

    def __getitem__(self, item):
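        # Materialize via .to_pandas() and index the NumPy-backed result;
        # simple but not efficient (see the module docstring).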
        if pd.api.types.is_scalar(item):
            return self._data.to_pandas()[item]
        else:
            vals = self._data.to_pandas()[item]
            return type(self).from_scalars(vals)

    def __len__(self):
        return len(self._data)

    def astype(self, dtype, copy=True):
        # needed so that a same-dtype astype (e.g. from the Series
        # constructor) returns self (or a copy) instead of falling back to
        # the NumPy-based conversion in ExtensionArray.astype.
        if isinstance(dtype, type(self.dtype)) and dtype == self.dtype:
            if copy:
                return self.copy()
            return self
        return super(ArrowBoolArray, self).astype(dtype, copy)

    @property
    def dtype(self):
        return self._dtype

    @property
    def nbytes(self):
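        # Sum the sizes of all Arrow buffers (validity bitmaps and data)
        # across every chunk.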
        return sum(x.size for chunk in self._data.chunks
                   for x in chunk.buffers()
                   if x is not None)

    def isna(self):
        nas = pd.isna(self._data.to_pandas())
        return type(self).from_scalars(nas)

    def take(self, indices, allow_fill=False, fill_value=None):
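        # Delegate to pandas' take helper on a materialized copy so that
        # allow_fill / fill_value follow standard ExtensionArray semantics.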
        data = self._data.to_pandas()

        if allow_fill and fill_value is None:
            fill_value = self.dtype.na_value

        result = take(data, indices, fill_value=fill_value,
                      allow_fill=allow_fill)
        return self._from_sequence(result, dtype=self.dtype)

    def copy(self, deep=False):
        if deep:
            return type(self)(copy.deepcopy(self._data))
        else:
            return type(self)(copy.copy(self._data))

    @classmethod
    def _concat_same_type(cls, to_concat):
        chunks = list(itertools.chain.from_iterable(x._data.chunks
                                                    for x in to_concat))
        arr = pa.chunked_array(chunks)
        return cls(arr)

    def __invert__(self):
        return type(self).from_scalars(
            ~self._data.to_pandas()
        )

    def _reduce(self, method, skipna=True, **kwargs):
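        # Drop missing values when skipna is True, then dispatch to the
        # method of that name (e.g. ``any`` / ``all``) defined below.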
        if skipna:
            arr = self[~self.isna()]
        else:
            arr = self

        try:
            op = getattr(arr, method)
        except AttributeError:
            raise TypeError("cannot perform '{}' with dtype "
                            "'{}'".format(method, self.dtype))
        return op(**kwargs)

    def any(self, axis=0, out=None):
        return self._data.to_pandas().any()

    def all(self, axis=0, out=None):
        return self._data.to_pandas().all()
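

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the upstream test helper; assumes pyarrow
# is installed). It only exercises the classes defined above.
if __name__ == "__main__":
    # Build an ArrowBoolArray from Python booleans and wrap it in a Series;
    # the Series picks up the registered 'arrow_bool' extension dtype.
    arr = ArrowBoolArray.from_scalars([True, False, True])
    ser = pd.Series(arr)
    print(ser.dtype)

    # Reductions dispatch through any / all defined above.
    print(arr.any())
    print(arr.all())

    # Because the dtype is registered, its name should also be usable
    # directly in the Series constructor (resolved via construct_from_string).
    ser2 = pd.Series([True, False], dtype='arrow_bool')
    print(ser2.dtype)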