# Copyright (c) 2023, Trustees of the University of Pennsylvania
# See LICENSE for licensing conditions
"""HTK command line tool wrappers."""
from dataclasses import dataclass
from math import log
from pathlib import Path
import signal
import subprocess
from subprocess import CalledProcessError
from typing import Iterable

from .utils import which
# Public API of this module.
__all__ = [
    'HTKError',
    'HTKSegfault',
    'HViteConfig',
    'hvite',
    'write_hmmdefs',
]
@dataclass
class HViteConfig:
    """HVite decoding configuration.

    Parameters
    ----------
    slf_path : pathlib.Path
        Path to HTK SLF file defining the recognition network.
    hmmdefs_path : pathlib.Path
        Path to HTK MMF file containing HMM definitions.
    macros_path : pathlib.Path
        Path to HTK MMF file containing additional macro definitions (e.g.,
        variance floors).
    config_path : pathlib.Path
        Path to HTK configuration file defining expected source audio format
        and feature extraction pipeline.
    dict_path : pathlib.Path
        Path to pronunciation dictionary.
    monophones_path : pathlib.Path
        Path to file listing HMMs to load from the MMF files.
    """
    slf_path: Path
    hmmdefs_path: Path
    macros_path: Path
    config_path: Path
    dict_path: Path
    monophones_path: Path

    @classmethod
    def from_model_dir(cls, model_dir):
        """Construct an :class:`HViteConfig` from a model directory.

        Each path is built from a fixed file name inside `model_dir`:

        - ``slf_path``        -> ``phone_net``
        - ``hmmdefs_path``    -> ``hmmdefs``
        - ``macros_path``     -> ``macros``
        - ``config_path``     -> ``config``
        - ``dict_path``       -> ``dict``
        - ``monophones_path`` -> ``monophones``

        Only the paths are constructed; the files are not checked for
        existence.

        Parameters
        ----------
        model_dir : str or pathlib.Path
            Model directory.

        Returns
        -------
        HViteConfig
            Instance of the class this is called on (subclasses get an
            instance of the subclass).
        """
        model_dir = Path(model_dir)
        return cls(
            slf_path=model_dir / 'phone_net',
            hmmdefs_path=model_dir / 'hmmdefs',
            macros_path=model_dir / 'macros',
            config_path=model_dir / 'config',
            dict_path=model_dir / 'dict',
            monophones_path=model_dir / 'monophones',
        )
class HTKError(Exception):
    """Call to HTK command line tool failed."""


class HTKSegfault(HTKError):
    """Call to HTK command line tool resulted in segmentation fault."""
def hvite(wav_path, config, working_dir, *, insertion_penalty=-0.3,
          lm_scale=5.0):
    """Perform Viterbi decoding for WAV file.

    Parameters
    ----------
    wav_path : pathlib.Path
        Path to WAV file to be decoded.
    config : HViteConfig
        Paths to the files defining the recognition network and models.
    working_dir : pathlib.Path
        Directory passed to HVite's ``-l`` option; the output label file is
        written there.
    insertion_penalty : float, optional
        Value passed to HVite's ``-p`` option (word insertion log
        probability).
        (Default: -0.3)
    lm_scale : float, optional
        Value passed to HVite's ``-s`` option (grammar scale factor).
        (Default: 5.0)

    Returns
    -------
    lab_path : pathlib.Path
        Path to output label file, ``working_dir / (wav_path.stem + '.lab')``.

    Raises
    ------
    FileNotFoundError
        If ``HVite`` cannot be found.
    HTKSegfault
        If HVite is terminated by a segmentation fault.
    HTKError
        If HVite exits with non-zero status and writes to stderr.
    subprocess.CalledProcessError
        If HVite exits with non-zero status without writing to stderr.
    """
    # Check that HVite exists.
    # TODO: Update link when docs are online.
    if not which('HVite'):
        raise FileNotFoundError(
            'HVite is not installed. Please install HTK and try again: '
            '[INSERT LINK TO INSTRUCTIONS HERE]')
    wav_path = Path(wav_path)
    working_dir = Path(working_dir)
    cmd = ['HVite',
           '-T', '0',
           '-w', str(config.slf_path),
           '-l', str(working_dir),
           '-H', str(config.macros_path),
           '-H', str(config.hmmdefs_path),
           '-C', str(config.config_path),
           '-p', str(insertion_penalty),
           '-s', str(lm_scale),
           '-y', 'lab',
           str(config.dict_path),
           str(config.monophones_path),
           str(wav_path),
           ]
    try:
        subprocess.run(cmd, capture_output=True, text=True, check=True)
    except CalledProcessError as e:
        # A negative return code means the child was killed by that signal.
        if e.returncode == -signal.SIGSEGV:
            raise HTKSegfault('HVite call caused segfault.') from None
        if e.stderr:
            raise HTKError(
                f'HVite failed with following error: \n{e.stderr}') from None
        raise
    # Because of ``-l working_dir`` the label file is written to working_dir,
    # named after the source file with the ``-y`` extension.
    return working_dir / wav_path.with_suffix('.lab').name
