Source code for matorage.model.manager

# Copyright 2020-present Tae Hwan Jung
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import json
import tables
import hashlib
import tempfile
from minio import Minio

from matorage.nas import NAS
from matorage.utils import check_nas, logger
from matorage.uploader import Uploader

_KB = 1024
"""The size of a Kilobyte in bytes"""

_MB = 1024 * _KB
"""The size of a Megabyte in bytes"""


[docs]class Manager(object): type = "model" def __init__(self, config, num_worker_threads=4, multipart_upload_size=5 * _MB): self.config = config self.num_worker_threads = num_worker_threads self.multipart_upload_size = multipart_upload_size self._client = ( Minio( endpoint=self.config.endpoint, access_key=self.config.access_key, secret_key=self.config.secret_key, secure=self.config.secure, region=self.config.region, ) if not check_nas(self.config.endpoint) else NAS(self.config.endpoint) ) self._uploader = Uploader( client=self._client, bucket=self.config.bucket_name, num_worker_threads=self.num_worker_threads, multipart_upload_size=self.multipart_upload_size, inmemory=True, ) def _uploader_closing(self): self._uploader.join_queue() _metadata_file = tempfile.mktemp("metadata.json") with open(_metadata_file, "w", encoding="utf-8") as writer: writer.write( json.dumps(self.config.metadata, indent=4, sort_keys=True) + "\n" ) self._client.fput_object( bucket_name=self.config.bucket_name, object_name="metadata.json", file_path=_metadata_file, ) os.remove(_metadata_file) def _save_with_clear(self, model_folder, model, overwrite=False): if overwrite: objects = self._client.list_objects( bucket_name=self.config.bucket_name, prefix=f"{model_folder}/" ) for obj in objects: self._client.remove_object( bucket_name=self.config.bucket_name, object_name=obj.object_name ) # saving model self._save_model(model_folder, model) self._uploader_closing() def _save_layer(self, model_folder, name, weight): _local_file = tempfile.mktemp(f"{name}.h5") _file = tables.open_file( _local_file, "w", driver="H5FD_CORE", driver_core_backing_store=False ) _file.create_carray( "/", self.type, obj=weight, filters=tables.Filters(**self.config.compressor) ) self._uploader.set_queue( local_file=_file.get_file_image(), remote_file=f"{model_folder}/{name}" ) _file.close() def save(self, model, **kwargs): if not self._client.bucket_exists(self.config.bucket_name): self._client.make_bucket( self.config.bucket_name, location=self.config.region ) if not isinstance(kwargs, dict): metadata = 0 else: metadata = kwargs model_folder = self._hashmap_transfer(metadata) if model_folder in self.config.metadata["model"]: logger.info( "{} {} is already exist, so model will be overwrited.".format( self.config.model_name, str(self.config.additional) ) ) self._save_with_clear(model_folder, model, overwrite=True) else: self.config.metadata["model"].update({model_folder: metadata}) self._save_with_clear(model_folder, model) logger.info("model with {} is saved".format(str(metadata))) def load(self, model, **kwargs): if not isinstance(kwargs, dict): metadata = 0 else: metadata = kwargs model_folder = self._hashmap_transfer(metadata) layers = self._client.list_objects( bucket_name=self.config.bucket_name, prefix=f"{model_folder}/", recursive=True, ) logger.info("model with {} is loaded".format(str(metadata))) return self._load_model(model_folder, layers, model) def _hashmap_transfer(self, metadata): """ Get unikey object folder with `metadata` of model. Returns: :obj: `str`: """ if isinstance(metadata, int): metadata = str(metadata) if not isinstance(metadata, str) and not isinstance(metadata, dict): raise ValueError( "metadata {} is empty or not str and dict type".format(metadata) ) key = json.dumps(metadata, indent=4, sort_keys=True) return hashlib.md5(key.encode("utf-8")).hexdigest() @property def get_metadata(self): """ Get all models according to metadata(ex. step, epoch) Returns: :obj:`dict`: model of metadata Examples:: >>> model_manager.save(model, step=100) >>> model_manager.save(model, step=200) >>> model_manager.get_metadata { 'additional': {'version': '1.0.1'}, 'compressor': {'complevel': 0, 'complib': 'zlib'}, 'endpoint': '127.0.0.1:9000', 'model': { 'ad44168f1343bc77b4d9ad6f1fef50b6': {'step': 100}, 'af0677ecf0d15d17d10204be9ff2f9f5': {'step': 200} }, 'model_name': 'testmodel' } """ return self.config.metadata