litecow_models package
Submodules
litecow_models.cli module
litecow_models.model module
- class litecow_models.model.ModelLoader
Bases: object
- static export_model(model_bucket: str, model_name: str) → None
Export a model from the model registry.
- Parameters
model_bucket (str) – Model registry bucket name
model_name (str) – Name of model
- static import_model(source: str, model_bucket: str, model_name: str, model_version: str) → None
Import a model into the model registry.
- Parameters
source (str) – Source URL of model
model_bucket (str) – Model registry bucket name
model_name (str) – Name of model
model_version (str) – Version tag for the uploaded object in S3.
- litecow_models.model.convert_github_raw(scheme: str, netloc: str, path: str) → str
Convert a GitHub blob URL to a raw content URL.
- Parameters
scheme (str) – URL scheme
netloc (str) – Network host address
path (str) – Path to file
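The rewrite that convert_github_raw describes can be sketched as follows. This is an illustrative stand-in, not the package's source: it assumes the standard GitHub layout in which https://github.com/<owner>/<repo>/blob/<ref>/<path> is served raw from https://raw.githubusercontent.com/<owner>/<repo>/<ref>/<path>, and the helper name to_raw_github_url is hypothetical.

```python
from urllib.parse import urlsplit, urlunsplit


def to_raw_github_url(scheme: str, netloc: str, path: str) -> str:
    """Hypothetical sketch: map a github.com blob URL to its raw-content URL."""
    parts = path.lstrip("/").split("/")
    # Expected blob layout: <owner>/<repo>/blob/<ref>/<path...>
    if netloc == "github.com" and len(parts) >= 5 and parts[2] == "blob":
        owner, repo, _blob, *rest = parts
        netloc = "raw.githubusercontent.com"
        path = "/" + "/".join([owner, repo, *rest])
    # Anything that is not a github.com blob URL is reassembled unchanged.
    return urlunsplit((scheme, netloc, path, "", ""))


# Usage: split a full URL into the three components the function takes.
url = "https://github.com/octocat/models/blob/main/weights/net.onnx"
split = urlsplit(url)
print(to_raw_github_url(split.scheme, split.netloc, split.path))
# https://raw.githubusercontent.com/octocat/models/main/weights/net.onnx
```

Taking the URL pre-split into scheme, netloc, and path (as the documented signature does) lets the caller decide which URLs to route through the conversion.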
- litecow_models.model.create_s3_client() → botocore.client
Create a boto3 S3 client.
- litecow_models.model.download_file(url: str) → str
Download a file from a URL and return its local filename.
- Parameters
url (str) – Source URL of download
- Returns
local_filename – Local filename of downloaded file
- Return type
str
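A minimal standard-library sketch of the download step follows. It assumes the local filename is taken from the last path component of the URL, which the docstring does not specify; download_to_local and its dest_dir parameter are hypothetical, and the package's own implementation may differ.

```python
import os
import shutil
import urllib.request
from urllib.parse import urlparse


def download_to_local(url: str, dest_dir: str = ".") -> str:
    """Hypothetical sketch: download `url` into `dest_dir` and return the local filename."""
    # Assumption: name the file after the URL's last path component.
    name = os.path.basename(urlparse(url).path) or "download"
    local_filename = os.path.join(dest_dir, name)
    # Stream the response to disk in chunks rather than reading it all into memory.
    with urllib.request.urlopen(url) as response, open(local_filename, "wb") as out:
        shutil.copyfileobj(response, out)
    return local_filename
```

Because urllib.request.urlopen also accepts file:// URLs, a function shaped like this can be exercised offline against a local file.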
- litecow_models.model.initialize_s3(bucket_name: str) → None
Connect to S3 and ensure that the bucket exists and has versioning enabled.
- Parameters
bucket_name (str) – Name of the S3 bucket to initialize.
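The check-then-create behaviour described for initialize_s3 can be sketched against any client exposing boto3's list_buckets, create_bucket, get_bucket_versioning, and put_bucket_versioning calls. This is an assumption about how the function works, not its source; ensure_versioned_bucket and the in-memory FakeS3 stand-in (used so the sketch runs without AWS) are hypothetical.

```python
from typing import Any, Dict


def ensure_versioned_bucket(s3_client: Any, bucket_name: str) -> None:
    """Hypothetical sketch: create `bucket_name` if missing, then enable versioning."""
    existing = {b["Name"] for b in s3_client.list_buckets().get("Buckets", [])}
    if bucket_name not in existing:
        # Outside us-east-1, real S3 also needs
        # CreateBucketConfiguration={"LocationConstraint": <region>}.
        s3_client.create_bucket(Bucket=bucket_name)
    # get_bucket_versioning omits "Status" for buckets that never had versioning set.
    status = s3_client.get_bucket_versioning(Bucket=bucket_name).get("Status")
    if status != "Enabled":
        s3_client.put_bucket_versioning(
            Bucket=bucket_name, VersioningConfiguration={"Status": "Enabled"}
        )


class FakeS3:
    """In-memory stand-in for the four S3 calls above (illustration only)."""

    def __init__(self) -> None:
        self.versioning: Dict[str, str] = {}  # bucket name -> versioning status

    def list_buckets(self) -> Dict[str, Any]:
        return {"Buckets": [{"Name": name} for name in self.versioning]}

    def create_bucket(self, Bucket: str) -> None:
        self.versioning[Bucket] = ""

    def get_bucket_versioning(self, Bucket: str) -> Dict[str, str]:
        status = self.versioning[Bucket]
        return {"Status": status} if status else {}

    def put_bucket_versioning(self, Bucket: str, VersioningConfiguration: Dict[str, str]) -> None:
        self.versioning[Bucket] = VersioningConfiguration["Status"]


client = FakeS3()
ensure_versioned_bucket(client, "model-registry")
print(client.get_bucket_versioning(Bucket="model-registry"))  # {'Status': 'Enabled'}
```

Checking state before each write keeps the operation idempotent: running it against a bucket that already exists with versioning enabled makes no changes.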
- litecow_models.model.pytorch_to_onnx_file(net: torch.nn.modules.module.Module, out_file: Union[str, pathlib.Path], model_input_height: int, model_input_width: int, dynamic_shape: Optional[bool] = True, output_names: Optional[List[str]] = None, dummy_forward_input: Optional[torch.Tensor] = None) → None
Serialize a PyTorch network to ONNX using a dummy input and a dynamic batch size. Serialization is done through tracing and/or scripting, as described in https://pytorch.org/docs/stable/onnx.html#tracing-vs-scripting. If the network's forward function uses control flow or loops, tracing/scripting may capture only some of these constructs.
- Parameters
net (torch.nn.Module) – Network to serialize.
out_file (Union[str, Path]) – File path to export the ONNX file to.
model_input_height (int) – Model input dimension height.
model_input_width (int) – Model input dimension width.
dynamic_shape (Optional[bool]) – Whether the spatial dimensions (height and width) of the input are dynamic. Defaults to True; set to False for a static shape.
output_names (Optional[List[str]]) – List of output head names for multi-head classification networks. Defaults to None, in which case ["output"] is used.
dummy_forward_input (Optional[torch.Tensor]) – Dummy input that is run through the model's forward pass for tracing-based export. Defaults to None, in which case a dummy input is generated.
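The parameters above mostly decide what gets passed to torch.onnx.export. The pure-Python sketch below shows one plausible way to derive the output names, the dynamic_axes mapping, and a dummy input shape from them. The input name "input", the 3-channel NCHW layout, and the axis labels are assumptions, and build_export_config is a hypothetical helper rather than part of litecow_models.

```python
from typing import Dict, List, Optional, Tuple

DynamicAxes = Dict[str, Dict[int, str]]


def build_export_config(
    model_input_height: int,
    model_input_width: int,
    dynamic_shape: bool = True,
    output_names: Optional[List[str]] = None,
    input_name: str = "input",  # assumed input name
    channels: int = 3,  # assumed RGB input in NCHW layout
) -> Tuple[List[str], DynamicAxes, Tuple[int, int, int, int]]:
    """Hypothetical helper: derive ONNX export settings from pytorch_to_onnx_file-style arguments."""
    if output_names is None:
        output_names = ["output"]  # mirrors the documented default
    # The batch dimension (axis 0) is always dynamic.
    dynamic_axes: DynamicAxes = {input_name: {0: "batch_size"}}
    if dynamic_shape:
        # Also mark height (axis 2) and width (axis 3) as dynamic.
        dynamic_axes[input_name].update({2: "height", 3: "width"})
    for name in output_names:
        dynamic_axes[name] = {0: "batch_size"}
    dummy_shape = (1, channels, model_input_height, model_input_width)
    return output_names, dynamic_axes, dummy_shape


names, axes, shape = build_export_config(224, 224, dynamic_shape=False, output_names=["head_a", "head_b"])
print(axes)
# {'input': {0: 'batch_size'}, 'head_a': {0: 'batch_size'}, 'head_b': {0: 'batch_size'}}
```

The results could then be handed to torch.onnx.export(net, torch.randn(*dummy_shape), out_file, input_names=["input"], output_names=output_names, dynamic_axes=dynamic_axes). Treat this as a sketch: the packaged function may name or order the axes differently.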
- litecow_models.model.serialize_model_to_file(net: torch.nn.modules.module.Module, out_path: Union[str, pathlib.Path], dummy_forward_input: torch.Tensor, output_names: List[str], dynamic_axes: Dict[str, Dict[int, str]]) → None
Serialize a given PyTorch model to a file path. Its main use is verifying that temporary-file I/O works with serialization and deserialization of ONNX models.
- Parameters
net (torch.nn.Module) – Pytorch model to serialize.
out_path (Union[str, Path]) – Path where the serialized model should be written to.
dummy_forward_input (torch.Tensor) – Forward input used to trace the model and build the ONNX execution graph for serialization.
output_names (List[str]) – Names of the model outputs.
dynamic_axes (Dict[str, Dict[int, str]]) – Dynamic axes configuration dictionary for ONNX export.
- litecow_models.model.verify_model_version(model_bucket: str, model_name: str, model_version: str, s3_client: mypy_boto3_s3.client.S3Client) → None
Verify that the given model and version do not conflict with pre-existing S3 objects.
- Parameters
model_bucket (str) – Bucket that models are uploaded to.
model_name (str) – Name of the model and bucket object.
model_version (str) – New version for a model.
s3_client (S3Client) – S3 client used to connect to S3.
- Raises
ValueError – If the given model and version already exist in S3.
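One way the conflict check could be implemented is sketched below. It assumes that each model is stored under the object key model_name in a versioned bucket and that its version string is kept in an object tag named "version"; both are assumptions, not documented behaviour. The sketch uses the real boto3 calls list_object_versions and get_object_tagging, but verify_model_version_sketch, VERSION_TAG_KEY, and the FakeS3Client stand-in are hypothetical, and pagination of list_object_versions is ignored for brevity.

```python
from typing import Any, Dict, List, Tuple

VERSION_TAG_KEY = "version"  # hypothetical tag name holding the model version


def verify_model_version_sketch(
    model_bucket: str, model_name: str, model_version: str, s3_client: Any
) -> None:
    """Raise ValueError if any stored version of model_name is tagged with model_version."""
    # Note: ignores pagination (IsTruncated / KeyMarker) for brevity.
    listing = s3_client.list_object_versions(Bucket=model_bucket, Prefix=model_name)
    for obj in listing.get("Versions", []):
        if obj["Key"] != model_name:
            continue  # Prefix matching also returns keys such as "<model_name>-large".
        tagging = s3_client.get_object_tagging(
            Bucket=model_bucket, Key=obj["Key"], VersionId=obj["VersionId"]
        )
        tags = {t["Key"]: t["Value"] for t in tagging["TagSet"]}
        if tags.get(VERSION_TAG_KEY) == model_version:
            raise ValueError(
                f"{model_name} version {model_version} already exists in bucket {model_bucket}"
            )


class FakeS3Client:
    """In-memory stand-in exposing only the two S3 calls used above."""

    def __init__(self, objects: List[Tuple[str, str, Dict[str, str]]]) -> None:
        self.objects = objects  # (key, version_id, tags)

    def list_object_versions(self, Bucket: str, Prefix: str) -> Dict[str, Any]:
        return {
            "Versions": [
                {"Key": key, "VersionId": vid}
                for key, vid, _ in self.objects
                if key.startswith(Prefix)
            ]
        }

    def get_object_tagging(self, Bucket: str, Key: str, VersionId: str) -> Dict[str, Any]:
        for key, vid, tags in self.objects:
            if key == Key and vid == VersionId:
                return {"TagSet": [{"Key": k, "Value": v} for k, v in tags.items()]}
        raise KeyError((Key, VersionId))


client = FakeS3Client([("resnet", "abc123", {"version": "1.0.0"})])
verify_model_version_sketch("models", "resnet", "1.1.0", client)  # no conflict, returns None
```

Checking every object version rather than only the latest one catches a clash with any previously uploaded version, which is what the registry's versioned bucket is meant to preserve.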