def write_hmmdefs(old_hmmdefs_path, new_hmmdefs_path, speech_scale_factor=1,
                  speech_phones=None):
    """Write a copy of an HTK hmmdefs file with speech model likelihoods
    rescaled by ``speech_scale_factor``.

    The first three lines of the input (the header) are copied verbatim.
    Every ``<GCONST>`` line of a model whose ``~h`` name is in
    `speech_phones` has ``log(speech_scale_factor)`` added to its value
    and is re-written in ``%.6e`` format. All other lines are copied
    unchanged.

    Parameters
    ----------
    old_hmmdefs_path : pathlib.Path
        Path to original HTK `hmmdefs` file.
    new_hmmdefs_path : pathlib.Path
        Path for modified HTK `hmmdefs` file. If file already exists, it
        will be overwritten.
    speech_scale_factor : float, optional
        Factor by which speech model acoustic likelihoods are scaled prior to
        beam search.
        (Default: 1)
    speech_phones : Iterable[str], optional
        Names of speech phones. Only relevant when `speech_scale_factor != 1`.
        If None, `speech_scale_factor` has no effect.
        (Default: None)

    Raises
    ------
    ValueError
        If rescaling is requested (``speech_scale_factor != 1`` and
        `speech_phones` non-empty) and ``speech_scale_factor <= 0``.
    """
    old_hmmdefs_path = Path(old_hmmdefs_path)
    new_hmmdefs_path = Path(new_hmmdefs_path)
    speech_phones = set() if speech_phones is None else set(speech_phones)
    rescale = speech_scale_factor != 1 and bool(speech_phones)
    # Computed once and before the output file is opened, so an invalid
    # factor fails without truncating/partially writing the output.
    # NOTE(review): in HTK, GCONST is log((2*pi)^n |Sigma|) and enters the
    # component log-density as -0.5 * (GCONST + ...), so adding log(s) scales
    # the likelihood by s**-0.5 rather than s -- confirm intended calibration.
    log_scale = log(speech_scale_factor) if rescale else 0.0
    gconst_tag = '<GCONST>'
    with open(old_hmmdefs_path, 'r', encoding='utf-8') as src:
        with open(new_hmmdefs_path, 'w', encoding='utf-8') as dst:
            # Header.
            for _ in range(3):
                dst.write(src.readline())
            # Model definitions.
            curr_phone = None
            for line in src:
                if line.startswith('~h'):
                    # e.g. '~h "sil"' -> 'sil'; tolerant of CRLF/extra spaces.
                    curr_phone = line[2:].strip().strip('"')
                elif (rescale and line.startswith(gconst_tag) and
                      curr_phone in speech_phones):
                    # float() ignores surrounding whitespace, so this also
                    # works for a final line that lacks a trailing newline.
                    gconst = float(line[len(gconst_tag):]) + log_scale
                    line = f'{gconst_tag} {gconst:.6e}\n'
                dst.write(line)