Source code for offline_translator.core

"""Core module."""
from __future__ import annotations

import contextlib
import logging
import multiprocessing
import os
import pathlib
import queue
import string
from traceback import print_exc
from typing import Callable

# Characters whose presence marks a text as worth translating: the ASCII
# letters and the decimal digits.
ALPHA_NUM = set(string.ascii_letters + string.digits)

# Absolute path of the translation model directory that sits next to this
# module, under resources/huggingface/Helsinki-NLP/opus-mt-fr-en.
MODEL = (
    pathlib.Path(__file__)
    .parent.joinpath(
        "resources",
        "huggingface",
        "Helsinki-NLP",
        "opus-mt-fr-en",
    )
    .resolve()
)

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


def worker(
    pending: multiprocessing.Queue[str], done: multiprocessing.Queue[str]
) -> None:
    """Translate texts taken from ``pending`` and push results to ``done``.

    The tokenizer and model are loaded once from ``MODEL``; the function then
    loops forever. On each pass it blocks for a text, discards everything but
    the newest queued entry, and answers on ``done``. A text without any
    character from ``ALPHA_NUM`` is sent back unchanged. Otherwise newlines
    are replaced by ``<n>`` before translation and restored afterwards.

    A ``KeyboardInterrupt`` ends the loop silently. Any other exception raised
    while loading the model or translating is printed and ends the loop as
    well, so nothing propagates out of the ``try`` block.
    """
    # Set the Hugging Face environment flags before transformers is imported.
    os.environ.update(
        {
            "TRANSFORMERS_OFFLINE": "1",
            "HF_DATASETS_OFFLINE": "1",
            "HF_HUB_DISABLE_SYMLINKS_WARNING": "1",
        }
    )
    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

    try:
        tokenizer = AutoTokenizer.from_pretrained(MODEL)
        model = AutoModelForSeq2SeqLM.from_pretrained(MODEL)
        while True:
            # Block for the next request, then skip ahead to the most recent
            # one: older pending texts are dropped without an answer.
            source = pending.get()
            while not pending.empty():
                source = pending.get()

            # Nothing alphanumeric to translate: echo the text back as-is.
            if ALPHA_NUM.isdisjoint(source.strip()):
                done.put_nowait(source)
                continue

            encoded = tokenizer.encode(
                source.replace("\n", "<n>"),
                return_tensors="pt",
            )
            generated = model.generate(encoded)
            answer: str = tokenizer.decode(
                generated[0],
                skip_special_tokens=True,
            )
            done.put_nowait(answer.replace("<n>", "\n"))
    except KeyboardInterrupt:
        pass
    # Do not propagate Exception in worker, just stop it
    except BaseException:  # noqa: BLE001
        print_exc()


class ModelWorker:
    """Run the translation ``worker`` in a child process and relay results.

    Texts are handed to the child through a "pending" queue by
    :meth:`submit`. Finished translations come back on a "done" queue and are
    passed to the callback each time :meth:`update` is called.
    """

    # Upper bound, in seconds, on how long close() waits for the terminated
    # child process to be reaped.
    _JOIN_TIMEOUT = 1.0

    def __init__(self, callback: Callable[[str], None]) -> None:
        """Create the queues and start the French-to-English worker.

        Args:
            callback: Called by :meth:`update` with each translated text.
        """
        self._pending: multiprocessing.Queue[str] = multiprocessing.Queue()
        self._done: multiprocessing.Queue[str] = multiprocessing.Queue()
        self._callback = callback
        self._worker: multiprocessing.Process | None = None
        self.load("fr", "en")

    def load(self, __from: str, __to: str, /) -> None:
        """Load language model.

        If a worker already exists it is closed and both queues are replaced
        before a new worker process is started.

        NOTE(review): the model name computed from the language pair is not
        passed on; the process is started with only the two queues, so the
        requested languages do not currently select the model.
        """
        self.model_name(__from, __to)
        if self._worker is not None:
            self.close()
            # close() closed the old queues, so the new worker needs new ones.
            self._pending = multiprocessing.Queue()
            self._done = multiprocessing.Queue()
        self._worker = multiprocessing.Process(
            target=worker,
            args=(self._pending, self._done),
        )
        self._worker.start()

    @staticmethod
    def model_name(__from: str, __to: str, /) -> str:
        """Get model name from destination and source language."""
        return f"Helsinki-NLP/opus-mt-{__from}-{__to}"

    def submit(self, text: str) -> None:
        """Queue ``text`` for translation by the worker."""
        logger.info("Submit %r", text)
        self._pending.put(text)

    def update(self) -> None:
        """Pass every translation that is already available to the callback.

        Never blocks: returns as soon as the "done" queue reports empty.
        """
        while True:
            # Only the queue read is guarded, so a ``queue.Empty`` raised by
            # the callback itself is not mistaken for "no more results".
            try:
                translation = self._done.get_nowait()
            except queue.Empty:
                return
            logger.info("Translation %r", translation)
            self._callback(translation)

    def close(self) -> None:
        """Close both queues and stop the worker process.

        Safe to call repeatedly, and when the worker was never created or
        never started.
        """
        self._pending.close()
        self._done.close()
        process = self._worker
        # is_alive() is False for a process that was never started or has
        # already exited, so neither case needs special handling.
        if process is not None and process.is_alive():
            process.terminate()
            # Reap the terminated child so it does not linger as a zombie;
            # the wait is bounded so close() cannot hang.
            process.join(self._JOIN_TIMEOUT)

    def __del__(self) -> None:
        """Cleanup object."""
        # __init__ may have failed before every attribute was assigned; the
        # finalizer must not raise because of that.
        with contextlib.suppress(AttributeError):
            self.close()