Source code for matorage.optimizer.manager

# Copyright 2020-present Tae Hwan Jung
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import json
import tables
import hashlib
import tempfile
from minio import Minio

from matorage.nas import NAS
from matorage.utils import check_nas, logger
from matorage.uploader import Uploader

_KB = 1024
"""The size of a Kilobyte in bytes"""

_MB = 1024 * _KB
"""The size of a Megabyte in bytes"""


class Manager(object):
    """Storage plumbing for saving and loading optimizer state.

    This class creates the bucket and clears objects on overwrite. It queues
    parameter uploads and writes ``metadata.json``. Several methods it calls
    are not defined here: ``_get_step``, ``_set_metadata``, ``_set_scheduler``,
    ``_save_optimizer`` and ``_load_optimizer``. A subclass must provide them.

    Args:
        config: Optimizer configuration. The attributes read by this class are
            ``endpoint``, ``access_key``, ``secret_key``, ``secure``,
            ``region``, ``bucket_name``, ``metadata`` (must contain an
            ``"optimizer"`` mapping), ``compressor`` (keyword arguments for
            ``tables.Filters``), ``optimizer_name`` and ``additional``.
        num_worker_threads (int): Number of worker threads used by the uploader.
        multipart_upload_size (int): Multipart upload part size in bytes.
    """

    # Name of the HDF5 array node each parameter is stored under.
    type = "optimizer"

    def __init__(self, config, num_worker_threads=4, multipart_upload_size=5 * _MB):
        self.config = config
        self.num_worker_threads = num_worker_threads
        self.multipart_upload_size = multipart_upload_size

        # Endpoints recognised by check_nas use the local NAS backend;
        # everything else goes through a MinIO client.
        self._client = (
            Minio(
                endpoint=self.config.endpoint,
                access_key=self.config.access_key,
                secret_key=self.config.secret_key,
                secure=self.config.secure,
                region=self.config.region,
            )
            if not check_nas(self.config.endpoint)
            else NAS(self.config.endpoint)
        )

        self._uploader = Uploader(
            client=self._client,
            bucket=self.config.bucket_name,
            num_worker_threads=self.num_worker_threads,
            multipart_upload_size=self.multipart_upload_size,
            inmemory=True,
        )

    def _uploader_closing(self):
        """Wait for queued uploads, then upload the metadata as ``metadata.json``.

        ``self.config.metadata`` is serialised to a temporary file created with
        ``tempfile.mkstemp``. This avoids the race inherent in
        ``tempfile.mktemp``. The file is removed even if the upload fails.
        """
        self._uploader.join_queue()

        fd, _metadata_file = tempfile.mkstemp(suffix="metadata.json")
        try:
            with os.fdopen(fd, "w", encoding="utf-8") as writer:
                writer.write(json.dumps(self.config.metadata, indent=4) + "\n")
            self._client.fput_object(
                bucket_name=self.config.bucket_name,
                object_name="metadata.json",
                file_path=_metadata_file,
            )
        finally:
            os.remove(_metadata_file)

    def _save_with_clear(self, step, optimizer, overwrite=False):
        """Save the optimizer for ``step``, optionally clearing old objects first.

        Args:
            step: Step identifier; objects live under the ``f"{step}/"`` prefix.
            optimizer: Optimizer passed through to ``_save_optimizer``.
            overwrite (bool): If True, delete every existing object under the
                step prefix before saving.
        """
        if overwrite:
            # recursive=True is required because parameters are stored under
            # nested keys such as f"{step}/{group}/{name}" (see _save_param).
            # A non-recursive listing only yields the top level, so nested
            # objects would survive the "clear".
            objects = self._client.list_objects(
                bucket_name=self.config.bucket_name,
                prefix=f"{step}/",
                recursive=True,
            )
            for obj in objects:
                self._client.remove_object(
                    bucket_name=self.config.bucket_name, object_name=obj.object_name
                )

        # saving optimizer
        self._save_optimizer(step, optimizer)
        self._uploader_closing()

    def _save_param(self, step, group, name, weight):
        """Serialise one tensor into an in-memory HDF5 image and queue its upload.

        Args:
            step: Step identifier used as the top-level key prefix.
            group: Parameter group. It is converted with ``str``, and the
                group level is omitted from the key only when that string is
                empty.
            name: Parameter name, used as the last component of the key.
            weight: Array-like data stored under the ``self.type`` node.
        """
        group = str(group)
        # driver_core_backing_store=False keeps the HDF5 file in memory; the
        # mktemp path only names it and is never created on disk.
        _local_file = tempfile.mktemp(f"{name}.h5")
        _file = tables.open_file(
            _local_file, "w", driver="H5FD_CORE", driver_core_backing_store=False
        )
        try:
            _file.create_carray(
                "/",
                self.type,
                obj=weight,
                filters=tables.Filters(**self.config.compressor),
            )
            remote_file = f"{step}/{group}/{name}" if group else f"{step}/{name}"
            self._uploader.set_queue(
                local_file=_file.get_file_image(), remote_file=remote_file
            )
        finally:
            # Always release the HDF5 handle, even if building or queueing fails.
            _file.close()

    def save(self, optimizer, scheduler=None):
        """Save the optimizer (and optional scheduler) for its current step.

        Args:
            optimizer: Optimizer whose step is obtained via ``_get_step``.
            scheduler: Optional scheduler. Its metadata is recorded before
                ``metadata.json`` is uploaded, so the upload includes it.
        """
        if not self._client.bucket_exists(self.config.bucket_name):
            self._client.make_bucket(
                self.config.bucket_name, location=self.config.region
            )

        step = self._get_step(optimizer)
        # Any falsy step (e.g. None or 0) aborts the save.
        if not step:
            logger.error(
                "{} {} step({})is not exist".format(
                    self.config.optimizer_name, self.config.additional, str(step)
                )
            )
            return

        overwrite = step in self.config.metadata["optimizer"]
        if overwrite:
            logger.info(
                "{} {} is already exist, so optimizer will be overwrited.".format(
                    self.config.optimizer_name, str(self.config.additional)
                )
            )
        else:
            self._set_metadata(
                metadata=self.config.metadata, optimizer=optimizer, step=step
            )

        # Record scheduler metadata BEFORE saving. _save_with_clear ends by
        # uploading self.config.metadata as metadata.json, so any change made
        # afterwards would never reach storage.
        if scheduler is not None:
            self._set_scheduler(
                metadata=self.config.metadata, scheduler=scheduler, step=step
            )

        self._save_with_clear(step, optimizer, overwrite=overwrite)
        logger.info("optimizer with {} is saved".format(str(step)))

    def load(self, optimizer, step):
        """Load the optimizer state saved for ``step`` into ``optimizer``.

        Args:
            optimizer: Optimizer passed to ``_load_optimizer``.
            step: Step identifier; every object under ``f"{step}/"`` is listed
                recursively and handed to ``_load_optimizer``.
        """
        layers = self._client.list_objects(
            bucket_name=self.config.bucket_name, prefix=f"{step}/", recursive=True
        )
        self._load_optimizer(step, layers, optimizer)
        # Log only after loading has actually completed.
        logger.info("optimizer with {} is loaded".format(str(step)))

    @property
    def get_metadata(self):
        """
        Get all optimizers according to metadata by step.

        Returns:
            :obj:`dict`: optimizer of metadata

        Examples::

            >>> optimizer_manager = OptimizerManager(config=optimizer_config)
            >>> optimizer_manager.save(optimizer)
            >>> optimizer_manager.get_metadata
            {'938':
                {
                    'framework': 'pytorch',
                    'param_groups': [
                        {
                            'lr': 0.01,
                            'betas': [0.9, 0.999],
                            'eps': 1e-08,
                            'weight_decay': 0,
                            'amsgrad': False,
                            'params': [
                                140516594711520,
                                140516594711760,
                                140517867028384,
                                140516594711680,
                                140516594693376,
                                140516594612336
                            ]
                        }
                    ]
                }
            }

        """
        return self.config.metadata["optimizer"]