2 changes: 2 additions & 0 deletions src/datacustomcode/__init__.py
@@ -17,11 +17,13 @@
from datacustomcode.credentials import AuthType, Credentials
from datacustomcode.io.reader.query_api import QueryAPIDataCloudReader
from datacustomcode.io.writer.print import PrintDataCloudWriter
from datacustomcode.proxy.client.local_proxy_client import LocalProxyClientProvider

__all__ = [
"AuthType",
"Client",
"Credentials",
"LocalProxyClientProvider",
"PrintDataCloudWriter",
"QueryAPIDataCloudReader",
]
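
With this export, the stub proxy is importable from the package root. A quick sketch (illustrative, not part of the diff):

from datacustomcode import LocalProxyClientProvider

proxy = LocalProxyClientProvider()  # standalone here; normally handed to Client or built from config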
49 changes: 47 additions & 2 deletions src/datacustomcode/client.py
@@ -33,6 +33,7 @@

from datacustomcode.io.reader.base import BaseDataCloudReader
from datacustomcode.io.writer.base import BaseDataCloudWriter, WriteMode
from datacustomcode.proxy.client.base import BaseProxyClient
from datacustomcode.spark.base import BaseSparkSessionProvider


@@ -106,17 +107,20 @@ class Client:
_reader: BaseDataCloudReader
_writer: BaseDataCloudWriter
_file: DefaultFindFilePath
_proxy: BaseProxyClient
_data_layer_history: dict[DataCloudObjectType, set[str]]

def __new__(
cls,
reader: Optional[BaseDataCloudReader] = None,
writer: Optional["BaseDataCloudWriter"] = None,
proxy: Optional[BaseProxyClient] = None,
spark_provider: Optional["BaseSparkSessionProvider"] = None,
) -> Client:
if cls._instance is None:
cls._instance = super().__new__(cls)

spark = None
# Initialize Readers and Writers from config
# and/or provided reader and writer
if reader is None or writer is None:
@@ -135,6 +139,22 @@ def __new__(
provider = DefaultSparkSessionProvider()

spark = provider.get_session(config.spark_config)
elif (
proxy is None
and config.proxy_config is not None
and config.spark_config is not None
):
# Both reader and writer provided; we still need spark for proxy init
provider = (
spark_provider
if spark_provider is not None
else (
config.spark_provider_config.to_object()
if config.spark_provider_config is not None
else DefaultSparkSessionProvider()
)
)
spark = provider.get_session(config.spark_config)

if config.reader_config is None and reader is None:
raise ValueError(
@@ -143,22 +163,44 @@
elif reader is None or (
config.reader_config is not None and config.reader_config.force
):
reader_init = config.reader_config.to_object(spark) # type: ignore
if config.proxy_config is None:
raise ValueError(
"Proxy config is required when reader is built from config"
)
assert (
spark is not None
) # set in "reader is None or writer is None" branch
assert config.reader_config is not None # ensured by branch condition
proxy_init = config.proxy_config.to_object(spark)

reader_init = config.reader_config.to_object(spark)
else:
reader_init = reader
if proxy is not None:
proxy_init = proxy
elif config.proxy_config is None:
raise ValueError("Proxy config is required when reader is provided")
else:
assert (
spark is not None
) # set in "both provided; proxy from config" branch
proxy_init = config.proxy_config.to_object(spark)
if config.writer_config is None and writer is None:
raise ValueError(
"Writer config is required when writer is not provided"
)
elif writer is None or (
config.writer_config is not None and config.writer_config.force
):
writer_init = config.writer_config.to_object(spark) # type: ignore
assert spark is not None # set when reader or writer from config
assert config.writer_config is not None # ensured by branch condition
writer_init = config.writer_config.to_object(spark)
else:
writer_init = writer
cls._instance._reader = reader_init
cls._instance._writer = writer_init
cls._instance._file = DefaultFindFilePath()
cls._instance._proxy = proxy_init
cls._instance._data_layer_history = {
DataCloudObjectType.DLO: set(),
DataCloudObjectType.DMO: set(),
@@ -217,6 +259,9 @@ def write_to_dmo(
self._validate_data_layer_history_does_not_contain(DataCloudObjectType.DLO)
return self._writer.write_to_dmo(name, dataframe, write_mode, **kwargs)

def call_llm_gateway(self, llmModelId: str, prompt: str, maxTokens: int) -> str:
return self._proxy.call_llm_gateway(llmModelId, prompt, maxTokens)

def find_file_path(self, file_name: str) -> Path:
"""Return a file path"""

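Taken together, the Client changes add an LLM-gateway surface alongside read/write. A usage sketch, assuming reader and writer are passed explicitly so the supplied proxy is honored (when the reader is built from config, the proxy is instead built from proxy_config):

from datacustomcode import Client, LocalProxyClientProvider

client = Client(
    reader=my_reader,  # hypothetical BaseDataCloudReader instance
    writer=my_writer,  # hypothetical BaseDataCloudWriter instance
    proxy=LocalProxyClientProvider(),
)
reply = client.call_llm_gateway("model-x", "Summarize this record.", 256)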
3 changes: 3 additions & 0 deletions src/datacustomcode/config.py
@@ -38,6 +38,7 @@
from datacustomcode.io.base import BaseDataAccessLayer
from datacustomcode.io.reader.base import BaseDataCloudReader # noqa: TCH001
from datacustomcode.io.writer.base import BaseDataCloudWriter # noqa: TCH001
from datacustomcode.proxy.client.base import BaseProxyClient # noqa: TCH001
from datacustomcode.spark.base import BaseSparkSessionProvider

DEFAULT_CONFIG_NAME = "config.yaml"
@@ -109,6 +110,7 @@ def to_object(self) -> _P:
class ClientConfig(BaseModel):
reader_config: Union[AccessLayerObjectConfig[BaseDataCloudReader], None] = None
writer_config: Union[AccessLayerObjectConfig[BaseDataCloudWriter], None] = None
proxy_config: Union[AccessLayerObjectConfig[BaseProxyClient], None] = None
spark_config: Union[SparkConfig, None] = None
spark_provider_config: Union[
SparkProviderConfig[BaseSparkSessionProvider], None
@@ -136,6 +138,7 @@ def merge(

self.reader_config = merge(self.reader_config, other.reader_config)
self.writer_config = merge(self.writer_config, other.writer_config)
self.proxy_config = merge(self.proxy_config, other.proxy_config)
self.spark_config = merge(self.spark_config, other.spark_config)
self.spark_provider_config = merge(
self.spark_provider_config, other.spark_provider_config
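A hedged sketch of the new field in use; the kwarg names are assumed from the YAML keys in config.yaml below, since this diff only shows AccessLayerObjectConfig generically:

cfg = ClientConfig(
    proxy_config=AccessLayerObjectConfig(  # kwarg names assumed from the YAML keys
        type_config_name="LocalProxyClientProvider",
        options={"credentials_profile": "default"},
    ),
)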
5 changes: 5 additions & 0 deletions src/datacustomcode/config.yaml
@@ -17,3 +17,8 @@ spark_config:
spark.submit.deployMode: client
spark.sql.execution.arrow.pyspark.enabled: 'true'
spark.driver.extraJavaOptions: -Djava.security.manager=allow

proxy_config:
type_config_name: LocalProxyClientProvider
options:
credentials_profile: default
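
This mirrors the reader/writer sections: type_config_name selects the class whose CONFIG_NAME matches, and options are passed through to it. Expected resolution, sketched (illustrative):

from datacustomcode import Client

client = Client()  # proxy built from proxy_config in config.yaml
client.call_llm_gateway("model-x", "hi", 16)  # handled by LocalProxyClientProvider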
14 changes: 14 additions & 0 deletions src/datacustomcode/proxy/__init__.py
@@ -0,0 +1,14 @@
# Copyright (c) 2025, Salesforce, Inc.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
24 changes: 24 additions & 0 deletions src/datacustomcode/proxy/base.py
@@ -0,0 +1,24 @@
# Copyright (c) 2025, Salesforce, Inc.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from abc import ABC

from datacustomcode.mixin import UserExtendableNamedConfigMixin


class BaseDataAccessLayer(ABC, UserExtendableNamedConfigMixin):
def __init__(self):
pass
14 changes: 14 additions & 0 deletions src/datacustomcode/proxy/client/__init__.py
@@ -0,0 +1,14 @@
# Copyright (c) 2025, Salesforce, Inc.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
28 changes: 28 additions & 0 deletions src/datacustomcode/proxy/client/base.py
@@ -0,0 +1,28 @@
# Copyright (c) 2025, Salesforce, Inc.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from abc import abstractmethod

from datacustomcode.io.base import BaseDataAccessLayer


class BaseProxyClient(BaseDataAccessLayer):
def __init__(self, spark=None, **kwargs):
if spark is not None:
super().__init__(spark)

@abstractmethod
def call_llm_gateway(self, llmModelId: str, prompt: str, maxTokens: int) -> str: ...
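
A subclass only needs to implement call_llm_gateway; a CONFIG_NAME (via the named-config mixin) makes it addressable from config.yaml. Hypothetical sketch:

class EchoProxyClient(BaseProxyClient):
    CONFIG_NAME = "EchoProxyClient"  # hypothetical registry name

    def call_llm_gateway(self, llmModelId: str, prompt: str, maxTokens: int) -> str:
        # Echo the prompt back, crudely truncated to the token budget.
        return prompt[:maxTokens]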
26 changes: 26 additions & 0 deletions src/datacustomcode/proxy/client/local_proxy_client.py
@@ -0,0 +1,26 @@
# Copyright (c) 2025, Salesforce, Inc.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from datacustomcode.proxy.client.base import BaseProxyClient


class LocalProxyClientProvider(BaseProxyClient):
"""Default proxy client provider."""

CONFIG_NAME = "LocalProxyClientProvider"

def call_llm_gateway(self, llmModelId: str, prompt: str, maxTokens: int) -> str:
return f"Hello, thanks for using {llmModelId}. So many tokens: {maxTokens}"
35 changes: 20 additions & 15 deletions src/datacustomcode/run.py
@@ -21,6 +21,7 @@
from typing import List, Union

from datacustomcode.config import config
from datacustomcode.scan import get_package_type


def _set_config_option(config_obj, key: str, value: str) -> None:
@@ -60,6 +61,8 @@ def run_entrypoint(
f"config.json not found at {config_json_path}. config.json is required."
)

package_type = get_package_type(entrypoint_dir)

try:
with open(config_json_path, "r") as f:
config_json = json.load(f)
@@ -68,21 +71,23 @@
f"config.json at {config_json_path} is not valid JSON"
) from err

# Require dataspace to be present in config.json
dataspace = config_json.get("dataspace")
if not dataspace:
raise ValueError(
f"config.json at {config_json_path} is missing required field 'dataspace'. "
f"Please ensure config.json contains a 'dataspace' field."
)

# Load config file first
if config_file:
config.load(config_file)

# Add dataspace to reader and writer config options
_set_config_option(config.reader_config, "dataspace", dataspace)
_set_config_option(config.writer_config, "dataspace", dataspace)
if package_type == "script":
# Require dataspace to be present in config.json
dataspace = config_json.get("dataspace")
if not dataspace:
raise ValueError(
f"config.json at {config_json_path} is missing required "
f"field 'dataspace'. "
f"Please ensure config.json contains a 'dataspace' field."
)

# Load config file first
if config_file:
config.load(config_file)

# Add dataspace to reader and writer config options
_set_config_option(config.reader_config, "dataspace", dataspace)
_set_config_option(config.writer_config, "dataspace", dataspace)

if profile != "default":
_set_config_option(config.reader_config, "credentials_profile", profile)
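Net effect: only "script" packages must declare a dataspace, and the config-file load plus dataspace injection now happen only for them. A minimal config.json sketch for a script package (just the field this code checks; real files may carry more):

import json

with open("payload/config.json", "w") as f:  # path is illustrative
    json.dump({"dataspace": "default"}, f)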