sagemaker-core/src/sagemaker/core/remote_function/client.py (6 additions, 6 deletions)
@@ -366,7 +366,7 @@ def wrapper(*args, **kwargs):
s3_uri=s3_path_join(
job_settings.s3_root_uri, job.job_name, EXCEPTION_FOLDER
),
-hmac_key=job.hmac_key,
+
)
except ServiceError as serr:
chained_e = serr.__cause__
@@ -403,7 +403,7 @@ def wrapper(*args, **kwargs):
return serialization.deserialize_obj_from_s3(
sagemaker_session=job_settings.sagemaker_session,
s3_uri=s3_path_join(job_settings.s3_root_uri, job.job_name, RESULTS_FOLDER),
-hmac_key=job.hmac_key,
+
)

if job.describe()["TrainingJobStatus"] == "Stopped":
@@ -983,7 +983,7 @@ def from_describe_response(describe_training_job_response, sagemaker_session):
job_return = serialization.deserialize_obj_from_s3(
sagemaker_session=sagemaker_session,
s3_uri=s3_path_join(job.s3_uri, RESULTS_FOLDER),
-hmac_key=job.hmac_key,
+
)
except DeserializationError as e:
client_exception = e
@@ -995,7 +995,7 @@ def from_describe_response(describe_training_job_response, sagemaker_session):
job_exception = serialization.deserialize_exception_from_s3(
sagemaker_session=sagemaker_session,
s3_uri=s3_path_join(job.s3_uri, EXCEPTION_FOLDER),
-hmac_key=job.hmac_key,
+
)
except ServiceError as serr:
chained_e = serr.__cause__
@@ -1085,7 +1085,7 @@ def result(self, timeout: float = None) -> Any:
self._return = serialization.deserialize_obj_from_s3(
sagemaker_session=self._job.sagemaker_session,
s3_uri=s3_path_join(self._job.s3_uri, RESULTS_FOLDER),
-hmac_key=self._job.hmac_key,
+
)
self._state = _FINISHED
return self._return
@@ -1094,7 +1094,7 @@ def result(self, timeout: float = None) -> Any:
self._exception = serialization.deserialize_exception_from_s3(
sagemaker_session=self._job.sagemaker_session,
s3_uri=s3_path_join(self._job.s3_uri, EXCEPTION_FOLDER),
-hmac_key=self._job.hmac_key,
+
)
except ServiceError as serr:
chained_e = serr.__cause__
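All six client.py hunks make the same mechanical edit: each deserialization call drops its hmac_key argument, leaving only a session and an S3 uri. Below is a minimal sketch of the resulting call shape; the serialization module and s3_path_join helper are taken as parameters rather than imported, since this diff does not show their import paths, and the "results" folder name is an assumption.

```python
from typing import Any

def fetch_job_result(serialization, s3_path_join, sagemaker_session,
                     s3_root_uri: str, job_name: str) -> Any:
    """Sketch of the post-change result fetch in client.py.

    RESULTS_FOLDER is assumed to be "results"; before this change the call
    below also required hmac_key=job.hmac_key.
    """
    results_folder = "results"  # assumed value of RESULTS_FOLDER
    return serialization.deserialize_obj_from_s3(
        sagemaker_session=sagemaker_session,
        s3_uri=s3_path_join(s3_root_uri, job_name, results_folder),
    )
```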
@@ -164,7 +164,6 @@ class _DelayedReturnResolver:
def __init__(
self,
delayed_returns: List[_DelayedReturn],
-hmac_key: str,
properties_resolver: _PropertiesResolver,
parameter_resolver: _ParameterResolver,
execution_variable_resolver: _ExecutionVariableResolver,
@@ -175,7 +174,6 @@ def __init__(

Args:
delayed_returns: list of delayed returns to resolve.
-hmac_key: key used to encrypt serialized and deserialized function and arguments.
properties_resolver: resolver used to resolve step properties.
parameter_resolver: resolver used to pipeline parameters.
execution_variable_resolver: resolver used to resolve execution variables.
@@ -197,7 +195,6 @@ def deserialization_task(uri):
return uri, deserialize_obj_from_s3(
sagemaker_session=settings["sagemaker_session"],
s3_uri=uri,
-hmac_key=hmac_key,
)

with ThreadPoolExecutor() as executor:
@@ -247,7 +244,6 @@ def resolve_pipeline_variables(
context: Context,
func_args: Tuple,
func_kwargs: Dict,
-hmac_key: str,
s3_base_uri: str,
**settings,
):
@@ -257,7 +253,6 @@ def resolve_pipeline_variables(
context: context for the execution.
func_args: function args.
func_kwargs: function kwargs.
-hmac_key: key used to encrypt serialized and deserialized function and arguments.
s3_base_uri: the s3 base uri of the function step that the serialized artifacts
will be uploaded to. The s3_base_uri = s3_root_uri + pipeline_name.
**settings: settings to pass to the deserialization function.
@@ -280,7 +275,6 @@ def resolve_pipeline_variables(
properties_resolver = _PropertiesResolver(context)
delayed_return_resolver = _DelayedReturnResolver(
delayed_returns=delayed_returns,
-hmac_key=hmac_key,
properties_resolver=properties_resolver,
parameter_resolver=parameter_resolver,
execution_variable_resolver=execution_variable_resolver,
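With hmac_key gone from _DelayedReturnResolver, its deserialization fan-out reduces to mapping S3 uris to objects. A standalone sketch of that pattern follows; the function and parameter names are illustrative rather than the SDK's.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable, Dict, Iterable

def resolve_delayed_returns(
    uris: Iterable[str],
    deserialize: Callable[..., Any],  # stands in for deserialize_obj_from_s3
    sagemaker_session,
) -> Dict[str, Any]:
    """Deserialize each artifact concurrently and key the result by uri,
    mirroring deserialization_task above: after this change the deserializer
    needs only a session and an s3_uri, no shared key."""
    def task(uri: str):
        return uri, deserialize(sagemaker_session=sagemaker_session, s3_uri=uri)

    with ThreadPoolExecutor() as executor:
        return dict(executor.map(task, uris))
```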
@@ -19,7 +19,6 @@
import io

import sys
-import hmac
import hashlib
import pickle

@@ -156,15 +155,14 @@ def deserialize(s3_uri: str, bytes_to_deserialize: bytes) -> Any:

# TODO: use dask serializer in case dask distributed is installed in users' environment.
def serialize_func_to_s3(
-func: Callable, sagemaker_session: Session, s3_uri: str, hmac_key: str, s3_kms_key: str = None
+func: Callable, sagemaker_session: Session, s3_uri: str, s3_kms_key: str = None
):
"""Serializes function and uploads it to S3.

Args:
sagemaker_session (sagemaker.core.helper.session.Session):
The underlying Boto3 session which AWS service calls are delegated to.
s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
-hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized func.
s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3.
func: function to be serialized and persisted
Raises:
@@ -173,14 +171,13 @@ def serialize_func_to_s3(

_upload_payload_and_metadata_to_s3(
bytes_to_upload=CloudpickleSerializer.serialize(func),
-hmac_key=hmac_key,
s3_uri=s3_uri,
sagemaker_session=sagemaker_session,
s3_kms_key=s3_kms_key,
)


-def deserialize_func_from_s3(sagemaker_session: Session, s3_uri: str, hmac_key: str) -> Callable:
+def deserialize_func_from_s3(sagemaker_session: Session, s3_uri: str) -> Callable:
"""Downloads from S3 and then deserializes data objects.

This method downloads the serialized training job outputs to a temporary directory and
@@ -190,7 +187,6 @@ def deserialize_func_from_s3(sagemaker_session: Session, s3_uri: str, hmac_key:
sagemaker_session (sagemaker.core.helper.session.Session):
The underlying sagemaker session which AWS service calls are delegated to.
s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
-hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized func.
Returns :
The deserialized function.
Raises:
@@ -203,14 +199,14 @@ def deserialize_func_from_s3(sagemaker_session: Session, s3_uri: str, hmac_key:
bytes_to_deserialize = _read_bytes_from_s3(f"{s3_uri}/payload.pkl", sagemaker_session)

_perform_integrity_check(
-expected_hash_value=metadata.sha256_hash, secret_key=hmac_key, buffer=bytes_to_deserialize
+expected_hash_value=metadata.sha256_hash, buffer=bytes_to_deserialize
)

return CloudpickleSerializer.deserialize(f"{s3_uri}/payload.pkl", bytes_to_deserialize)


def serialize_obj_to_s3(
-obj: Any, sagemaker_session: Session, s3_uri: str, hmac_key: str, s3_kms_key: str = None
+obj: Any, sagemaker_session: Session, s3_uri: str, s3_kms_key: str = None
):
"""Serializes data object and uploads it to S3.

@@ -219,15 +215,13 @@ def serialize_obj_to_s3(
The underlying Boto3 session which AWS service calls are delegated to.
s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3.
-hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized obj.
obj: object to be serialized and persisted
Raises:
SerializationError: when fail to serialize object to bytes.
"""

_upload_payload_and_metadata_to_s3(
bytes_to_upload=CloudpickleSerializer.serialize(obj),
-hmac_key=hmac_key,
s3_uri=s3_uri,
sagemaker_session=sagemaker_session,
s3_kms_key=s3_kms_key,
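serialize_obj_to_s3 and deserialize_obj_from_s3 now share no secret; integrity rides entirely on the sha256 recorded next to the payload. A round-trip usage sketch with the new signatures follows, again taking the serialization module as a parameter because its import path is not shown in this diff.

```python
from typing import Any

def roundtrip(serialization, sagemaker_session, s3_uri: str, obj: Any) -> Any:
    """Serialize an object to S3 and read it straight back.

    The deserializer recomputes the payload's sha256 and compares it to the
    uploaded metadata; no hmac_key is involved after this change."""
    serialization.serialize_obj_to_s3(
        obj=obj,
        sagemaker_session=sagemaker_session,
        s3_uri=s3_uri,
        s3_kms_key=None,  # optional server-side encryption key, unchanged
    )
    return serialization.deserialize_obj_from_s3(
        sagemaker_session=sagemaker_session,
        s3_uri=s3_uri,
    )
```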
@@ -274,14 +268,13 @@ def json_serialize_obj_to_s3(
)


-def deserialize_obj_from_s3(sagemaker_session: Session, s3_uri: str, hmac_key: str) -> Any:
+def deserialize_obj_from_s3(sagemaker_session: Session, s3_uri: str) -> Any:
"""Downloads from S3 and then deserializes data objects.

Args:
sagemaker_session (sagemaker.core.helper.session.Session):
The underlying sagemaker session which AWS service calls are delegated to.
s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
-hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized obj.
Returns :
Deserialized python objects.
Raises:
@@ -295,14 +288,14 @@ def deserialize_obj_from_s3(sagemaker_session: Session, s3_uri: str, hmac_key: s
bytes_to_deserialize = _read_bytes_from_s3(f"{s3_uri}/payload.pkl", sagemaker_session)

_perform_integrity_check(
-expected_hash_value=metadata.sha256_hash, secret_key=hmac_key, buffer=bytes_to_deserialize
+expected_hash_value=metadata.sha256_hash, buffer=bytes_to_deserialize
)

return CloudpickleSerializer.deserialize(f"{s3_uri}/payload.pkl", bytes_to_deserialize)


def serialize_exception_to_s3(
-exc: Exception, sagemaker_session: Session, s3_uri: str, hmac_key: str, s3_kms_key: str = None
+exc: Exception, sagemaker_session: Session, s3_uri: str, s3_kms_key: str = None
):
"""Serializes exception with traceback and uploads it to S3.

Expand All @@ -311,7 +304,6 @@ def serialize_exception_to_s3(
The underlying Boto3 session which AWS service calls are delegated to.
s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3.
-hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized exception.
exc: Exception to be serialized and persisted
Raises:
SerializationError: when fail to serialize object to bytes.
@@ -320,7 +312,6 @@ def serialize_exception_to_s3(

_upload_payload_and_metadata_to_s3(
bytes_to_upload=CloudpickleSerializer.serialize(exc),
-hmac_key=hmac_key,
s3_uri=s3_uri,
sagemaker_session=sagemaker_session,
s3_kms_key=s3_kms_key,
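Every serialize_*_to_s3 helper above funnels into _upload_payload_and_metadata_to_s3, which writes two objects under the given prefix. Here is a local filesystem sketch of that layout; the metadata.json name and the exact JSON shape of _MetaData are assumptions, since only payload.pkl is visible in this diff.

```python
import hashlib
import json
from pathlib import Path

def write_payload_and_metadata(payload: bytes, out_dir: str) -> None:
    """Local stand-in for _upload_payload_and_metadata_to_s3: one pickled
    payload plus a small metadata document carrying its hash. After this
    change the hash is an unkeyed sha256, so no hmac_key is threaded through."""
    root = Path(out_dir)
    root.mkdir(parents=True, exist_ok=True)
    (root / "payload.pkl").write_bytes(payload)
    metadata = {"sha256_hash": hashlib.sha256(payload).hexdigest()}  # assumed _MetaData shape
    (root / "metadata.json").write_text(json.dumps(metadata))
```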
@@ -329,7 +320,6 @@ def _upload_payload_and_metadata_to_s3(

def _upload_payload_and_metadata_to_s3(
bytes_to_upload: Union[bytes, io.BytesIO],
-hmac_key: str,
s3_uri: str,
sagemaker_session: Session,
s3_kms_key,
@@ -338,15 +328,14 @@ def _upload_payload_and_metadata_to_s3(

Args:
bytes_to_upload (bytes): Serialized bytes to upload.
-hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized obj.
s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
sagemaker_session (sagemaker.core.helper.session.Session):
The underlying Boto3 session which AWS service calls are delegated to.
s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3.
"""
_upload_bytes_to_s3(bytes_to_upload, f"{s3_uri}/payload.pkl", s3_kms_key, sagemaker_session)

-sha256_hash = _compute_hash(bytes_to_upload, secret_key=hmac_key)
+sha256_hash = _compute_hash(bytes_to_upload)

_upload_bytes_to_s3(
_MetaData(sha256_hash).to_json(),
@@ -356,14 +345,13 @@
)


-def deserialize_exception_from_s3(sagemaker_session: Session, s3_uri: str, hmac_key: str) -> Any:
+def deserialize_exception_from_s3(sagemaker_session: Session, s3_uri: str) -> Any:
"""Downloads from S3 and then deserializes exception.

Args:
sagemaker_session (sagemaker.core.helper.session.Session):
The underlying sagemaker session which AWS service calls are delegated to.
s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
-hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized exception.
Returns :
Deserialized exception with traceback.
Raises:
@@ -377,7 +365,7 @@ def deserialize_exception_from_s3(sagemaker_session: Session, s3_uri: str, hmac_
bytes_to_deserialize = _read_bytes_from_s3(f"{s3_uri}/payload.pkl", sagemaker_session)

_perform_integrity_check(
-expected_hash_value=metadata.sha256_hash, secret_key=hmac_key, buffer=bytes_to_deserialize
+expected_hash_value=metadata.sha256_hash, buffer=bytes_to_deserialize
)

return CloudpickleSerializer.deserialize(f"{s3_uri}/payload.pkl", bytes_to_deserialize)
@@ -403,19 +391,19 @@ def _read_bytes_from_s3(s3_uri, sagemaker_session):
) from e


-def _compute_hash(buffer: bytes, secret_key: str) -> str:
-"""Compute the hmac-sha256 hash"""
-return hmac.new(secret_key.encode(), msg=buffer, digestmod=hashlib.sha256).hexdigest()
+def _compute_hash(buffer: bytes) -> str:
+"""Compute the sha256 hash"""
+return hashlib.sha256(buffer).hexdigest()


-def _perform_integrity_check(expected_hash_value: str, secret_key: str, buffer: bytes):
+def _perform_integrity_check(expected_hash_value: str, buffer: bytes):
"""Performs integrity checks for serialized code/arguments uploaded to s3.

Verifies whether the hash read from s3 matches the hash calculated
during remote function execution.
"""
-actual_hash_value = _compute_hash(buffer=buffer, secret_key=secret_key)
-if not hmac.compare_digest(expected_hash_value, actual_hash_value):
+actual_hash_value = _compute_hash(buffer=buffer)
+if expected_hash_value != actual_hash_value:
raise DeserializationError(
"Integrity check for the serialized function or data failed. "
"Please restrict access to your S3 bucket"