aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/python/pyarrow/pyarrow/jvm.py
diff options
context:
space:
mode:
authornkozlovskiy <nmk@ydb.tech>2023-09-29 12:24:06 +0300
committernkozlovskiy <nmk@ydb.tech>2023-09-29 12:41:34 +0300
commite0e3e1717e3d33762ce61950504f9637a6e669ed (patch)
treebca3ff6939b10ed60c3d5c12439963a1146b9711 /contrib/python/pyarrow/pyarrow/jvm.py
parent38f2c5852db84c7b4d83adfcb009eb61541d1ccd (diff)
downloadydb-e0e3e1717e3d33762ce61950504f9637a6e669ed.tar.gz
add ydb deps
Diffstat (limited to 'contrib/python/pyarrow/pyarrow/jvm.py')
-rw-r--r--contrib/python/pyarrow/pyarrow/jvm.py335
1 files changed, 335 insertions, 0 deletions
diff --git a/contrib/python/pyarrow/pyarrow/jvm.py b/contrib/python/pyarrow/pyarrow/jvm.py
new file mode 100644
index 0000000000..161c5ff4d6
--- /dev/null
+++ b/contrib/python/pyarrow/pyarrow/jvm.py
@@ -0,0 +1,335 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Functions to interact with Arrow memory allocated by Arrow Java.
+
+These functions convert the objects holding the metadata, the actual
+data is not copied at all.
+
+This will only work with a JVM running in the same process such as provided
+through jpype. Modules that talk to a remote JVM like py4j will not work as the
+memory addresses reported by them are not reachable in the python process.
+"""
+
+import pyarrow as pa
+
+
class _JvmBufferNanny:
    """
    Keep-alive holder for the memory backing an
    org.apache.arrow.memory.ArrowBuf.

    While an instance is alive, the JVM reference manager's count on the
    buffer is incremented, preventing the JVM from freeing the memory.
    """
    # Class-level default so __del__ stays safe even when __init__ raised
    # before the instance attribute was assigned.
    ref_manager = None

    def __init__(self, jvm_buf):
        manager = jvm_buf.getReferenceManager()
        # retain() raises a java.lang.IllegalArgumentException when the
        # buffer has already been freed; that exception cannot easily be
        # caught from Python, so we let it propagate.
        manager.retain()
        self.ref_manager = manager

    def __del__(self):
        manager = self.ref_manager
        if manager is not None:
            manager.release()
+
+
def jvm_buffer(jvm_buf):
    """
    Construct an Arrow buffer from org.apache.arrow.memory.ArrowBuf

    Parameters
    ----------

    jvm_buf: org.apache.arrow.memory.ArrowBuf
        Arrow Buffer representation on the JVM.

    Returns
    -------
    pyarrow.Buffer
        Python Buffer that references the JVM memory.
    """
    # The nanny retains the JVM buffer for as long as the returned
    # pyarrow.Buffer (which keeps it as `base`) is alive.
    keepalive = _JvmBufferNanny(jvm_buf)
    return pa.foreign_buffer(jvm_buf.memoryAddress(),
                             jvm_buf.capacity(),
                             base=keepalive)
+
+
def _from_jvm_int_type(jvm_type):
    """
    Convert a JVM int type to its Python equivalent.

    Parameters
    ----------
    jvm_type : org.apache.arrow.vector.types.pojo.ArrowType$Int

    Returns
    -------
    typ : pyarrow.DataType
        None when the bit width is not one of 8/16/32/64.
    """
    bit_width = jvm_type.getBitWidth()
    if jvm_type.getIsSigned():
        constructors = {8: pa.int8, 16: pa.int16, 32: pa.int32, 64: pa.int64}
    else:
        constructors = {8: pa.uint8, 16: pa.uint16,
                        32: pa.uint32, 64: pa.uint64}
    ctor = constructors.get(bit_width)
    if ctor is not None:
        return ctor()
+
+
def _from_jvm_float_type(jvm_type):
    """
    Convert a JVM float type to its Python equivalent.

    Parameters
    ----------
    jvm_type: org.apache.arrow.vector.types.pojo.ArrowType$FloatingPoint

    Returns
    -------
    typ: pyarrow.DataType
        None when the JVM precision name is not recognised.
    """
    constructors = {
        'HALF': pa.float16,
        'SINGLE': pa.float32,
        'DOUBLE': pa.float64,
    }
    ctor = constructors.get(jvm_type.getPrecision().toString())
    if ctor is not None:
        return ctor()
+
+
def _from_jvm_time_type(jvm_type):
    """
    Convert a JVM time type to its Python equivalent.

    Parameters
    ----------
    jvm_type: org.apache.arrow.vector.types.pojo.ArrowType$Time

    Returns
    -------
    typ: pyarrow.DataType
        None when the JVM time unit is not recognised.
    """
    # Map the JVM unit name to the bit width Arrow mandates for it and
    # the matching pyarrow constructor call.
    units = {
        'SECOND': (32, lambda: pa.time32('s')),
        'MILLISECOND': (32, lambda: pa.time32('ms')),
        'MICROSECOND': (64, lambda: pa.time64('us')),
        'NANOSECOND': (64, lambda: pa.time64('ns')),
    }
    entry = units.get(jvm_type.getUnit().toString())
    if entry is not None:
        expected_width, make_type = entry
        assert jvm_type.getBitWidth() == expected_width
        return make_type()
+
+
def _from_jvm_timestamp_type(jvm_type):
    """
    Convert a JVM timestamp type to its Python equivalent.

    Parameters
    ----------
    jvm_type: org.apache.arrow.vector.types.pojo.ArrowType$Timestamp

    Returns
    -------
    typ: pyarrow.DataType
        None when the JVM time unit is not recognised.
    """
    units = {
        'SECOND': 's',
        'MILLISECOND': 'ms',
        'MICROSECOND': 'us',
        'NANOSECOND': 'ns',
    }
    timezone = jvm_type.getTimezone()
    if timezone is not None:
        # The JVM hands back a java.lang.String proxy; force a Python str.
        timezone = str(timezone)
    unit = units.get(jvm_type.getUnit().toString())
    if unit is not None:
        return pa.timestamp(unit, tz=timezone)
+
+
def _from_jvm_date_type(jvm_type):
    """
    Convert a JVM date type to its Python equivalent

    Parameters
    ----------
    jvm_type: org.apache.arrow.vector.types.pojo.ArrowType$Date

    Returns
    -------
    typ: pyarrow.DataType
        None when the JVM date unit is not recognised.
    """
    day_unit = jvm_type.getUnit().toString()
    # DAY-based dates are 32 bit, MILLISECOND-based dates are 64 bit.
    if day_unit == 'DAY':
        return pa.date32()
    if day_unit == 'MILLISECOND':
        return pa.date64()
+
+
def field(jvm_field):
    """
    Construct a Field from a org.apache.arrow.vector.types.pojo.Field
    instance.

    Parameters
    ----------
    jvm_field: org.apache.arrow.vector.types.pojo.Field

    Returns
    -------
    pyarrow.Field

    Raises
    ------
    NotImplementedError
        For complex JVM types (Struct, List, FixedSizeList, Union,
        Dictionary) and for primitive types without a mapping.
    """
    name = str(jvm_field.getName())
    jvm_type = jvm_field.getType()

    if jvm_type.isComplex():
        # TODO: The following JVM types are not implemented:
        #       Struct, List, FixedSizeList, Union, Dictionary
        raise NotImplementedError(
            "JVM field conversion only implemented for primitive types.")

    # Dispatch on the JVM type-ID string; each converter receives the
    # JVM type object so parameterized types can read their attributes.
    converters = {
        'Null': lambda t: pa.null(),
        'Int': _from_jvm_int_type,
        'FloatingPoint': _from_jvm_float_type,
        'Utf8': lambda t: pa.string(),
        'Binary': lambda t: pa.binary(),
        'FixedSizeBinary': lambda t: pa.binary(t.getByteWidth()),
        'Bool': lambda t: pa.bool_(),
        'Time': _from_jvm_time_type,
        'Timestamp': _from_jvm_timestamp_type,
        'Date': _from_jvm_date_type,
        'Decimal': lambda t: pa.decimal128(t.getPrecision(), t.getScale()),
    }
    type_str = jvm_type.getTypeID().toString()
    converter = converters.get(type_str)
    if converter is None:
        raise NotImplementedError(
            "Unsupported JVM type: {}".format(type_str))
    typ = converter(jvm_type)

    nullable = jvm_field.isNullable()
    jvm_metadata = jvm_field.getMetadata()
    metadata = None
    if not jvm_metadata.isEmpty():
        metadata = {str(entry.getKey()): str(entry.getValue())
                    for entry in jvm_metadata.entrySet()}
    return pa.field(name, typ, nullable, metadata)
+
+
def schema(jvm_schema):
    """
    Construct a Schema from a org.apache.arrow.vector.types.pojo.Schema
    instance.

    Parameters
    ----------
    jvm_schema: org.apache.arrow.vector.types.pojo.Schema

    Returns
    -------
    pyarrow.Schema
    """
    converted_fields = [field(f) for f in jvm_schema.getFields()]
    jvm_metadata = jvm_schema.getCustomMetadata()
    metadata = None
    if not jvm_metadata.isEmpty():
        # Convert the JVM string proxies into plain Python strings.
        metadata = {str(entry.getKey()): str(entry.getValue())
                    for entry in jvm_metadata.entrySet()}
    return pa.schema(converted_fields, metadata)
+
+
def array(jvm_array):
    """
    Construct an (Python) Array from its JVM equivalent.

    Parameters
    ----------
    jvm_array : org.apache.arrow.vector.ValueVector

    Returns
    -------
    array : Array

    Raises
    ------
    NotImplementedError
        When the JVM vector holds a complex (nested) type.
    """
    jvm_field_pojo = jvm_array.getField()
    if jvm_field_pojo.getType().isComplex():
        minor_type_str = jvm_array.getMinorType().toString()
        raise NotImplementedError(
            "Cannot convert JVM Arrow array of type {},"
            " complex types not yet implemented.".format(minor_type_str))
    dtype = field(jvm_field_pojo).type

    # Wrap each JVM buffer zero-copy; passing False avoids clearing the
    # vector's own references.
    buffers = [jvm_buffer(buf) for buf in jvm_array.getBuffers(False)]

    # An empty JVM vector exposes no buffers at all, so build the empty
    # array explicitly instead of going through from_buffers.
    if not buffers:
        return pa.array([], type=dtype)

    return pa.Array.from_buffers(dtype,
                                 jvm_array.getValueCount(),
                                 buffers,
                                 jvm_array.getNullCount())
+
+
def record_batch(jvm_vector_schema_root):
    """
    Construct a (Python) RecordBatch from a JVM VectorSchemaRoot

    Parameters
    ----------
    jvm_vector_schema_root : org.apache.arrow.vector.VectorSchemaRoot

    Returns
    -------
    record_batch: pyarrow.RecordBatch
    """
    pa_schema = schema(jvm_vector_schema_root.getSchema())

    # Convert the JVM vectors column by column, in schema order.
    columns = [array(jvm_vector_schema_root.getVector(column_name))
               for column_name in pa_schema.names]

    return pa.RecordBatch.from_arrays(
        columns,
        pa_schema.names,
        metadata=pa_schema.metadata
    )