from typing import Optional, List, Any, Union, Iterable, TYPE_CHECKING
import uuid
if TYPE_CHECKING:
    from .catalog import Catalog
    from pandas.core.frame import DataFrame as PandasDataFrame
from ..exception import ContributionsAcceptedError
from .types import StructType, AtomicType, DataType
from ..conf import SparkConf
from .dataframe import DataFrame
from .conf import RuntimeConfig
from .readwriter import DataFrameReader
from ..context import SparkContext
from .udf import UDFRegistration
from .streaming import DataStreamReader
import duckdb
from ..errors import (
    PySparkTypeError,
    PySparkValueError,
)
from ..errors.error_classes import *
# In Spark:
# - a SparkSession holds a SparkContext
# - a SparkContext gets created from a SparkConf
# At that level Spark checks whether the instance already exists and only needs to be
# retrieved, or whether it still has to be created.
# For us this is done inside `duckdb.connect`, based on the passed-in path + configuration.
# SparkContext can be compared to our Connection class, and SparkConf to our ClientContext class.
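#
# For illustration only (the path and config value below are made up for this example):
# the equivalent of "building a context from a conf" on our side is roughly
#
#   conn = duckdb.connect("spark.db", config={"threads": "4"})
#
# and a SparkSession in this module simply wraps such a connection via SparkContext.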
# 'data' is a list of rows;
# every value in each row needs to be turned into a duckdb Value.
def _combine_data_and_schema(data: Iterable[Any], schema: StructType):
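    """
    Turn every row in 'data' into a list of duckdb Values, each tagged with the DuckDB type
    of the corresponding schema field.

    Illustrative sketch only; it assumes the usual StructField/IntegerType/StringType helpers
    from .types:

        schema = StructType([StructField("id", IntegerType()), StructField("name", StringType())])
        rows = _combine_data_and_schema([(1, "a"), (2, "b")], schema)
        # rows is now roughly [[Value(1, INTEGER), Value("a", VARCHAR)], ...]
    """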
    from duckdb import Value

    new_data = []
    for row in data:
        new_row = [Value(x, dtype.duckdb_type) for x, dtype in zip(row, [y.dataType for y in schema])]
        new_data.append(new_row)
    return new_data

class SparkSession:
    def __init__(self, context: SparkContext):
        self.conn = context.connection
        self._context = context
        self._conf = RuntimeConfig(self.conn)

    def _create_dataframe(self, data: Union[Iterable[Any], "PandasDataFrame"]) -> DataFrame:
        try:
            import pandas

            has_pandas = True
        except ImportError:
            has_pandas = False
        if has_pandas and isinstance(data, pandas.DataFrame):
            unique_name = f'pyspark_pandas_df_{uuid.uuid1()}'
            self.conn.register(unique_name, data)
            return DataFrame(self.conn.sql(f'select * from "{unique_name}"'), self)

        def verify_tuple_integrity(tuples):
            if len(tuples) <= 1:
                return
            expected_length = len(tuples[0])
            for i, item in enumerate(tuples[1:]):
                actual_length = len(item)
                if expected_length == actual_length:
                    continue
                raise PySparkTypeError(
                    error_class="LENGTH_SHOULD_BE_THE_SAME",
                    message_parameters={
                        "arg1": f"data{i}",
                        "arg2": f"data{i+1}",
                        "arg1_length": str(expected_length),
                        "arg2_length": str(actual_length),
                    },
                )

        if not isinstance(data, list):
            data = list(data)
        verify_tuple_integrity(data)
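        # For example, for data = [(1, 'a'), (2, 'b')] the helpers below build the
        # parameterized query
        #     select * from (values ($1, $2), ($3, $4))
        # and the flattened parameter list [1, 'a', 2, 'b'].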
        def construct_query(tuples) -> str:
            def construct_values_list(row, start_param_idx):
                parameter_count = len(row)
                parameters = [f'${x+start_param_idx}' for x in range(parameter_count)]
                parameters = '(' + ', '.join(parameters) + ')'
                return parameters

            row_size = len(tuples[0])
            values_list = [construct_values_list(x, 1 + (i * row_size)) for i, x in enumerate(tuples)]
            values_list = ', '.join(values_list)
            query = f"""
                select * from (values {values_list})
            """
            return query

        query = construct_query(data)

        def construct_parameters(tuples):
            parameters = []
            for row in tuples:
                parameters.extend(list(row))
            return parameters

        parameters = construct_parameters(data)
        rel = self.conn.sql(query, params=parameters)
        return DataFrame(rel, self)

    def _createDataFrameFromPandas(self, data: "PandasDataFrame", types, names) -> DataFrame:
        df = self._create_dataframe(data)
        # Cast to the requested types
        if types:
            df = df._cast_types(*types)
        # Alias to the requested names
        if names:
            df = df.toDF(*names)
        return df

    def createDataFrame(
        self,
        data: Union["PandasDataFrame", Iterable[Any]],
        schema: Optional[Union[StructType, List[str]]] = None,
        samplingRatio: Optional[float] = None,
        verifySchema: bool = True,
    ) -> DataFrame:
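        """
        Create a DataFrame from a pandas DataFrame or an iterable of rows.
        'schema' may be a StructType or a plain list of column names; 'samplingRatio' and
        'verifySchema=False' are not supported.

        Minimal usage sketch (column names and values are made up for the example):

            spark.createDataFrame([(1, "a"), (2, "b")], ["id", "name"])
        """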
        if samplingRatio:
            raise NotImplementedError
        if not verifySchema:
            raise NotImplementedError

        types = None
        names = None
        if isinstance(data, DataFrame):
            raise PySparkTypeError(
                error_class="SHOULD_NOT_DATAFRAME",
                message_parameters={"arg_name": "data"},
            )

        if schema:
            if isinstance(schema, StructType):
                types, names = schema.extract_types_and_names()
            else:
                names = schema

        try:
            import pandas

            has_pandas = True
        except ImportError:
            has_pandas = False
        # Truthiness of a pandas DataFrame is not defined, so first check whether 'data'
        # is a pandas DataFrame, and only then whether it is None or [].
        if has_pandas and isinstance(data, pandas.DataFrame):
            return self._createDataFrameFromPandas(data, types, names)

        # Finally, check whether only a list of column names was provided for otherwise empty data
        is_empty = False
        if not data and names:
            # Create NULLs for every column in our dataframe
            is_empty = True
            data = [tuple(None for _ in names)]

        if schema and isinstance(schema, StructType):
            # Transform the data into Values to combine the data + schema
            data = _combine_data_and_schema(data, schema)

        df = self._create_dataframe(data)
        if is_empty:
            rel = df.relation
            # Add an impossible where clause so the result keeps the columns but returns no rows
            rel = rel.filter('1=0')
            df = DataFrame(rel, self)

        # Cast to the requested types
        if types:
            df = df._cast_types(*types)
        # Alias to the requested names
        if names:
            df = df.toDF(*names)
        return df

    def newSession(self) -> "SparkSession":
        return SparkSession(self._context)

    def range(
        self,
        start: int,
        end: Optional[int] = None,
        step: int = 1,
        numPartitions: Optional[int] = None,
    ) -> "DataFrame":
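        """
        Create a DataFrame backed by DuckDB's 'range' table function. As in PySpark,
        a single argument is treated as the exclusive end with an implicit start of 0;
        'numPartitions' is not supported.

        Illustrative examples:

            spark.range(3)        # rows 0, 1, 2
            spark.range(2, 8, 2)  # rows 2, 4, 6
        """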
        if numPartitions:
            raise ContributionsAcceptedError

        if end is None:
            end = start
            start = 0

        return DataFrame(self.conn.table_function("range", parameters=[start, end, step]), self)

    def sql(self, sqlQuery: str, **kwargs: Any) -> DataFrame:
        if kwargs:
            raise NotImplementedError
        relation = self.conn.sql(sqlQuery)
        return DataFrame(relation, self)

    def stop(self) -> None:
        self._context.stop()

    def table(self, tableName: str) -> DataFrame:
        relation = self.conn.table(tableName)
        return DataFrame(relation, self)

    def getActiveSession(self) -> "SparkSession":
        return self

    @property
    def catalog(self) -> "Catalog":
        if not hasattr(self, "_catalog"):
            from duckdb.experimental.spark.sql.catalog import Catalog

            self._catalog = Catalog(self)
        return self._catalog

    @property
    def conf(self) -> RuntimeConfig:
        return self._conf

    @property
    def read(self) -> DataFrameReader:
        return DataFrameReader(self)

    @property
    def readStream(self) -> DataStreamReader:
        return DataStreamReader(self)

    @property
    def sparkContext(self) -> SparkContext:
        return self._context

    @property
    def streams(self) -> Any:
        raise ContributionsAcceptedError

    @property
    def udf(self) -> UDFRegistration:
        return UDFRegistration(self)

    @property
    def version(self) -> str:
        return '1.0.0'

    class Builder:
        def __init__(self):
            pass

        def master(self, name: str) -> "SparkSession.Builder":
            # no-op
            return self

        def appName(self, name: str) -> "SparkSession.Builder":
            # no-op
            return self

        def remote(self, url: str) -> "SparkSession.Builder":
            # no-op
            return self

        def getOrCreate(self) -> "SparkSession":
            context = SparkContext("__ignored__")
            return SparkSession(context)

        def config(
            self, key: Optional[str] = None, value: Optional[Any] = None, conf: Optional[SparkConf] = None
        ) -> "SparkSession.Builder":
            return self

        def enableHiveSupport(self) -> "SparkSession.Builder":
            # no-op
            return self

    builder = Builder()
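
# Typical usage, for illustration only (the app name, column names, and data are made up):
#
#   spark = SparkSession.builder.appName("demo").getOrCreate()
#   df = spark.createDataFrame([(1, "a"), (2, "b")], ["id", "name"])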
__all__ = ["SparkSession"]