4 changes: 3 additions & 1 deletion src/datacustomcode/__init__.py
@@ -17,7 +17,9 @@
from datacustomcode.credentials import AuthType, Credentials
from datacustomcode.io.reader.query_api import QueryAPIDataCloudReader
from datacustomcode.io.writer.print import PrintDataCloudWriter
from datacustomcode.proxy.client.local_proxy_client import LocalProxyClientProvider
from datacustomcode.proxy.client.LocalProxyClientProvider import (
LocalProxyClientProvider,
)

__all__ = [
"AuthType",
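For context, a minimal sketch of how the re-exported provider can now be imported (the package-root import assumes the re-export shown above; the model id and token count are illustrative):

```python
# Sketch only: assumes the package-root re-export added in this diff.
from datacustomcode import LocalProxyClientProvider

provider = LocalProxyClientProvider()

# The local provider is a stub that echoes its inputs, so this call
# needs no running gateway.
print(provider.call_llm_gateway("demo-model", "Hello", 256))
```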
64 changes: 21 additions & 43 deletions src/datacustomcode/client.py
@@ -107,20 +107,23 @@ class Client:
_reader: BaseDataCloudReader
_writer: BaseDataCloudWriter
_file: DefaultFindFilePath
_proxy: BaseProxyClient
_proxy: Optional[BaseProxyClient]
_data_layer_history: dict[DataCloudObjectType, set[str]]
_code_type: str

def __new__(
cls,
reader: Optional[BaseDataCloudReader] = None,
writer: Optional["BaseDataCloudWriter"] = None,
proxy: Optional[BaseProxyClient] = None,
spark_provider: Optional["BaseSparkSessionProvider"] = None,
code_type: str = "script",
) -> Client:
if "function" in code_type:
return cls._new_function_client()

if cls._instance is None:
cls._instance = super().__new__(cls)

spark = None
# Initialize Readers and Writers from config
# and/or provided reader and writer
if reader is None or writer is None:
@@ -139,22 +142,6 @@ def __new__(
provider = DefaultSparkSessionProvider()

spark = provider.get_session(config.spark_config)
elif (
proxy is None
and config.proxy_config is not None
and config.spark_config is not None
):
# Both reader and writer provided; we still need spark for proxy init
provider = (
spark_provider
if spark_provider is not None
else (
config.spark_provider_config.to_object()
if config.spark_provider_config is not None
else DefaultSparkSessionProvider()
)
)
spark = provider.get_session(config.spark_config)

if config.reader_config is None and reader is None:
raise ValueError(
@@ -163,44 +150,23 @@
elif reader is None or (
config.reader_config is not None and config.reader_config.force
):
if config.proxy_config is None:
raise ValueError(
"Proxy config is required when reader is built from config"
)
assert (
spark is not None
) # set in "reader is None or writer is None" branch
assert config.reader_config is not None # ensured by branch condition
proxy_init = config.proxy_config.to_object(spark)

reader_init = config.reader_config.to_object(spark)
reader_init = config.reader_config.to_object(spark) # type: ignore
else:
reader_init = reader
if proxy is not None:
proxy_init = proxy
elif config.proxy_config is None:
raise ValueError("Proxy config is required when reader is provided")
else:
assert (
spark is not None
) # set in "both provided; proxy from config" branch
proxy_init = config.proxy_config.to_object(spark)
if config.writer_config is None and writer is None:
raise ValueError(
"Writer config is required when writer is not provided"
)
elif writer is None or (
config.writer_config is not None and config.writer_config.force
):
assert spark is not None # set when reader or writer from config
assert config.writer_config is not None # ensured by branch condition
writer_init = config.writer_config.to_object(spark)
writer_init = config.writer_config.to_object(spark) # type: ignore
else:
writer_init = writer

cls._instance._reader = reader_init
cls._instance._writer = writer_init
cls._instance._file = DefaultFindFilePath()
cls._instance._proxy = proxy_init
cls._instance._data_layer_history = {
DataCloudObjectType.DLO: set(),
DataCloudObjectType.DMO: set(),
@@ -209,6 +175,16 @@
raise ValueError("Cannot set reader or writer after client is initialized")
return cls._instance

@classmethod
def _new_function_client(cls) -> Client:
cls._instance = super().__new__(cls)
cls._instance._proxy = (
config.proxy_config.to_object() # type: ignore
if config.proxy_config is not None
else None
)
return cls._instance

def read_dlo(self, name: str) -> PySparkDataFrame:
"""Read a DLO from Data Cloud.

@@ -260,6 +236,8 @@ def write_to_dmo(
return self._writer.write_to_dmo(name, dataframe, write_mode, **kwargs)

def call_llm_gateway(self, llmModelId: str, prompt: str, maxTokens: int) -> str:
    if self._proxy is None:
        raise ValueError("No proxy configured; set proxy or proxy_config")
    return self._proxy.call_llm_gateway(llmModelId, prompt, maxTokens)

def find_file_path(self, file_name: str) -> Path:
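A minimal sketch of the new function-client path introduced above, assuming a loaded config whose proxy_config names a registered proxy client (the model id and prompt are illustrative):

```python
from datacustomcode.client import Client

# A code_type containing "function" short-circuits __new__ into
# _new_function_client(), which only wires up the proxy.
client = Client(code_type="function")

# With proxy_config set, the guarded call goes through; with no proxy
# configured, call_llm_gateway raises ValueError instead of failing on
# a None attribute.
reply = client.call_llm_gateway("demo-model", "Summarize this.", 256)
```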
20 changes: 19 additions & 1 deletion src/datacustomcode/config.py
@@ -38,6 +38,7 @@
from datacustomcode.io.base import BaseDataAccessLayer
from datacustomcode.io.reader.base import BaseDataCloudReader # noqa: TCH001
from datacustomcode.io.writer.base import BaseDataCloudWriter # noqa: TCH001
from datacustomcode.proxy.base import BaseProxyAccessLayer
from datacustomcode.proxy.client.base import BaseProxyClient # noqa: TCH001
from datacustomcode.spark.base import BaseSparkSessionProvider

@@ -93,6 +94,23 @@ class SparkConfig(ForceableConfig):

_P = TypeVar("_P", bound=BaseSparkSessionProvider)

_PX = TypeVar("_PX", bound=BaseProxyAccessLayer)


class ProxyAccessLayerObjectConfig(ForceableConfig, Generic[_PX]):
"""Config for proxy clients that take no constructor args (e.g. no spark)."""

model_config = ConfigDict(validate_default=True, extra="forbid")
type_base: ClassVar[Type[BaseProxyAccessLayer]] = BaseProxyAccessLayer
type_config_name: str = Field(
description="CONFIG_NAME of the proxy client (e.g. 'LocalProxyClientProvider').",
)
options: dict[str, Any] = Field(default_factory=dict)

def to_object(self) -> _PX:
type_ = self.type_base.subclass_from_config_name(self.type_config_name)
return cast(_PX, type_(**self.options))


class SparkProviderConfig(ForceableConfig, Generic[_P]):
model_config = ConfigDict(validate_default=True, extra="forbid")
@@ -110,7 +128,7 @@ def to_object(self) -> _P:
class ClientConfig(BaseModel):
reader_config: Union[AccessLayerObjectConfig[BaseDataCloudReader], None] = None
writer_config: Union[AccessLayerObjectConfig[BaseDataCloudWriter], None] = None
proxy_config: Union[AccessLayerObjectConfig[BaseProxyClient], None] = None
proxy_config: Union[ProxyAccessLayerObjectConfig[BaseProxyClient], None] = None
spark_config: Union[SparkConfig, None] = None
spark_provider_config: Union[
SparkProviderConfig[BaseSparkSessionProvider], None
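A sketch of how the new proxy config resolves to an object, built directly in Python rather than from a config file (field values are illustrative; only the class and field names come from this diff, and ForceableConfig defaults are assumed):

```python
from datacustomcode.config import ProxyAccessLayerObjectConfig

proxy_config = ProxyAccessLayerObjectConfig(
    type_config_name="LocalProxyClientProvider",
    options={},  # forwarded as keyword arguments to the proxy constructor
)

# subclass_from_config_name() resolves the class registered under that
# CONFIG_NAME; to_object() instantiates it without a Spark session.
proxy = proxy_config.to_object()
```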
2 changes: 1 addition & 1 deletion src/datacustomcode/proxy/base.py
@@ -19,6 +19,6 @@
from datacustomcode.mixin import UserExtendableNamedConfigMixin


class BaseDataAccessLayer(ABC, UserExtendableNamedConfigMixin):
class BaseProxyAccessLayer(ABC, UserExtendableNamedConfigMixin):
def __init__(self):
pass
src/datacustomcode/proxy/client/LocalProxyClientProvider.py
@@ -22,5 +22,8 @@ class LocalProxyClientProvider(BaseProxyClient):

CONFIG_NAME = "LocalProxyClientProvider"

def __init__(self, **kwargs: object) -> None:
pass

def call_llm_gateway(self, llmModelId: str, prompt: str, maxTokens: int) -> str:
return f"Hello, thanks for using {llmModelId}. So many tokens: {maxTokens}"
9 changes: 4 additions & 5 deletions src/datacustomcode/proxy/client/base.py
@@ -16,13 +16,12 @@

from abc import abstractmethod

from datacustomcode.io.base import BaseDataAccessLayer
from datacustomcode.proxy.base import BaseProxyAccessLayer


class BaseProxyClient(BaseDataAccessLayer):
def __init__(self, spark=None, **kwargs):
if spark is not None:
super().__init__(spark)
class BaseProxyClient(BaseProxyAccessLayer):
def __init__(self):
pass

@abstractmethod
def call_llm_gateway(self, llmModelId: str, prompt: str, maxTokens: int) -> str: ...
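For illustration, a sketch of a custom proxy client against the slimmed-down base class (the subclass and its behavior are hypothetical; only the abstract signature comes from this diff):

```python
from datacustomcode.proxy.client.base import BaseProxyClient


class EchoProxyClient(BaseProxyClient):
    """Hypothetical client that echoes requests instead of calling a gateway."""

    CONFIG_NAME = "EchoProxyClient"

    def call_llm_gateway(self, llmModelId: str, prompt: str, maxTokens: int) -> str:
        # Truncate the prompt to maxTokens characters as a stand-in response.
        return f"[{llmModelId}] {prompt[:maxTokens]}"


print(EchoProxyClient().call_llm_gateway("demo-model", "Hello, gateway!", 32))
```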