Source code for ckipnlp.driver.base

#!/usr/bin/env python3
# -*- coding:utf-8 -*-

"""
This module provides base drivers.
"""

__author__ = 'Mu Yang <http://muyang.pro>'
__copyright__ = '2018-2023 CKIP Lab'
__license__ = 'GPL-3.0'

from abc import (
    ABCMeta as _ABCMeta,
    abstractmethod as _abstractmethod,
)

from ckipnlp.util.logger import (
    get_logger as _get_logger,
)

################################################################################################################################

[docs]class DriverRegister: """The driver registering utility.""" _DRIVERS = {} @staticmethod def get(driver_type, driver_family): # pylint: disable=missing-docstring if driver_family is None: return DummyDriver driver = DriverRegister._DRIVERS.get((driver_type, driver_family,)) if not driver: raise KeyError(f'{driver_type} is not implemented for type {driver_family}') if not driver.is_dummy: _get_logger().debug(f'Use {driver.__name__} ...') return driver
################################################################################################################################
[docs]class BaseDriver(metaclass=_ABCMeta): """The base CKIPNLP driver.""" is_dummy = False def __init__(self, *, lazy=False): self._core = None self._inited = False if not lazy: self.init() def init(self): # pylint: disable=missing-docstring if self._inited: return _get_logger().info(f'Initializing {self.__class__.__name__} ...') self._init() self._inited = True def __call__(self, *args, **kwargs): self.init() return self._call(*args, **kwargs) ######################################################################################################################## @_abstractmethod def driver_type(self): # pylint: disable=missing-docstring return NotImplemented @_abstractmethod def driver_family(self): # pylint: disable=missing-docstring return NotImplemented @_abstractmethod def driver_inputs(self): # pylint: disable=missing-docstring return NotImplemented @_abstractmethod def _init(self): return NotImplemented @_abstractmethod def _call(self): return NotImplemented def _call_from_pipeline(self, pipeline, doc): return self.__call__(**{ key: pipeline._get(key, doc) for key in self.driver_inputs # pylint: disable=protected-access }) ######################################################################################################################## def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) driver_type = cls.driver_type driver_family = cls.driver_family key = (driver_type, driver_family,) assert key not in DriverRegister._DRIVERS, f'{key} already registered!' # pylint: disable=protected-access DriverRegister._DRIVERS[key] = cls # pylint: disable=protected-access
################################################################################################################################
[docs]class DummyDriver(BaseDriver): """The dummy driver.""" driver_type = None driver_family = None driver_inputs = None is_dummy = True def __init__(self, *, lazy=False): super().__init__(lazy=lazy) self._inited = True def _init(self): pass def _call(self): raise NotImplementedError