|
import importlib |
|
import inspect |
|
import os |
|
|
|
from .artifact import Artifact, Artifactories |
|
from .catalog import LocalCatalog, GithubCatalog, PATHS_SEP |
|
from .utils import Singleton |
|
|
|
|
|
# Name of the environment variable listing extra local catalog locations;
# _register_all_catalogs splits its value on PATHS_SEP.
UNITXT_ARTIFACTORIES_ENV_VAR = "UNITXT_ARTIFACTORIES"

# Module files in this package that _register_all_artifacts does not import
# or scan for Artifact subclasses (the scanning module itself is also skipped).
non_registered_files = [
    "__init__.py",
    "artifact.py",
    "utils.py",
    "register.py",
    "metric.py",
    "dataset.py",
    "blocks.py",
]
|
|
|
|
|
def _register_all_catalogs():
    """Register every artifact catalog with the ``Artifactories`` registry.

    Catalogs are registered in this order:

    1. the default ``LocalCatalog()``;
    2. one ``LocalCatalog(location=path)`` per entry of the
       ``UNITXT_ARTIFACTORIES`` environment variable, split on ``PATHS_SEP``
       (only when the variable is set);
    3. the ``GithubCatalog()``.

    Empty entries are skipped rather than registered as catalogs with an
    empty location. They come from an empty variable, a doubled separator,
    or a leading/trailing separator.
    """
    Artifactories().register_atrifactory(LocalCatalog())

    # An unset variable behaves like an empty one: no extra local catalogs.
    extra_locations = os.environ.get(UNITXT_ARTIFACTORIES_ENV_VAR, "")
    for path in extra_locations.split(PATHS_SEP):
        # "".split(sep) == [""] and "a<sep>".split(sep) ends with "";
        # neither is a usable catalog location.
        if not path:
            continue
        Artifactories().register_atrifactory(LocalCatalog(location=path))

    Artifactories().register_atrifactory(GithubCatalog())
|
|
|
def _register_all_artifacts():
    """Import this package's sibling modules and register their Artifact subclasses.

    Scans the directory containing this file for ``.py`` files. It skips this
    module itself and every file named in ``non_registered_files``. Each
    remaining module is imported relative to this package. Every class the
    module exposes that is a strict subclass of ``Artifact`` is passed to
    ``Artifact.register_class``. That includes classes the module merely
    imports from elsewhere, since ``inspect.getmembers`` lists all module
    attributes.

    Files are processed in sorted name order, so registration order does not
    depend on the filesystem's directory-listing order.
    """
    package_dir = os.path.dirname(__file__)
    this_file = os.path.basename(__file__)

    for file in sorted(os.listdir(package_dir)):
        if (
            not file.endswith(".py")
            or file in non_registered_files
            or file == this_file
        ):
            continue

        # Remove only the trailing extension; str.replace(".py", "") would also
        # strip any ".py" occurring inside the stem.
        module_name = os.path.splitext(file)[0]
        module = importlib.import_module("." + module_name, __package__)

        for _name, obj in inspect.getmembers(module, inspect.isclass):
            if issubclass(obj, Artifact) and obj is not Artifact:
                Artifact.register_class(obj)
|
|
|
|
|
class ProjectArtifactRegisterer(metaclass=Singleton):
    """One-shot registrar for the project's catalogs and artifact classes.

    Construction runs ``_register_all_catalogs`` and then
    ``_register_all_artifacts``, unless this instance already records a
    completed registration in ``self._registered``. The guard makes a repeated
    ``__init__`` call on the same instance a no-op. Whether the ``Singleton``
    metaclass ever re-invokes ``__init__`` depends on its implementation in
    ``.utils``.
    """

    def __init__(self):
        if getattr(self, "_registered", False):
            # Registration already finished for this instance.
            return

        # Mark as not-yet-registered; stays False if either step below raises.
        self._registered = False
        _register_all_catalogs()
        _register_all_artifacts()
        self._registered = True
|
|
|
|
|
def register_all_artifacts():
    """Ensure all project catalogs and artifact classes are registered.

    Public entry point. Instantiating ``ProjectArtifactRegisterer`` performs
    the registration, and that class skips the work when its instance has
    already completed it.
    """
    ProjectArtifactRegisterer()
|
|