Source code for easytexminer.utils.initializer

# coding=utf-8
# Copyright (c) 2020 Alibaba PAI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""EasyTexMiner initialization."""
import random
import os
import time
import numpy as np

import torch
from .global_vars import get_args, set_global_variables, get_tensorboard_writer
from .logger import init_logger
from .io_utils import TFOSSIO, parse_oss_buckets, OSSIO, io

[docs]def initialize_easytexminer(extra_args_provider=None, args_defaults={}, ignore_unknown_args=False): # Make sure cuda is available. assert torch.cuda.is_available(), 'EasyTexMiner requires CUDA.' set_global_variables(extra_args_provider=extra_args_provider, args_defaults=args_defaults, ignore_unknown_args=ignore_unknown_args) args = get_args() _initialize_distributed() # Random seeds for reproducibility. if args.rank == 0: print('> setting random seeds to {} ...'.format(args.seed)) _set_random_seed(args.seed) os.environ['EASYTEXMINER_MODELZOO_BASE_DIR'] = args.model_zoo_base_dir os.environ["EASYTEXMINER_IS_MASTER"] = str(args.is_master_node) os.environ["EASYTEXMINER_N_GPUS"] = str(args.n_gpu) init_logger(local_rank=args.rank) if args.buckets is not None: init_oss_io(args) if args.mode == "train" or not args.checkpoint_dir: from . import set_local_pretrained_model_dir args.pretrained_model_name_or_path = set_local_pretrained_model_dir(args.pretrained_model_name_or_path) else: args.pretrained_model_name_or_path = args.checkpoint_dir # Compile dependencies. #_compile_dependencies() # No continuation function return None
def _compile_dependencies(): args = get_args() # Always build on rank zero first. if torch.distributed.get_rank() == 0: start_time = time.time() print('> compiling and loading fused kernels ...', flush=True) fused_kernels.load(args) torch.distributed.barrier() else: torch.distributed.barrier() fused_kernels.load(args) # Simple barrier to make sure all ranks have passed the # compilation phase successfully before moving on to the # rest of the program. We think this might ensure that # the lock is released. torch.distributed.barrier() if torch.distributed.get_rank() == 0: print('>>> done with compiling and loading fused kernels. ' 'Compilation time: {:.3f} seconds'.format( time.time() - start_time), flush=True) def _initialize_distributed(): """Initialize torch.distributed and mpu.""" args = get_args() device_count = torch.cuda.device_count() if torch.distributed.is_initialized(): if args.rank == 0: print('torch distributed is already initialized, ' 'skipping initialization ...', flush=True) args.rank = torch.distributed.get_rank() args.world_size = torch.distributed.get_world_size() else: if args.rank == 0: print('> initializing torch distributed ...', flush=True) # Manually set the device ids. if device_count > 0: device = args.rank % device_count if args.local_rank is not None: assert args.local_rank == device, \ 'expected local-rank to be the same as rank % device-count.' else: args.local_rank = device torch.cuda.set_device(device) # Call the init process init_method = 'tcp://' master_ip = os.getenv('MASTER_ADDR', 'localhost') master_port = os.getenv('MASTER_PORT', '6000') init_method += master_ip + ':' + master_port torch.distributed.init_process_group( backend=args.distributed_backend, world_size=args.world_size, rank=args.rank, init_method=init_method) def _set_random_seed(seed): """Set random seed for reproducability.""" if seed is not None and seed > 0: random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) else: raise ValueError('Seed ({}) should be a positive integer.'.format(seed))
[docs]def write_args_to_tensorboard(): """Write arguments to tensorboard.""" args = get_args() writer = get_tensorboard_writer() if writer: for arg in vars(args): writer.add_text(arg, str(getattr(args, arg)), global_step=args.iteration)
[docs]def init_oss_io(cfg): if "role_arn" in cfg.buckets: new_io = TFOSSIO() else: access_key_id, access_key_secret, hosts, buckets = parse_oss_buckets(cfg.buckets) if cfg.model_zoo_base_dir and "oss://" in cfg.model_zoo_base_dir: _, _, mz_hosts, mz_buckets = parse_oss_buckets(cfg.model_zoo_base_dir) hosts += mz_hosts buckets += mz_buckets new_io = OSSIO(access_key_id=access_key_id, access_key_secret=access_key_secret, hosts=hosts, buckets=buckets) io.set_io(new_io)
[docs]def init_odps_io(odps_config): os.environ['ODPS_CONFIG_FILE_PATH'] = odps_config