diff --git a/.dockerignore b/.dockerignore
index c42a40019d..66cb3564d1 100644
--- a/.dockerignore
+++ b/.dockerignore
@@ -1,2 +1,3 @@
 *
-!requirements*
\ No newline at end of file
+!requirements*
+!_requirements*
diff --git a/.fs_cache/.keep b/.fs_cache/.keep
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml
new file mode 100644
index 0000000000..45572b2246
--- /dev/null
+++ b/.github/FUNDING.yml
@@ -0,0 +1,2 @@
+patreon: faceswap
+github: deepfakes
diff --git a/.github/ISSUE_TEMPLATE.md b/.github/ISSUE_TEMPLATE.md
index a3b1c7d447..0d01824152 100644
--- a/.github/ISSUE_TEMPLATE.md
+++ b/.github/ISSUE_TEMPLATE.md
@@ -1,6 +1,6 @@
-**Note: Please only report bugs in this repository. Just because you are getting an error message does not automatically mean you have discovered a bug. If you don't have a lot of experience with this type of project, or if you need for setup help and other issues in using the faceswap tool, please refer to the [faceswap-playground](https://github.com/deepfakes/faceswap-playground/issues) instead. The faceswap-playground is also an excellent place to ask questions and submit feedback.**
+**Note: Please only report bugs in this repository. Just because you are getting an error message does not automatically mean you have discovered a bug. If you don't have a lot of experience with this type of project, or if you need setup help or have other issues using the faceswap tool, please refer to the [faceswap Forum](https://faceswap.dev/forum) instead. The [faceswap Forum](https://faceswap.dev/forum) is also an excellent place to ask questions and submit feedback. Non-bugs are likely to be closed without response.**
 
-**Please always attach your generated crash_report.log to any bug report**
+**Please always attach your generated crash_report.log to any bug report. Failure to attach this report may lead to your issue being closed without response.**
 
 ## Expected behavior
diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md
index dd84ea7824..68ebacd28e 100644
--- a/.github/ISSUE_TEMPLATE/bug_report.md
+++ b/.github/ISSUE_TEMPLATE/bug_report.md
@@ -6,6 +6,11 @@ labels: ''
 assignees: ''
 
 ---
+*Note: For general usage questions and help, please use either our [FaceSwap Forum](https://faceswap.dev/forum)
+or [FaceSwap Discord server](https://discord.gg/FC54sYg). General usage questions are liable to be closed without
+response.*
+
+**Crash reports MUST be included when reporting bugs.**
 
 **Describe the bug**
 A clear and concise description of what the bug is.
@@ -25,14 +30,12 @@ If applicable, add screenshots to help explain your problem.
 
 **Desktop (please complete the following information):**
  - OS: [e.g. iOS]
- - Browser [e.g. chrome, safari]
- - Version [e.g. 22]
-
-**Smartphone (please complete the following information):**
- - Device: [e.g. iPhone6]
- - OS: [e.g. iOS8.1]
- - Browser [e.g. stock browser, safari]
- - Version [e.g. 22]
-
+ - Python Version [e.g. 3.5, 3.6]
+ - Conda Version [e.g. 4.5.12]
+ - Commit ID [e.g. e83819f]
+
-
 **Additional context**
 Add any other context about the problem here.
+ +**Crash Report** +The crash report generated in the root of your Faceswap folder \ No newline at end of file diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml new file mode 100644 index 0000000000..2b3415a366 --- /dev/null +++ b/.github/workflows/pytest.yml @@ -0,0 +1,185 @@ +name: ci/build + +on: + push: + pull_request: + paths-ignore: + - docs/** + - "**/README.md" + +jobs: + build_conda: + name: conda (${{ matrix.os }}, ${{ matrix.backend }} ${{ matrix.python-version }}) + runs-on: ${{ matrix.os }} + defaults: + run: + shell: bash -el {0} + strategy: + fail-fast: false + matrix: + # TODO revert. Despite documentation to the contrary, MacOS runners are always x86-64 + #os: ["ubuntu-latest", "windows-latest", "macos-latest"] + os: ["ubuntu-latest", "windows-latest"] + python-version: ["3.11", "3.12", "3.13"] + backend: ["nvidia", "cpu", "rocm", "apple-silicon"] + exclude: + # CPU + Nvidia only on Windows + - os: "windows-latest" + backend: "rocm" + - os: windows-latest + backend: apple-silicon + # No apple-silicon on Linux + - os: ubuntu-latest + backend: apple-silicon + # Only Apple-Silicon on MacOS + - os: "macos-latest" + backend: "rocm" + - os: "macos-latest" + backend: "cpu" + - os: "macos-latest" + backend: "nvidia" + steps: + - uses: actions/checkout@v3 + - name: Cleanup space + # We run out of space on rocm. Ref: https://github.com/actions/runner-images/issues/709 + if: matrix.backend == 'rocm' + run: | + sudo rm -rf "/usr/local/share/boost" "$AGENT_TOOLSDIRECTORY" + - name: Set cache date + run: echo "DATE=$(date +'%Y%m%d')" >> $GITHUB_ENV + # TODO Re-enable. Currently disabled as it does not seem to get used and takes a lot of space + #- name: Cache conda + # uses: actions/cache@v3 + # env: + # # Increase this value to manually reset cache + # CACHE_NUMBER: 1 + # REQ_FILE: ./requirements/requirements_${{ matrix.backend }}.txt + # with: + # path: ~/conda_pkgs_dir + # key: ${{ runner.os }}-${{ matrix.backend }}-conda-${{ matrix.python-version }}-${{ env.CACHE_NUMBER }}-${{ env.DATE }}-${{ hashFiles('./requirements/requirements.txt', env.REQ_FILE) }} + - name: Set up Conda + uses: conda-incubator/setup-miniconda@v2 + with: + python-version: ${{ matrix.python-version }} + miniconda-version: "latest" + auto-update-conda: true + activate-environment: faceswap + - name: Conda info + run: conda info && conda list + - name: Install + run: | + python setup.py --installer --dev --${{ matrix.backend }} + pip install wheel pytest-xvfb types-attrs types-cryptography types-pyOpenSSL + - name: Lint with flake8 + run: | + # stop the build if there are Python syntax errors or undefined names + flake8 . --select=E9,F63,F7,F82 --show-source + flake8 . --exit-zero + - name: MyPy Typing + continue-on-error: true + run: | + mypy . 
+ - name: SysInfo + run: python -m lib.system.sysinfo + - name: Unit Tests + # These backends will fail as GPU drivers not available + if: matrix.backend == 'cpu' + run: | + KERAS_BACKEND=torch FACESWAP_BACKEND="${{ matrix.backend }}" py.test -v tests/; + - name: End to End Tests + # These backends will fail as GPU drivers not available + if: matrix.backend == 'cpu' + run: | + KERAS_BACKEND=torch FACESWAP_BACKEND="${{ matrix.backend }}" python tests/simple_tests.py; + + build_linux: + name: "pip (ubuntu-latest, ${{ matrix.backend }} ${{ matrix.python-version }})" + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: ["3.11", "3.12", "3.13"] + backend: ["cpu"] + include: + - backend: "cpu" + steps: + - uses: actions/checkout@v5 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v6 + with: + python-version: ${{ matrix.python-version }} + cache: 'pip' + cache-dependency-path: | + './requirements/requirements_base.txt' + './requirements/requirements_${{ matrix.backend }}.txt' + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r ./requirements/requirements_${{ matrix.backend }}.txt + pip install -r ./requirements/_requirements_dev.txt + pip install wheel pytest-xvfb types-attrs types-cryptography types-pyOpenSSL + - name: Lint with flake8 + run: | + # stop the build if there are Python syntax errors or undefined names + flake8 . --select=E9,F63,F7,F82 --show-source + # exit-zero treats all errors as warnings. + flake8 . --exit-zero + - name: MyPy Typing + continue-on-error: true + run: | + mypy . + - name: SysInfo + run: FACESWAP_BACKEND="${{ matrix.backend }}" python -m lib.system.sysinfo + - name: Unit Tests + run: | + KERAS_BACKEND=torch FACESWAP_BACKEND="${{ matrix.backend }}" py.test -v tests/; + - name: End to End Tests + run: | + KERAS_BACKEND=torch FACESWAP_BACKEND="${{ matrix.backend }}" python tests/simple_tests.py; + + build_windows: + name: "pip (windows-latest, ${{ matrix.backend }} ${{ matrix.python-version }})" + runs-on: windows-latest + strategy: + fail-fast: false + matrix: + python-version: ["3.11", "3.12", "3.13"] + backend: ["cpu"] + include: + - backend: "cpu" + steps: + - uses: actions/checkout@v5 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v6 + with: + python-version: ${{ matrix.python-version }} + cache: 'pip' + cache-dependency-path: | + './requirements/requirements_base.txt' + './requirements/requirements_${{ matrix.backend }}.txt' + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install types-attrs types-cryptography types-pyOpenSSL wheel + pip install -r ./requirements/_requirements_dev.txt + pip install -r ./requirements/requirements_${{ matrix.backend }}.txt + - name: Set Faceswap Backend EnvVar + run: echo "FACESWAP_BACKEND=${{ matrix.backend }}" | Out-File -FilePath $env:GITHUB_ENV -Append + - name: Set Keras Backend EnvVar + run: echo "KERAS_BACKEND=torch" | Out-File -FilePath $env:GITHUB_ENV -Append + - name: Lint with flake8 + run: | + # stop the build if there are Python syntax errors or undefined names + flake8 . --select=E9,F63,F7,F82 --show-source + # exit-zero treats all errors as warnings. + flake8 . --exit-zero + - name: MyPy Typing + continue-on-error: true + run: | + mypy . 
+ - name: SysInfo + run: python -m lib.system.sysinfo + - name: Unit Tests + run: py.test -v tests + - name: End to End Tests + run: python tests/simple_tests.py diff --git a/.gitignore b/.gitignore index f4fd44aada..ba4b84e9eb 100644 --- a/.gitignore +++ b/.gitignore @@ -1,35 +1,73 @@ +# Global (Exclude all + retain files that are unlikely to pollute local installs) * -!setup.cfg -!*.dat -!*.h5 -!*.ico -!*.inf !*.keep !*.md -!*.npy -!*.nsi -!*.pb -!*.png -!*.py -!*.txt -!.cache + +# Requirements files +!/requirements/ +!/requirements/*requirements*.txt +!/requirements/*conda*.yml +!/requirements/*.py + +# Root files !Dockerfile* -!requirements* +!pyproject.toml +!.gitignore +!.travis.yml +!/faceswap.py +!/setup.py +!/tools.py +!/update_deps.py + +# Support files +!/.github/ +!/.github/workflows/ +!/.github/workflows/*.yml !.install/ -!.install/windows +!.install/** !config/ +!.readthedocs.yml +!docs/ +!docs/_static/ +!docs/_static/*.png +!docs/full/ +!docs/full/**/ +!docs/full/**/*.rst +!locales/ +!locales/** + +# Test files +!tests/ +!tests/**/ +!tests/**/*.py +!tests/**/*.mp4 +!tests/**/*.jpg + +# Core files +!.fs_cache !lib/ -!lib/* -!lib/gui -!lib/gui/.cache/preview -!lib/gui/.cache/icons -!scripts +!lib/**/ +!lib/**/*.py +!lib/gui/**/icons/*.png +!lib/gui/**/themes/default.json +!lib/gui/**/presets/**/*.json !plugins/ -!plugins/* -!plugins/extract/* -!plugins/train/* -!tools -!tools/lib* -*.ini -*.pyc -__pycache__/ +!plugins/**/ +!plugins/**/*.py +!scripts/ +!scripts/*.py +!tools/ +!tools/**/ +!tools/**/*.py + +# GUI Plugin Presets +!lib/gui/**/presets/train/model_phaze_a_dfaker_preset.json +!lib/gui/**/presets/train/model_phaze_a_dfl-h128_preset.json +!lib/gui/**/presets/train/model_phaze_a_dfl-sae-df_preset.json +!lib/gui/**/presets/train/model_phaze_a_dfl-sae-liae_preset.json +!lib/gui/**/presets/train/model_phaze_a_dfl-saehd-df_preset.json +!lib/gui/**/presets/train/model_phaze_a_dfl-saehd-liae_preset.json +!lib/gui/**/presets/train/model_phaze_a_iae_preset.json +!lib/gui/**/presets/train/model_phaze_a_lightweight_preset.json +!lib/gui/**/presets/train/model_phaze_a_original_preset.json +!lib/gui/**/presets/train/model_phaze_a_stojo_preset.json diff --git a/.install/linux/faceswap_setup_x64.sh b/.install/linux/faceswap_setup_x64.sh new file mode 100644 index 0000000000..f9833ff8c8 --- /dev/null +++ b/.install/linux/faceswap_setup_x64.sh @@ -0,0 +1,502 @@ +#!/bin/bash +# TODO force conda-forge + +TMP_DIR="/tmp/faceswap_install" +DL_CONDA="https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh" +DL_FACESWAP="https://github.com/deepfakes/faceswap.git" + +CONDA_PATHS=("/opt" "$HOME") +CONDA_NAMES=("/ana" "/mini") +CONDA_VERSIONS=("3" "2") +CONDA_BINS=("/bin/conda" "/condabin/conda") +DIR_CONDA="$HOME/miniconda3" +CONDA_EXECUTABLE="${DIR_CONDA}/bin/conda" +CONDA_TO_PATH=false +ENV_NAME="faceswap" +PYENV_VERSION="3.13" + +DIR_FACESWAP="$HOME/faceswap" +VERSION="nvidia" +LIB_VERSION="13" + +DESKTOP=false + +header() { + # Format header text + length=${#1} + padding=$(( (72 - length) / 2)) + sep=$(printf '=%.0s' $(seq 1 $padding)) + echo "" + echo -e "\e[32m$sep $1 $sep" +} + +info () { + # output info message + while read -r line ; do + echo -e "\e[32mINFO\e[97m $line" + done <<< "$(echo "$1" | fmt -cu -w 70)" +} + +warn () { + # output warning message + while read -r line ; do + echo -e "\e[33mWARNING\e[97m $line" + done <<< "$(echo "$1" | fmt -cu -w 70)" +} + +error () { + # output error message. 
+    while read -r line ; do
+        echo -e "\e[31mERROR\e[97m $line"
+    done <<< "$(echo "$1" | fmt -cu -w 70)"
+}
+
+yellow () {
+    # Change text color to yellow
+    echo -en "\e[33m"
+}
+
+check_file_exists () {
+    # Check whether a file exists and return true or false
+    test -f "$1"
+}
+
+check_folder_exists () {
+    # Check whether a folder exists and return true or false
+    test -d "$1"
+}
+
+download_file () {
+    # Download a file to the temp folder
+    fname=$(basename -- "$1")
+    curl "$1" --output "$TMP_DIR/$fname" --progress-bar
+}
+
+check_for_sudo() {
+    # Ensure user isn't running as sudo/root. We don't want to screw up any system install
+    if [ "$EUID" == 0 ] ; then
+        error "This install script should not be run with root privileges. Please run as a normal user."
+        exit 1
+    fi
+}
+
+check_for_curl() {
+    # Ensure that curl is available on the system
+    if ! command -V curl &> /dev/null ; then
+        error "'curl' is required for running the Faceswap installer, but could not be found. \
+        Please install 'curl' using the package manager for your distribution before proceeding."
+        exit 1
+    fi
+}
+
+create_tmp_dir() {
+    TMP_DIR="$(mktemp -d)"
+    if [ -z "$TMP_DIR" -o ! -d "$TMP_DIR" ]; then
+        # This shouldn't happen, but check just in case, to prevent the tmp cleanup function from messing things up.
+        error "Failed creating the temporary install directory."
+        exit 2
+    fi
+    trap cleanup_tmp_dir EXIT
+}
+
+cleanup_tmp_dir() {
+    rm -rf "$TMP_DIR"
+}
+
+ask () {
+    # Ask for input. First parameter: Display text, 2nd parameter variable name
+    default="${!2}"
+    read -rp $'\e[36m'"$1 [default: '$default']: "$'\e[97m' inp
+    inp="${inp:-${default}}"
+    if [ "$inp" == "\n" ] ; then inp=${!2} ; fi
+    printf -v $2 "$inp"
+}
+
+ask_yesno () {
+    # Ask yes or no. First Param: Question, 2nd param: Default
+    # Returns True for yes, False for No
+    case $2 in
+        [Yy]* ) opts="[YES/no]" ;;
+        [Nn]* ) opts="[yes/NO]" ;;
+    esac
+    while true; do
+        read -rp $'\e[36m'"$1 $opts: "$'\e[97m' yn
+        yn="${yn:-${2}}"
+        case $yn in
+            [Yy]* ) retval=true ; break ;;
+            [Nn]* ) retval=false ; break ;;
+            * ) echo "Please answer yes or no." ;;
+        esac
+    done
+    $retval
+}
+
+
+ask_version() {
+    # Ask which version of faceswap to install
+    while true; do
+        default=1
+        read -rp $'\e[36mSelect:\t1: NVIDIA\n\t2: AMD (ROCm)\n\t3: CPU\n'"[default: $default]: "$'\e[97m' vers
+        vers="${vers:-${default}}"
+        case $vers in
+            1) VERSION="nvidia" ; break ;;
+            2) VERSION="rocm" ; break ;;
+            3) VERSION="cpu" ; break ;;
+            * ) echo "Invalid selection." ;;
+        esac
+    done
+}
+
+
+ask_cuda_version() {
+    # Ask which Cuda Version to install
+    while true; do
+        default=1
+        read -rp $'\e[36mSelect:\t1: RTX 20xx ->\n\t2: GTX 9xx - GTX 10xx\n\t3: GTX 7xx - GTX 8xx\n'"[default: $default]: "$'\e[97m' vers
+        vers="${vers:-${default}}"
+        case $vers in
+            1) LIB_VERSION="13" ; break ;;
+            2) LIB_VERSION="12" ; break ;;
+            3) LIB_VERSION="11" ; break ;;
+            * ) echo "Invalid selection." ;;
+        esac
+    done
+}
+
+
+ask_rocm_version() {
+    # Ask which ROCm version to install
+    while true; do
+        default=1
+        read -rp $'\e[36mSelect:\t1: ROCm 6.4\n\t2: ROCm 6.3\n\t3: ROCm 6.2\n\t4: ROCm 6.1\n\t5: ROCm 6.0\n'"[default: $default]: "$'\e[97m' vers
+        vers="${vers:-${default}}"
+        case $vers in
+            1) LIB_VERSION="64" ; break ;;
+            2) LIB_VERSION="63" ; break ;;
+            3) LIB_VERSION="62" ; break ;;
+            4) LIB_VERSION="61" ; break ;;
+            5) LIB_VERSION="60" ; break ;;
+            * ) echo "Invalid selection." ;;
+        esac
+    done
+}
+
+banner () {
+    echo -e " \e[32m                       001"
+    echo -e " \e[32m                    11  10     010"
+    echo -e " \e[97m    @@@@\e[32m            10"
+    echo -e " \e[97m @@@@@@@@\e[32m         00            1"
+    echo -e " \e[97m @@@@@@@@@@\e[32m      1         1  0"
+    echo -e " \e[97m @@@@@@@@\e[32m      0000     01111"
+    echo -e " \e[97m  @@@@@@@@@@\e[32m    01   110        01 1"
+    echo -e " \e[97m@@@@@@@@@@@@\e[32m      111    010       0"
+    echo -e " \e[97m@@@@@@@@@@@@@@@@\e[32m       10        0"
+    echo -e " \e[97m@@@@@@@@@@@@@\e[32m        0010      1"
+    echo -e " \e[97m@@@@@@@@@ @@@\e[32m           100     1"
+    echo -e " \e[97m@@@@@@@ .@@@@\e[32m            10    1"
+    echo -e " \e[97m #@@@@@@@@@@@\e[32m            001  0"
+    echo -e " \e[97m  @@@@@@@@@@@         ,"
+    echo -e " \e[97m   @@@@@@@@   @@@@@"
+    echo -e " \e[97m   @@@@@@@@  @@@@@@@@"
+    echo -e " \e[97m   @@@@@@@@@,@@@@@@@@   / _|"
+    echo -e " \e[97m   %@@@@@@@@@@@@@@@@@  | |_ ___ "
+    echo -e " \e[97m     @@@@@@@@@@@@@@    |  _|/ __|"
+    echo -e " \e[97m      @@@@@@@@@@@@     | |  \__ \\"
+    echo -e " \e[97m       @@@@@@@@@@(     |_|  |___/"
+    echo -e " \e[97m        @@@@@@"
+    echo -e " \e[97m         @@@@"
+    sleep 2
+}
+
+find_conda_install() {
+    if check_conda_path;
+    then true
+    elif check_conda_locations ; then true
+    else false
+    fi
+}
+
+set_conda_dir_from_bin() {
+    # Set the DIR_CONDA variable from the bin file
+    DIR_CONDA=$(readlink -f "$(dirname "$1")/..")
+    info "Found existing conda install at: $DIR_CONDA"
+}
+
+check_conda_path() {
+    # Check if conda is in PATH
+    conda_bin="$(which conda 2>/dev/null)"
+    if [[ "$?" == "0" ]]; then
+        set_conda_dir_from_bin "$conda_bin"
+        CONDA_EXECUTABLE="$conda_bin"
+        true
+    else
+        false
+    fi
+}
+
+check_conda_locations() {
+    # Check common conda install locations
+    retval=false
+    for path in "${CONDA_PATHS[@]}"; do
+        for name in "${CONDA_NAMES[@]}" ; do
+            foldername="$path${name}conda"
+            for vers in "${CONDA_VERSIONS[@]}" ; do
+                for bin in "${CONDA_BINS[@]}" ; do
+                    condabin="$foldername$vers$bin"
+                    if check_file_exists "$condabin" ; then
+                        set_conda_dir_from_bin "$condabin"
+                        CONDA_EXECUTABLE="$condabin";
+                        retval=true
+                        break 4
+                    fi
+                done
+            done
+        done
+    done
+    $retval
+}
+
+user_input() {
+    # Get user options for install
+    header "Welcome to the Linux Faceswap Installer"
+    info "To get set up we need to gather some information about where you would like Faceswap\
+    and Conda to be installed."
+    info "To accept the default values just hit the 'ENTER' key for each option. You will have\
+    an opportunity to review your responses prior to commencing the install."
+    echo ""
+    info "\e[33mIMPORTANT:\e[97m Make sure that the user '$USER' has full permissions for all of the\
+    destinations that you select."
+    read -rp $'\e[36m'"Press 'ENTER' to continue with the setup..."$'\e[97m'
+    conda_opts
+    faceswap_opts
+    post_install_opts
+}
+
+conda_opts () {
+    # Options pertaining to the installation of conda
+    header "CONDA"
+    info "Faceswap uses Conda as it handles the installation of all prerequisites."
+    if find_conda_install && ask_yesno "Use the pre-installed conda?" "Yes"; then
+        info "Using Conda install at $DIR_CONDA"
+    else
+        info "If you have an existing Conda install then enter the location here,\
+        otherwise Miniconda3 will be installed in the given location."
+        err_msg="The location for Conda must not contain spaces (this is a specific\
+        limitation of Conda)."
+        tmp_dir_conda="$DIR_CONDA"
+        while true ; do
+            ask "Please specify a location for Conda." "DIR_CONDA"
+            case ${DIR_CONDA} in
+                *\ * ) error "$err_msg" ; DIR_CONDA=$tmp_dir_conda ;;
+                * ) break ;;
+            esac
+        done
+        # Record the executable path once a valid (space-free) location has been accepted
+        CONDA_EXECUTABLE="${DIR_CONDA}/bin/conda"
+    fi
+    if ! check_file_exists "$CONDA_EXECUTABLE" ; then
+        info "The Conda executable can be added to your PATH. This makes it easier to run Conda\
+        commands directly. If you already have a pre-existing Conda install then you should\
+        probably not enable this, otherwise this should be fine."
+        if ask_yesno "Add Conda executable to path?" "Yes" ; then CONDA_TO_PATH=true ; fi
+    fi
+    echo ""
+    info "Faceswap will be installed inside a Conda Environment. If an environment already\
+    exists with the name specified then it will be deleted."
+    ask "Please specify a name for the Faceswap Conda Environment" "ENV_NAME"
+}
+
+faceswap_opts () {
+    # Options pertaining to the installation of faceswap
+    header "FACESWAP"
+    info "Faceswap will be installed in the given location. If a folder exists at the\
+    location you specify, then it will be deleted."
+    ask "Please specify a location for Faceswap" "DIR_FACESWAP"
+    echo ""
+    info "Faceswap can be run on NVIDIA or AMD GPUs or on CPU. You should make sure that you have the \
+    latest graphics card drivers installed from the relevant vendor. Please select the version\
+    of Faceswap you wish to install."
+    ask_version
+    if [ $VERSION == "nvidia" ] ; then
+        info "Depending on your GPU a different version of Cuda may be required. Please select the \
+        generation of Nvidia GPU you use below."
+        ask_cuda_version
+    fi
+    if [ $VERSION == "rocm" ] ; then
+        info "Depending on your installed version of ROCm a different version of PyTorch may be required. \
+        Please select the ROCm version you use below."
+        ask_rocm_version
+        warn "ROCm support is experimental. Please make sure that your GPU is supported by ROCm and that \
+        ROCm has been installed on your system before proceeding. Installation instructions: \
+        https://docs.amd.com/bundle/ROCm_Installation_Guidev5.0/page/Overview_of_ROCm_Installation_Methods.html"
+        sleep 2
+    fi
+}
+
+post_install_opts() {
+    # Post installation options
+    if check_folder_exists "$HOME/Desktop" ; then
+        header "POST INSTALLATION ACTIONS"
+        info "Launching Faceswap requires activating your Conda Environment and then running\
+        Faceswap. The installer can simplify this by creating a desktop shortcut to launch\
+        straight into the Faceswap GUI"
+        if ask_yesno "Create Desktop Shortcut?" "Yes"
+        then DESKTOP=true
+        fi
+    fi
+}
+
+review() {
+    # Review user options and ask to continue
+    header "Review install options"
+    info "Please review the selected installation options before proceeding:"
+    echo ""
+    if ! check_folder_exists "$DIR_CONDA"
+    then
+        echo "  - MiniConda3 will be installed in '$DIR_CONDA'"
+    else
+        echo "  - Existing Conda install at '$DIR_CONDA' will be used"
+    fi
+    if $CONDA_TO_PATH ; then echo "  - MiniConda3 will be added to your PATH" ; fi
+    if check_env_exists ; then
+        echo -e "  \e[33m- Existing Conda Environment '$ENV_NAME' will be removed\e[97m"
+    fi
+    echo "  - Conda Environment '$ENV_NAME' will be created."
+ if check_folder_exists "$DIR_FACESWAP" ; then + echo -e " \e[33m- Existing Faceswap folder '$DIR_FACESWAP' will be removed\e[97m" + fi + echo " - Faceswap will be installed in '$DIR_FACESWAP'" + echo " - Installing for '$VERSION'" + if [ $VERSION == "nvidia" ] ; then + echo " - Cuda version $LIB_VERSION will be used" + fi + if [ $VERSION == "rocm" ] ; then + echo " - ROCm version '$LIB_VERSION' will be used" + echo -e " \e[33m- Note: Please ensure that ROCm is supported by your GPU\e[97m" + echo -e " \e[33m and is installed prior to proceeding.\e[97m" + fi + if $DESKTOP ; then echo " - A Desktop shortcut will be created" ; fi + if ! ask_yesno "Do you wish to continue?" "No" ; then exit ; fi +} + +conda_install() { + # Download and install Mini Conda3 + if ! check_folder_exists "$DIR_CONDA" ; then + info "Downloading Miniconda3..." + yellow ; download_file $DL_CONDA + info "Installing Miniconda3..." + yellow ; fname="$(basename -- $DL_CONDA)" + bash "$TMP_DIR/$fname" -b -p "$DIR_CONDA" + "$CONDA_EXECUTABLE" tos accept + if $CONDA_TO_PATH ; then + info "Adding Miniconda3 to PATH..." + yellow ; "$CONDA_EXECUTABLE" init + "$CONDA_EXECUTABLE" config --set auto_activate false + fi + fi +} + +check_env_exists() { + # Check if an environment with the given name exists + if check_file_exists "$CONDA_EXECUTABLE" ; then + "$CONDA_EXECUTABLE" env list | grep -qE "^${ENV_NAME}\W" + else false + fi +} + +delete_env() { + # Delete the env if it previously exists + if check_env_exists ; then + info "Removing pre-existing Virtual Environment" + yellow ; "$CONDA_EXECUTABLE" env remove -n "$ENV_NAME" + fi +} + +create_env() { + # Create Python 3.13 env for faceswap + delete_env + info "Creating Conda Virtual Environment..." + yellow ; "$CONDA_EXECUTABLE" create -n "$ENV_NAME" -c conda-forge -q python="$PYENV_VERSION" -y +} + + +activate_env() { + # Activate the conda environment + # shellcheck source=/dev/null + source "$DIR_CONDA/etc/profile.d/conda.sh" activate + conda activate "$ENV_NAME" +} + +install_git() { + # Install git inside conda environment + info "Installing Git..." + # TODO On linux version 2.45.2 makes the font fixed TK pull in Python from + # graalpy, which breaks pretty much everything + yellow ; conda install "git<2.45" -q -y +} + +delete_faceswap() { + # Delete existing faceswap folder + if check_folder_exists "$DIR_FACESWAP" ; then + info "Removing Faceswap folder: '$DIR_FACESWAP'" + rm -rf "$DIR_FACESWAP" + fi +} + +clone_faceswap() { + # Clone the faceswap repo + delete_faceswap + info "Downloading Faceswap..." + yellow ; git clone --depth 1 --no-single-branch "$DL_FACESWAP" "$DIR_FACESWAP" +} + +setup_faceswap() { + # Run faceswap setup script + info "Setting up Faceswap..." 
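+    # setup.py receives the backend flag with the library version appended, e.g. --nvidia13 or --rocm64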
+ python -u "$DIR_FACESWAP/setup.py" --installer --$VERSION$LIB_VERSION +} + +create_gui_launcher () { + # Create a shortcut to launch into the GUI + launcher="$DIR_FACESWAP/faceswap_gui_launcher.sh" + launch_script="source \"$DIR_CONDA/etc/profile.d/conda.sh\" activate &&\n" + launch_script+="conda activate '$ENV_NAME' &&\n" + launch_script+="python \"$DIR_FACESWAP/faceswap.py\" gui\n" + echo -e "$launch_script" > "$launcher" + chmod +x "$launcher" +} + +create_desktop_shortcut () { + # Create a shell script to launch the GUI and add a desktop shortcut + if $DESKTOP ; then + desktop_icon="$HOME/Desktop/faceswap.desktop" + desktop_file="[Desktop Entry]\n" + desktop_file+="Version=1.0\n" + desktop_file+="Type=Application\n" + desktop_file+="Terminal=true\n" + desktop_file+="Name=FaceSwap\n" + desktop_file+="Exec=bash $launcher\n" + desktop_file+="Comment=FaceSwap\n" + desktop_file+="Icon=$DIR_FACESWAP/.install/linux/fs_logo.ico\n" + echo -e "$desktop_file" > "$desktop_icon" + chmod +x "$desktop_icon" + fi ; +} + +check_for_sudo +check_for_curl +banner +user_input +review +create_tmp_dir +conda_install +create_env +activate_env +install_git +clone_faceswap +setup_faceswap +create_gui_launcher +create_desktop_shortcut +info "Faceswap installation is complete!" +if $DESKTOP ; then info "You can launch Faceswap from the icon on your desktop" ; exit ; fi +if $CONDA_TO_PATH ; then + info "You should close the terminal and re-open to activate Conda before proceeding" ; fi diff --git a/.install/linux/fs_logo.ico b/.install/linux/fs_logo.ico new file mode 100644 index 0000000000..c96ff6105f Binary files /dev/null and b/.install/linux/fs_logo.ico differ diff --git a/.install/macos/app.zip b/.install/macos/app.zip new file mode 100644 index 0000000000..9ce64629d6 Binary files /dev/null and b/.install/macos/app.zip differ diff --git a/.install/macos/faceswap_setup_macos.sh b/.install/macos/faceswap_setup_macos.sh new file mode 100644 index 0000000000..443c9bf127 --- /dev/null +++ b/.install/macos/faceswap_setup_macos.sh @@ -0,0 +1,492 @@ +#!/bin/bash + +TMP_DIR="/tmp/faceswap_install" + +URL_CONDA="https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-" +DL_CONDA="${URL_CONDA}x86_64.sh" +DL_FACESWAP="https://github.com/deepfakes/faceswap.git" +DL_XQUARTZ="https://github.com/XQuartz/XQuartz/releases/latest/download/XQuartz-2.8.5.pkg" + +CONDA_PATHS=("/opt" "$HOME") +CONDA_NAMES=("anaconda" "miniconda" "miniforge") +CONDA_VERSIONS=("3" "2") +CONDA_BINS=("/bin/conda" "/condabin/conda") +DIR_CONDA="$HOME/miniconda3" +CONDA_EXECUTABLE="${DIR_CONDA}/bin/conda" +CONDA_TO_PATH=false +ENV_NAME="faceswap" +PYENV_VERSION="3.13" + +DIR_FACESWAP="$HOME/faceswap" +VERSION="nvidia" + +DESKTOP=false +XQUARTZ=false + +header() { + # Format header text + length=${#1} + padding=$(( (72 - length) / 2)) + sep=$(printf '=%.0s' $(seq 1 $padding)) + echo "" + echo $'\e[32m'$sep $1 $sep +} + +info () { + # output info message + while read -r line ; do + echo $'\e[32mINFO\e[39m '$line + done <<< "$(echo "$1" | fmt -s -w 70)" +} + +warn () { + # output warning message + while read -r line ; do + echo $'\e[33mWARNING\e[39m '$line + done <<< "$(echo "$1" | fmt -s -w 70)" +} + +error () { + # output error message. 
+    while read -r line ; do
+        echo $'\e[31mERROR\e[39m '$line
+    done <<< "$(echo "$1" | fmt -s -w 70)"
+}
+
+yellow () {
+    # Change text color to yellow
+    echo $'\e[33m'
+}
+
+check_file_exists () {
+    # Check whether a file exists and return true or false
+    test -f "$1"
+}
+
+check_folder_exists () {
+    # Check whether a folder exists and return true or false
+    test -d "$1"
+}
+
+download_file () {
+    # Download a file to the temp folder
+    fname=$(basename -- "$1")
+    curl -L "$1" --output "$TMP_DIR/$fname" --progress-bar
+}
+
+check_for_sudo() {
+    # Ensure user isn't running as sudo/root. We don't want to screw up any system install
+    if [ "$EUID" == 0 ] ; then
+        error "This install script should not be run with root privileges. Please run as a normal user."
+        exit 1
+    fi
+}
+
+check_for_curl() {
+    # Ensure that curl is available on the system
+    if ! command -V curl &> /dev/null ; then
+        error "'curl' is required for running the Faceswap installer, but could not be found. \
+        Please install 'curl' before proceeding."
+        exit 1
+    fi
+}
+
+check_for_xcode() {
+    # Ensure that xcode command line tools are available on the system
+    if xcode-select -p 2>&1 | grep -q "xcode-select: error" ; then
+        error "Xcode is required to install faceswap. Please install Xcode Command Line Tools \
+        before proceeding. If the Xcode installer does not automatically open, then \
+        you can run the command:"
+        error "xcode-select --install"
+        echo ""
+        xcode-select --install
+        exit 1
+    fi
+}
+
+create_tmp_dir() {
+    TMP_DIR="$(mktemp -d)"
+    if [ -z "$TMP_DIR" -o ! -d "$TMP_DIR" ]; then
+        # This shouldn't happen, but check just in case, to prevent the tmp cleanup function from messing things up.
+        error "Failed creating the temporary install directory."
+        exit 2
+    fi
+    trap cleanup_tmp_dir EXIT
+}
+
+cleanup_tmp_dir() {
+    rm -rf "$TMP_DIR"
+}
+
+ask () {
+    # Ask for input. First parameter: Display text, 2nd parameter variable name
+    default="${!2}"
+    read -rp $'\e[35m'"$1 [default: '$default']: "$'\e[39m' inp
+    inp="${inp:-${default}}"
+    if [ "$inp" == "\n" ] ; then inp=${!2} ; fi
+    printf -v $2 "$inp"
+}
+
+ask_yesno () {
+    # Ask yes or no. First Param: Question, 2nd param: Default
+    # Returns True for yes, False for No
+    case $2 in
+        [Yy]* ) opts="[YES/no]" ;;
+        [Nn]* ) opts="[yes/NO]" ;;
+    esac
+    while true; do
+        read -rp $'\e[35m'"$1 $opts: "$'\e[39m' yn
+        yn="${yn:-${2}}"
+        case $yn in
+            [Yy]* ) retval=true ; break ;;
+            [Nn]* ) retval=false ; break ;;
+            * ) echo "Please answer yes or no." ;;
+        esac
+    done
+    $retval
+}
+
+
+ask_version() {
+    # Ask which version of faceswap to install
+    while true; do
+        default=1
+        read -rp $'\e[35mSelect:\t1: Apple Silicon\n\t2: NVIDIA\n\t3: CPU\n'"[default: $default]: "$'\e[39m' vers
+        vers="${vers:-${default}}"
+        case $vers in
+            1) VERSION="apple_silicon" ; break ;;
+            2) VERSION="nvidia" ; break ;;
+            3) VERSION="cpu" ; break ;;
+            * ) echo "Invalid selection." ;;
+        esac
+    done
+}
+
+banner () {
+    echo $' \e[32m                       001'
+    echo $' \e[32m                    11  10     010'
+    echo $' \e[39m    @@@@\e[32m            10'
+    echo $' \e[39m @@@@@@@@\e[32m         00            1'
+    echo $' \e[39m @@@@@@@@@@\e[32m      1         1  0'
+    echo $' \e[39m @@@@@@@@\e[32m      0000     01111'
+    echo $' \e[39m  @@@@@@@@@@\e[32m    01   110        01 1'
+    echo $' \e[39m@@@@@@@@@@@@\e[32m      111    010       0'
+    echo $' \e[39m@@@@@@@@@@@@@@@@\e[32m       10        0'
+    echo $' \e[39m@@@@@@@@@@@@@\e[32m        0010      1'
+    echo $' \e[39m@@@@@@@@@ @@@\e[32m           100     1'
+    echo $' \e[39m@@@@@@@ .@@@@\e[32m            10    1'
+    echo $' \e[39m #@@@@@@@@@@@\e[32m            001  0'
+    echo $' \e[39m  @@@@@@@@@@@         ,'
+    echo '   @@@@@@@@   @@@@@'
+    echo '   @@@@@@@@  @@@@@@@@    _'
+    echo '   @@@@@@@@@,@@@@@@@@   / _|'
+    echo '   %@@@@@@@@@@@@@@@@@  | |_ ___ '
+    echo '     @@@@@@@@@@@@@@    |  _|/ __|'
+    echo '      @@@@@@@@@@@@     | |  \__ \'
+    echo '       @@@@@@@@@@(     |_|  |___/'
+    echo '        @@@@@@'
+    echo '         @@@@'
+    sleep 2
+}
+
+find_conda_install() {
+    if check_conda_path;
+    then true
+    elif check_conda_locations ; then true
+    else false
+    fi
+}
+
+set_conda_dir_from_bin() {
+    # Set the DIR_CONDA variable from the bin file
+    pth="$(dirname "$1")/.."
+    DIR_CONDA=$(python -c "import os, sys; print(os.path.realpath('$pth'))")
+    info "Found existing conda install at: $DIR_CONDA"
+}
+
+check_conda_path() {
+    # Check if conda is in PATH
+    conda_bin="$(which conda 2>/dev/null)"
+    if [[ "$?" == "0" ]]; then
+        set_conda_dir_from_bin "$conda_bin"
+        CONDA_EXECUTABLE="$conda_bin"
+        true
+    else
+        false
+    fi
+}
+
+check_conda_locations() {
+    # Check common conda install locations
+    retval=false
+    for path in "${CONDA_PATHS[@]}"; do
+        for name in "${CONDA_NAMES[@]}" ; do
+            foldername="$path/$name"
+            for vers in "${CONDA_VERSIONS[@]}" ; do
+                for bin in "${CONDA_BINS[@]}" ; do
+                    condabin="$foldername$vers$bin"
+                    if check_file_exists "$condabin" ; then
+                        set_conda_dir_from_bin "$condabin"
+                        CONDA_EXECUTABLE="$condabin";
+                        retval=true
+                        break 4
+                    fi
+                done
+            done
+        done
+    done
+    $retval
+}
+
+user_input() {
+    # Get user options for install
+    header "Welcome to the macOS Faceswap Installer"
+    info "To get set up we need to gather some information about where you would like Faceswap\
+    and Conda to be installed."
+    info "To accept the default values just hit the 'ENTER' key for each option. You will have\
+    an opportunity to review your responses prior to commencing the install."
+    echo ""
+    info "IMPORTANT: Make sure that the user '$USER' has full permissions for all of the\
+    destinations that you select."
+    read -rp $'\e[35m'"Press 'ENTER' to continue with the setup..."$'\e[39m'
+    apps_opts
+    conda_opts
+    faceswap_opts
+    post_install_opts
+}
+
+apps_opts () {
+    # Options pertaining to additional apps that are required
+    if ! command -V xquartz &> /dev/null ; then
+        header "APPS"
+        info "XQuartz is required to use the Faceswap GUI but was not detected. "
+        if ask_yesno "Install XQuartz for GUI support?" "Yes" ; then
+            XQUARTZ=true
+        fi
+    fi
+}
+
+conda_opts () {
+    # Options pertaining to the installation of conda
+    header "CONDA"
+    info "Faceswap uses Conda as it handles the installation of all prerequisites."
+    if find_conda_install && ask_yesno "Use the pre-installed conda?" "Yes"; then
+        info "Using Conda install at $DIR_CONDA"
+    else
+        echo ""
+        info "If you have an existing Conda install then enter the location here,\
+        otherwise Miniconda3 will be installed in the given location."
+        err_msg="The location for Conda must not contain spaces (this is a specific\
+        limitation of Conda)."
+        tmp_dir_conda="$DIR_CONDA"
+        while true ; do
+            ask "Please specify a location for Conda." "DIR_CONDA"
+            case ${DIR_CONDA} in
+                *\ * ) error "$err_msg" ; DIR_CONDA=$tmp_dir_conda ;;
+                * ) break ;;
+            esac
+        done
+        # Record the executable path once a valid (space-free) location has been accepted
+        CONDA_EXECUTABLE="${DIR_CONDA}/bin/conda"
+    fi
+    if ! check_file_exists "$CONDA_EXECUTABLE" ; then
+        echo ""
+        info "The Conda executable can be added to your PATH. This makes it easier to run Conda\
+        commands directly. If you already have a pre-existing Conda install then you should\
+        probably not enable this, otherwise this should be fine."
+        if ask_yesno "Add Conda executable to path?" "Yes" ; then CONDA_TO_PATH=true ; fi
+    fi
+    echo ""
+    info "Faceswap will be installed inside a Conda Environment. If an environment already\
+    exists with the name specified then it will be deleted."
+    ask "Please specify a name for the Faceswap Conda Environment" "ENV_NAME"
+}
+
+faceswap_opts () {
+    # Options pertaining to the installation of faceswap
+    header "FACESWAP"
+    info "Faceswap will be installed in the given location. If a folder exists at the\
+    location you specify, then it will be deleted."
+    ask "Please specify a location for Faceswap" "DIR_FACESWAP"
+    echo ""
+    info "Faceswap can be run on Apple Silicon (M1, M2 etc.), compatible NVIDIA GPUs, or on CPU. You should make sure that any \
+    drivers are up to date. Please select the version of Faceswap you wish to install."
+    ask_version
+    if [ $VERSION == "apple_silicon" ] ; then
+        DL_CONDA="${URL_CONDA}arm64.sh"
+    fi
+}
+
+post_install_opts() {
+    # Post installation options
+    header "POST INSTALLATION ACTIONS"
+    info "Launching Faceswap requires activating your Conda Environment and then running\
+    Faceswap. The installer can simplify this by creating an Application Launcher file and placing it \
+    on your desktop to launch straight into the Faceswap GUI"
+    if ask_yesno "Create FaceswapGUI Launcher?" "Yes" ; then
+        DESKTOP=true
+    fi
+}
+
+review() {
+    # Review user options and ask to continue
+    header "Review install options"
+    info "Please review the selected installation options before proceeding:"
+    echo ""
+    if $XQUARTZ ; then echo "  - The XQuartz installer will be downloaded and launched" ; fi
+    if ! check_folder_exists "$DIR_CONDA"
+    then
+        echo "  - MiniConda3 will be installed in '$DIR_CONDA'"
+    else
+        echo "  - Existing Conda install at '$DIR_CONDA' will be used"
+    fi
+    if $CONDA_TO_PATH ; then echo "  - MiniConda3 will be added to your PATH" ; fi
+    if check_env_exists ; then
+        echo $'  \e[33m- Existing Conda Environment '$ENV_NAME $' will be removed\e[39m'
+    fi
+    echo "  - Conda Environment '$ENV_NAME' will be created."
+    if check_folder_exists "$DIR_FACESWAP" ; then
+        echo $'  \e[33m- Existing Faceswap folder '$DIR_FACESWAP $' will be removed\e[39m'
+    fi
+    echo "  - Faceswap will be installed in '$DIR_FACESWAP'"
+    echo "  - Installing for '$VERSION'"
+    if [ $VERSION == "nvidia" ] ; then
+        echo $'  \e[33m- Note: Please ensure that Nvidia drivers are installed prior to proceeding\e[39m'
+    fi
+    if $DESKTOP ; then echo "  - An Application Launcher will be created" ; fi
+    if ! ask_yesno "Do you wish to continue?" "No" ; then exit ; fi
+}
+
+xquartz_install() {
+    # Download and install XQuartz
+    if $XQUARTZ ; then
+        info "Downloading XQuartz..."
+        yellow ; download_file $DL_XQUARTZ
+        echo ""
+
+        info "Installing XQuartz..."
+        info "Admin password required to install XQuartz:"
+        fname="$(basename -- $DL_XQUARTZ)"
+        yellow ; sudo installer -pkg "$TMP_DIR/$fname" -target /
+        echo ""
+    fi
+}
+
+conda_install() {
+    # Download and install Mini Conda3
+    if ! check_folder_exists "$DIR_CONDA" ; then
+        info "Downloading Miniconda3..."
+        # The installer URL matches the selected build: x86_64 by default, arm64 for Apple Silicon
+ yellow ; download_file $DL_CONDA + info "Installing Miniconda3..." + yellow ; fname="$(basename -- $DL_CONDA)" + bash "$TMP_DIR/$fname" -b -p "$DIR_CONDA" + "$CONDA_EXECUTABLE" tos accept + if $CONDA_TO_PATH ; then + info "Adding Miniconda3 to PATH..." + yellow ; "$CONDA_EXECUTABLE" init zsh bash + "$CONDA_EXECUTABLE" config --set auto_activate false + fi + fi +} + +check_env_exists() { + # Check if an environment with the given name exists + if check_file_exists "$CONDA_EXECUTABLE" ; then + "$CONDA_EXECUTABLE" env list | grep -qE "^${ENV_NAME}\W" + else false + fi +} + +delete_env() { + # Delete the env if it previously exists + if check_env_exists ; then + info "Removing pre-existing Virtual Environment" + yellow ; "$CONDA_EXECUTABLE" env remove -n "$ENV_NAME" + fi +} + +create_env() { + # Create Python 3.13 env for faceswap + delete_env + info "Creating Conda Virtual Environment..." + yellow ; "$CONDA_EXECUTABLE" create -n "$ENV_NAME" -c conda-forge -q python="$PYENV_VERSION" -y +} + + +activate_env() { + # Activate the conda environment + # shellcheck source=/dev/null + source "$DIR_CONDA/etc/profile.d/conda.sh" activate + conda activate "$ENV_NAME" +} + +delete_faceswap() { + # Delete existing faceswap folder + if check_folder_exists "$DIR_FACESWAP" ; then + info "Removing Faceswap folder: '$DIR_FACESWAP'" + rm -rf "$DIR_FACESWAP" + fi +} + +clone_faceswap() { + # Clone the faceswap repo + delete_faceswap + info "Downloading Faceswap..." + yellow ; git clone --depth 1 --no-single-branch "$DL_FACESWAP" "$DIR_FACESWAP" +} + +setup_faceswap() { + # Run faceswap setup script + info "Setting up Faceswap..." + python -u "$DIR_FACESWAP/setup.py" --installer --$VERSION +} + +create_gui_launcher () { + # Create a shortcut to launch into the GUI + launcher="$DIR_FACESWAP/faceswap_gui_launcher.command" + launch_script="#!/bin/bash\n" + launch_script+="source \"$DIR_CONDA/etc/profile.d/conda.sh\" activate && \n" + launch_script+="conda activate '$ENV_NAME' && \n" + launch_script+="python \"$DIR_FACESWAP/faceswap.py\" gui" + printf "$launch_script" > "$launcher" + chmod +x "$launcher" +} + +create_app_on_desktop () { + # Create a simple .app wrapper to launch GUI + if $DESKTOP ; then + app_name="FaceswapGUI" + app_dir="$TMP_DIR/$app_name.app" + + unzip -qq "$DIR_FACESWAP/.install/macos/app.zip" -d "$TMP_DIR" + + script="#!/bin/bash\n" + script+="bash \"$DIR_FACESWAP/faceswap_gui_launcher.command\"" + printf "$script" > "$app_dir/Contents/Resources/script" + chmod +x "$app_dir/Contents/Resources/script" + + rm -rf "$HOME/Desktop/$app_name.app" + mv "$app_dir" "$HOME/Desktop" + fi ; +} + +check_for_sudo +check_for_curl +check_for_xcode +banner +user_input +review +create_tmp_dir +xquartz_install +conda_install +create_env +activate_env +clone_faceswap +setup_faceswap +create_gui_launcher +create_app_on_desktop +info "Faceswap installation is complete!" +if $CONDA_TO_PATH ; then + info "You should close the terminal before proceeding" ; fi +if $DESKTOP ; then info "You can launch Faceswap from the icon on your desktop" ; fi +if $XQUARTZ ; then + warn "XQuartz has been installed. 
You must log out and log in again to be able to use the GUI" ; fi diff --git a/.install/windows/fs_logo.ico b/.install/windows/fs_logo.ico new file mode 100644 index 0000000000..c33f64010d Binary files /dev/null and b/.install/windows/fs_logo.ico differ diff --git a/.install/windows/fs_logo_32.ico b/.install/windows/fs_logo_32.ico index fbf031b301..c33f64010d 100644 Binary files a/.install/windows/fs_logo_32.ico and b/.install/windows/fs_logo_32.ico differ diff --git a/.install/windows/git_install.inf b/.install/windows/git_install.inf deleted file mode 100644 index c0cf808a95..0000000000 --- a/.install/windows/git_install.inf +++ /dev/null @@ -1,18 +0,0 @@ -[Setup] -Lang=default -Dir=C:\Program Files\Git -Group=Git -NoIcons=0 -SetupType=default -Components=ext,ext\shellhere,ext\guihere,gitlfs,assoc,assoc_sh -Tasks= -EditorOption=VisualStudioCode -CustomEditorPath= -PathOption=Cmd -SSHOption=OpenSSH -CURLOption=OpenSSL -CRLFOption=CRLFAlways -BashTerminalOption=MinTTY -PerformanceTweaksFSCache=Enabled -UseCredentialManager=Enabled -EnableSymlinks=Disabled diff --git a/.install/windows/install.nsi b/.install/windows/install.nsi index 4983ab2ecc..0e60506beb 100644 --- a/.install/windows/install.nsi +++ b/.install/windows/install.nsi @@ -1,5 +1,8 @@ +# TODO: Install visualstudio build tools for fastcluster +# TODO: Check if we still get realtime output with Subprocess in setup.py !include MUI2.nsh !include nsDialogs.nsh +!include winmessages.nsh !include LogicLib.nsh !include CPUFeatures.nsh !include MultiDetailPrint.nsi @@ -9,48 +12,39 @@ OutFile "faceswap_setup_x64.exe" Name "Faceswap" InstallDir $PROFILE\faceswap -# Download sites -!define wwwGit "https://github.com/git-for-windows/git/releases/download/v2.20.1.windows.1/Git-2.20.1-64-bit.exe" +# Sometimes miniconda breaks. 
Uncomment/comment the following 2 lines to pin !define wwwConda "https://repo.anaconda.com/miniconda/Miniconda3-latest-Windows-x86_64.exe" +#!define wwwConda "https://repo.anaconda.com/miniconda/Miniconda3-4.5.12-Windows-x86_64.exe" !define wwwRepo "https://github.com/deepfakes/faceswap.git" - +!define wwwFaceswap "https://www.faceswap.dev" # Faceswap Specific -!define flagsSetup "setup.py --installer" +!define flagsSetup "--installer" # Install cli flags -!define flagsConda "/S /RegisterPython=0 /AddToPath=0 /D=$Profile\MiniConda3" -!define flagsGit "/SILENT /NORESTART /NOCANCEL /SP /CLOSEAPPLICATIONS /RESTARTAPPLICATIONS" +!define flagsConda "/S /RegisterPython=0 /AddToPath=0 /D=$PROFILE\MiniConda3" !define flagsRepo "--depth 1 --no-single-branch ${wwwRepo}" -!define flagsEnv "-y python=3.6" - -# Dlib Wheel prefix -!define prefixDlib "dlib-19.16.99-cp36-cp36m-win_amd64" -!define dlibFinalName "dlib-19.16.99-cp36-cp36m-win_amd64.whl" # Dlib Wheel MUST have this name before installing -!define cudaDlib "_cuda90" -!define avxDlib "_avx" -!define sseDlib "_sse4" -!define noneDlib "_none" - +!define flagsEnv "-y python=3.13" # Folders +Var ProgramData Var dirTemp Var dirMiniconda +Var dirMinicondaAll Var dirAnaconda +Var dirAnacondaAll Var dirConda # Items to Install -Var InstallGit Var InstallConda -Var dlibWhl # Misc -Var gitInf Var InstallFailed Var lblPos Var hasAVX Var hasSSE4 -Var noNvidia +Var setupType +Var ctlRadio Var ctlCondaText Var ctlCondaButton Var Log @@ -61,7 +55,7 @@ Var envName !define MUI_ABORTWARNING # Install Location Page -!define MUI_ICON "fs_logo_32.ico" +!define MUI_ICON "fs_logo.ico" !define MUI_PAGE_HEADER_TEXT "Faceswap.py Installer" !define MUI_PAGE_HEADER_SUBTEXT "Install Location" !define MUI_DIRECTORYPAGE_TEXT_DESTINATION "Select Destination Folder:" @@ -72,6 +66,7 @@ Var envName Page custom pgPrereqCreate pgPrereqLeave # Install Faceswap Page +!define MUI_PAGE_CUSTOMFUNCTION_SHOW InstFilesShow !define MUI_PAGE_HEADER_SUBTEXT "Installing Faceswap..." 
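+# InstFilesShow (defined below) re-enables the Cancel button once installation starts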
!insertmacro MUI_PAGE_INSTFILES @@ -80,19 +75,27 @@ Page custom pgPrereqCreate pgPrereqLeave # Init Function .onInit + SetShellVarContext all + StrCpy $ProgramData $APPDATA + SetShellVarContext current # It's better to put stuff in $pluginsdir, $temp is shared InitPluginsDir StrCpy $dirTemp "$pluginsdir\faceswap\temp" StrCpy $dirMiniconda "$PROFILE\Miniconda3" StrCpy $dirAnaconda "$PROFILE\Anaconda3" - StrCpy $gitInf "$dirTemp\git_install.inf" + StrCpy $dirMinicondaAll "$ProgramData\Miniconda3" + StrCpy $dirAnacondaAll "$ProgramData\Anaconda3" StrCpy $envName "faceswap" SetOutPath "$dirTemp" - File *.whl - File git_install.inf Call CheckPrerequisites FunctionEnd +# Enable the cancel button during installation +Function InstFilesShow + GetDlgItem $0 $HWNDPARENT 2 + EnableWindow $0 1 +FunctionEnd + Function VerifyInstallDir # Check install folder does not already exist IfFileExists $INSTDIR 0 +3 @@ -117,47 +120,58 @@ Function pgPrereqCreate StrCpy $lblPos 14 # Info Installing applications - ${NSD_CreateGroupBox} 5% 5% 90% 35% "The following applications will be installed" + ${NSD_CreateGroupBox} 1% 1% 98% 30% "The following applications will be installed" Pop $0 - ${If} $InstallGit == 1 - ${NSD_CreateLabel} 10% $lblPos% 80% 14u "Git for Windows" - Pop $0 - intOp $lblPos $lblPos + 7 - ${EndIf} - ${If} $InstallConda == 1 ${NSD_CreateLabel} 10% $lblPos% 80% 14u "MiniConda 3" Pop $0 intOp $lblPos $lblPos + 7 ${EndIf} ${NSD_CreateLabel} 10% $lblPos% 80% 14u "Faceswap" + Pop $0 - StrCpy $lblPos 46 + intOp $lblPos $lblPos + 15 # Info Custom Options - ${NSD_CreateGroupBox} 5% 40% 90% 60% "Custom Items" + ${NSD_CreateGroupBox} 1% 31% 98% 65% "GPU and Location" Pop $0 - ${NSD_CreateCheckBox} 10% $lblPos% 80% 11u " IMPORTANT! Check here if you do NOT have an NVIDIA graphics card" - Pop $noNvidia - intOp $lblPos $lblPos + 10 - - ${NSD_CreateLabel} 10% $lblPos% 80% 10u "Environment Name (NB: Existing envs with this name will be deleted):" + ${NSD_CreateRadioButton} 4% $lblPos% 27% 20u "NVIDIA RTX 20xx +" + Pop $ctlRadio + ${NSD_AddStyle} $ctlRadio ${WS_GROUP} + nsDialogs::SetUserData $ctlRadio "nvidia13" + ${NSD_OnClick} $ctlRadio RadioClick + ${NSD_CreateRadioButton} 32% $lblPos% 25% 20u "Nvidia GTX 9xx - GTX 10xx" + Pop $ctlRadio + nsDialogs::SetUserData $ctlRadio "nvidia12" + ${NSD_OnClick} $ctlRadio RadioClick + ${NSD_CreateRadioButton} 60% $lblPos% 25% 20u "Nvidia GTX 7xx - GTX 8xx" + Pop $ctlRadio + nsDialogs::SetUserData $ctlRadio "nvidia11" + ${NSD_OnClick} $ctlRadio RadioClick + ${NSD_CreateRadioButton} 88% $lblPos% 25% 20u "CPU" + Pop $ctlRadio + nsDialogs::SetUserData $ctlRadio "cpu" + ${NSD_OnClick} $ctlRadio RadioClick + + intOp $lblPos $lblPos + 18 + + ${NSD_CreateLabel} 4% $lblPos% 90% 10u "Environment Name (NB: Existing envs with this name will be deleted):" pop $0 intOp $lblPos $lblPos + 7 - ${NSD_CreateText} 10% $lblPos% 80% 11u "$envName" + ${NSD_CreateText} 4% $lblPos% 90% 11u "$envName" Pop $envName intOp $lblPos $lblPos + 11 ${If} $InstallConda == 1 - ${NSD_CreateLabel} 10% $lblPos% 80% 18u "Conda is required but could not be detected. If you have Conda already installed specify the location below, otherwise leave blank:" + ${NSD_CreateLabel} 4% $lblPos% 90% 18u "Conda is required but could not be detected. If you have Conda already installed specify the location below, otherwise leave blank:" Pop $0 intOp $lblPos $lblPos + 13 - ${NSD_CreateText} 10% $lblPos% 73% 12u "" + ${NSD_CreateText} 4% $lblPos% 73% 12u "" Pop $ctlCondaText - ${NSD_CreateButton} 83% $lblPos% 7% 12u "..." 
+ ${NSD_CreateButton} 77% $lblPos% 13% 12u "..." Pop $ctlCondaButton ${NSD_OnClick} $ctlCondaButton fnc_hCtl_test_DirRequest1_Click ${EndIf} @@ -165,6 +179,12 @@ Function pgPrereqCreate nsDialogs::Show FunctionEnd +Function RadioClick + Pop $R0 + nsDialogs::GetUserData $R0 + Pop $setupType +FunctionEnd + Function fnc_hCtl_test_DirRequest1_Click Pop $R0 ${If} $R0 == $ctlCondaButton @@ -178,16 +198,24 @@ Function fnc_hCtl_test_DirRequest1_Click FunctionEnd Function pgPrereqLeave + call CheckSetupType Call CheckCustomCondaPath - ${NSD_GetState} $noNvidia $noNvidia ${NSD_GetText} $envName $envName FunctionEnd +Function CheckSetupType + ${If} $setupType == "" + MessageBox MB_OK "Please specify whether to setup for Nvidia or CPU." + Abort + ${EndIf} + StrCpy $Log "$log(check) Setting up for: $setupType$\n" +FunctionEnd + Function CheckCustomCondaPath ${NSD_GetText} $ctlCondaText $2 ${If} $2 != "" - nsExec::ExecToStack "$2\Scripts\conda.exe -V" + nsExec::ExecToStack "$\"$2\Scripts\conda.exe$\" -V" pop $0 pop $1 ${If} $0 == 0 @@ -200,39 +228,57 @@ Function CheckCustomCondaPath ${EndIf} FunctionEnd -Function CheckPrerequisites - #Git - nsExec::ExecToStack "git --version" - pop $0 - pop $1 - ${If} $0 == 0 - StrCpy $Log "$log(check) Git installed: $1" - ${Else} - StrCpy $InstallGit 1 - ${EndIf} +Function CheckConda + # miniconda + nsExec::ExecToStack "$\"$dirMiniconda\Scripts\conda.exe$\" -V" + pop $0 + pop $1 + + nsExec::ExecToStack "$\"$dirMinicondaAll\Scripts\conda.exe$\" -V" + pop $2 + pop $3 + + # anaconda + nsExec::ExecToStack "$\"$dirAnaconda\Scripts\conda.exe$\" -V" + pop $4 + pop $5 + + nsExec::ExecToStack "$\"$dirAnacondaAll\Scripts\conda.exe$\" -V" + pop $6 + pop $7 + + ${If} $0 == 0 + StrCpy $dirConda "$dirMiniconda" + StrCpy $Log "$log(check) MiniConda installed: $1" + ${ElseIf} $2 == 0 + StrCpy $dirConda "$dirMinicondaAll" + StrCpy $Log "$log(check) MiniConda installed: $3" + ${ElseIf} $4 == 0 + StrCpy $dirConda "$dirAnaconda" + StrCpy $Log "$log(check) AnaConda installed: $5" + ${ElseIf} $6 == 0 + StrCpy $dirConda "$dirAnacondaAll" + StrCpy $Log "$log(check) AnaConda installed: $7" + ${EndIf} +FunctionEnd +Function CheckPrerequisites # Conda - # miniconda - nsExec::ExecToStack "$dirMiniconda\Scripts\conda.exe -V" - pop $0 - pop $1 - - # anaconda - nsExec::ExecToStack "$dirAnaconda\Scripts\conda.exe -V" - pop $2 - pop $3 + Call CheckConda + Push $PROFILE + Call CheckForSpaces + Pop $R0 + # If spaces in user profile look for and install Conda in C: + ${If} $dirConda == "" + ${AndIf} $R0 != 0 + StrCpy $dirMiniconda "C:\Miniconda3" + StrCpy $dirAnaconda "C:\Anaconda3" + Call CheckConda + ${EndIf} - ${If} $0 == 0 - StrCpy $dirConda "$dirMiniconda" - StrCpy $Log "$log(check) MiniConda installed: $1" - ${Else} - ${If} $2 == 0 - StrCpy $dirConda "$dirAnaconda" - StrCpy $Log "$log(check) AnaConda installed: $0" - ${Else} - StrCpy $InstallConda 1 - ${EndIf} - ${EndIf} + ${If} $dirConda == "" + StrCpy $InstallConda 1 + ${EndIf} # CPU Capabilities ${If} ${CPUSupports} "AVX2" @@ -249,98 +295,121 @@ Function CheckPrerequisites StrCpy $Log "$Log(check) Completed check for installed applications$\n" FunctionEnd +Function CheckForSpaces +# Check a string for space (Used for defining MiniConda install Location) + Exch $R0 + Push $R1 + Push $R2 + Push $R3 + StrCpy $R1 -1 + StrCpy $R3 $R0 + StrCpy $R0 0 + loop: + StrCpy $R2 $R3 1 $R1 + IntOp $R1 $R1 - 1 + StrCmp $R2 "" done + StrCmp $R2 " " 0 loop + IntOp $R0 $R0 + 1 + Goto loop + done: + Pop $R3 + Pop $R2 + Pop $R1 + Exch $R0 + +FunctionEnd + Section 
Install Push $Log Call MultiDetailPrint - Call InstallPrerequisites - Call CloneRepo + Call InstallConda Call SetEnvironment - Call InstallDlib + Call InstallGit + Call CloneRepo Call SetupFaceSwap + Call AddGuiLauncher Call DesktopShortcut + ExecShell "open" "${wwwFaceswap}" + DetailPrint "Visit ${wwwFaceswap} for help and support." SectionEnd -Function InstallPrerequisites - # GIT - ${If} $InstallGit == 1 - DetailPrint "Downloading Git..." - inetc::get /caption "Downloading Git..." /canceltext "Cancel" ${wwwGit} "git_installer.exe" /end - Pop $0 # return value = exit code, "OK" means OK - ${If} $0 == "OK" - DetailPrint "Installing Git..." - SetDetailsPrint listonly - ExecWait "$dirTemp\git_installer.exe ${flagsGit} /LOADINF=$\"$gitInf$\"" $0 - SetDetailsPrint both - ${If} $0 != 0 - DetailPrint "Error Installing Git" - StrCpy $InstallFailed 1 - ${EndIf} - ${Else} - DetailPrint "Error Downloading Git" - StrCpy $InstallFailed 1 - ${EndIf} - ${EndIf} - - # CONDA - ${If} $InstallConda == 1 - DetailPrint "Downloading Miniconda3..." - inetc::get /caption "Downloading Miniconda3." /canceltext "Cancel" ${wwwConda} "Miniconda3.exe" /end - Pop $0 - ${If} $0 == "OK" - DetailPrint "Installing Miniconda3. This will take a few minutes..." - SetDetailsPrint listonly - ExecWait "$dirTemp\Miniconda3.exe ${flagsConda}" $0 - StrCpy $dirConda "$dirMiniconda" - SetDetailsPrint both - ${If} $0 != 0 - DetailPrint "Error Installing Miniconda3" - StrCpy $InstallFailed 1 - ${EndIf} - ${Else} - DetailPrint "Error Downloading Miniconda3" +Function InstallConda + ${If} $InstallConda == 1 + DetailPrint "Downloading Miniconda3..." + inetc::get /caption "Downloading Miniconda3." /canceltext "Cancel" ${wwwConda} "Miniconda3.exe" /end + Pop $0 + ${If} $0 == "OK" + DetailPrint "Installing Miniconda3. This will take a few minutes..." + SetDetailsPrint listonly + ExecDos::exec /NOUNLOAD /ASYNC /DETAILED "$\"$dirTemp\Miniconda3.exe$\" ${flagsConda}" + pop $0 + ExecDos::wait $0 + pop $0 + StrCpy $dirConda "$dirMiniconda" + SetDetailsPrint both + ${If} $0 != 0 + DetailPrint "Error Installing Miniconda3" StrCpy $InstallFailed 1 ${EndIf} + ${Else} + DetailPrint "Error Downloading Miniconda3" + StrCpy $InstallFailed 1 ${EndIf} + ${EndIf} ${If} $InstallFailed == 1 Call Abort ${Else} - DetailPrint "All Prerequisites installed." + DetailPrint "Miniconda3 installed." ${EndIf} FunctionEnd -Function CloneRepo - DetailPrint "Downloading Faceswap..." +Function SetEnvironment + DetailPrint "Initializing Conda..." SetDetailsPrint listonly - ExecWait "$PROGRAMFILES64\git\bin\git.exe clone ${flagsRepo} $INSTDIR" $0 + ExecDos::exec /NOUNLOAD /ASYNC /DETAILED "$\"$dirConda\Scripts\conda.exe$\" tos accept" + pop $0 + ExecDos::exec /NOUNLOAD /ASYNC /DETAILED "$\"$dirConda\scripts\activate.bat$\" && conda update -y -n base -c defaults conda && conda deactivate" + pop $0 + ExecDos::wait $0 + pop $0 SetDetailsPrint both - ${If} $0 != 0 - DetailPrint "Error Downloading Faceswap" - Call Abort - ${EndIf} -FunctionEnd - -Function SetEnvironment - # Updating Conda breaks setup.py. Commented out in case this issue gets resolved in future -# DetailPrint "Initializing Conda..." -# SetDetailsPrint listonly -# ExecWait "$dirConda\scripts\activate.bat && conda update -y -n base -c defaults conda && conda deactivate" $0 -# SetDetailsPrint both DetailPrint "Creating Conda Virtual Environment..." 
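+    # Remove any pre-existing environment with this name before creating a fresh one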
-    IfFileExists "$dirConda\envs\faceswap" DeleteEnv CreateEnv
+    IfFileExists "$dirConda\envs\$envName" DeleteEnv CreateEnv
     DeleteEnv:
+        DetailPrint "Removing existing Conda Virtual Environment..."
         SetDetailsPrint listonly
-        ExecWait "$dirConda\scripts\activate.bat && conda env remove -y -n $\"$envName$\" && conda deactivate" $0
+        ExecDos::exec /NOUNLOAD /ASYNC /DETAILED "$\"$dirConda\scripts\activate.bat$\" && conda env remove -y -n $\"$envName$\" && conda deactivate"
+        pop $0
+        ExecDos::wait $0
+        pop $0
         SetDetailsPrint both
         ${If} $0 != 0
             DetailPrint "Error deleting Conda Virtual Environment"
             Call Abort
         ${EndIf}
+        # Often Conda won't actually remove the folder and some of its contents, which leads to permission problems later
+        IfFileExists "$dirConda\envs\$envName" DeleteFolder CreateEnv
+        DeleteFolder:
+            DetailPrint "Deleting stale Conda Virtual Environment files..."
+            SetDetailsPrint listonly
+            # RMDir reports failure via the error flag rather than the stack
+            ClearErrors
+            RMDir /r "$dirConda\envs\$envName"
+            SetDetailsPrint both
+            ${If} ${Errors}
+                DetailPrint "Error deleting Conda Virtual Environment Folder"
+                Call Abort
+            ${EndIf}
+
     CreateEnv:
         SetDetailsPrint listonly
-        ExecWait "$dirConda\scripts\activate.bat && conda create ${flagsEnv} -n $\"$envName$\" && conda deactivate" $0
+        StrCpy $0 "${flagsEnv}"
+        ExecDos::exec /NOUNLOAD /ASYNC /DETAILED "$\"$dirConda\scripts\activate.bat$\" && conda create $0 -c conda-forge -n $\"$envName$\" && conda deactivate"
+        pop $0
+        ExecDos::wait $0
+        pop $0
         SetDetailsPrint both
         ${If} $0 != 0
             DetailPrint "Error Creating Conda Virtual Environment"
@@ -348,46 +417,43 @@ Function SetEnvironment
     ${EndIf}
 FunctionEnd
 
-Function InstallDlib
-    DetailPrint "Installing Dlib..."
+Function InstallGit
+    DetailPrint "Installing Git..."
     SetDetailsPrint listonly
-
-    StrCpy $dlibWhl ${prefixDlib}
-
-    ${If} $noNvidia != 1
-        StrCpy $dlibWhl "$dlibWhl${cudaDlib}"
-    ${EndIf}
-
-    ${If} $hasAVX == 1
-        StrCpy $dlibWhl "$dlibWhl${avxDlib}"
-    ${ElseIf} $hasSSE4 == 1
-        StrCpy $dlibWhl "$dlibWhl${sseDlib}"
-    ${Else}
-        StrCpy $dlibWhl "$dlibWhl${noneDlib}"
+    ExecDos::exec /NOUNLOAD /ASYNC /DETAILED "$\"$dirConda\scripts\activate.bat$\" && conda activate $\"$envName$\" && conda install git -y -q && conda deactivate"
+    pop $0
+    ExecDos::wait $0
+    pop $0
+    SetDetailsPrint both
+    ${If} $0 != 0
+        DetailPrint "Error Installing Git"
+        StrCpy $InstallFailed 1
     ${EndIf}
+FunctionEnd
 
-    StrCpy $dlibWhl "$dlibWhl.whl"
-    DetailPrint "Renaming $dlibWhl to ${dlibFinalName}"
-    Rename $dirTemp\$dlibWhl $dirTemp\${dlibFinalName}
-
-    ExecWait "$dirConda\scripts\activate.bat && conda activate $\"$envName$\" && pip install $dirTemp\${dlibFinalName} && conda deactivate" $0
+Function CloneRepo
+    DetailPrint "Downloading Faceswap..."
+    SetDetailsPrint listonly
+    ExecDos::exec /NOUNLOAD /ASYNC /DETAILED "$\"$dirConda\scripts\activate.bat$\" && conda activate $\"$envName$\" && git clone ${flagsRepo} $\"$INSTDIR$\" && conda deactivate"
+    pop $0
+    ExecDos::wait $0
+    pop $0
     SetDetailsPrint both
     ${If} $0 != 0
-        DetailPrint "Error Installing Dlib"
+        DetailPrint "Error Downloading Faceswap"
         Call Abort
     ${EndIf}
-
 FunctionEnd
 
 Function SetupFaceSwap
-    DetailPrint "Setting up FaceSwap Environment"
+    DetailPrint "Setting up FaceSwap Environment... This may take a while"
     StrCpy $0 "${flagsSetup}"
-    ${If} $noNvidia != 1
-        StrCpy $0 "$0 --gpu"
-    ${EndIf}
-
+    StrCpy $0 "$0 --$setupType"
     SetDetailsPrint listonly
-    ExecWait "$dirConda\scripts\activate.bat && conda activate $\"$envName$\" && python $INSTDIR\$0 && conda deactivate" $0
+    ExecDos::exec /NOUNLOAD /ASYNC /DETAILED "$\"$dirConda\scripts\activate.bat$\" && conda activate $\"$envName$\" && python -u $\"$INSTDIR\setup.py$\" $0 && conda deactivate"
+    pop $0
+    ExecDos::wait $0
+    pop $0
     SetDetailsPrint both
     ${If} $0 != 0
         DetailPrint "Error Setting up Faceswap"
@@ -395,12 +461,16 @@ Function SetupFaceSwap
     ${EndIf}
 FunctionEnd
 
-Function DesktopShortcut
-    DetailPrint "Creating Desktop Shortcut"
+Function AddGuiLauncher
+    DetailPrint "Creating GUI Launcher"
     SetOutPath "$INSTDIR"
     StrCpy $0 "faceswap_win_launcher.bat"
     FileOpen $9 "$INSTDIR\$0" w
     FileWrite $9 "$\"$dirConda\scripts\activate.bat$\" && conda activate $\"$envName$\" && python $\"$INSTDIR/faceswap.py$\" gui$\r$\n"
     FileClose $9
-    CreateShortCut "$DESKTOP\FaceSwap.lnk" "$INSTDIR\$0" "" "$INSTDIR\.install\windows\fs_logo_32.ico"
-FunctionEnd
\ No newline at end of file
+FunctionEnd
+
+Function DesktopShortcut
+    DetailPrint "Creating Desktop Shortcut"
+    CreateShortCut "$DESKTOP\FaceSwap.lnk" "$\"$INSTDIR\$0$\"" "" "$INSTDIR\.install\windows\fs_logo.ico"
+FunctionEnd
diff --git a/.readthedocs.yml b/.readthedocs.yml
new file mode 100644
index 0000000000..1b36b2bbec
--- /dev/null
+++ b/.readthedocs.yml
@@ -0,0 +1,27 @@
+# .readthedocs.yaml
+# Read the Docs configuration file
+# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
+
+# Required
+version: 2
+
+# Set the version of Python and other tools you might need
+build:
+  os: ubuntu-24.04
+  tools:
+    python: "3.13"
+  apt_packages:
+    - graphviz
+
+# Build documentation in the docs/ directory with Sphinx
+sphinx:
+  configuration: docs/conf.py
+
+# If using Sphinx, optionally build your docs in additional formats such as PDF
+# formats:
+#    - pdf
+
+# Optionally declare the Python requirements required to build your docs
+python:
+  install:
+    - requirements: docs/sphinx_requirements.txt
diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md
new file mode 100644
index 0000000000..6cc7a89219
--- /dev/null
+++ b/CODE_OF_CONDUCT.md
@@ -0,0 +1,76 @@
+# Contributor Covenant Code of Conduct
+
+## Our Pledge
+
+In the interest of fostering an open and welcoming environment, we as
+contributors and maintainers pledge to making participation in our project and
+our community a harassment-free experience for everyone, regardless of age, body
+size, disability, ethnicity, sex characteristics, gender identity and expression,
+level of experience, education, socio-economic status, nationality, personal
+appearance, race, religion, or sexual identity and orientation.
+ +## Our Standards + +Examples of behavior that contributes to creating a positive environment +include: + +* Using welcoming and inclusive language +* Being respectful of differing viewpoints and experiences +* Gracefully accepting constructive criticism +* Focusing on what is best for the community +* Showing empathy towards other community members + +Examples of unacceptable behavior by participants include: + +* The use of sexualized language or imagery and unwelcome sexual attention or + advances +* Trolling, insulting/derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or electronic + address, without explicit permission +* Other conduct which could reasonably be considered inappropriate in a + professional setting + +## Our Responsibilities + +Project maintainers are responsible for clarifying the standards of acceptable +behavior and are expected to take appropriate and fair corrective action in +response to any instances of unacceptable behavior. + +Project maintainers have the right and responsibility to remove, edit, or +reject comments, commits, code, wiki edits, issues, and other contributions +that are not aligned to this Code of Conduct, or to ban temporarily or +permanently any contributor for other behaviors that they deem inappropriate, +threatening, offensive, or harmful. + +## Scope + +This Code of Conduct applies both within project spaces and in public spaces +when an individual is representing the project or its community. Examples of +representing a project or community include using an official project e-mail +address, posting via an official social media account, or acting as an appointed +representative at an online or offline event. Representation of a project may be +further defined and clarified by project maintainers. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be +reported by contacting the project team at deefakesrepo@gmail.com. All +complaints will be reviewed and investigated and will result in a response that +is deemed necessary and appropriate to the circumstances. The project team is +obligated to maintain confidentiality with regard to the reporter of an incident. +Further details of specific enforcement policies may be posted separately. + +Project maintainers who do not follow or enforce the Code of Conduct in good +faith may face temporary or permanent repercussions as determined by other +members of the project's leadership. 
+ +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, +available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html + +[homepage]: https://www.contributor-covenant.org + +For answers to common questions about this code of conduct, see +https://www.contributor-covenant.org/faq diff --git a/Dockerfile.cpu b/Dockerfile.cpu index eb201a2e73..0c27ec9b69 100755 --- a/Dockerfile.cpu +++ b/Dockerfile.cpu @@ -1,14 +1,19 @@ -FROM tensorflow/tensorflow:latest-py3 +FROM ubuntu:22.04 -RUN apt-get update -qq -y \ - && apt-get install -y libsm6 libxrender1 libxext-dev python3-tk\ - && apt-get clean \ - && rm -rf /var/lib/apt/lists/* +# To disable tzdata and others from asking for input +ENV DEBIAN_FRONTEND noninteractive +ENV FACESWAP_BACKEND cpu -COPY requirements.txt /opt/ -RUN pip3 install cmake -RUN pip3 install dlib --install-option=--yes --install-option=USE_AVX_INSTRUCTIONS -RUN pip3 --no-cache-dir install -r /opt/requirements.txt && rm /opt/requirements.txt +RUN apt-get update -qq -y +RUN apt-get upgrade -y +RUN apt-get install -y libgl1 libglib2.0-0 python3 python3-pip python3-tk git -WORKDIR "/notebooks" -CMD ["/run_jupyter.sh", "--allow-root"] +RUN ln -s $(which python3) /usr/local/bin/python + +RUN git clone --depth 1 --no-single-branch https://github.com/deepfakes/faceswap.git +WORKDIR "/faceswap" + +RUN python -m pip install --upgrade pip +RUN python -m pip --no-cache-dir install -r ./requirements/requirements_cpu.txt + +CMD ["/bin/bash"] diff --git a/Dockerfile.gpu b/Dockerfile.gpu index c7ca4b62c9..5b9c0abd0a 100755 --- a/Dockerfile.gpu +++ b/Dockerfile.gpu @@ -1,20 +1,18 @@ -FROM tensorflow/tensorflow:latest-gpu-py3 +FROM nvidia/cuda:11.8.0-cudnn8-runtime-ubuntu22.04 -RUN apt-get update -qq -y \ - && apt-get install -y libsm6 libxrender1 libxext-dev python3-tk\ - && apt-get clean \ - && rm -rf /var/lib/apt/lists/* +ENV DEBIAN_FRONTEND=noninteractive +ENV FACESWAP_BACKEND nvidia -COPY requirements.txt /opt/ -RUN pip3 install cmake -RUN pip3 install dlib --install-option=--yes --install-option=USE_AVX_INSTRUCTIONS -RUN pip3 --no-cache-dir install -r /opt/requirements.txt && rm /opt/requirements.txt +RUN apt-get update -qq -y +RUN apt-get upgrade -y +RUN apt-get install -y libgl1 libglib2.0-0 python3 python3-pip python3-tk git -# patch for tensorflow:latest-gpu-py3 image -RUN cd /usr/local/cuda/lib64 \ - && mv stubs/libcuda.so ./ \ - && ln -s libcuda.so libcuda.so.1 \ - && ldconfig +RUN ln -s $(which python3) /usr/local/bin/python -WORKDIR "/notebooks" -CMD ["/run_jupyter.sh", "--allow-root"] +RUN git clone --depth 1 --no-single-branch https://github.com/deepfakes/faceswap.git +WORKDIR "/faceswap" + +RUN python -m pip install --upgrade pip +RUN python -m pip --no-cache-dir install -r ./requirements/requirements_nvidia.txt + +CMD ["/bin/bash"] diff --git a/INSTALL.md b/INSTALL.md index eded06d367..63ba045f91 100755 --- a/INSTALL.md +++ b/INSTALL.md @@ -1,42 +1,51 @@ -# Installing Faceswap -- [Installing Faceswap](#installing-faceswap) +# Installing faceswap +- [Installing faceswap](#installing-faceswap) - [Prerequisites](#prerequisites) - [Hardware Requirements](#hardware-requirements) - [Supported operating systems](#supported-operating-systems) - [Important before you proceed](#important-before-you-proceed) -- [General Install Guide](#general-install-guide) - - [Installing dependencies](#installing-dependencies) - - [Getting the faceswap code](#getting-the-faceswap-code) - - [Setup](#setup) - - [About some of the 
options](#about-some-of-the-options) - - [Run the project](#run-the-project) - - [Notes](#notes) -- [Windows Install Guide](#windows-install-guide) +- [Linux, Windows and macOS Install Guide](#linux-windows-and-macos-install-guide) - [Installer](#installer) - - [Manual Install](#Manual-install) + - [Manual Install](#manual-install) - [Prerequisites](#prerequisites-1) - - [Microsoft Visual Studio 2015](#microsoft-visual-studio-2015) - - [Cuda](#cuda) - - [cuDNN](#cudnn) - - [CMake](#cmake) - [Anaconda](#anaconda) - [Git](#git) - - [Setup](#setup-1) + - [Setup](#setup) - [Anaconda](#anaconda-1) - [Set up a virtual environment](#set-up-a-virtual-environment) - [Entering your virtual environment](#entering-your-virtual-environment) - - [Faceswap](#faceswap) + - [faceswap](#faceswap) - [Easy install](#easy-install) - - [Manual install](#manual-install) - - [Running Faceswap](#running-faceswap) + - [Manual install](#manual-install-1) + - [Running faceswap](#running-faceswap) - [Create a desktop shortcut](#create-a-desktop-shortcut) - [Updating faceswap](#updating-faceswap) - - [Dlib](#dlib) - - [Build Latest Dlib with GPU Support](#build-latest-dlib-with-gpu-support) - - [Easy install of Dlib without GPU Support](#easy-install-of-dlib-without-gpu-support) +- [macOS (Apple Silicon) Install Guide](#macos-apple-silicon-install-guide) + - [Prerequisites](#prerequisites-2) + - [OS](#os) + - [XCode Tools](#xcode-tools) + - [XQuartz](#xquartz) + - [Conda](#conda) + - [Setup](#setup-1) + - [Create and Activate the Environment](#create-and-activate-the-environment) + - [faceswap](#faceswap-1) + - [Easy install](#easy-install-1) +- [General Install Guide](#general-install-guide) + - [Installing dependencies](#installing-dependencies) + - [Git](#git-1) + - [Python](#python) + - [Virtual Environment](#virtual-environment) + - [Getting the faceswap code](#getting-the-faceswap-code) + - [Setup](#setup-2) + - [About some of the options](#about-some-of-the-options) +- [Docker Install Guide](#docker-install-guide) + - [Docker CPU](#docker-cpu) + - [Docker Nvidia](#docker-nvidia) +- [Run the project](#run-the-project) + - [Notes](#notes) # Prerequisites -Machine learning essentially involves a ton of trial and error. You're letting a program try millions of different settings to land on an algorithm that sort of does what you want it to do. This process is really really slow unless you have the hardware required to speed this up. +Machine learning essentially involves a ton of trial and error. You're letting a program try millions of different settings to land on an algorithm that sort of does what you want it to do. This process is really really slow unless you have the hardware required to speed this up. The type of computations that the process does are well suited for graphics cards, rather than regular processors. **It is pretty much required that you run the training process on a desktop or server capable GPU.** Running this on your CPU means it can take weeks to train your model, compared to several hours on a GPU. @@ -46,214 +55,47 @@ The type of computations that the process does are well suited for graphics card - **A powerful CPU** - Laptop CPUs can often run the software, but will not be fast enough to train at reasonable speeds - **A powerful GPU** - - Currently only Nvidia GPUs are supported. AMD graphics cards are not supported. - This is not something that we have control over. It is a requirement of the Tensorflow library. 
-  - The GPU needs to support at least CUDA Compute Capability 3.0 or higher.
+  - Currently, Nvidia GPUs are fully supported.
+  - More modern AMD GPUs are supported on Linux through ROCm.
+  - M-series Macs are supported using Metal.
+  - If using an Nvidia GPU, then it needs to support at least CUDA Compute Capability 3.5. (Release 1.0 will work on Compute Capability 3.0)
   To see which version your GPU supports, consult this list: https://developer.nvidia.com/cuda-gpus
   Desktop cards later than the 7xx series are most likely supported.
 - **A lot of patience**

## Supported operating systems
-- **Windows 10**
-  Windows 7 and 8 might work. Your milage may vary. Windows has an installer which will set up everything you need. See: https://github.com/deepfakes/faceswap/releases
+- **Windows 10/11**
+  Windows 7 and 8 might work for Nvidia. Your mileage may vary.
+  Windows has an installer which will set up everything you need. See: https://github.com/deepfakes/faceswap/releases
 - **Linux**
-  Most Ubuntu/Debian or CentOS based Linux distributions will work.
+  Most Ubuntu/Debian or CentOS based Linux distributions will work. There is a Linux install script that will install and set up everything you need. See: https://github.com/deepfakes/faceswap/releases
 - **macOS**
-  GPU support on macOS is limited due to lack of drivers/libraries from Nvidia.
-- All operating systems must be 64-bit for Tensorflow to run.
+  Experimental support for GPU-accelerated, native Apple Silicon processing (e.g. Apple M1 chips). Installation instructions can be found [further down this page](#macos-apple-silicon-install-guide).
+  Intel based macOS systems should work, but you will need to follow the [Manual Install](#manual-install) instructions.
+- All operating systems must be 64-bit.

-Alternatively there is a docker image that is based on Debian.
+Alternatively, there is a docker image that is based on Debian.

# Important before you proceed

-**In its current iteration, the project relies heavily on the use of the command line, although a gui is available. if you are unfamiliar with command line tools, you may have difficulty setting up the environment and should perhaps not attempt any of the steps described in this guide.** This guide assumes you have intermediate knowledge of the command line.
+**In its current iteration, the project relies heavily on the use of the command line, although a GUI is available. If you are unfamiliar with command line tools, you may have difficulty setting up the environment and should perhaps not attempt any of the steps described in this guide.** This guide assumes you have intermediate knowledge of the command line.
+The developers are also not responsible for any damage you might cause to your own computer.

-# General Install Guide
-## Installing dependencies
-- Python >= 3.2-3.6 64-bit (cannot be 3.7.x as Tensorflow has not been updated to provide support)
-  - apt/yum install python3 (Linux)
-  - [Installer](https://www.python.org/downloads/release/python-368/) (Windows)
-  - [brew](https://brew.sh/) install python3 (macOS)
-
-- [virtualenv](https://github.com/pypa/virtualenv) and [virtualenvwrapper](https://virtualenvwrapper.readthedocs.io) may help when you are not using docker.
-- If you are using an Nvidia graphics card You should install CUDA (https://developer.nvidia.com/cuda-zone) and CUDNN (https://developer.nvidia.com/cudnn). 
If you do not plan to build Tensorflow yourself, make sure you install no higher than version 9.0 of CUDA and 7.0.x of CUDNN -- dlib is required for face recognition and is compiled as part of the setup process. You will need the following applications for your os to successfully install dlib (nb: list may be incomplete. Please raise an issue if another prerequisite is required for your OS): - - Windows: Visual Studio 2015, CMake v3.8.2 - - Linux: build-essential, cmake - - macOS: xquartz - -## Getting the faceswap code -Simply download the code from http://github.com/deepfakes/faceswap - For development it is recommended to use git instead of downloading the code and extracting it. - -For now, extract the code to a directory where you're comfortable working with it. Navigate to it with the command line. For our example we will use `~/faceswap/` as our project directory. - -## Setup -Enter the folder that faceswap has been downloaded to and run: -```bash -python setup.py -``` -If setup fails for any reason you can still manually install the packages listed within requirements.txt - -### About some of the options - - CUDA: For acceleration. Requires a good nVidia Graphics Card (which supports CUDA inside) - - Docker: Provide a ready-made image. Hide trivial details. Get you straight to the project. - - nVidia-Docker: Access to the nVidia GPU on host machine from inside container. - -CUDA with Docker in 20 minutes. -``` -INFO The tool provides tips for installation - and installs required python packages -INFO Setup in Linux 4.14.39-1-MANJARO -INFO Installed Python: 3.6.5 64bit -INFO Installed PIP: 10.0.1 -Enable Docker? [Y/n] -INFO Docker Enabled -Enable CUDA? [Y/n] -INFO CUDA Enabled -INFO 1. Install Docker - https://www.docker.com/community-edition - - 2. Install Nvidia-Docker & Restart Docker Service - https://github.com/NVIDIA/nvidia-docker - - 3. Build Docker Image For Faceswap - docker build -t deepfakes-gpu -f Dockerfile.gpu . - - 4. Mount faceswap volume and Run it - # without gui. tools.py gui not working. - nvidia-docker run --rm -it -p 8888:8888 \ - --hostname faceswap-gpu --name faceswap-gpu \ - -v /opt/faceswap:/srv \ - deepfakes-gpu - - # with gui. tools.py gui working. - ## enable local access to X11 server - xhost +local: - ## enable nvidia device if working under bumblebee - echo ON > /proc/acpi/bbswitch - ## create container - nvidia-docker run -p 8888:8888 \ - --hostname faceswap-gpu --name faceswap-gpu \ - -v /opt/faceswap:/srv \ - -v /tmp/.X11-unix:/tmp/.X11-unix \ - -e DISPLAY=unix$DISPLAY \ - -e AUDIO_GID=`getent group audio | cut -d: -f3` \ - -e VIDEO_GID=`getent group video | cut -d: -f3` \ - -e GID=`id -g` \ - -e UID=`id -u` \ - deepfakes-gpu - - 5. Open a new terminal to interact with the project - docker exec faceswap-gpu python /srv/tools.py gui -``` - -A successful setup log, without docker. -``` -INFO The tool provides tips for installation - and installs required python packages -INFO Setup in Linux 4.14.39-1-MANJARO -INFO Installed Python: 3.6.5 64bit -INFO Installed PIP: 10.0.1 -Enable Docker? [Y/n] n -INFO Docker Disabled -Enable CUDA? [Y/n] -INFO CUDA Enabled -INFO CUDA version: 9.1 -INFO cuDNN version: 7 -WARNING Tensorflow has no official prebuild for CUDA 9.1 currently. - To continue, You have to build your own tensorflow-gpu. - Help: https://www.tensorflow.org/install/install_sources -Are System Dependencies met? [y/N] y -INFO Installing Missing Python Packages... -INFO Installing tensorflow-gpu -INFO Installing pathlib==1.0.1 -...... 
-INFO Installing tqdm -INFO Installing matplotlib -INFO All python3 dependencies are met. - You are good to go. -``` - -## Run the project -Once all these requirements are installed, you can attempt to run the faceswap tools. Use the `-h` or `--help` options for a list of options. - -```bash -python faceswap.py -h -``` - -or run with `gui` to launch the GUI -```bash -python faceswap.py gui -``` - - -Proceed to [../blob/master/USAGE.md](USAGE.md) - -## Notes -This guide is far from complete. Functionality may change over time, and new dependencies are added and removed as time goes on. - -If you are experiencing issues, please raise them in the [faceswap-playground](https://github.com/deepfakes/faceswap-playground) repository instead of the main repo. - -# Windows Install Guide +# Linux, Windows and macOS Install Guide ## Installer -Windows now has an installer which installs everything for you and creates a desktop shortcut to launch straight into the GUI. You can download the installer from https://github.com/deepfakes/faceswap/releases. +Windows, Linux and macOS all have installers which set up everything for you. You can download the installer from https://github.com/deepfakes/faceswap/releases. -If you have issues with the installer then read on for the more manual way to install Faceswap on Windows. +If you have issues with the installer then read on for the more manual way to install faceswap on Windows. ## Manual Install -Setting up Faceswap can seem a little intimidating to new users, but it isn't that complicated, although a little time consuming. It is recommended to use Linux where possible as Windows will hog about 20% of your GPU Memory, making Faceswap run a little slower, however using Windows is perfectly fine and 100% supported. +Setting up faceswap can seem a little intimidating to new users, but it isn't that complicated, although a little time consuming. It is recommended to use Linux where possible as Windows will hog about 20% of your GPU Memory, making faceswap run a little slower, however using Windows is perfectly fine and 100% supported. ## Prerequisites -### Microsoft Visual Studio 2015 -**Important** Make sure to downoad the 2015 version of Microsoft Visual Studio - -Download and install Microsoft Visual Studio 2015 from: https://go.microsoft.com/fwlink/?LinkId=532606&clcid=0x409 - -On the install screen: -- Select "Custom" then click "Next"\ -![MSVS Custom](https://i.imgur.com/Bx8fjzT.png) -- Uncheck all previously checked options -- Expand "Programming Languages" and select "Visual C++"\ -![MSVS C++](https://i.imgur.com/c8k1IYD.png) -- Select "Next" and "Install" - - -### Cuda -**GPU Only** If you do not have an Nvidia GPU you can skip this step. - -At the time of writing Tensorflow (version 1.12) only supports Cuda up to version 9.0, but check https://www.tensorflow.org/install/gpu for the latest supported version. It is crucial that you download the correct version of Cuda. - -Download and install the correct version of the Cuda Toolkit from: https://developer.nvidia.com/cuda-toolkit-archive - -NB: Make a note of the install folder as you'll need to access it in the next step. - -### cuDNN -**GPU Only** If you do not have an Nvidia GPU you can skip this step. - -As with Cuda you will need to install the correct version of cuDNN that the latest Tensorflow supports. At the time of writing this is Tensorflow v1.12 which supports cuDNN version 7.2, but check https://www.tensorflow.org/install/gpu for the latest supported version. 
-
-Download cuDNN from https://developer.nvidia.com/cudnn. You will need to create an account with Nvidia.
-
-At the bottom of the list of latest cuDNN release will be a link to "Archived cuDNN Releases":
-![cuDNN Archive](https://i.imgur.com/dHiAsxg.png)
-
-Select this and choose the latest version of cuDNN that supports the version of Cuda you installed and has a minor version greater than or equal to the latest version that Tensorflow supports. (Eg Tensorflow 1.12 supports Cuda 9.0 and cuDNN 7.2. There is not an archived version of cuDNN 7.2 for Cuda 9.0, so select cuDNN version 7.3)
-- Open the zip file
-- Extract all of the files and folders into your Cuda folder (It is likely to be located in `C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA`):\
-![cuDNN to Cuda](https://i.imgur.com/X098w0N.png)
-
-### CMake
-Install the latest stable release of CMake from https://cmake.org/download/. (Scroll down the page for Latest Releases and select the relevant Binary distribution installer for your OS).
-
-When installing CMake make sure to enable the option to CMake to the system path:
-![cmake path](https://i.imgur.com/XTtacdY.png)
-
 ### Anaconda
-Download and install the latest Python 3 Anacconda from: https://www.anaconda.com/download/. Unless you know what you are doing, you can leave all the options at default.
+Download and install the latest Python 3 Anaconda from: https://www.anaconda.com/download/. Unless you know what you are doing, you can leave all the options at default.

 ### Git
 Download and install Git for Windows: https://git-scm.com/download/win. Unless you know what you are doing, you can leave all the options at default.
@@ -268,9 +110,9 @@ Reboot your PC, so that everything you have just installed gets registered.
 - Select "Create" at the bottom
 - In the pop up:
   - Give it the name: faceswap
-  - **IMPORTANT**: Select python version 3.6
-  - Hit "Create" (NB: This may take a while as it will need to download Python 3.6)
-![Anaconda virtual env setup](https://i.imgur.com/Tl5tyVq.png)
+  - **IMPORTANT**: Select python version 3.13
+  - Hit "Create" (NB: This may take a while as it will need to download Python)
+![Anaconda virtual env setup](https://i.imgur.com/CLIDDfa.png)

 #### Entering your virtual environment
 To enter the virtual environment:
@@ -279,33 +121,48 @@ To enter the virtual environment:
 - Hit the ">" arrow next to your faceswap environment and select "Open Terminal"
 ![Anaconda enter virtual env](https://i.imgur.com/rKSq2Pd.png)

-### Faceswap
+### faceswap
 - If you are not already in your virtual environment follow [these steps](#entering-your-virtual-environment)
-- Get the Faceswap repo by typing: `git clone https://github.com/deepfakes/faceswap.git`
+- Get the faceswap repo by typing: `git clone --depth 1 https://github.com/deepfakes/faceswap.git`
 - Enter the faceswap folder: `cd faceswap`

 #### Easy install
-- Enter `python setup.py` and follow the prompts.
-
-If you have issues/errors follow the Manual install steps below.
+- Enter the command `python setup.py` and follow the prompts.
+- If you have issues/errors follow the Manual install steps below.

 #### Manual install
-If dlib failed to install you can follow the steps to [manually install dlib](#dlib).\
-Once dlib is installed follow these steps:
-
+Do not follow these steps if the Easy Install above completed successfully.
+If you are using an Nvidia card make sure you have the correct versions of Cuda/cuDNN installed for the required version of Torch - Install tkinter (required for the GUI) by typing: `conda install tk` -- Install requirements: `pip install -r requirements.txt` -- Install Tensorflow (either GPU or CPU version depending on your setup): - - GPU Version: `pip install tensorflow-gpu` - - Non GPU Version: `pip install tensorflow` - -## Running Faceswap +- Install requirements: + - For **Nvidia** GPU users: + - RTX20xx GPUS onwards: `pip install -r ./requirements/requirements_nvidia_13.txt` + - GTX9xx - GTX10xx GPUs: `pip install -r ./requirements/requirements_nvidia_12.txt` + - GTX7xx - GTX8xx GPUs: `pip install -r ./requirements/requirements_nvidia_11.txt` + - **Note:** Maximum supported Python version for GTX8xx - GTX9xx GPUs is `3.13` + + - For **AMD** GPU users (Linux only): + - **Note** You must install a version of ROCm to your system that is compatible with your OS and GPU. + - ROCm 6.4: `pip install -r ./requirements/requirements_rocm64.txt` + - ROCm 6.3: `pip install -r ./requirements/requirements_rocm63.txt` + - ROCm 6.2: `pip install -r ./requirements/requirements_rocm62.txt` + - **Note:** Maximum supported Python version for ROCm 6.2 is `3.13` + - ROCm 6.1: `pip install -r ./requirements/requirements_rocm61.txt` + - **Note:** Maximum supported Python version for ROCm 6.1 is `3.13` + - ROCm 6.0: `pip install -r ./requirements/requirements_rocm60.txt` + - **Note:** Maximum supported Python version for ROCm 6.0 is `3.12` + + - For **CPU** users: `pip install -r ./requirements/requirements_cpu.txt` + + - For **Apple-Silicon (M Series)** users: `pip install -r ./requirements/requirements_apple-silicon.txt` + +## Running faceswap - If you are not already in your virtual environment follow [these steps](#entering-your-virtual-environment) - Enter the faceswap folder: `cd faceswap` - Enter the following to see the list of commands: `python faceswap.py -h` or enter `python faceswap.py gui` to launch the GUI ## Create a desktop shortcut -A desktop shortcut can be added to easily launch staight into the faceswap GUI: +A desktop shortcut can be added to easily launch straight into the faceswap GUI: - Open Notepad - Paste the following: @@ -316,34 +173,200 @@ A desktop shortcut can be added to easily launch staight into the faceswap GUI: ## Updating faceswap It's good to keep faceswap up to date as new features are added and bugs are fixed. To do so: +- If using the GUI you can go to the Help menu and select "Check for Updates...". If updates are available go to the Help menu and select "Update Faceswap". Restart Faceswap to complete the update. - If you are not already in your virtual environment follow [these steps](#entering-your-virtual-environment) - Enter the faceswap folder: `cd faceswap` - Enter the following `git pull --all` -- Once the latest version has downloaded, make sure your requirements are up to date: `pip install --upgrade -r requirements.txt` +- Once the latest version has downloaded, make sure your dependencies are up to date. There is a script to help with this: `python update_deps.py` -## Dlib -You should only need to follow these steps if you want the latest Dlib code or the process was unable to install Dlib for you. +# macOS (Apple Silicon) Install Guide -For reasons outside of our control, this is the trickiest part of the process, and most of the prerequisites you installed are to support just Dlib. 
It is recommended to build Dlib from source for 3 main reasons: -1) To get the latest version -2) Enable GPU Support in Dlib -3) To prevent yourself running into a whole host of issues later in the process. +macOS now has [an installer](#linux-windows-and-macos-install-guide) which sets everything up for you, but if you run into difficulties and need to set things up manually, the steps are as follows: -If you are not bothered about having GPU support or the latest version, scroll to the end of this section for a simple one-liner to install the CPU version of Dlib. -### Build Latest Dlib with GPU Support -- If you are not already in your virtual environment follow [these steps](#entering-your-virtual-environment) -- In the terminal type: `git clone https://github.com/davisking/dlib.git` -- Enter the dlib folder: `cd dlib` -- Add Visual Studio to your path by typing: `SET PATH=%PATH%;C:\Program Files (x86)\Microsoft Visual Studio 14.0\VC\bin` -- Enter: `python setup.py -G "Visual Studio 14 2015" install --yes USE_AVX_INSTRUCTIONS --yes DLIB_USE_CUDA --clean` +## Prerequisites -This will build and install dlib for you. It is worth backing up the generated .egg file somewhere so that you can re-install it if you ever need to rather than having to re-compile: -- From within the dlib folder copy the file named `dlib-xx.yy.zz-py3.5-win-amd64.egg` to somewhere safe -- If you ever need to install it again, then from within your virtual environment enter: `python -m easy_install ` +### OS +macOS 12.0+ -Once Dlib is built, you can remove Visual Studio and CMake from your PC. +### XCode Tools +```sh +xcode-select --install +``` -### Easy install of Dlib without GPU Support -NB: Don't do this if you have already compiled Dlib with GPU support. -- If you are not already in your virtual environment follow [these steps](#entering-your-virtual-environment) -- In the terminal type: `conda install -c conda-forge dlib` +### XQuartz +Download and install from: +- https://www.xquartz.org/ + +### Conda +Download and install the latest Conda env from: +- https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-MacOSX-arm64.sh + +Install Conda: +```sh +$ chmod +x ~/Downloads/Miniforge3-MacOSX-arm64.sh +$ sh ~/Downloads/Miniforge3-MacOSX-arm64.sh +$ source ~/miniforge3/bin/activate +``` +## Setup +### Create and Activate the Environment +```sh +$ conda create --name faceswap python=3.13 +$ conda activate faceswap +``` + +### faceswap +- Download the faceswap repo and enter the faceswap folder: +```sh +$ git clone --depth 1 https://github.com/deepfakes/faceswap.git +$ cd faceswap +``` + +#### Easy install +```sh +$ python setup.py +``` + +- If you have issues/errors follow the Manual install steps below. + + +# General Install Guide + +## Installing dependencies +### Git +Git is required for obtaining the code and keeping your codebase up to date. +Obtain git for your distribution from the [git website](https://git-scm.com/downloads). + +### Python +The recommended install method is to use a Conda3 Environment as this will handle the installation of Nvidia's CUDA and cuDNN straight into your Conda Environment. This is by far the easiest and most reliable way to setup the project. + - MiniConda3 is recommended: [MiniConda3](https://docs.conda.io/en/latest/miniconda.html) + +Alternatively you can install Python (3.14 64-bit) for your distribution (links below.) 
If you go down this route and are using an Nvidia GPU you should install CUDA (https://developer.nvidia.com/cuda-zone) and cuDNN (https://developer.nvidia.com/cudnn) for your system. If you do not plan to build Torch yourself, make sure you install the correct Cuda and cuDNN package for the currently installed version of Torch.
+  - Python distributions:
+    - apt/yum install python3 (Linux)
+    - [Installer](https://www.python.org/downloads/release/python-368/) (Windows)
+    - [brew](https://brew.sh/) install python3 (macOS)
+
+### Virtual Environment
+ It is highly recommended that you set up faceswap inside a virtual environment. In fact, we will not generally support installations that are not within a virtual environment, as troubleshooting package conflicts can be next to impossible.
+
+ If using Conda3 then setting up virtual environments is relatively straightforward. More information can be found at [Conda Docs](https://docs.conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html).
+
+ If using a default Python distribution then [virtualenv](https://github.com/pypa/virtualenv) and [virtualenvwrapper](https://virtualenvwrapper.readthedocs.io) may help when you are not using docker.
+
+
+## Getting the faceswap code
+It is recommended to clone the repo with git instead of downloading the code from http://github.com/deepfakes/faceswap and extracting it as this will make it far easier to get the latest code (which can be done from the GUI). To clone a repo you can either use the Git GUI for your distribution or open up a command prompt, enter the folder where you want to store faceswap and enter:
+```bash
+git clone https://github.com/deepfakes/faceswap.git
+```
+
+
+## Setup
+Enter your virtual environment and then enter the folder that faceswap has been downloaded to and run:
+```bash
+python setup.py
+```
+If setup fails for any reason you can still manually install the packages listed within the files in the requirements folder.
+
+### About some of the options
+ - CUDA: For acceleration. Requires a good nVidia Graphics Card (which supports CUDA inside)
+ - Docker: Provide a ready-made image. Hide trivial details. Get you straight to the project.
+ - nVidia-Docker: Access to the nVidia GPU on host machine from inside container.
+
+# Docker Install Guide
+
+This Faceswap repo contains Docker build scripts for CPU and Nvidia backends. The scripts will set up a Docker container for you and install the latest version of the Faceswap software.
+
+You must first ensure that Docker is installed and running on your system. Follow the guide for downloading and installing Docker from their website:
+
+ - https://www.docker.com/get-started
+
+Once Docker is installed and running, follow the relevant steps for your chosen backend.
+
+## Docker CPU
+To run the CPU version of Faceswap follow these steps:
+
+1. Build the Docker image for faceswap:
+```
+docker build \
+-t faceswap-cpu \
+https://raw.githubusercontent.com/deepfakes/faceswap/master/Dockerfile.cpu
+```
+2. Launch and enter the Faceswap container:
+
+    a. For the **headless/command line** version of Faceswap run:
+    ```
+    docker run --rm -it faceswap-cpu
+    ```
+    You can then execute faceswap the standard way:
+    ```
+    python faceswap.py --help
+    ```
+    b. 
For the **GUI** version of Faceswap run: + ``` + xhost +local: && \ + docker run --rm -it \ + -v /tmp/.X11-unix:/tmp/.X11-unix \ + -e DISPLAY=${DISPLAY} \ + faceswap-cpu + ``` + You can then launch the GUI with + ``` + python faceswap.py gui + ``` + ## Docker Nvidia +To build the NVIDIA GPU version of Faceswap, follow these steps: + +1. Nvidia Docker builds need extra resources to provide the Docker container with access to your GPU. + + a. Follow the instructions to install and apply the `Nvidia Container Toolkit` for your distribution from: + - https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html + + b. If Docker is already running, restart it to pick up the changes made by the Nvidia Container Toolkit. + +2. Build the Docker image For faceswap +``` +docker build \ +-t faceswap-gpu \ +https://raw.githubusercontent.com/deepfakes/faceswap/master/Dockerfile.gpu +``` +1. Launch and enter the Faceswap container: + + a. For the **headless/command line** version of Faceswap run: + ``` + docker run --runtime=nvidia --rm -it faceswap-gpu + ``` + You can then execute faceswap the standard way: + ``` + python faceswap.py --help + ``` + b. For the **GUI** version of Faceswap run: + ``` + xhost +local: && \ + docker run --runtime=nvidia --rm -it \ + -v /tmp/.X11-unix:/tmp/.X11-unix \ + -e DISPLAY=${DISPLAY} \ + faceswap-gpu + ``` + You can then launch the GUI with + ``` + python faceswap.py gui + ``` +# Run the project +Once all these requirements are installed, you can attempt to run the faceswap tools. Use the `-h` or `--help` options for a list of options. + +```bash +python faceswap.py -h +``` + +or run with `gui` to launch the GUI +```bash +python faceswap.py gui +``` + + +Proceed to [../blob/master/USAGE.md](USAGE.md) + +## Notes +This guide is far from complete. Functionality may change over time, and new dependencies are added and removed as time goes on. + +If you are experiencing issues, please raise them in the [faceswap Forum](https://faceswap.dev/forum) instead of the main repo. Usage questions raised in the issues within this repo are liable to be closed without response. diff --git a/README.md b/README.md index 382140781b..fd7941cb19 100755 --- a/README.md +++ b/README.md @@ -1,95 +1,110 @@ -**Notice:** This repository is not operated or maintained by /u/deepfakes. Please read the explanation below for details. # deepfakes_faceswap -Faceswap is a tool that utilizes deep learning to recognize and swap faces in pictures and videos. -![Screenshots](https://i.imgur.com/nWHFLDf.jpg) +### Important information for **Patreon** and **PayPal** supporters. Please see this forum post: https://forum.faceswap.dev/viewtopic.php?f=14&t=3120 + +
+FaceSwap is a tool that utilizes deep learning to recognize and swap faces in pictures and videos.
+
+[Example face swap images]
+
+Emma Stone/Scarlett Johansson FaceSwap using the Phaze-A model
+
-Jennifer Lawrence/Steve Buscemi Faceswap using the Villain model
+Jennifer Lawrence/Steve Buscemi FaceSwap using the Villain model
+ +![Build Status](https://github.com/deepfakes/faceswap/actions/workflows/pytest.yml/badge.svg) [![Documentation Status](https://readthedocs.org/projects/faceswap/badge/?version=latest)](https://faceswap.readthedocs.io/en/latest/?badge=latest) + Make sure you check out [INSTALL.md](INSTALL.md) before getting started. +- [deepfakes\_faceswap](#deepfakes_faceswap) + - [Important information for **Patreon** and **PayPal** supporters. Please see this forum post: https://forum.faceswap.dev/viewtopic.php?f=14\&t=3120](#important-information-for-patreon-and-paypal-supporters-please-see-this-forum-post-httpsforumfaceswapdevviewtopicphpf14t3120) - [Manifesto](#manifesto) + - [FaceSwap has ethical uses.](#faceswap-has-ethical-uses) - [How To setup and run the project](#how-to-setup-and-run-the-project) - - [Overview](#overview) +- [Overview](#overview) - [Extract](#extract) - [Train](#train) - [Convert](#convert) - [GUI](#gui) - - [General notes:](#general-notes) +- [General notes:](#general-notes) - [Help I need support!](#help-i-need-support) - [Discord Server](#discord-server) - - [Faceswap-Playground](#faceswap-playground) + - [FaceSwap Forum](#faceswap-forum) - [Donate](#donate) - - [@torzdf](#@torzdf) - - [@andenixa](#andenixa) - - [@kvrooman](#@kvrooman) + - [Patreon](#patreon) + - [One time Donations](#one-time-donations) + - [@torzdf](#torzdf) + - [@andenixa](#andenixa) - [How to contribute](#how-to-contribute) - [For people interested in the generative models](#for-people-interested-in-the-generative-models) - [For devs](#for-devs) - [For non-dev advanced users](#for-non-dev-advanced-users) - [For end-users](#for-end-users) - - [For haters](#for-haters) -- [About github.com/deepfakes](#about-githubcomdeepfakes) - - [What is this repo?](#what-is-this-repo) - - [Why this repo?](#why-this-repo) - - [Why is it named 'deepfakes' if it is not /u/deepfakes?](#why-is-it-named-deepfakes-if-it-is-not-udeepfakes) - - [What if /u/deepfakes feels bad about that?](#what-if-udeepfakes-feels-bad-about-that) - [About machine learning](#about-machine-learning) - - [How does a computer know how to recognise/shape a faces? How does machine learning work? What is a neural network?](#how-does-a-computer-know-how-to-recogniseshape-a-faces-how-does-machine-learning-work-what-is-a-neural-network) + - [How does a computer know how to recognize/shape faces? How does machine learning work? What is a neural network?](#how-does-a-computer-know-how-to-recognizeshape-faces-how-does-machine-learning-work-what-is-a-neural-network) ---- -## Manifesto +# Manifesto -### Faceswap is not porn. +## FaceSwap has ethical uses. -When faceswapping was first developed and published, the technology was groundbreaking, it was a huge step in AI development. It was also completely ignored outside of academia because the code was confusing and fragmentary. It required a thorough understanding of complicated AI techniques and took a lot of effort to figure it out. Until one individual brought it together into a single, cohesive collection. It ran, it worked, and as is so often the way with new technology emerging on the internet, it was immediately used to create porn. The problem was that this was the first AI code that anyone could download, run and learn by experimentation without having a PhD in math, computer theory, psychology, and more. Before "deepfakes" these techniques were like black magic, only practiced by those who could understand all of the inner workings as described in esoteric and endlessly complicated books and papers. 
+When faceswapping was first developed and published, the technology was groundbreaking, it was a huge step in AI development. It was also completely ignored outside of academia because the code was confusing and fragmentary. It required a thorough understanding of complicated AI techniques and took a lot of effort to figure it out. Until one individual brought it together into a single, cohesive collection. It ran, it worked, and as is so often the way with new technology emerging on the internet, it was immediately used to create inappropriate content. Despite the inappropriate uses the software was given originally, it was the first AI code that anyone could download, run and learn by experimentation without having a Ph.D. in math, computer theory, psychology, and more. Before "deepfakes" these techniques were like black magic, only practiced by those who could understand all of the inner workings as described in esoteric and endlessly complicated books and papers. "Deepfakes" changed all that and anyone could participate in AI development. To us, developers, the release of this code opened up a fantastic learning opportunity. It allowed us to build on ideas developed by others, collaborate with a variety of skilled coders, experiment with AI whilst learning new skills and ultimately contribute towards an emerging technology which will only see more mainstream use as it progresses. -Are there some out there doing horrible things with similar software? Yes. And because of this, the developers have been following strict ethical standards. Many of us don't even use it to create videos, we just tinker with the code to see what it does. Sadly, the media concentrates only on the unethical uses of this software. That is, unfortunately the nature of how it was first exposed to the public, but it is not representative of why it was created, how we use it now, or what we see in its future. Like any technology, it can be used for good or it can be abused. It is our intention to develop faceswap in a way that its potential for abuse is minimized whilst maximizing its potential as a tool for learning, experimenting and, yes, for legitimate faceswapping. +Are there some out there doing horrible things with similar software? Yes. And because of this, the developers have been following strict ethical standards. Many of us don't even use it to create videos, we just tinker with the code to see what it does. Sadly, the media concentrates only on the unethical uses of this software. That is, unfortunately, the nature of how it was first exposed to the public, but it is not representative of why it was created, how we use it now, or what we see in its future. Like any technology, it can be used for good or it can be abused. It is our intention to develop FaceSwap in a way that its potential for abuse is minimized whilst maximizing its potential as a tool for learning, experimenting and, yes, for legitimate faceswapping. We are not trying to denigrate celebrities or to demean anyone. We are programmers, we are engineers, we are Hollywood VFX artists, we are activists, we are hobbyists, we are human beings. To this end, we feel that it's time to come out with a standard statement of what this software is and isn't as far as us developers are concerned. -- Faceswap is not for creating porn -- Faceswap is not for changing faces without consent or with the intent of hiding its use. -- Faceswap is not for any illicit, unethical, or questionable purposes. 
-- Faceswap exists to experiment and discover AI techniques, for social or political commentary, for movies, and for any number of ethical and reasonable uses. +- FaceSwap is not for creating inappropriate content. +- FaceSwap is not for changing faces without consent or with the intent of hiding its use. +- FaceSwap is not for any illicit, unethical, or questionable purposes. +- FaceSwap exists to experiment and discover AI techniques, for social or political commentary, for movies, and for any number of ethical and reasonable uses. -We are very troubled by the fact that faceswap can be used for unethical and disreputable things. However, we support the development of tools and techniques that can be used ethically as well as provide education and experience in AI for anyone who wants to learn it hands-on. We will take a zero tolerance approach to anyone using this software for any unethical purposes and will actively discourage any such uses. +We are very troubled by the fact that FaceSwap can be used for unethical and disreputable things. However, we support the development of tools and techniques that can be used ethically as well as provide education and experience in AI for anyone who wants to learn it hands-on. We will take a zero tolerance approach to anyone using this software for any unethical purposes and will actively discourage any such uses. -## How To setup and run the project -Faceswap is a Python program that will run on multiple Operating Systems including Windows, Linux and MacOS. +# How To setup and run the project +FaceSwap is a Python program that will run on multiple Operating Systems including Windows, Linux, and MacOS. -See [INSTALL.md](INSTALL.md) for full installation instructions. You will need a modern GPU with CUDA support for best performance. +See [INSTALL.md](INSTALL.md) for full installation instructions. You will need a modern GPU with CUDA support for best performance. Many AMD GPUs are supported through ROCm (Linux). -## Overview +# Overview The project has multiple entry points. You will have to: - - Gather photos (or use the one provided in the training data provided below) + - Gather photos and/or videos - **Extract** faces from your raw photos - - **Train** a model on your photos (or use the one provided in the training data provided below) + - **Train** a model on the faces extracted from the photos/videos - **Convert** your sources with the model Check out [USAGE.md](USAGE.md) for more detailed instructions. -### Extract +## Extract From your setup folder, run `python faceswap.py extract`. This will take photos from `src` folder and extract faces into `extract` folder. -### Train +## Train From your setup folder, run `python faceswap.py train`. This will take photos from two folders containing pictures of both faces and train a model that will be saved inside the `models` folder. -### Convert +## Convert From your setup folder, run `python faceswap.py convert`. This will take photos from `original` folder and apply new faces into `modified` folder. -### GUI -Alternatively you can run the GUI by running `python faceswap.py gui` +## GUI +Alternatively, you can run the GUI by running `python faceswap.py gui` -## General notes: +# General notes: - All of the scripts mentioned have `-h`/`--help` options with arguments that they will accept. You're smart, you can figure out how this works, right?! -NB: there is a conversion tool for video. This can be accessed by running `python tools.py effmpeg -h`. 
Alternatively you can use [ffmpeg](https://www.ffmpeg.org) to convert video into photos, process images, and convert images back to video. +NB: there is a conversion tool for video. This can be accessed by running `python tools.py effmpeg -h`. Alternatively, you can use [ffmpeg](https://www.ffmpeg.org) to convert video into photos, process images, and convert images back to the video. **Some tips:** @@ -97,84 +112,67 @@ NB: there is a conversion tool for video. This can be accessed by running `pytho Reusing existing models will train much faster than starting from nothing. If there is not enough training data, start with someone who looks similar, then switch the data. -## Help I need support! -### Discord Server -Your best bet is to join the [Faceswap Discord server](https://discord.gg/FdEwxXd) where there are plenty of users willing to help. Please note that, like this repo, this is a SFW Server! +# Help I need support! +## Discord Server +Your best bet is to join the [FaceSwap Discord server](https://discord.gg/FC54sYg) where there are plenty of users willing to help. Please note that, like this repo, this is a SFW Server! -### Faceswap-Playground -Alternatively you can post questions in the [Faceswap Playground](https://github.com/deepfakes/faceswap-playground). Please do not post general support questions in this repo. +## FaceSwap Forum +Alternatively, you can post questions in the [FaceSwap Forum](https://faceswap.dev/forum). Please do not post general support questions in this repo as they are liable to be deleted without response. -## Donate -The developers work tirelessly to improve and develop faceswap. Many hours have been put in to provide the software as it is today, but this is an extremely time consuming process with no financial reward. If you enjoy using the software, please consider donating to the devs, so they can spend more time implementing improvements. +# Donate +The developers work tirelessly to improve and develop FaceSwap. Many hours have been put in to provide the software as it is today, but this is an extremely time-consuming process with no financial reward. If you enjoy using the software, please consider donating to the devs, so they can spend more time implementing improvements. -### @torzdf ### - There is very little faceswap code that hasn't been touched by torzdf. He is responsible for implementing the GUI, FAN aligner, MTCNN detector and porting the Villain, DFL-H128 and DFaker models to faceswap, as well as significantly improving many areas of the code. +## Patreon +The best way to support us is through our Patreon page: -**Bitcoin:** 385a1r9tyZpt5LyZcNk1FALTxC8ZHta7yq +[![become-a-patron](https://c5.patreon.com/external/logo/become_a_patron_button.png)](https://www.patreon.com/bePatron?u=23238350) -**Ethereum:** 0x18CBbff5fA7C78de7B949A2b0160A0d1bd649f80 +## One time Donations +Alternatively you can give a one off donation to any of our Devs: +### @torzdf + There is very little FaceSwap code that hasn't been touched by torzdf. He is responsible for implementing the GUI, FAN aligner, MTCNN detector and porting the Villain, DFL-H128 and DFaker models to FaceSwap, as well as significantly improving many areas of the code. 
+ +**Bitcoin:** bc1qpm22suz59ylzk0j7qk5e4c7cnkjmve2rmtrnc6 + +**Ethereum:** 0xd3e954dC241B87C4E8E1A801ada485DC1d530F01 + +**Monero:** 45dLrtQZ2pkHizBpt3P3yyJKkhcFHnhfNYPMSnz3yVEbdWm3Hj6Kr5TgmGAn3Far8LVaQf1th2n3DJVTRkfeB5ZkHxWozSX **Paypal:** [![torzdf](https://www.paypalobjects.com/en_GB/i/btn/btn_donate_SM.gif)](https://www.paypal.com/cgi-bin/webscr?cmd=_s-xclick&hosted_button_id=JZ8PP3YE9J62L) -### @andenixa ### +### @andenixa Creator of the Unbalanced and OHR models, as well as expanding various capabilities within the training process. Andenixa is currently working on new models and will take requests for donations. **Paypal:** [![andenixa](https://www.paypalobjects.com/en_GB/i/btn/btn_donate_SM.gif)](https://www.paypal.com/cgi-bin/webscr?cmd=_s-xclick&hosted_button_id=NRVLQYGS6NWTU) -### @kvrooman ### -Responsible for consolidating the converters, adding a lot of code to fix model stability issues, and helping significantly towards making the training process more modular, kvrooman continues to be a very active contributor. +# How to contribute -**Ethereum:** 0x18CBbff5fA7C78de7B949A2b0160A0d1bd649f80 - -## How to contribute - -### For people interested in the generative models +## For people interested in the generative models - Go to the 'faceswap-model' to discuss/suggest/commit alternatives to the current algorithm. -### For devs +## For devs - Read this README entirely - Fork the repo - - Download the data with the link provided above - Play with it - Check issues with the 'dev' tag - For devs more interested in computer vision and openCV, look at issues with the 'opencv' tag. Also feel free to add your own alternatives/improvements -### For non-dev advanced users +## For non-dev advanced users - Read this README entirely - Clone the repo - - Download the data with the link provided above - Play with it - Check issues with the 'advuser' tag - - Also go to the 'faceswap-playground' repo and help others. + - Also go to the '[faceswap Forum](https://faceswap.dev/forum)' and help others. -### For end-users +## For end-users - Get the code here and play with it if you can - - You can also go to the 'faceswap-playground' repo and help or get help from others. - - Be patient. This is relatively new technology for developers as well. Much effort is already being put into making this program easy to use for the average user. It just takes time! - - **Notice** Any issue related to running the code has to be open in the 'faceswap-playground' project! - -### For haters -Sorry, no time for that. - -# About github.com/deepfakes - -## What is this repo? -It is a community repository for active users. - -## Why this repo? -The joshua-wu repo seems not active. Simple bugs like missing _http://_ in front of urls have not been solved since days. - -## Why is it named 'deepfakes' if it is not /u/deepfakes? - 1. Because a typosquat would have happened sooner or later as project grows - 2. Because we wanted to recognize the original author - 3. Because it will better federate contributors and users - -## What if /u/deepfakes feels bad about that? -This is a friendly typosquat, and it is fully dedicated to the project. If /u/deepfakes wants to take over this repo/user and drive the project, he is welcomed to do so (Raise an issue, and he will be contacted on Reddit). Please do not send /u/deepfakes messages for help with the code you find here. + - You can also go to the [faceswap Forum](https://faceswap.dev/forum) and help or get help from others. + - Be patient. 
This is a relatively new technology for developers as well. Much effort is already being put into making this program easy to use for the average user. It just takes time! + - **Notice** Any issue related to running the code has to be opened in the [faceswap Forum](https://faceswap.dev/forum)! # About machine learning -## How does a computer know how to recognise/shape a faces? How does machine learning work? What is a neural network? +## How does a computer know how to recognize/shape faces? How does machine learning work? What is a neural network? It's complicated. Here's a good video that makes the process understandable: [![How Machines Learn](https://img.youtube.com/vi/R9OHn5ZF4Uo/0.jpg)](https://www.youtube.com/watch?v=R9OHn5ZF4Uo) diff --git a/USAGE.md b/USAGE.md index 852e35479c..de78f5abf8 100755 --- a/USAGE.md +++ b/USAGE.md @@ -1,110 +1,192 @@ -**Before attempting any of this, please make sure you have read, understood and completed the [installation instructions](../master/INSTALL.md). If you are experiencing issues, please raise them in the [faceswap-playground](https://github.com/deepfakes/faceswap-playground) repository instead of the main repo.** - # Workflow -So, you want to swap faces in pictures and videos? Well hold up, because first you gotta understand what this collection of scripts will do, how it does it and what it can't currently do. + +**Before attempting any of this, please make sure you have read, understood and completed the [installation instructions](../master/INSTALL.md). If you are experiencing issues, please raise them in the [faceswap Forum](https://faceswap.dev/forum) or the [FaceSwap Discord server](https://discord.gg/FdEwxXd) instead of the main repo.** + +- [Workflow](#workflow) +- [Introduction](#introduction) + - [Disclaimer](#disclaimer) + - [Getting Started](#getting-started) +- [Extract](#extract) + - [Gathering raw data](#gathering-raw-data) + - [Extracting Faces](#extracting-faces) + - [General Tips](#general-tips) +- [Training a model](#training-a-model) + - [General Tips](#general-tips-1) +- [Converting a video](#converting-a-video) + - [General Tips](#general-tips-2) +- [GUI](#gui) +- [Video's](#videos) +- [EFFMPEG](#effmpeg) +- [Extracting video frames with FFMPEG](#extracting-video-frames-with-ffmpeg) +- [Generating a video](#generating-a-video) +- [Notes](#notes) + +# Introduction + +## Disclaimer +This guide provides a high level overview of the faceswapping process. It does not aim to go into every available option, but will provide a useful entry point to using the software. There are many more options available that are not covered by this guide. These can be found, and explained, by passing the `-h` flag to the command line (eg: `python faceswap.py extract -h`) or by hovering over the options within the GUI. + +## Getting Started +So, you want to swap faces in pictures and videos? Well hold up, because first you gotta understand what this application will do, how it does it and what it can't currently do. The basic operation of this script is simple. It trains a machine learning model to recognize and transform two faces based on pictures. The machine learning model is our little "bot" that we're teaching to do the actual swapping and the pictures are the "training data" that we use to train it. Note that the bot is primarily processing faces. Other objects might not work. So here's our plan. 
We want to create a reality where Donald Trump lost the presidency to Nic Cage; we have his inauguration video; let's replace Trump with Cage. +# Extract ## Gathering raw data -In order to accomplish this, the bot needs to learn to recognize both face A (Trump) and face B (Nic Cage). By default, the bot doesn't know what a Trump or a Nic Cage looks like. So we need to show it some pictures and let it guess which is which. So we need pictures of both of these faces first. +In order to accomplish this, the bot needs to learn to recognize both face A (Trump) and face B (Nic Cage). By default, the bot doesn't know what a Trump or a Nic Cage looks like. So we need to show it lots of pictures and let it guess which is which. So we need pictures of both of these faces first. -A possible source is Google, DuckDuckGo or Bing image search. There are scripts to download large amounts of images. Alternatively, if you have a video of the person you're looking for (from interviews, public speeches, or movies), you can convert this video to still images and use those. see [Extracting video frames](#Extracting_video_frames) for more information. +A possible source is Google, DuckDuckGo or Bing image search. There are scripts to download large amounts of images. A better source of images is videos (from interviews, public speeches, or movies), as these will capture many more natural poses and expressions. Fortunately FaceSwap has you covered and can extract faces from both still images and video files. See [Extracting video frames with FFMPEG](#extracting-video-frames-with-ffmpeg) for more information. -Feel free to list your image sets in the [faceswap-playground](https://github.com/deepfakes/faceswap-playground), or add more methods to this file. +Feel free to list your image sets in the [faceswap Forum](https://faceswap.dev/forum), or add more methods to this file. -So now we have a folder full of pictures of Trump and a separate folder of Nic Cage. Let's save them in our directory where we put the faceswap project. Example: `~/faceswap/photo/trump` and `~/faceswap/photo/cage` +So now we have a folder full of pictures/videos of Trump and a separate folder of Nic Cage. Let's save them in the directory where we put the FaceSwap project. Example: `~/faceswap/src/trump` and `~/faceswap/src/cage` -## EXTRACT -So here's a problem. We have a ton of pictures of both our subjects, but they're just pictures of them doing stuff or in an environment with other people. Their bodies are on there, they're on there with other people... It's a mess. We can only train our bot if the data we have is consistent and focusses on the subject we want to swap. This is where faceswap first comes in. +## Extracting Faces +So here's a problem. We have a ton of pictures and videos of both our subjects, but these are just of them doing stuff or in an environment with other people. Their bodies are on there, they're on there with other people... It's a mess. We can only train our bot if the data we have is consistent and focuses on the subject we want to swap. This is where FaceSwap first comes in.
+**Command Line:** ```bash -# To convert trump: -python faceswap.py extract -i ~/faceswap/photo/trump -o ~/faceswap/data/trump -# To convert cage: -python faceswap.py extract -i ~/faceswap/photo/cage -o ~/faceswap/data/cage +# To extract trump from photos in a folder: +python faceswap.py extract -i ~/faceswap/src/trump -o ~/faceswap/faces/trump +# To extract trump from a video file: +python faceswap.py extract -i ~/faceswap/src/trump.mp4 -o ~/faceswap/faces/trump +# To extract cage from photos in a folder: +python faceswap.py extract -i ~/faceswap/src/cage -o ~/faceswap/faces/cage +# To extract cage from a video file: +python faceswap.py extract -i ~/faceswap/src/cage.mp4 -o ~/faceswap/faces/cage ``` -We specify our photo input directory and the output folder where our training data will be saved. The script will then try its best to recognize face landmarks, crop the image to that size, and save it to the output folder. Note: this script will make grabbing test data much easier, but it is not perfect. It will (incorrectly) detect multiple faces in some photos and does not recognize if the face is the person who we want to swap. Therefore: **Always check your training data before you start training.** The training data will influence how good your model will be at swapping. +**GUI:** + +To extract trump from photos in a folder (Right hand folder icon): +![ExtractFolder](https://i.imgur.com/H3h0k36.jpg) + +To extract cage from a video file (Left hand folder icon): +![ExtractVideo](https://i.imgur.com/TK02F0u.jpg) + +For input, we either specify our photo directory or video file; for output, we specify the folder where our extracted faces will be saved. The script will then try its best to recognize face landmarks, crop the images to a consistent size, and save the faces to the output folder. An `alignments.json` file will also be created and saved into your input folder. This file contains information about each of the faces that will be used by FaceSwap. + +Note: this script will make grabbing test data much easier, but it is not perfect. It will (incorrectly) detect multiple faces in some photos and does not recognize if the face is the person whom we want to swap. Therefore: **Always check your training data before you start training.** The training data will influence how good your model will be at swapping. -You can see the full list of arguments for extracting via help flag. i.e. +## General Tips +When extracting faces for training, you are looking to gather around 500 to 5000 faces for each subject you wish to train. These should be of a high quality and contain a wide variety of angles, expressions and lighting conditions. +You do not want to extract every single frame from a video for training, as consecutive frames will be very similar. + +You can see the full list of arguments for extracting by hovering over the options in the GUI or passing the help flag. i.e.: ```bash python faceswap.py extract -h ``` -Some of the plugins have configurable options. You can find the config options in: `\plugins\extract\config.ini`. Extract needs to have been run at least once to generate the config file +Some of the plugins have configurable options. You can find the config options in: `\config\extract.ini`. You will need to have run Extract or the GUI at least once for this file to be generated. + +# Training a model +Ok, now you have a folder full of Trump faces and a folder full of Cage faces. What now? It's time to train our bot!
This creates a 'model' that contains information about what a Cage is and what a Trump is and how to swap between the two. + +The training process will take the longest; how long depends on many factors: the model used, the number of images, your GPU etc. However, a ballpark figure is 12-48 hours on GPU and weeks if training on CPU. -## TRAIN -The training process will take the longest, especially on CPU. We specify the folders where the two faces are, and where we will save our training model. It will start hammering the training data once you run the command. I personally really like to go by the preview and quit the processing once I'm happy with the results. +We specify the folders where the two faces are, and where we will save our training model. +**Command Line:** ```bash -python faceswap.py train -A ~/faceswap/data/trump -B ~/faceswap/data/cage -m ~/faceswap/models/ +python faceswap.py train -A ~/faceswap/faces/trump -B ~/faceswap/faces/cage -m ~/faceswap/trump_cage_model/ # or -p to show a preview -python faceswap.py train -A ~/faceswap/data/trump -B ~/faceswap/data/cage -m ~/faceswap/models/ -p +python faceswap.py train -A ~/faceswap/faces/trump -B ~/faceswap/faces/cage -m ~/faceswap/trump_cage_model/ -p ``` +**GUI:** + +![Training](https://i.imgur.com/j8bjk4I.jpg) + +Once you run the command, it will start hammering the training data. If you have a preview up, then you will see a load of blotches appear. These are the faces it is learning. They don't look like much, but then your model hasn't learned anything yet. Over time these will increasingly come to resemble Trump and Cage. + +You want to leave your model learning until you are happy with the images in the preview. To stop training you can: +- Command Line: Press "Enter" in the preview window or in the console +- GUI: Press the Terminate button -If you use the preview feature, select the preview window and press ENTER to save your processed data and quit gracefully. Without the preview enabled, you might have to forcefully quit by hitting Ctrl+C to cancel the command. Note that it will save the model once it's gone through about 100 iterations, which can take quite a while. So make sure you save before stopping the process. +When stopping training, the model will save and the process will exit. This can take a little while, so be patient. The model will also save every 100 iterations or so. -You can see the full list of arguments for training via help flag. i.e. +You can stop and resume training at any time. Just point FaceSwap at the same folders and carry on. + +## General Tips +If you are training with a mask or using Warp to Landmarks, you will need to pass in an `alignments.json` file for each of the face sets. See [Extract - General Tips](#general-tips) for more information. + +You can see the full list of arguments for training by hovering over the options in the GUI or passing the help flag. i.e.: ```bash python faceswap.py train -h ``` -Some of the plugins have configurable options. You can find the config options in: `\plugins\traom\config.ini`. Train needs to have been run at least once to generate the config file +Some of the plugins have configurable options. You can find the config options in: `\config\train.ini`. You will need to have run Train or the GUI at least once for this file to be generated. + +The model is automatically backed up at every save iteration where the overall loss has dropped (i.e. the model has improved). If your model becomes corrupted for some reason, you can go into the model folder and remove the `.bk` extension from the backups to restore the model from backup, as sketched below.
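+
+A minimal sketch of such a restore on Linux/macOS, assuming the backups sit next to the model files with a `.bk` extension appended (illustrative only; it simply renames the files):
+
+```bash
+# Run inside your model folder: strip the .bk extension so each
+# backup replaces the corresponding (corrupted) model file.
+for backup in *.bk; do
+    mv "${backup}" "${backup%.bk}"
+done
+```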
+ + +# Converting a video +Now that we're happy with our trained model, we can convert our video. How does it work? + +Well, firstly we need to generate an `alignments.json` file for our swap. To do this, follow the steps in [Extracting Faces](#extracting-faces), only this time you want to run extract for every face in your source video (a sketch of this is shown below). This file tells the convert process where the face is on the source frame. +You are likely going to want to clean up your alignments file by deleting false positives, badly aligned faces etc. These will not look good on your final convert. There are tools to help with this.
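+
+A minimal sketch of that extract pass (paths are illustrative; only the `-i`/`-o` options shown earlier are used). Assuming the behaviour described in [Extracting Faces](#extracting-faces), the `alignments.json` file is generated with your input, and the faces land in a scratch folder you can use to review what was detected:
+
+```bash
+# Extract every face from the full source video prior to converting:
+python faceswap.py extract -i ~/faceswap/src/trump.mp4 -o ~/faceswap/convert_faces/trump
+```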
-## CONVERT -Now that we're happy with our trained model, we can convert our video. How does it work? Similarly to the extraction script, actually! The conversion script basically detects a face in a picture using the same algorithm, quickly crops the image to the right size, runs our bot on this cropped image of the face it has found, and then (crudely) pastes the processed face back into the picture. +Just like extract you can convert from a series of images or from a video file. Remember those initial pictures we had of Trump? Let's try swapping a face there. We will use that directory as our input directory, create a new folder where the output will be saved, and tell it which model to use. +**Command Line:** ```bash -python faceswap.py convert -i ~/faceswap/photo/trump/ -o ~/faceswap/output/ -m ~/faceswap/models/ +python faceswap.py convert -i ~/faceswap/src/trump/ -o ~/faceswap/converted/ -m ~/faceswap/trump_cage_model/ ``` +**GUI:** + +![convert](https://i.imgur.com/GzX1ME2.jpg) + It should now start swapping faces of all these pictures. -You can see the full list of arguments available for converting via help flag. i.e. + +## General Tips +You can see the full list of arguments for Converting by hovering over the options in the GUI or passing the help flag. i.e.: ```bash python faceswap.py convert -h ``` -## GUI +Some of the plugins have configurable options. You can find the config options in: `\config\convert.ini`. You will need to have run Convert or the GUI at least once for this file to be generated. + +# GUI All of the above commands and options can be run from the GUI. This is launched with: ```bash python faceswap.py gui ``` +The GUI allows a more user-friendly interface into the scripts and also has some extended functionality. Hovering over options in the GUI will tell you more about what the option does. - -## Video's +# Videos A video is just a series of pictures in the form of frames. Therefore you can gather the raw images from them for your dataset or combine your results into a video. -## EFFMPEG -You can perform various video processes with the built in effmpeg tool. You can see the full list of arguments available by running: +# EFFMPEG You can perform various video processes with the built-in effmpeg tool. You can see the full list of arguments available by running: ```bash python tools.py effmpeg -h ``` -## Extracting video frames with FFMPEG -Alternatively you can split a video into separate frames using [ffmpeg](https://www.ffmpeg.org) for instance. Below is an example command to process a video to separate frames. +# Extracting video frames with FFMPEG +Alternatively, you can split a video into separate frames using [ffmpeg](https://www.ffmpeg.org) for instance. Below is an example command to process a video to separate frames. ```bash ffmpeg -i /path/to/my/video.mp4 /path/to/output/video-frame-%d.png ``` -## Generating a video +# Generating a video If you split a video into frames, using [ffmpeg](https://www.ffmpeg.org) for example, and used those frames as a target for swapping faces onto, you can combine the swapped frames back into a video. The command below stitches the png frames back into a single video again. ```bash ffmpeg -i video-frame-%0d.png -c:v libx264 -vf "fps=25,format=yuv420p" out.mp4 ```
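+
+Not covered by the commands above, but often wanted: a sketch of muxing the original audio back in while stitching (assuming the source video and the `fps` value match the frames; otherwise the audio will drift):
+
+```bash
+# Stitch the swapped frames and copy the audio track from the source video.
+ffmpeg -i video-frame-%0d.png -i /path/to/my/video.mp4 -map 0:v -map 1:a? \
+       -c:v libx264 -vf "fps=25,format=yuv420p" -c:a copy -shortest out.mp4
+```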
-## Notes +# Notes This guide is far from complete. Functionality may change over time, and new dependencies are added and removed as time goes on. -If you are experiencing issues, please raise them in the [faceswap-playground](https://github.com/deepfakes/faceswap-playground) repository instead of the main repo. +If you are experiencing issues, please raise them in the [faceswap Forum](https://faceswap.dev/forum) or the [FaceSwap Discord server](https://discord.gg/FdEwxXd). Usage questions raised in this repo are likely to be closed without response. diff --git a/_config.yml b/_config.yml new file mode 100644 index 0000000000..c4192631f2 --- /dev/null +++ b/_config.yml @@ -0,0 +1 @@ +theme: jekyll-theme-cayman \ No newline at end of file diff --git a/docs/_static/logo.png b/docs/_static/logo.png new file mode 100755 index 0000000000..fc26247981 Binary files /dev/null and b/docs/_static/logo.png differ diff --git a/docs/conf.py b/docs/conf.py new file mode 100644 index 0000000000..0e46e443ca --- /dev/null +++ b/docs/conf.py @@ -0,0 +1,128 @@ +# Configuration file for the Sphinx documentation builder. +# +# This file only contains a selection of the most common options. For a full +# list see the documentation: +# https://www.sphinx-doc.org/en/master/usage/configuration.html + +# -- Path setup -------------------------------------------------------------- + +# If extensions (or modules to document with autodoc) are in another directory, +# add these directories to sys.path here. If the directory is relative to the +# documentation root, use os.path.abspath to make it absolute, like shown here. + +# NOTE: To generate docs: +# $ cd docs +# $ rm -rf _build api +# $ python -m sphinx -T -b html -d _build/doctrees -D language=en . _build/output/html + +# pylint:skip-file +import logging +import os
+import sys +from unittest import mock + +os.environ["FACESWAP_BACKEND"] = "cpu" +os.environ["KERAS_BACKEND"] = "torch" + +sys.path.insert(0, os.path.abspath('../')) +sys.setrecursionlimit(1500) + + +MOCK_MODULES = ["pynvml", "ctypes.windll", "comtypes"] +for mod_name in MOCK_MODULES: + sys.modules[mod_name] = mock.Mock() + +# -- Project information ----------------------------------------------------- + +project = 'faceswap' +copyright = '2025, faceswap.dev' +author = 'faceswap.dev' + +# The full version, including alpha/beta/rc tags +release = '3.0' + + +# -- General configuration --------------------------------------------------- + +# Add any Sphinx extension module names here, as strings. They can be +# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom +# ones. +extensions = ['sphinx.ext.napoleon', "sphinx_automodapi.automodapi"] +napoleon_custom_sections = ['License'] +numpydoc_show_class_members = False +automodsumm_inherited_members = True + +# Add any paths that contain templates here, relative to this directory. +templates_path = ['_templates'] + +# List of patterns, relative to source directory, that match files and +# directories to ignore when looking for source files. +# This pattern also affects html_static_path and html_extra_path. +exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] + + +# -- Options for HTML output ------------------------------------------------- + +# The theme to use for HTML and HTML Help pages. See the documentation for +# a list of builtin themes. +# +html_theme = 'sphinx_rtd_theme' +html_theme_options = { + 'analytics_id': 'UA-145659566-2', + 'logo_only': True, + # Toc options + 'navigation_depth': -1, +} +html_logo = '_static/logo.png' +latex_logo = '_static/logo.png' + +# Add any paths that contain custom static files (such as style sheets) here, +# relative to this directory. They are copied after the builtin static files, +# so a file named "default.css" will overwrite the builtin "default.css". +html_static_path = ['_static'] + +master_doc = 'index' + +# Suppress warnings from all 3rd party libraries +_suppressed_warning_count = 0 + + +def _suppress_third_party_warnings(): + """ Override Sphinx logging to ignore any warnings generated by 3rd party libraries """ + skip = ["lib/python", "site-packages", # system packages/python lib + ".variables", ".non_trainable_variables"] # keras layer inheritance + root = logging.getLogger("sphinx") + for handler in root.handlers: + orig_emit = handler.emit + + def make_filtered_emit(orig_emit): + + def filtered_emit(record): + if record.levelname in ("WARNING", "ERROR"): + try: + msg = record.getMessage() + except TypeError: + orig_emit(record) + return + loc = getattr(record, "location", "") + if any(x in msg or x in str(loc) for x in skip): + global _suppressed_warning_count + _suppressed_warning_count += 1 + return + orig_emit(record) + return filtered_emit + handler.emit = make_filtered_emit(orig_emit) + + +def _on_build_finish(app, exception): + """ Subtract our suppressed warnings from the total warnings count """ + if hasattr(app, "_warncount") and _suppressed_warning_count: + setattr(app, "_warncount", max(0, + getattr(app, + "_warncount", 0) - _suppressed_warning_count)) + + +def setup(app): + """ Install our warnings filter and capture suppressed counts """ + _suppress_third_party_warnings() + app.connect("build-finished", _on_build_finish) diff --git a/docs/full/lib/align.rst b/docs/full/lib/align.rst new file mode 100644 index 0000000000..1b1a5d4479 --- /dev/null +++ b/docs/full/lib/align.rst @@ -0,0 +1,45 @@ +***************** +lib.align package +***************** + +The align Package handles detected faces, their alignments and masks. + +.. contents:: Contents + :local: + :depth: 2 + +.. automodapi:: lib.align.aligned_face + :include-all-objects: + :no-inheritance-diagram: + +| +.. automodapi:: lib.align.aligned_mask + :include-all-objects: + +| +.. automodapi:: lib.align.alignments + :include-all-objects: + :no-inheritance-diagram: + +| +.. automodapi:: lib.align.constants + :include-all-objects: + +| +.. automodapi:: lib.align.detected_face + :include-all-objects: + :no-inheritance-diagram: + +| +.. automodapi:: lib.align.pose + :include-all-objects: + :no-inheritance-diagram: + +| +.. automodapi:: lib.align.thumbnails + :include-all-objects: + :no-inheritance-diagram: + +| +.. automodapi:: lib.align.updater + :include-all-objects: diff --git a/docs/full/lib/cli.rst b/docs/full/lib/cli.rst new file mode 100644 index 0000000000..4ffca57193 --- /dev/null +++ b/docs/full/lib/cli.rst @@ -0,0 +1,25 @@ +*************** +lib.cli package +*************** + +The CLI Package handles the Command Line Arguments that act as the entry point into Faceswap. + +.. 
contents:: Contents + :local: + :depth: 2 + +.. automodapi:: lib.cli.actions + :include-all-objects: + +.. automodapi:: lib.cli.args_extract_convert + :include-all-objects: + +.. automodapi:: lib.cli.args_train + :include-all-objects: + +.. automodapi:: lib.cli.args + :include-all-objects: + +.. automodapi:: lib.cli.launcher + :include-all-objects: + :no-inheritance-diagram: diff --git a/docs/full/lib/config.rst b/docs/full/lib/config.rst new file mode 100755 index 0000000000..36aacfb202 --- /dev/null +++ b/docs/full/lib/config.rst @@ -0,0 +1,23 @@ +****************** +lib.config package +****************** + +Holds, validates and handles faceswap configuration items, ensuring type correctness. Handles +interfacing with saved config .ini files + +.. contents:: Contents + :local: + :depth: 2 + +.. automodapi:: lib.config.config + :include-all-objects: + :no-inheritance-diagram: + +| +.. automodapi:: lib.config.ini + :include-all-objects: + :no-inheritance-diagram: + +| +.. automodapi:: lib.config.objects + :include-all-objects: diff --git a/docs/full/lib/convert.rst b/docs/full/lib/convert.rst new file mode 100755 index 0000000000..b01c3adc82 --- /dev/null +++ b/docs/full/lib/convert.rst @@ -0,0 +1,3 @@ +.. automodapi:: lib.convert + :include-all-objects: + :no-inheritance-diagram: diff --git a/docs/full/lib/git.rst b/docs/full/lib/git.rst new file mode 100644 index 0000000000..55ccdc06b1 --- /dev/null +++ b/docs/full/lib/git.rst @@ -0,0 +1,3 @@ +.. automodapi:: lib.git + :include-all-objects: + :no-inheritance-diagram: \ No newline at end of file diff --git a/docs/full/lib/gpu_stats.rst b/docs/full/lib/gpu_stats.rst new file mode 100755 index 0000000000..a895857be2 --- /dev/null +++ b/docs/full/lib/gpu_stats.rst @@ -0,0 +1,23 @@ +********************** +lib.gpu\_stats package +********************** + +The GPU Stats Package handles collection of information from connected GPUs + +.. contents:: Contents + :local: + :depth: 2 + +.. automodapi:: lib.gpu_stats.apple_silicon + :include-all-objects: + +.. automodapi:: lib.gpu_stats.cpu + :include-all-objects: + +| +.. automodapi:: lib.gpu_stats.nvidia + :include-all-objects: + +| +.. automodapi:: lib.gpu_stats.rocm + :include-all-objects: diff --git a/docs/full/lib/gui.rst b/docs/full/lib/gui.rst new file mode 100755 index 0000000000..dd4f1f050e --- /dev/null +++ b/docs/full/lib/gui.rst @@ -0,0 +1,120 @@ +*************** +lib.gui package +*************** + +The GUI Package contains the entire code base for Faceswap's optional GUI. The GUI itself +is largely self-generated from the command line options specified in :mod:`lib.cli.args`. + +.. contents:: Contents + :local: + :depth: 2 + +analysis package +================ + +.. automodapi:: lib.gui.analysis.event_reader + :include-all-objects: + :no-inheritance-diagram: + +| +.. automodapi:: lib.gui.analysis.stats + :include-all-objects: + :no-inheritance-diagram: + +| +.. automodapi:: lib.gui.analysis.moving_average + :include-all-objects: + :no-inheritance-diagram: + +utils package +============= + +| +.. automodapi:: lib.gui.utils.config + :include-all-objects: + :no-inheritance-diagram: + +| +.. automodapi:: lib.gui.utils.file_handler + :include-all-objects: + :no-inheritance-diagram: + +| +.. automodapi:: lib.gui.utils.image + :include-all-objects: + :no-inheritance-diagram: + +| +.. automodapi:: lib.gui.utils.misc + :include-all-objects: + + +gui package +=========== + +| +.. automodapi:: lib.gui.gui_config + :include-all-objects: + +| +.. 
automodapi:: lib.gui.command + :include-all-objects: + +| +.. automodapi:: lib.gui.control_helper + :include-all-objects: + +| +.. automodapi:: lib.gui.custom_widgets + :include-all-objects: + +| +.. automodapi:: lib.gui.display + :include-all-objects: + +| +.. automodapi:: lib.gui.display_analysis + :include-all-objects: + +| +.. automodapi:: lib.gui.display_command + :include-all-objects: + +| +.. automodapi:: lib.gui.display_graph + :include-all-objects: + +| +.. automodapi:: lib.gui.display_page + :include-all-objects: + +| +.. automodapi:: lib.gui.menu + :include-all-objects: + +| +.. automodapi:: lib.gui.options + :include-all-objects: + :no-inheritance-diagram: + +| +.. automodapi:: lib.gui.popup_configure + :include-all-objects: + +| +.. automodapi:: lib.gui.popup_session + :include-all-objects: + +| +.. automodapi:: lib.gui.project + :include-all-objects: + +| +.. automodapi:: lib.gui.theme + :include-all-objects: + :no-inheritance-diagram: + +| +.. automodapi:: lib.gui.wrapper + :include-all-objects: + :no-inheritance-diagram: diff --git a/docs/full/lib/image.rst b/docs/full/lib/image.rst new file mode 100755 index 0000000000..6f7fffd075 --- /dev/null +++ b/docs/full/lib/image.rst @@ -0,0 +1,2 @@ +.. automodapi:: lib.image + :include-all-objects: diff --git a/docs/full/lib/keras_utils.rst b/docs/full/lib/keras_utils.rst new file mode 100644 index 0000000000..03a049a76d --- /dev/null +++ b/docs/full/lib/keras_utils.rst @@ -0,0 +1,3 @@ +.. automodapi:: lib.keras_utils + :include-all-objects: + :no-inheritance-diagram: diff --git a/docs/full/lib/keypress.rst b/docs/full/lib/keypress.rst new file mode 100644 index 0000000000..16fc211302 --- /dev/null +++ b/docs/full/lib/keypress.rst @@ -0,0 +1,2 @@ +.. automodapi:: lib.keypress + :include-all-objects: \ No newline at end of file diff --git a/docs/full/lib/lib.rst b/docs/full/lib/lib.rst new file mode 100644 index 0000000000..4f20a3c3dc --- /dev/null +++ b/docs/full/lib/lib.rst @@ -0,0 +1,10 @@ +lib package +=========== + +The lib package holds core functionality used throughout Faceswap. + +.. toctree:: + :maxdepth: 2 + :glob: + + * diff --git a/docs/full/lib/logger.rst b/docs/full/lib/logger.rst new file mode 100755 index 0000000000..9f69671d82 --- /dev/null +++ b/docs/full/lib/logger.rst @@ -0,0 +1,3 @@ +.. automodapi:: lib.logger + :include-all-objects: + :no-inheritance-diagram: diff --git a/docs/full/lib/model.rst b/docs/full/lib/model.rst new file mode 100755 index 0000000000..bd5dea4224 --- /dev/null +++ b/docs/full/lib/model.rst @@ -0,0 +1,66 @@ +***************** +lib.model package +***************** +The Model Package handles interfacing with the neural network backend and holds custom objects. + +.. contents:: Contents + :local: + :depth: 2 + +losses package +============== + +.. automodapi:: lib.model.losses.feature_loss + :include-all-objects: + +| +.. automodapi:: lib.model.losses.loss + :include-all-objects: + +| +.. automodapi:: lib.model.losses.perceptual_loss + :include-all-objects: + +networks package +================ + +.. automodapi:: lib.model.networks.clip + :include-all-objects: + :noindex: + +| +.. automodapi:: lib.model.networks.simple_nets + :include-all-objects: + +model package +============= + +.. automodapi:: lib.model.autoclip + :include-all-objects: + :no-inheritance-diagram: + +| +.. automodapi:: lib.model.backup_restore + :include-all-objects: + :no-inheritance-diagram: + +| +.. automodapi:: lib.model.initializers + :include-all-objects: + +| +.. 
automodapi:: lib.model.layers + :include-all-objects: + +| +.. automodapi:: lib.model.nn_blocks + :include-all-objects: + :no-inheritance-diagram: + +| +.. automodapi:: lib.model.normalization + :include-all-objects: + +| +.. automodapi:: lib.model.optimizers + :include-all-objects: diff --git a/docs/full/lib/multithreading.rst b/docs/full/lib/multithreading.rst new file mode 100644 index 0000000000..b99a029a2f --- /dev/null +++ b/docs/full/lib/multithreading.rst @@ -0,0 +1,2 @@ +.. automodapi:: lib.multithreading + :include-all-objects: diff --git a/docs/full/lib/queue_manager.rst b/docs/full/lib/queue_manager.rst new file mode 100755 index 0000000000..9021183da6 --- /dev/null +++ b/docs/full/lib/queue_manager.rst @@ -0,0 +1,2 @@ +.. automodapi:: lib.queue_manager + :include-all-objects: diff --git a/docs/full/lib/serializer.rst b/docs/full/lib/serializer.rst new file mode 100755 index 0000000000..04b177993d --- /dev/null +++ b/docs/full/lib/serializer.rst @@ -0,0 +1,3 @@ +.. automodapi:: lib.serializer + :include-all-objects: + :no-inheritance-diagram: diff --git a/docs/full/lib/system.rst b/docs/full/lib/system.rst new file mode 100644 index 0000000000..25da19aae4 --- /dev/null +++ b/docs/full/lib/system.rst @@ -0,0 +1,23 @@ +****************** +lib.system package +****************** + +The System Package handles collecting information about the running system + +.. contents:: Contents + :local: + :depth: 2 + +.. automodapi:: lib.system.ml_libs + :include-all-objects: + :no-inheritance-diagram: + +| +.. automodapi:: lib.system.sysinfo + :include-all-objects: + :no-inheritance-diagram: + +| +.. automodapi:: lib.system.system + :include-all-objects: + :no-inheritance-diagram: diff --git a/docs/full/lib/training.rst b/docs/full/lib/training.rst new file mode 100644 index 0000000000..b47349e087 --- /dev/null +++ b/docs/full/lib/training.rst @@ -0,0 +1,43 @@ +********************* +lib.training package +********************* + +The training Package handles libraries to assist with training a model + +.. contents:: Contents + :local: + :depth: 2 + +.. automodapi:: lib.training.augmentation + :include-all-objects: + :no-inheritance-diagram: + +| +.. automodapi:: lib.training.cache + :include-all-objects: + :no-inheritance-diagram: + +| +.. automodapi:: lib.training.generator + :include-all-objects: + +| +.. automodapi:: lib.training.lr_finder + :include-all-objects: + +| +.. automodapi:: lib.training.lr_warmup + :include-all-objects: + :no-inheritance-diagram: + +| +.. automodapi:: lib.training.preview_cv + :include-all-objects: + +| +.. automodapi:: lib.training.preview_tk + :include-all-objects: + +| +.. automodapi:: lib.training.tensorboard + :include-all-objects: diff --git a/docs/full/lib/utils.rst b/docs/full/lib/utils.rst new file mode 100755 index 0000000000..237dc2ab5a --- /dev/null +++ b/docs/full/lib/utils.rst @@ -0,0 +1,3 @@ +.. automodapi:: lib.utils + :include-all-objects: + :no-inheritance-diagram: diff --git a/docs/full/modules.rst b/docs/full/modules.rst new file mode 100644 index 0000000000..1286cb4a7e --- /dev/null +++ b/docs/full/modules.rst @@ -0,0 +1,12 @@ +faceswap +======== + +.. 
toctree:: + :maxdepth: 3 + + lib/lib + plugins/plugins + scripts + tools/tools + setup + update_deps diff --git a/docs/full/plugins/convert.rst b/docs/full/plugins/convert.rst new file mode 100755 index 0000000000..845dbd3cbb --- /dev/null +++ b/docs/full/plugins/convert.rst @@ -0,0 +1,66 @@ +*************** +convert package +*************** + +The Convert Package handles the various plugins available for performing conversion in Faceswap + +.. contents:: Contents + :local: + :depth: 2 + +colour package +============== + +.. automodapi:: plugins.convert.color.avg_color + :include-all-objects: + +| +.. automodapi:: plugins.convert.color.color_transfer + :include-all-objects: + +| +.. automodapi:: plugins.convert.color.manual_balance + :include-all-objects: + +| +.. automodapi:: plugins.convert.color.match_hist + :include-all-objects: + +| +.. automodapi:: plugins.convert.color.seamless_clone + :include-all-objects: + +mask package +============ + +.. automodapi:: plugins.convert.mask.mask_blend + :include-all-objects: + :no-inheritance-diagram: + +scaling package +=============== + +.. automodapi:: plugins.convert.scaling.sharpen + :include-all-objects: + +writer package +============== + +.. automodapi:: plugins.convert.writer.ffmpeg + :include-all-objects: + +| +.. automodapi:: plugins.convert.writer.gif + :include-all-objects: + +| +.. automodapi:: plugins.convert.writer.opencv + :include-all-objects: + +| +.. automodapi:: plugins.convert.writer.patch + :include-all-objects: + +| +.. automodapi:: plugins.convert.writer.pillow + :include-all-objects: diff --git a/docs/full/plugins/extract.rst b/docs/full/plugins/extract.rst new file mode 100755 index 0000000000..6d990623be --- /dev/null +++ b/docs/full/plugins/extract.rst @@ -0,0 +1,94 @@ +*************** +extract package +*************** + +The Extract Package handles the various plugins available for extracting face sets in Faceswap. + +.. contents:: Contents + :local: + :depth: 2 + +align package +============= + +.. automodapi:: plugins.extract.align.cv2_dnn + :include-all-objects: + +| +.. automodapi:: plugins.extract.align.external + :include-all-objects: + +| +.. automodapi:: plugins.extract.align.fan + :include-all-objects: + +detect package +============== + +.. automodapi:: plugins.extract.detect.cv2_dnn + :include-all-objects: + +| +.. automodapi:: plugins.extract.detect.external + :include-all-objects: + +| +.. automodapi:: plugins.extract.detect.mtcnn + :include-all-objects: + +| +.. automodapi:: plugins.extract.detect.s3fd + :include-all-objects: + +mask package +============ + +.. automodapi:: plugins.extract.mask.bisenet_fp + :include-all-objects: + +| +.. automodapi:: plugins.extract.mask.components + :include-all-objects: + +| +.. automodapi:: plugins.extract.mask.custom + :include-all-objects: + +| +.. automodapi:: plugins.extract.mask.extended + :include-all-objects: + +| +.. automodapi:: plugins.extract.mask.unet_dfl + :include-all-objects: + +| +.. automodapi:: plugins.extract.mask.vgg_clear + :include-all-objects: + +| +.. automodapi:: plugins.extract.mask.vgg_obstructed + :include-all-objects: + +recognition package +=================== + +.. automodapi:: plugins.extract.recognition.vgg_face2 + :include-all-objects: + +extract package +=============== + +.. automodapi:: plugins.extract.extract_config + :include-all-objects: + :no-inheritance-diagram: + +| +.. automodapi:: plugins.extract.extract_media + :include-all-objects: + :no-inheritance-diagram: + +| +.. 
automodapi:: plugins.extract.pipeline + :include-all-objects: + :no-inheritance-diagram: diff --git a/docs/full/plugins/plugin_loader.rst b/docs/full/plugins/plugin_loader.rst new file mode 100755 index 0000000000..bf42d393ce --- /dev/null +++ b/docs/full/plugins/plugin_loader.rst @@ -0,0 +1,2 @@ +.. automodapi:: plugins.plugin_loader + :include-all-objects: diff --git a/docs/full/plugins/plugins.rst b/docs/full/plugins/plugins.rst new file mode 100644 index 0000000000..70f8ca69b6 --- /dev/null +++ b/docs/full/plugins/plugins.rst @@ -0,0 +1,11 @@ +plugins package +=============== + +The plugins package holds Extraction, Training and Conversion plugins for Faceswap. + +.. toctree:: + :maxdepth: 3 + :glob: + + * + diff --git a/docs/full/plugins/train.rst b/docs/full/plugins/train.rst new file mode 100755 index 0000000000..51d6b55d76 --- /dev/null +++ b/docs/full/plugins/train.rst @@ -0,0 +1,72 @@ +************* +train package +************* + +The Train Package handles the Model and Trainer plugins for training models in Faceswap. + +.. contents:: Contents + :local: + +model package +============= + +This package contains various helper functions that plugins can inherit from + +.. automodapi:: plugins.train.model._base.inference + :include-all-objects: + :no-inheritance-diagram: + +.. automodapi:: plugins.train.model._base.io + :include-all-objects: + :no-inheritance-diagram: + +| +.. automodapi:: plugins.train.model._base.model + :include-all-objects: + :no-inheritance-diagram: + +| +.. automodapi:: plugins.train.model._base.settings + :include-all-objects: + :no-inheritance-diagram: + +| +.. automodapi:: plugins.train.model._base.state + :include-all-objects: + :no-inheritance-diagram: + +| +.. automodapi:: plugins.train.model._base.update + :include-all-objects: + :no-inheritance-diagram: + +| +.. automodapi:: plugins.train.model.original + :include-all-objects: + + +trainer package +=============== + +This package contains the training loop for Faceswap + +.. automodapi:: plugins.train.trainer._base + :include-all-objects: + :no-inheritance-diagram: + +| +.. automodapi:: plugins.train.trainer._display + :include-all-objects: + :no-inheritance-diagram: + +| +.. automodapi:: plugins.train.trainer.distributed + :include-all-objects: + +| +.. automodapi:: plugins.train.trainer.original + :include-all-objects: + +| +.. automodapi:: plugins.train.trainer.trainer_config + :include-all-objects: diff --git a/docs/full/scripts.rst b/docs/full/scripts.rst new file mode 100644 index 0000000000..c620a8e46a --- /dev/null +++ b/docs/full/scripts.rst @@ -0,0 +1,26 @@ +*************** +scripts package +*************** + +The Scripts Package is the entry point into Faceswap. + +.. contents:: Contents + :local: + +.. automodapi:: scripts.convert + :include-all-objects: + :no-inheritance-diagram: + +.. automodapi:: scripts.extract + :include-all-objects: + :no-inheritance-diagram: + +.. automodapi:: scripts.fsmedia + :include-all-objects: + +.. automodapi:: scripts.gui + :include-all-objects: + +.. automodapi:: scripts.train + :include-all-objects: + :no-inheritance-diagram: diff --git a/docs/full/setup.rst b/docs/full/setup.rst new file mode 100644 index 0000000000..c0419ad122 --- /dev/null +++ b/docs/full/setup.rst @@ -0,0 +1,3 @@ +.. 
automodapi:: setup + :include-all-objects: + :no-inheritance-diagram: \ No newline at end of file diff --git a/docs/full/tools/alignments.rst b/docs/full/tools/alignments.rst new file mode 100644 index 0000000000..f91cce2733 --- /dev/null +++ b/docs/full/tools/alignments.rst @@ -0,0 +1,35 @@ +************************ +tools.alignments package +************************ + +.. contents:: Contents + :local: + :depth: 2 + +.. automodapi:: tools.alignments.alignments + :include-all-objects: + :no-inheritance-diagram: + +| +.. automodapi:: tools.alignments.cli + :include-all-objects: + :no-inheritance-diagram: + +| +.. automodapi:: tools.alignments.jobs + :include-all-objects: + :no-inheritance-diagram: + +| +.. automodapi:: tools.alignments.jobs_faces + :include-all-objects: + :no-inheritance-diagram: + +| +.. automodapi:: tools.alignments.jobs_frames + :include-all-objects: + :no-inheritance-diagram: + +| +.. automodapi:: tools.alignments.media + :include-all-objects: diff --git a/docs/full/tools/ffmpeg.rst b/docs/full/tools/ffmpeg.rst new file mode 100644 index 0000000000..f370594827 --- /dev/null +++ b/docs/full/tools/ffmpeg.rst @@ -0,0 +1,15 @@ +********************* +tools.effmpeg package +********************* + +.. contents:: Contents + :local: + :depth: 2 + +.. automodapi:: tools.effmpeg.cli + :include-all-objects: + +| +.. automodapi:: tools.effmpeg.effmpeg + :include-all-objects: + :no-inheritance-diagram: diff --git a/docs/full/tools/manual.rst b/docs/full/tools/manual.rst new file mode 100644 index 0000000000..859267a88a --- /dev/null +++ b/docs/full/tools/manual.rst @@ -0,0 +1,75 @@ +******************** +tools.manual package +******************** + +.. contents:: Contents + :local: + :depth: 2 + +manual.faceviewer package +========================= + +.. automodapi:: tools.manual.faceviewer.frame + :include-all-objects: + +| +.. automodapi:: tools.manual.faceviewer.interact + :include-all-objects: + :no-inheritance-diagram: + +| +.. automodapi:: tools.manual.faceviewer.viewport + :include-all-objects: + :no-inheritance-diagram: + +manual.frameviewer package +========================== + +.. automodapi:: tools.manual.frameviewer.control + :include-all-objects: + :no-inheritance-diagram: + +| +.. automodapi:: tools.manual.frameviewer.frame + :include-all-objects: + +| +.. automodapi:: tools.manual.frameviewer.editor.bounding_box + :include-all-objects: + +| +.. automodapi:: tools.manual.frameviewer.editor.extract_box + :include-all-objects: + +| +.. automodapi:: tools.manual.frameviewer.editor.landmarks + :include-all-objects: + +| +.. automodapi:: tools.manual.frameviewer.editor.mask + :include-all-objects: + +manual package +========================== + +.. automodapi:: tools.manual.cli + :include-all-objects: + +| +.. automodapi:: tools.manual.detected_faces + :include-all-objects: + :no-inheritance-diagram: + +| +.. automodapi:: tools.manual.globals + :include-all-objects: + :no-inheritance-diagram: + +| +.. automodapi:: tools.manual.manual + :include-all-objects: + +| +.. automodapi:: tools.manual.thumbnails + :include-all-objects: + :no-inheritance-diagram: diff --git a/docs/full/tools/mask.rst b/docs/full/tools/mask.rst new file mode 100644 index 0000000000..963cf95b3f --- /dev/null +++ b/docs/full/tools/mask.rst @@ -0,0 +1,35 @@ +****************** +tools.mask package +****************** + +.. contents:: Contents + :local: + :depth: 2 + +.. automodapi:: tools.mask.cli + :include-all-objects: + +| +.. 
automodapi:: tools.mask.loader + :include-all-objects: + :no-inheritance-diagram: + +| +.. automodapi:: tools.mask.mask + :include-all-objects: + :no-inheritance-diagram: + +| +.. automodapi:: tools.mask.mask_generate + :include-all-objects: + :no-inheritance-diagram: + +| +.. automodapi:: tools.mask.mask_import + :include-all-objects: + :no-inheritance-diagram: + +| +.. automodapi:: tools.mask.mask_output + :include-all-objects: + :no-inheritance-diagram: diff --git a/docs/full/tools/model.rst b/docs/full/tools/model.rst new file mode 100644 index 0000000000..3d59937855 --- /dev/null +++ b/docs/full/tools/model.rst @@ -0,0 +1,15 @@ +******************* +tools.model package +******************* + +.. contents:: Contents + :local: + :depth: 2 + +.. automodapi:: tools.model.cli + :include-all-objects: + +| +.. automodapi:: tools.model.model + :include-all-objects: + :no-inheritance-diagram: diff --git a/docs/full/tools/preview.rst b/docs/full/tools/preview.rst new file mode 100644 index 0000000000..350c953d84 --- /dev/null +++ b/docs/full/tools/preview.rst @@ -0,0 +1,22 @@ +********************* +tools.preview package +********************* + +.. contents:: Contents + :local: + :depth: 2 + +.. automodapi:: tools.preview.cli + :include-all-objects: + +| +.. automodapi:: tools.preview.control_panels + :include-all-objects: + +| +.. automodapi:: tools.preview.preview + :include-all-objects: + +| +.. automodapi:: tools.preview.viewer + :include-all-objects: diff --git a/docs/full/tools/sort.rst b/docs/full/tools/sort.rst new file mode 100644 index 0000000000..05aae7ec7e --- /dev/null +++ b/docs/full/tools/sort.rst @@ -0,0 +1,23 @@ +************ +sort package +************ + +.. contents:: Contents + :local: + :depth: 2 + +.. automodapi:: tools.sort.cli + :include-all-objects: + +| +.. automodapi:: tools.sort.sort + :include-all-objects: + :no-inheritance-diagram: + +| +.. automodapi:: tools.sort.sort_methods + :include-all-objects: + +| +.. automodapi:: tools.sort.sort_methods_aligned + :include-all-objects: diff --git a/docs/full/tools/tools.rst b/docs/full/tools/tools.rst new file mode 100644 index 0000000000..9e02fcbc6f --- /dev/null +++ b/docs/full/tools/tools.rst @@ -0,0 +1,11 @@ +************* +tools package +************* + +The Tools Package provides various tools for working with Faceswap outside of the core functionality. + +.. toctree:: + :maxdepth: 3 + :glob: + + * diff --git a/docs/full/update_deps.rst b/docs/full/update_deps.rst new file mode 100644 index 0000000000..be3d11dc52 --- /dev/null +++ b/docs/full/update_deps.rst @@ -0,0 +1,3 @@ +.. automodapi:: update_deps + :include-all-objects: + :no-inheritance-diagram: diff --git a/docs/index.rst b/docs/index.rst new file mode 100755 index 0000000000..511d36b598 --- /dev/null +++ b/docs/index.rst @@ -0,0 +1,21 @@ +.. faceswap documentation master file, created by + sphinx-quickstart on Fri Sep 13 11:28:50 2019. + You can adapt this file completely to your liking, but it should at least + contain the root `toctree` directive. + +faceswap.dev Developer Documentation +==================================== + +.. 
toctree:: + :maxdepth: 4 + :caption: Contents: + + full/modules + + +Indices and tables +================== + +* :ref:`genindex` +* :ref:`modindex` +* :ref:`search` diff --git a/docs/sphinx_requirements.txt b/docs/sphinx_requirements.txt new file mode 100755 index 0000000000..c1699c198f --- /dev/null +++ b/docs/sphinx_requirements.txt @@ -0,0 +1,6 @@ +# NB Do not install from this requirements file +# It is for documentation purposes only +-r ../requirements/requirements_cpu.txt +-r ../requirements/_requirements_dev.txt +sphinx_rtd_theme +sphinx-automodapi diff --git a/faceswap.py b/faceswap.py index 89d9514eeb..6fb2f06b39 100755 --- a/faceswap.py +++ b/faceswap.py @@ -1,36 +1,61 @@ #!/usr/bin/env python3 """ The master faceswap.py script """ +import gettext +import locale +import os import sys -import lib.cli as cli +# Translations don't work by default in Windows, so hack in environment variable +if sys.platform.startswith("win"): + import ctypes + windll = ctypes.windll.kernel32 + os.environ["LANG"] = locale.windows_locale[windll.GetUserDefaultUILanguage()] -if sys.version_info[0] < 3: - raise Exception("This program requires at least python3.2") -if sys.version_info[0] == 3 and sys.version_info[1] < 2: - raise Exception("This program requires at least python3.2") +from lib.cli import args as cli_args # pylint:disable=wrong-import-position +from lib.cli.args_train import TrainArgs # pylint:disable=wrong-import-position +from lib.cli.args_extract_convert import ConvertArgs, ExtractArgs # noqa:E501 pylint:disable=wrong-import-position +from lib.config import generate_configs # pylint:disable=wrong-import-position +from lib.system import System # pylint:disable=wrong-import-position +# LOCALES +_LANG = gettext.translation("faceswap", localedir="locales", fallback=True) +_ = _LANG.gettext -def bad_args(args): - """ Print help on bad arguments """ - PARSER.print_help() - exit(0) +system = System() +system.validate_python() + +_PARSER = cli_args.FullHelpArgumentParser() + + +def _bad_args(*args) -> None: # pylint:disable=unused-argument + """ Print help to console when bad arguments are provided. """ + _PARSER.print_help() + sys.exit(0) + + +def _main() -> None: + """ The main entry point into Faceswap. + + - Generates the config files, if they don't pre-exist. + - Compiles the :class:`~lib.cli.args.FullHelpArgumentParser` objects for each section of + Faceswap. + - Sets the default values and launches the relevant script. + - Outputs help if invalid parameters are provided. 
+ """ + generate_configs() + + subparser = _PARSER.add_subparsers() + ExtractArgs(subparser, "extract", _("Extract the faces from pictures or a video")) + TrainArgs(subparser, "train", _("Train a model for the two faces A and B")) + ConvertArgs(subparser, + "convert", + _("Convert source pictures or video to a new one with the face swapped")) + cli_args.GuiArgs(subparser, "gui", _("Launch the Faceswap Graphical User Interface")) + _PARSER.set_defaults(func=_bad_args) + arguments = _PARSER.parse_args() + arguments.func(arguments) if __name__ == "__main__": - PARSER = cli.FullHelpArgumentParser() - SUBPARSER = PARSER.add_subparsers() - EXTRACT = cli.ExtractArgs(SUBPARSER, - "extract", - "Extract the faces from pictures") - TRAIN = cli.TrainArgs(SUBPARSER, - "train", - "This command trains the model for the two faces A and B") - CONVERT = cli.ConvertArgs(SUBPARSER, - "convert", - "Convert a source image to a new one with the face swapped") - GUI = cli.GuiArgs(SUBPARSER, - "gui", - "Launch the Faceswap Graphical User Interface") - PARSER.set_defaults(func=bad_args) - ARGUMENTS = PARSER.parse_args() - ARGUMENTS.func(ARGUMENTS) + _main() diff --git a/lib/Serializer.py b/lib/Serializer.py deleted file mode 100644 index 23a01d624f..0000000000 --- a/lib/Serializer.py +++ /dev/null @@ -1,104 +0,0 @@ -#!/usr/bin/env python3 -""" -Library providing convenient classes and methods for writing data to files. -""" -import logging -import json -import pickle - -try: - import yaml -except ImportError: - yaml = None - -logger = logging.getLogger(__name__) # pylint: disable=invalid-name - - -class Serializer(): - """ Parent Serializer class """ - ext = "" - woptions = "" - roptions = "" - - @classmethod - def marshal(cls, input_data): - """ Override for marshalling """ - raise NotImplementedError() - - @classmethod - def unmarshal(cls, input_string): - """ Override for unmarshalling """ - raise NotImplementedError() - - -class YAMLSerializer(Serializer): - """ YAML Serializer """ - ext = "yml" - woptions = "w" - roptions = "r" - - @classmethod - def marshal(cls, input_data): - return yaml.dump(input_data, default_flow_style=False) - - @classmethod - def unmarshal(cls, input_string): - return yaml.load(input_string) - - -class JSONSerializer(Serializer): - """ JSON Serializer """ - ext = "json" - woptions = "w" - roptions = "r" - - @classmethod - def marshal(cls, input_data): - return json.dumps(input_data, indent=2) - - @classmethod - def unmarshal(cls, input_string): - return json.loads(input_string) - - -class PickleSerializer(Serializer): - """ Picke Serializer """ - ext = "p" - woptions = "wb" - roptions = "rb" - - @classmethod - def marshal(cls, input_data): - return pickle.dumps(input_data) - - @classmethod - def unmarshal(cls, input_bytes): # pylint: disable=arguments-differ - return pickle.loads(input_bytes) - - -def get_serializer(serializer): - """ Return requested serializer """ - if serializer == "json": - return JSONSerializer - if serializer == "pickle": - return PickleSerializer - if serializer == "yaml" and yaml is not None: - return YAMLSerializer - if serializer == "yaml" and yaml is None: - logger.warning("You must have PyYAML installed to use YAML as the serializer." 
- "Switching to JSON as the serializer.") - return JSONSerializer - - -def get_serializer_from_ext(ext): - """ Get the sertializer from filename extension """ - if ext == ".json": - return JSONSerializer - if ext == ".p": - return PickleSerializer - if ext in (".yaml", ".yml") and yaml is not None: - return YAMLSerializer - if ext in (".yaml", ".yml") and yaml is None: - logger.warning("You must have PyYAML installed to use YAML as the serializer.\n" - "Switching to JSON as the serializer.") - return JSONSerializer diff --git a/lib/__init__.py b/lib/__init__.py index e69de29bb2..c87f4c4316 100644 --- a/lib/__init__.py +++ b/lib/__init__.py @@ -0,0 +1,4 @@ +#!/usr/bin/env python3 +""" Initialization for faceswap's lib section """ +# Import logger here so our custom loglevels are set for when executing code outside of FS +from . import logger diff --git a/lib/align/__init__.py b/lib/align/__init__.py new file mode 100644 index 0000000000..3f5887bcd6 --- /dev/null +++ b/lib/align/__init__.py @@ -0,0 +1,9 @@ +#!/usr/bin/env python3 +""" Package for handling alignments files, detected faces and aligned faces along with their +associated objects. """ +from .aligned_face import (AlignedFace, get_adjusted_center, get_matrix_scaling, + get_centered_size, transform_image) +from .aligned_mask import BlurMask, LandmarksMask, Mask +from .alignments import Alignments +from .constants import CenteringType, EXTRACT_RATIOS, LANDMARK_PARTS, LandmarkType +from .detected_face import DetectedFace, update_legacy_png_header diff --git a/lib/align/aligned_face.py b/lib/align/aligned_face.py new file mode 100644 index 0000000000..0a5c92c66e --- /dev/null +++ b/lib/align/aligned_face.py @@ -0,0 +1,775 @@ +#!/usr/bin/env python3 +""" Aligner for faceswap.py """ +from __future__ import annotations + +from dataclasses import dataclass, field +import logging +import typing as T + +from threading import Lock + +import cv2 +import numpy as np + +from lib.logger import parse_class_init +from lib.utils import get_module_objects + +from .constants import CenteringType, EXTRACT_RATIOS, LandmarkType, _MEAN_FACE +from .pose import PoseEstimate + +logger = logging.getLogger(__name__) + + +def get_matrix_scaling(matrix: np.ndarray) -> tuple[int, int]: + """ Given a matrix, return the cv2 Interpolation method and inverse interpolation method for + applying the matrix on an image. + + Parameters + ---------- + matrix: :class:`numpy.ndarray` + The transform matrix to return the interpolator for + + Returns + ------- + tuple + The interpolator and inverse interpolator for the given matrix. This will be (Cubic, Area) + for an upscale matrix and (Area, Cubic) for a downscale matrix + """ + x_scale = np.sqrt(matrix[0, 0] * matrix[0, 0] + matrix[0, 1] * matrix[0, 1]) + if x_scale == 0: + y_scale = 0. + else: + y_scale = (matrix[0, 0] * matrix[1, 1] - matrix[0, 1] * matrix[1, 0]) / x_scale + avg_scale = (x_scale + y_scale) * 0.5 + if avg_scale >= 1.: + interpolators = cv2.INTER_CUBIC, cv2.INTER_AREA + else: + interpolators = cv2.INTER_AREA, cv2.INTER_CUBIC + logger.trace("interpolator: %s, inverse interpolator: %s", # type:ignore[attr-defined] + interpolators[0], interpolators[1]) + return interpolators + + +def transform_image(image: np.ndarray, + matrix: np.ndarray, + size: int, + padding: int = 0) -> np.ndarray: + """ Perform transformation on an image, applying the given size and padding to the matrix. 
+ + Parameters + ---------- + image: :class:`numpy.ndarray` + The image to transform + matrix: :class:`numpy.ndarray` + The transformation matrix to apply to the image + size: int + The final size of the transformed image + padding: int, optional + The amount of padding to apply to the final image. Default: `0` + + Returns + ------- + :class:`numpy.ndarray` + The transformed image + """ + logger.trace("image shape: %s, matrix: %s, size: %s. padding: %s", # type:ignore[attr-defined] + image.shape, matrix, size, padding) + # transform the matrix for size and padding + mat = matrix * (size - 2 * padding) + mat[:, 2] += padding + + # transform image + interpolators = get_matrix_scaling(mat) + retval = cv2.warpAffine(image, mat, (size, size), flags=interpolators[0]) + logger.trace("transformed matrix: %s, final image shape: %s", # type:ignore[attr-defined] + mat, image.shape) + return retval + + +def get_adjusted_center(image_size: int, + source_offset: np.ndarray, + target_offset: np.ndarray, + source_centering: CenteringType, + y_offset: float) -> np.ndarray: + """ Obtain the correct center of a face extracted image to translate between two different + extract centerings. + + Parameters + ---------- + image_size: int + The size of the image at the given :attr:`source_centering` + source_offset: :class:`numpy.ndarray` + The pose offset to translate a base extracted face to source centering + target_offset: :class:`numpy.ndarray` + The pose offset to translate a base extracted face to target centering + source_centering: ["face", "head", "legacy"] + The centering of the source image + y_offset: float + Amount to additionally offset the center of the image along the y-axis + + Returns + ------- + :class:`numpy.ndarray` + The center point of the image at the given size for the target centering + """ + source_size = image_size - (image_size * EXTRACT_RATIOS[source_centering]) + offset = target_offset - source_offset - [0., y_offset] + offset *= source_size + center = np.rint(offset + image_size / 2).astype("int32") + logger.trace( # type:ignore[attr-defined] + "image_size: %s, source_offset: %s, target_offset: %s, source_centering: '%s', " + "y_offset: %s, adjusted_offset: %s, center: %s", + image_size, source_offset, target_offset, source_centering, y_offset, offset, center) + return center + + +def get_centered_size(source_centering: CenteringType, + target_centering: CenteringType, + size: int, + coverage_ratio: float = 1.0) -> int: + """ Obtain the size of a cropped face from an aligned image. + + Given an image of a certain dimensions, returns the dimensions of the sub-crop within that + image for the requested centering at the requested coverage ratio + + Notes + ----- + `"legacy"` places the nose in the center of the image (the original method for aligning). + `"face"` aligns for the nose to be in the center of the face (top to bottom) but the center + of the skull for left to right. `"head"` places the center in the middle of the skull in 3D + space. + + The ROI in relation to the source image is calculated by rounding the padding of one side + to the nearest integer then applying this padding to the center of the crop, to ensure that + any dimensions always have an even number of pixels. 
+ + Parameters + ---------- + source_centering: ["head", "face", "legacy"] + The centering that the original image is aligned at + target_centering: ["head", "face", "legacy"] + The centering that the sub-crop size should be obtained for + size: int + The size of the source image to obtain the cropped size for + coverage_ratio: float, optional + The coverage ratio to be applied to the target image. Default: `1.0` + + Returns + ------- + int + The pixel size of a sub-crop image from a full head aligned image with the given coverage + ratio + """ + if source_centering == target_centering and coverage_ratio == 1.0: + src_size: float | int = size + retval = size + else: + src_size = size - (size * EXTRACT_RATIOS[source_centering]) + retval = 2 * int(np.rint((src_size / (1 - EXTRACT_RATIOS[target_centering]) + * coverage_ratio) / 2)) + logger.trace( # type:ignore[attr-defined] + "source_centering: %s, target_centering: %s, size: %s, coverage_ratio: %s, " + "source_size: %s, crop_size: %s", + source_centering, target_centering, size, coverage_ratio, src_size, retval) + return retval + + +@dataclass +class _FaceCache: # pylint:disable=too-many-instance-attributes + """ Cache for storing items related to a single aligned face. + + Items are cached so that they are only created the first time they are called. + Each item includes a threading lock to make cache creation thread safe. + + Parameters + ---------- + pose: :class:`lib.align.PoseEstimate`, optional + The estimated pose in 3D space. Default: ``None`` + original_roi: :class:`numpy.ndarray`, optional + The location of the extracted face box within the original frame. Default: ``None`` + landmarks: :class:`numpy.ndarray`, optional + The 68 point facial landmarks aligned to the extracted face box. Default: ``None`` + landmarks_normalized: :class:`numpy.ndarray`, optional + The 68 point facial landmarks normalized to 0.0 - 1.0 as aligned by Umeyama. + Default: ``None`` + average_distance: float, optional + The average distance of the core landmarks (18-67) from the mean face that was used for + aligning the image. Default: `0.0` + relative_eye_mouth_position: float, optional + A float value representing the relative position of the lowest eye/eye-brow point to the + highest mouth point. Positive values indicate that eyes/eyebrows are aligned above the + mouth, negative values indicate that eyes/eyebrows are misaligned below the mouth. + Default: `0.0` + adjusted_matrix: :class:`numpy.ndarray`, optional + The 3x2 transformation matrix for extracting and aligning the core face area out of the + original frame with padding and sizing applied. Default: ``None`` + interpolators: tuple, optional + (`interpolator` and `reverse interpolator`) for the :attr:`adjusted matrix`. + Default: `(0, 0)` + cropped_roi: dict, optional + The (`left`, `top`, `right`, `bottom`) location of the region of interest within an + aligned face centered for each centering. Default: `{}` + cropped_slices: dict, optional + The slices for an input full head image and output cropped image. 
Default: `{}`
+    """
+    pose: PoseEstimate | None = None
+    original_roi: np.ndarray | None = None
+    landmarks: np.ndarray | None = None
+    landmarks_normalized: np.ndarray | None = None
+    average_distance: float = 0.0
+    relative_eye_mouth_position: float = 0.0
+    adjusted_matrix: np.ndarray | None = None
+    interpolators: tuple[int, int] = (0, 0)
+    cropped_roi: dict[CenteringType, np.ndarray] = field(default_factory=dict)
+    cropped_slices: dict[CenteringType, dict[T.Literal["in", "out"],
+                                             tuple[slice, slice]]] = field(default_factory=dict)
+
+    _locks: dict[str, Lock] = field(default_factory=dict)
+
+    def __post_init__(self):
+        """ Initialize the locks for the class parameters """
+        self._locks = {name: Lock() for name in self.__dict__}
+
+    def lock(self, name: str) -> Lock:
+        """ Obtain the lock for the given property
+
+        Parameters
+        ----------
+        name: str
+            The name of a parameter within the cache
+
+        Returns
+        -------
+        :class:`threading.Lock`
+            The lock associated with the requested parameter
+        """
+        return self._locks[name]
+
+
+class AlignedFace():  # pylint:disable=too-many-instance-attributes
+    """ Class to align a face.
+
+    Holds the aligned landmarks and face image, as well as associated matrices and information
+    about an aligned face.
+
+    Parameters
+    ----------
+    landmarks: :class:`numpy.ndarray`
+        The original 68 point landmarks that pertain to the given image for this face
+    image: :class:`numpy.ndarray`, optional
+        The original frame that contains the face that is to be aligned. Pass `None` if the
+        aligned face is not to be generated, and just the co-ordinates should be calculated.
+    centering: ["legacy", "face", "head"], optional
+        The type of extracted face that should be loaded. "legacy" places the nose in the center
+        of the image (the original method for aligning). "face" aligns for the nose to be in the
+        center of the face (top to bottom) but the center of the skull for left to right. "head"
+        aligns for the center of the skull (in 3D space) being the center of the extracted image,
+        with the crop holding the full head. Default: `"face"`
+    size: int, optional
+        The size, in pixels, of each edge of the final aligned face. Default: `64`
+    coverage_ratio: float, optional
+        The amount of the aligned image to return. A ratio of 1.0 will return the full contents
+        of the aligned image. A ratio of 0.5 will return an image of the given size, but will
+        crop to the central 50% of the image. Default: `1.0`
+    y_offset: float, optional
+        Amount to adjust the aligned face along the y-axis in the range -1. to 1. Default: 0.0
+    dtype: str, optional
+        Set a data type for the final face to be returned as. Passing ``None`` will return a face
+        with the same data type as the original :attr:`image`. Default: ``None``
+    is_aligned: bool, optional
+        Indicates that the :attr:`image` is an aligned face rather than a frame.
+        Default: ``False``
+    is_legacy: bool, optional
+        Only used if `is_aligned` is ``True``.
``True`` indicates that the aligned image being + loaded is a legacy extracted face rather than a current head extracted face + """ + def __init__(self, + landmarks: np.ndarray, + image: np.ndarray | None = None, + centering: CenteringType = "face", + size: int = 64, + coverage_ratio: float = 1.0, + y_offset: float = 0.0, + dtype: str | None = None, + is_aligned: bool = False, + is_legacy: bool = False) -> None: + logger.trace(parse_class_init(locals())) # type:ignore[attr-defined] + self._frame_landmarks = landmarks + self._landmark_type = LandmarkType.from_shape(landmarks.shape) + self._centering = centering + self._size = size + self._coverage_ratio = coverage_ratio + self._y_offset = y_offset + self._dtype = dtype + self._is_aligned = is_aligned + self._source_centering: CenteringType = "legacy" if is_legacy and is_aligned else "head" + self._padding = self._padding_from_coverage(size, coverage_ratio) + + lookup = self._landmark_type + self._mean_lookup = LandmarkType.LM_2D_51 if lookup == LandmarkType.LM_2D_68 else lookup + + self._cache = _FaceCache() + self._matrices: dict[CenteringType, np.ndarray] = {"legacy": self._get_default_matrix()} + + self._face = self.extract_face(image) + logger.trace("Initialized: %s (padding: %s, face shape: %s)", # type:ignore[attr-defined] + self.__class__.__name__, self._padding, + self._face if self._face is None else self._face.shape) + + @property + def centering(self) -> T.Literal["legacy", "head", "face"]: + """ str: The centering of the Aligned Face. One of `"legacy"`, `"head"`, `"face"`. """ + return self._centering + + @property + def size(self) -> int: + """ int: The size (in pixels) of one side of the square extracted face image. """ + return self._size + + @property + def padding(self) -> int: + """ int: The amount of padding (in pixels) that is applied to each side of the + extracted face image for the selected extract type. """ + return self._padding[self._centering] + + @property + def y_offset(self) -> float: + """ float: Additional offset applied to the face along the y-axis in -1. to 1. range """ + return self._y_offset + + @property + def matrix(self) -> np.ndarray: + """ :class:`numpy.ndarray`: The 3x2 transformation matrix for extracting and aligning the + core face area out of the original frame, with no padding or sizing applied. The returned + matrix is offset for the given :attr:`centering`. """ + if self._centering not in self._matrices: + matrix = self._matrices["legacy"].copy() + matrix[:, 2] -= self.pose.offset[self._centering] + self._matrices[self._centering] = matrix + logger.trace("original matrix: %s, new matrix: %s", # type:ignore[attr-defined] + self._matrices["legacy"], matrix) + return self._matrices[self._centering] + + @property + def pose(self) -> PoseEstimate: + """ :class:`lib.align.PoseEstimate`: The estimated pose in 3D space. """ + with self._cache.lock("pose"): + if self._cache.pose is None: + lms = np.nan_to_num(cv2.transform(np.expand_dims(self._frame_landmarks, axis=1), + self._matrices["legacy"]).squeeze()) + self._cache.pose = PoseEstimate(lms, self._landmark_type) + return self._cache.pose + + @property + def adjusted_matrix(self) -> np.ndarray: + """ :class:`numpy.ndarray`: The 3x2 transformation matrix for extracting and aligning the + core face area out of the original frame with padding and sizing applied. 
""" + with self._cache.lock("adjusted_matrix"): + if self._cache.adjusted_matrix is None: + matrix = self.matrix.copy() + mat = matrix * (self._size - 2 * self.padding) + mat[:, 2] += self.padding + logger.trace("adjusted_matrix: %s", mat) # type:ignore[attr-defined] + self._cache.adjusted_matrix = mat + return self._cache.adjusted_matrix + + @property + def face(self) -> np.ndarray | None: + """ :class:`numpy.ndarray`: The aligned face at the given :attr:`size` at the specified + :attr:`coverage` in the given :attr:`dtype`. If an :attr:`image` has not been provided + then an the attribute will return ``None``. """ + return self._face + + @property + def original_roi(self) -> np.ndarray: + """ :class:`numpy.ndarray`: The location of the extracted face box within the original + frame. """ + with self._cache.lock("original_roi"): + if self._cache.original_roi is None: + roi = np.array([[0, 0], + [0, self._size - 1], + [self._size - 1, self._size - 1], + [self._size - 1, 0]]) + roi = np.rint(self.transform_points(roi, invert=True)).astype("int32") + logger.trace("original roi: %s", roi) # type:ignore[attr-defined] + self._cache.original_roi = roi + return self._cache.original_roi + + @property + def landmarks(self) -> np.ndarray: + """ :class:`numpy.ndarray`: The 68 point facial landmarks aligned to the extracted face + box. """ + with self._cache.lock("landmarks"): + if self._cache.landmarks is None: + lms = self.transform_points(self._frame_landmarks) + logger.trace("aligned landmarks: %s", lms) # type:ignore[attr-defined] + self._cache.landmarks = lms + return self._cache.landmarks + + @property + def landmark_type(self) -> LandmarkType: + """:class:`~LandmarkType`: The type of landmarks that generated this aligned face """ + return self._landmark_type + + @property + def normalized_landmarks(self) -> np.ndarray: + """ :class:`numpy.ndarray`: The 68 point facial landmarks normalized to 0.0 - 1.0 as + aligned by Umeyama. """ + with self._cache.lock("landmarks_normalized"): + if self._cache.landmarks_normalized is None: + lms = np.expand_dims(self._frame_landmarks, axis=1) + lms = cv2.transform(lms, self._matrices["legacy"]).squeeze() + logger.trace("normalized landmarks: %s", lms) # type:ignore[attr-defined] + self._cache.landmarks_normalized = lms + return self._cache.landmarks_normalized + + @property + def interpolators(self) -> tuple[int, int]: + """ tuple: (`interpolator` and `reverse interpolator`) for the :attr:`adjusted matrix`. """ + with self._cache.lock("interpolators"): + if not any(self._cache.interpolators): + interpolators = get_matrix_scaling(self.adjusted_matrix) + logger.trace("interpolators: %s", interpolators) # type:ignore[attr-defined] + self._cache.interpolators = interpolators + return self._cache.interpolators + + @property + def average_distance(self) -> float: + """ float: The average distance of the core landmarks (18-67) from the mean face that was + used for aligning the image. 
""" + with self._cache.lock("average_distance"): + if not self._cache.average_distance: + mean_face = _MEAN_FACE[self._mean_lookup] + lms = self.normalized_landmarks + if self._landmark_type == LandmarkType.LM_2D_68: + lms = lms[17:] # 68 point landmarks only use core face items + average_distance = np.mean(np.abs(lms - mean_face)) + logger.trace("average_distance: %s", average_distance) # type:ignore[attr-defined] + self._cache.average_distance = average_distance + return self._cache.average_distance + + @property + def relative_eye_mouth_position(self) -> float: + """ float: Value representing the relative position of the lowest eye/eye-brow point to the + highest mouth point. Positive values indicate that eyes/eyebrows are aligned above the + mouth, negative values indicate that eyes/eyebrows are misaligned below the mouth. """ + with self._cache.lock("relative_eye_mouth_position"): + if not self._cache.relative_eye_mouth_position: + if self._landmark_type != LandmarkType.LM_2D_68: + position = 1.0 # arbitrary positive value + else: + lowest_eyes = np.max(self.normalized_landmarks[np.r_[17:27, 36:48], 1]) + highest_mouth = np.min(self.normalized_landmarks[48:68, 1]) + position = highest_mouth - lowest_eyes + logger.trace("lowest_eyes: %s, highest_mouth: %s, " # type:ignore[attr-defined] + "relative_eye_mouth_position: %s", lowest_eyes, highest_mouth, + position) + self._cache.relative_eye_mouth_position = position + return self._cache.relative_eye_mouth_position + + @classmethod + def _padding_from_coverage(cls, size: int, coverage_ratio: float) -> dict[CenteringType, int]: + """ Return the image padding for a face from coverage_ratio set against a + pre-padded training image. + + Parameters + ---------- + size: int + The final size of the aligned image in pixels + coverage_ratio: float + The ratio of the final image to pad to + + Returns + ------- + dict + The padding required, in pixels for 'head', 'face' and 'legacy' face types + """ + retval = {_type: round((size * (coverage_ratio - (1 - EXTRACT_RATIOS[_type]))) / 2) + for _type in T.get_args(T.Literal["legacy", "face", "head"])} + logger.trace(retval) # type:ignore[attr-defined] + return retval + + def _get_default_matrix(self) -> np.ndarray: + """ Get the default (legacy) matrix. All subsequent matrices are calculated from this + + Returns + ------- + :class:`numpy.ndarray` + The default 'legacy' matrix + """ + lms = self._frame_landmarks + if self._landmark_type == LandmarkType.LM_2D_68: + lms = lms[17:] # 68 point landmarks only use core face items + retval = _umeyama(lms, _MEAN_FACE[self._mean_lookup], True)[0:2] + logger.trace("Default matrix: %s", retval) # type:ignore[attr-defined] + return retval + + def transform_points(self, points: np.ndarray, invert: bool = False) -> np.ndarray: + """ Perform transformation on a series of (x, y) co-ordinates in world space into + aligned face space. + + Parameters + ---------- + points: :class:`numpy.ndarray` + The points to transform + invert: bool, optional + ``True`` to reverse the transformation (i.e. transform the points into world space from + aligned face space). 
Default: ``False`` + + Returns + ------- + :class:`numpy.ndarray` + The transformed points + """ + retval = np.expand_dims(points, axis=1) + mat = cv2.invertAffineTransform(self.adjusted_matrix) if invert else self.adjusted_matrix + retval = cv2.transform(retval, mat).squeeze() + logger.trace( # type:ignore[attr-defined] + "invert: %s, Original points: %s, transformed points: %s", invert, points, retval) + return retval + + def extract_face(self, image: np.ndarray | None) -> np.ndarray | None: + """ Extract the face from a source image and populate :attr:`face`. If an image is not + provided then ``None`` is returned. + + Parameters + ---------- + image: :class:`numpy.ndarray` or ``None`` + The original frame to extract the face from. ``None`` if the face should not be + extracted + + Returns + ------- + :class:`numpy.ndarray` or ``None`` + The extracted face at the given size, with the given coverage of the given dtype or + ``None`` if no image has been provided. + """ + if image is None: + logger.trace("_extract_face called without a loaded " # type:ignore[attr-defined] + "image. Returning empty face.") + return None + + if self._is_aligned: + # Crop out the sub face from full head + image = self._convert_centering(image) + + if self._is_aligned and image.shape[0] != self._size: # Resize the given aligned face + interp = cv2.INTER_CUBIC if image.shape[0] < self._size else cv2.INTER_AREA + retval = cv2.resize(image, (self._size, self._size), interpolation=interp) + elif self._is_aligned: + retval = image + else: + retval = transform_image(image, self.matrix, self._size, self.padding) + retval = retval if self._dtype is None else retval.astype(self._dtype) + return retval + + def _convert_centering(self, image: np.ndarray) -> np.ndarray: + """ When the face being loaded is pre-aligned, the loaded image will have 'head' centering + so it needs to be cropped out to the appropriate centering. + + This function temporarily converts this object to a full head aligned face, extracts the + sub-cropped face to the correct centering, reverse the sub crop and returns the cropped + face at the selected coverage ratio. + + Parameters + ---------- + image: :class:`numpy.ndarray` + The original head-centered aligned image + + Returns + ------- + :class:`numpy.ndarray` + The aligned image with the correct centering, scaled to image input size + """ + logger.trace( # type:ignore[attr-defined] + "image_size: %s, target_size: %s, coverage_ratio: %s", + image.shape[0], self.size, self._coverage_ratio) + + img_size = image.shape[0] + target_size = get_centered_size(self._source_centering, + self._centering, + img_size, + self._coverage_ratio) + out = np.zeros((target_size, target_size, image.shape[-1]), dtype=image.dtype) + + slices = self._get_cropped_slices(img_size, target_size) + out[slices["out"][0], slices["out"][1], :] = image[slices["in"][0], slices["in"][1], :] + logger.trace( # type:ignore[attr-defined] + "Cropped from aligned extract: (centering: %s, in shape: %s, out shape: %s)", + self._centering, image.shape, out.shape) + return out + + def _get_cropped_slices(self, + image_size: int, + target_size: int, + ) -> dict[T.Literal["in", "out"], tuple[slice, slice]]: + """ Obtain the slices to turn a full head extract into an alternatively centered extract. 
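+
+        The returned slice pairs are consumed together, broadly as::
+
+            out[slices["out"][0], slices["out"][1]] = image[slices["in"][0], slices["in"][1]]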
+ + Parameters + ---------- + image_size: int + The size of the full head extracted image loaded from disk + target_size: int + The size of the target centered face with coverage ratio applied in relation to the + original image size + + Returns + ------- + dict + The slices for an input full head image and output cropped image + """ + with self._cache.lock("cropped_slices"): + if not self._cache.cropped_slices.get(self._centering): + roi = self.get_cropped_roi(image_size, target_size, self._centering) + slice_in = (slice(max(roi[1], 0), max(roi[3], 0)), + slice(max(roi[0], 0), max(roi[2], 0))) + slice_out = (slice(max(roi[1] * -1, 0), + target_size - min(target_size, max(0, roi[3] - image_size))), + slice(max(roi[0] * -1, 0), + target_size - min(target_size, max(0, roi[2] - image_size)))) + self._cache.cropped_slices[self._centering] = {"in": slice_in, "out": slice_out} + logger.trace("centering: %s, cropped_slices: %s", # type:ignore[attr-defined] + self._centering, self._cache.cropped_slices[self._centering]) + return self._cache.cropped_slices[self._centering] + + def get_cropped_roi(self, + image_size: int, + target_size: int, + centering: CenteringType) -> np.ndarray: + """ Obtain the region of interest within an aligned face set to centered coverage for + an alternative centering + + Parameters + ---------- + image_size: int + The size of the full head extracted image loaded from disk + target_size: int + The size of the target centered face with coverage ratio applied in relation to the + original image size + + centering: ["legacy", "face"] + The type of centering to obtain the region of interest for. "legacy" places the nose + in the center of the image (the original method for aligning). "face" aligns for the + nose to be in the center of the face (top to bottom) but the center of the skull for + left to right. + + Returns + ------- + :class:`numpy.ndarray` + The (`left`, `top`, `right`, `bottom` location of the region of interest within an + aligned face centered on the head for the given centering + """ + with self._cache.lock("cropped_roi"): + if centering not in self._cache.cropped_roi: + center = get_adjusted_center(image_size, + self.pose.offset[self._source_centering], + self.pose.offset[centering], + self._source_centering, + self.y_offset) + padding = target_size // 2 + roi = np.array([center - padding, center + padding]).ravel() + logger.trace( # type:ignore[attr-defined] + "centering: '%s', center: %s, padding: %s, sub roi: %s", + centering, center, padding, roi) + self._cache.cropped_roi[centering] = roi + return self._cache.cropped_roi[centering] + + def split_mask(self) -> np.ndarray: + """ Remove the mask from the alpha channel of :attr:`face` and return the mask + + Returns + ------- + :class:`numpy.ndarray` + The mask that was stored in the :attr:`face`'s alpha channel + + Raises + ------ + AssertionError + If :attr:`face` does not contain a mask in the alpha channel + """ + assert self._face is not None + assert self._face.shape[-1] == 4, "No mask stored in the alpha channel" + mask = self._face[..., 3] + self._face = self._face[..., :3] + return mask + + +def _umeyama(source: np.ndarray, destination: np.ndarray, estimate_scale: bool) -> np.ndarray: + """Estimate N-D similarity transformation with or without scaling. + + Imported, and slightly adapted, directly from: + https://github.com/scikit-image/scikit-image/blob/master/skimage/transform/_geometric.py + + + Parameters + ---------- + source: :class:`numpy.ndarray` + (M, N) array source coordinates. 
+ destination: :class:`numpy.ndarray` + (M, N) array destination coordinates. + estimate_scale: bool + Whether to estimate scaling factor. + + Returns + ------- + :class:`numpy.ndarray` + (N + 1, N + 1) The homogeneous similarity transformation matrix. The matrix contains + NaN values only if the problem is not well-conditioned. + + References + ---------- + .. [1] "Least-squares estimation of transformation parameters between two + point patterns", Shinji Umeyama, PAMI 1991, :DOI:`10.1109/34.88573` + """ + # pylint:disable=invalid-name,too-many-locals + num = source.shape[0] + dim = source.shape[1] + + # Compute mean of source and destination. + src_mean = source.mean(axis=0) + dst_mean = destination.mean(axis=0) + + # Subtract mean from source and destination. + src_demean = source - src_mean + dst_demean = destination - dst_mean + + # Eq. (38). + A = dst_demean.T @ src_demean / num + + # Eq. (39). + d = np.ones((dim,), dtype=np.double) + if np.linalg.det(A) < 0: + d[dim - 1] = -1 + + retval = np.eye(dim + 1, dtype=np.double) + + U, S, V = np.linalg.svd(A) + + # Eq. (40) and (43). + rank = np.linalg.matrix_rank(A) + if rank == 0: + return np.nan * retval + if rank == dim - 1: + if np.linalg.det(U) * np.linalg.det(V) > 0: + retval[:dim, :dim] = U @ V + else: + s = d[dim - 1] + d[dim - 1] = -1 + retval[:dim, :dim] = U @ np.diag(d) @ V + d[dim - 1] = s + else: + retval[:dim, :dim] = U @ np.diag(d) @ V + + if estimate_scale: + # Eq. (41) and (42). + scale = 1.0 / src_demean.var(axis=0).sum() * (S @ d) + else: + scale = 1.0 + + retval[:dim, dim] = dst_mean - scale * (retval[:dim, :dim] @ src_mean.T) + retval[:dim, :dim] *= scale + + return retval + + +__all__ = get_module_objects(__name__) diff --git a/lib/align/aligned_mask.py b/lib/align/aligned_mask.py new file mode 100644 index 0000000000..ad5a7e7e1c --- /dev/null +++ b/lib/align/aligned_mask.py @@ -0,0 +1,607 @@ +#!/usr/bin python3 +""" Handles retrieval and storage of Faceswap aligned masks """ + +from __future__ import annotations +import logging +import typing as T + +from zlib import compress, decompress + +import cv2 +import numpy as np + +from lib.logger import parse_class_init +from lib.utils import get_module_objects + +from .alignments import MaskAlignmentsFileDict +from . import get_adjusted_center, get_centered_size + +if T.TYPE_CHECKING: + from collections.abc import Callable + from .aligned_face import CenteringType + +logger = logging.getLogger(__name__) + + +class Mask(): # pylint:disable=too-many-instance-attributes + """ Face Mask information and convenience methods + + Holds a Faceswap mask as generated from :mod:`plugins.extract.mask` and the information + required to transform it to its original frame. + + Holds convenience methods to handle the warping, storing and retrieval of the mask. + + Parameters + ---------- + storage_size: int, optional + The size (in pixels) that the mask should be stored at. Default: 128. + storage_centering, str (optional): + The centering to store the mask at. One of `"legacy"`, `"face"`, `"head"`. + Default: `"face"` + + Attributes + ---------- + stored_size: int + The size, in pixels, of the stored mask across its height and width. + stored_centering: str + The centering that the mask is stored at. 
One of `"legacy"`, `"face"`, `"head"` + """ + def __init__(self, + storage_size: int = 128, + storage_centering: CenteringType = "face") -> None: + logger.trace(parse_class_init(locals())) # type:ignore[attr-defined] + self.stored_size = storage_size + self.stored_centering: CenteringType = storage_centering + + self._mask: bytes | None = None + self._affine_matrix: np.ndarray | None = None + self._interpolator: int | None = None + + self._blur_type: T.Literal["gaussian", "normalized"] | None = None + self._blur_passes: int = 0 + self._blur_kernel: float | int = 0 + self._threshold = 0.0 + self._dilation: tuple[T.Literal["erode", "dilate"], np.ndarray | None] = ("erode", None) + self._sub_crop_size = 0 + self._sub_crop_slices: dict[T.Literal["in", "out"], list[slice]] = {} + + self.set_blur_and_threshold() + logger.trace("Initialized: %s", self.__class__.__name__) # type:ignore[attr-defined] + + @property + def mask(self) -> np.ndarray: + """ :class:`numpy.ndarray`: The mask at the size of :attr:`stored_size` with any requested + blurring, threshold amount and centering applied.""" + mask = self.stored_mask + if self._dilation[-1] is not None or self._threshold != 0.0 or self._blur_kernel != 0: + mask = mask.copy() + self._dilate_mask(mask) + if self._threshold != 0.0: + mask[mask < self._threshold] = 0.0 + mask[mask > 255.0 - self._threshold] = 255.0 + if self._blur_kernel != 0 and self._blur_type is not None: + mask = BlurMask(self._blur_type, + mask, + self._blur_kernel, + passes=self._blur_passes).blurred + if self._sub_crop_size: # Crop the mask to the given centering + out = np.zeros((self._sub_crop_size, self._sub_crop_size, 1), dtype=mask.dtype) + slice_in, slice_out = self._sub_crop_slices["in"], self._sub_crop_slices["out"] + out[slice_out[0], slice_out[1], :] = mask[slice_in[0], slice_in[1], :] + mask = out + logger.trace("mask shape: %s", mask.shape) # type:ignore[attr-defined] + return mask + + @property + def stored_mask(self) -> np.ndarray: + """ :class:`numpy.ndarray`: The mask at the size of :attr:`stored_size` as it is stored + (i.e. with no blurring/centering applied). """ + assert self._mask is not None + dims = (self.stored_size, self.stored_size, 1) + mask = np.frombuffer(decompress(self._mask), dtype="uint8").reshape(dims) + logger.trace("stored mask shape: %s", mask.shape) # type:ignore[attr-defined] + return mask + + @property + def original_roi(self) -> np.ndarray: + """ :class: `numpy.ndarray`: The original region of interest of the mask in the + source frame. """ + points = np.array([[0, 0], + [0, self.stored_size - 1], + [self.stored_size - 1, self.stored_size - 1], + [self.stored_size - 1, 0]], np.int32).reshape((-1, 1, 2)) + matrix = cv2.invertAffineTransform(self.affine_matrix) + roi = cv2.transform(points, matrix).reshape((4, 2)) + logger.trace("Returning: %s", roi) # type:ignore[attr-defined] + return roi + + @property + def affine_matrix(self) -> np.ndarray: + """ :class: `numpy.ndarray`: The affine matrix to transpose the mask to a full frame. """ + assert self._affine_matrix is not None + return self._affine_matrix + + @property + def interpolator(self) -> int: + """ int: The cv2 interpolator required to transpose the mask to a full frame. """ + assert self._interpolator is not None + return self._interpolator + + def _dilate_mask(self, mask: np.ndarray) -> None: + """ Erode/Dilate the mask. The action is performed in-place on the given mask. 
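+
+        The operation is equivalent in spirit to (``kernel`` being the structuring element built
+        by :meth:`set_dilation`)::
+
+            cv2.erode(mask, kernel, dst=mask, iterations=1)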
+
+        No action is performed if a dilation amount has not been set
+
+        Parameters
+        ----------
+        mask: :class:`numpy.ndarray`
+            The mask to be eroded/dilated
+        """
+        if self._dilation[-1] is None:
+            return
+
+        func = cv2.erode if self._dilation[0] == "erode" else cv2.dilate
+        func(mask, self._dilation[-1], dst=mask, iterations=1)
+
+    def get_full_frame_mask(self, width: int, height: int) -> np.ndarray:
+        """ Return the stored mask in a full size frame of the given dimensions
+
+        Parameters
+        ----------
+        width: int
+            The width of the original frame that the mask was extracted from
+        height: int
+            The height of the original frame that the mask was extracted from
+
+        Returns
+        -------
+        :class:`numpy.ndarray`
+            The mask affined to the original full frame of the given dimensions
+        """
+        # numpy shape order is (rows, columns), i.e. (height, width)
+        frame = np.zeros((height, width, 1), dtype="uint8")
+        mask = cv2.warpAffine(self.mask,
+                              self.affine_matrix,
+                              (width, height),
+                              frame,
+                              flags=cv2.WARP_INVERSE_MAP | self.interpolator,
+                              borderMode=cv2.BORDER_CONSTANT)
+        logger.trace("mask shape: %s, mask dtype: %s, mask min: %s, "  # type:ignore[attr-defined]
+                     "mask max: %s", mask.shape, mask.dtype, mask.min(), mask.max())
+        return mask
+
+    def add(self, mask: np.ndarray, affine_matrix: np.ndarray, interpolator: int) -> None:
+        """ Add a Faceswap mask to this :class:`Mask`.
+
+        The mask should be the original output from :mod:`plugins.extract.mask`
+
+        Parameters
+        ----------
+        mask: :class:`numpy.ndarray`
+            The mask that is to be added as output from :mod:`plugins.extract.mask`.
+            It should be in the range 0.0 - 1.0 ideally with a ``dtype`` of ``float32``
+        affine_matrix: :class:`numpy.ndarray`
+            The transformation matrix required to transform the mask to the original frame.
+        interpolator: int
+            The CV2 interpolator required to transform this mask to its original frame
+        """
+        logger.trace("mask shape: %s, mask dtype: %s, mask min: %s, "  # type:ignore[attr-defined]
+                     "mask max: %s, affine_matrix: %s, interpolator: %s",
+                     mask.shape, mask.dtype, mask.min(), mask.max(), affine_matrix, interpolator)
+        self._affine_matrix = self._adjust_affine_matrix(mask.shape[0], affine_matrix)
+        self._interpolator = interpolator
+        self.replace_mask(mask)
+
+    def replace_mask(self, mask: np.ndarray) -> None:
+        """ Replace the existing :attr:`_mask` with the given mask.
+
+        Parameters
+        ----------
+        mask: :class:`numpy.ndarray`
+            The mask that is to be added as output from :mod:`plugins.extract.mask`.
+            It should be in the range 0.0 - 1.0 ideally with a ``dtype`` of ``float32``
+        """
+        mask = (cv2.resize(mask * 255.0,
+                           (self.stored_size, self.stored_size),
+                           interpolation=cv2.INTER_AREA)).astype("uint8")
+        self._mask = compress(mask.tobytes())
+
+    def set_dilation(self, amount: float) -> None:
+        """ Set the internal dilation object for returned masks
+
+        Parameters
+        ----------
+        amount: float
+            The amount of erosion/dilation to apply as a percentage of the total mask size.
+            Negative values erode the mask.
Positive values dilate the mask
+        """
+        if amount == 0:
+            self._dilation = ("erode", None)
+            return
+
+        action: T.Literal["erode", "dilate"] = "erode" if amount < 0 else "dilate"
+        kernel = int(round(self.stored_size * abs(amount / 100.), 0))
+        self._dilation = (action, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel, kernel)))
+
+        logger.trace("action: '%s', amount: %s, kernel: %s",  # type:ignore[attr-defined]
+                     action, amount, kernel)
+
+    def set_blur_and_threshold(self,
+                               blur_kernel: int = 0,
+                               blur_type: T.Literal["gaussian", "normalized"] | None = "gaussian",
+                               blur_passes: int = 1,
+                               threshold: int = 0) -> None:
+        """ Set the internal blur kernel and threshold amount for returned masks
+
+        Parameters
+        ----------
+        blur_kernel: int, optional
+            The kernel size, in pixels, to apply gaussian blurring to the mask. Set to 0 for no
+            blurring. Should be odd; if an even number is passed in (outside of 0) then it is
+            rounded up to the next odd number. Default: 0
+        blur_type: ["gaussian", "normalized"], optional
+            The blur type to use. ``gaussian`` or ``normalized`` box filter. Default: ``gaussian``
+        blur_passes: int, optional
+            The number of passes to perform when blurring. Default: 1
+        threshold: int, optional
+            The threshold amount for snapping low mask values to 0 and high mask values to 255,
+            given as a percentage. Default: 0
+        """
+        logger.trace("blur_kernel: %s, blur_type: %s, "  # type:ignore[attr-defined]
+                     "blur_passes: %s, threshold: %s",
+                     blur_kernel, blur_type, blur_passes, threshold)
+        if blur_type is not None:
+            blur_kernel += 0 if blur_kernel == 0 or blur_kernel % 2 == 1 else 1
+        self._blur_kernel = blur_kernel
+        self._blur_type = blur_type
+        self._blur_passes = blur_passes
+        self._threshold = (threshold / 100.0) * 255.0
+
+    def set_sub_crop(self,
+                     source_offset: np.ndarray,
+                     target_offset: np.ndarray,
+                     centering: CenteringType,
+                     coverage_ratio: float = 1.0,
+                     y_offset: float = 0.0) -> None:
+        """ Set the internal crop area of the mask to be returned.
+
+        This impacts the returned mask from :attr:`mask` if the requested mask is required for
+        different face centering than what has been stored.
+
+        Parameters
+        ----------
+        source_offset: :class:`numpy.ndarray`
+            The (x, y) offset for the mask at its stored centering
+        target_offset: :class:`numpy.ndarray`
+            The (x, y) offset for the mask at the requested target centering
+        centering: str
+            The centering to set the sub crop area for. One of `"legacy"`, `"face"`, `"head"`
+        coverage_ratio: float, optional
+            The coverage ratio to be applied to the target image. Default: `1.0`
+        y_offset: float, optional
+            Amount to additionally adjust the mask's offset along the y-axis.
Default: 0.0 + """ + if centering == self.stored_centering and coverage_ratio == 1.0: + return + + center = get_adjusted_center(self.stored_size, + source_offset, + target_offset, + self.stored_centering, + y_offset) + crop_size = get_centered_size(self.stored_centering, + centering, + self.stored_size, + coverage_ratio=coverage_ratio) + roi = np.array([center - crop_size // 2, center + crop_size // 2]).ravel() + + self._sub_crop_size = crop_size + self._sub_crop_slices["in"] = [slice(max(roi[1], 0), max(roi[3], 0)), + slice(max(roi[0], 0), max(roi[2], 0))] + self._sub_crop_slices["out"] = [ + slice(max(roi[1] * -1, 0), + crop_size - min(crop_size, max(0, roi[3] - self.stored_size))), + slice(max(roi[0] * -1, 0), + crop_size - min(crop_size, max(0, roi[2] - self.stored_size)))] + + logger.trace("src_size: %s, coverage_ratio: %s, " # type:ignore[attr-defined] + "sub_crop_size: %s, sub_crop_slices: %s", + roi, coverage_ratio, self._sub_crop_size, self._sub_crop_slices) + + def _adjust_affine_matrix(self, mask_size: int, affine_matrix: np.ndarray) -> np.ndarray: + """ Adjust the affine matrix for the mask's storage size + + Parameters + ---------- + mask_size: int + The original size of the mask. + affine_matrix: :class:`numpy.ndarray` + The affine matrix to transform the mask at original size to the parent frame. + + Returns + ------- + affine_matrix: :class:`numpy,ndarray` + The affine matrix adjusted for the mask at its stored dimensions. + """ + zoom = self.stored_size / mask_size + zoom_mat = np.array([[zoom, 0, 0.], [0, zoom, 0.]]) + adjust_mat = np.dot(zoom_mat, np.concatenate((affine_matrix, np.array([[0., 0., 1.]])))) + logger.trace("storage_size: %s, mask_size: %s, zoom: %s, " # type:ignore[attr-defined] + "original matrix: %s, adjusted_matrix: %s", self.stored_size, mask_size, zoom, + affine_matrix.shape, adjust_mat.shape) + return adjust_mat + + def to_dict(self, is_png=False) -> MaskAlignmentsFileDict: + """ Convert the mask to a dictionary for saving to an alignments file + + Parameters + ---------- + is_png: bool + ``True`` if the dictionary is being created for storage in a png header otherwise + ``False``. Default: ``False`` + + Returns + ------- + dict: + The :class:`Mask` for saving to an alignments file. Contains the keys ``mask``, + ``affine_matrix``, ``interpolator``, ``stored_size``, ``stored_centering`` + """ + assert self._mask is not None + affine_matrix = self.affine_matrix.tolist() if is_png else self.affine_matrix + retval = MaskAlignmentsFileDict(mask=self._mask, + affine_matrix=affine_matrix, + interpolator=self.interpolator, + stored_size=self.stored_size, + stored_centering=self.stored_centering) + logger.trace({k: v if k != "mask" else type(v) # type:ignore[attr-defined] + for k, v in retval.items()}) + return retval + + def to_png_meta(self) -> MaskAlignmentsFileDict: + """ Convert the mask to a dictionary supported by png itxt headers. + + Returns + ------- + dict: + The :class:`Mask` for saving to an alignments file. Contains the keys ``mask``, + ``affine_matrix``, ``interpolator``, ``stored_size``, ``stored_centering`` + """ + return self.to_dict(is_png=True) + + def from_dict(self, mask_dict: MaskAlignmentsFileDict) -> None: + """ Populates the :class:`Mask` from a dictionary loaded from an alignments file. 
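+
+        A round-trip sketch (``stored`` is illustrative, assumed to have come from
+        :meth:`to_dict`)::
+
+            mask = Mask()
+            mask.from_dict(stored)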
+ + Parameters + ---------- + mask_dict: dict + A dictionary stored in an alignments file containing the keys ``mask``, + ``affine_matrix``, ``interpolator``, ``stored_size``, ``stored_centering`` + """ + self._mask = mask_dict["mask"] + affine_matrix = mask_dict["affine_matrix"] + self._affine_matrix = (affine_matrix if isinstance(affine_matrix, np.ndarray) + else np.array(affine_matrix, dtype="float64")) + self._interpolator = mask_dict["interpolator"] + self.stored_size = mask_dict["stored_size"] + centering = mask_dict.get("stored_centering") + self.stored_centering = "face" if centering is None else centering + logger.trace({k: v if k != "mask" else type(v) # type:ignore[attr-defined] + for k, v in mask_dict.items()}) + + +class LandmarksMask(Mask): + """ Create a single channel mask from aligned landmark points. + + Landmarks masks are created on the fly, so the stored centering and size should be the same as + the aligned face that the mask will be applied to. As the masks are created on the fly, blur + + dilation is applied to the mask at creation (prior to compression) rather than after + decompression when requested. + + Note + ---- + Threshold is not used for Landmarks mask as the mask is binary + + Parameters + ---------- + points : list[:class:`numpy.ndarray`] + A list of landmark points that correspond to the given storage_size to create + the mask. Each item in the list should be a :class:`numpy.ndarray` that a filled + convex polygon will be created from + storage_size : int, optional + The size (in pixels) that the compressed mask should be stored at. Default: 128. + storage_centering : str, optional: + The centering to store the mask at. One of `"legacy"`, `"face"`, `"head"`. + Default: `"face"` + dilation : float, optional + The amount of dilation to apply to the mask. as a percentage of the mask size. Default: 0.0 + """ + def __init__(self, + points: list[np.ndarray], + storage_size: int = 128, + storage_centering: CenteringType = "face", + dilation: float = 0.0) -> None: + super().__init__(storage_size=storage_size, storage_centering=storage_centering) + self._points = points + self.set_dilation(dilation) + + @property + def mask(self) -> np.ndarray: + """ :class:`numpy.ndarray`: Overrides the default mask property, creating the processed + mask at first call and compressing it. The decompressed mask is returned from this + property. """ + return self.stored_mask + + def generate_mask(self, affine_matrix: np.ndarray, interpolator: int) -> None: + """ Generate the mask. + + Creates the mask applying any requested dilation and blurring and assigns compressed mask + to :attr:`_mask` + + Parameters + ---------- + affine_matrix: :class:`numpy.ndarray` + The transformation matrix required to transform the mask to the original frame. 
+ interpolator, int: + The CV2 interpolator required to transform this mask to it's original frame + """ + mask = np.zeros((self.stored_size, self.stored_size, 1), dtype="float32") + for landmarks in self._points: + lms = np.rint(landmarks).astype("int") + cv2.fillConvexPoly(mask, cv2.convexHull(lms), [1.0], lineType=cv2.LINE_AA) + if self._dilation[-1] is not None: + self._dilate_mask(mask) + if self._blur_kernel != 0 and self._blur_type is not None: + mask = BlurMask(self._blur_type, + mask, + self._blur_kernel, + passes=self._blur_passes).blurred + logger.trace("mask: (shape: %s, dtype: %s)", # type:ignore[attr-defined] + mask.shape, mask.dtype) + self.add(mask, affine_matrix, interpolator) + + +class BlurMask(): + """ Factory class to return the correct blur object for requested blur type. + + Works for square images only. Currently supports Gaussian and Normalized Box Filters. + + Parameters + ---------- + blur_type: ["gaussian", "normalized"] + The type of blur to use + mask: :class:`numpy.ndarray` + The mask to apply the blur to + kernel: int or float + Either the kernel size (in pixels) or the size of the kernel as a ratio of mask size + is_ratio: bool, optional + Whether the given :attr:`kernel` parameter is a ratio or not. If ``True`` then the + actual kernel size will be calculated from the given ratio and the mask size. If + ``False`` then the kernel size will be set directly from the :attr:`kernel` parameter. + Default: ``False`` + passes: int, optional + The number of passes to perform when blurring. Default: ``1`` + + Example + ------- + >>> print(mask.shape) + (128, 128, 1) + >>> new_mask = BlurMask("gaussian", mask, 3, is_ratio=False, passes=1).blurred + >>> print(new_mask.shape) + (128, 128, 1) + """ + def __init__(self, + blur_type: T.Literal["gaussian", "normalized"], + mask: np.ndarray, + kernel: int | float, + is_ratio: bool = False, + passes: int = 1) -> None: + logger.trace(parse_class_init(locals())) # type:ignore[attr-defined] + self._blur_type = blur_type + self._mask = mask + self._passes = passes + kernel_size = self._get_kernel_size(kernel, is_ratio) + self._kernel_size = self._get_kernel_tuple(kernel_size) + logger.trace("Initialized %s", self.__class__.__name__) # type:ignore[attr-defined] + + @property + def blurred(self) -> np.ndarray: + """ :class:`numpy.ndarray`: The final mask with blurring applied. """ + func = self._func_mapping[self._blur_type] + kwargs = self._get_kwargs() + blurred = self._mask + for i in range(self._passes): + assert isinstance(kwargs["ksize"], tuple) + ksize = int(kwargs["ksize"][0]) + logger.trace("Pass: %s, kernel_size: %s", # type:ignore[attr-defined] + i + 1, (ksize, ksize)) + blurred = func(blurred, **kwargs) + ksize = int(round(ksize * self._multipass_factor)) + kwargs["ksize"] = self._get_kernel_tuple(ksize) + blurred = blurred[..., None] + logger.trace("Returning blurred mask. Shape: %s", # type:ignore[attr-defined] + blurred.shape) + return blurred + + @property + def _multipass_factor(self) -> float: + """ For multiple passes the kernel must be scaled down. This value is + different for box filter and gaussian """ + factor = {"gaussian": 0.8, "normalized": 0.5} + return factor[self._blur_type] + + @property + def _sigma(self) -> T.Literal[0]: + """ int: The Sigma for Gaussian Blur. Returns 0 to force calculation from kernel size. """ + return 0 + + @property + def _func_mapping(self) -> dict[T.Literal["gaussian", "normalized"], Callable]: + """ dict: :attr:`_blur_type` mapped to cv2 Function name. 
""" + return {"gaussian": cv2.GaussianBlur, "normalized": cv2.blur} + + @property + def _kwarg_requirements(self) -> dict[T.Literal["gaussian", "normalized"], list[str]]: + """ dict: :attr:`_blur_type` mapped to cv2 Function required keyword arguments. """ + return {"gaussian": ['ksize', 'sigmaX'], "normalized": ['ksize']} + + @property + def _kwarg_mapping(self) -> dict[str, int | tuple[int, int]]: + """ dict: cv2 function keyword arguments mapped to their parameters. """ + return {"ksize": self._kernel_size, "sigmaX": self._sigma} + + def _get_kernel_size(self, kernel: int | float, is_ratio: bool) -> int: + """ Set the kernel size to absolute value. + + If :attr:`is_ratio` is ``True`` then the kernel size is calculated from the given ratio and + the :attr:`_mask` size, otherwise the given kernel size is just returned. + + Parameters + ---------- + kernel: int or float + Either the kernel size (in pixels) or the size of the kernel as a ratio of mask size + is_ratio: bool, optional + Whether the given :attr:`kernel` parameter is a ratio or not. If ``True`` then the + actual kernel size will be calculated from the given ratio and the mask size. If + ``False`` then the kernel size will be set directly from the :attr:`kernel` parameter. + + Returns + ------- + int + The size (in pixels) of the blur kernel + """ + if not is_ratio: + return int(kernel) + + mask_diameter = np.sqrt(np.sum(self._mask)) + radius = round(max(1., mask_diameter * kernel / 100.)) + kernel_size = int(radius * 2 + 1) + logger.trace("kernel_size: %s", kernel_size) # type:ignore[attr-defined] + return kernel_size + + @staticmethod + def _get_kernel_tuple(kernel_size: int) -> tuple[int, int]: + """ Make sure kernel_size is odd and return it as a tuple. + + Parameters + ---------- + kernel_size: int + The size in pixels of the blur kernel + + Returns + ------- + tuple + The kernel size as a tuple of ('int', 'int') + """ + kernel_size += 1 if kernel_size % 2 == 0 else 0 + retval = (kernel_size, kernel_size) + logger.trace(retval) # type:ignore[attr-defined] + return retval + + def _get_kwargs(self) -> dict[str, int | tuple[int, int]]: + """ dict: the valid keyword arguments for the requested :attr:`_blur_type` """ + retval = {kword: self._kwarg_mapping[kword] + for kword in self._kwarg_requirements[self._blur_type]} + logger.trace("BlurMask kwargs: %s", retval) # type:ignore[attr-defined] + return retval + + +__all__ = get_module_objects(__name__) diff --git a/lib/align/alignments.py b/lib/align/alignments.py new file mode 100644 index 0000000000..aedf0c94ee --- /dev/null +++ b/lib/align/alignments.py @@ -0,0 +1,753 @@ +#!/usr/bin/env python3 +""" Alignments file functions for reading, writing and manipulating the data stored in a +serialized alignments file. """ +from __future__ import annotations +import logging +import os +import typing as T +from datetime import datetime + +import numpy as np + +from lib.serializer import get_serializer +from lib.utils import FaceswapError, get_module_objects + +from .thumbnails import Thumbnails +from .updater import (FileStructure, IdentityAndVideoMeta, LandmarkRename, Legacy, ListToNumpy, + MaskCentering, VideoExtension) + +if T.TYPE_CHECKING: + from collections.abc import Generator + from .aligned_face import CenteringType + +logger = logging.getLogger(__name__) +_VERSION = 2.4 +# VERSION TRACKING +# 1.0 - Never really existed. Basically any alignments file prior to version 2.0 +# 2.0 - Implementation of full head extract. 
Any alignments version below this will have used +# legacy extract +# 2.1 - Alignments data to extracted face PNG header. SHA1 hashes of faces no longer calculated +# or stored in alignments file +# 2.2 - Add support for differently centered masks (i.e. not all masks stored as face centering) +# 2.3 - Add 'identity' key to alignments file. May or may not be populated, to contain vggface2 +# embeddings. Make 'video_meta' key a standard key. Can be unpopulated +# 2.4 - Update video file alignment keys to end in the video extension rather than '.png' + + +# TODO Convert these to Dataclasses +class MaskAlignmentsFileDict(T.TypedDict): + """ Typed Dictionary for storing Masks. """ + mask: bytes + affine_matrix: list[float] | np.ndarray + interpolator: int + stored_size: int + stored_centering: CenteringType + + +class PNGHeaderAlignmentsDict(T.TypedDict): + """ Base Dictionary for storing a single faces' Alignment Information in Alignments files and + PNG Headers. """ + x: int + y: int + w: int + h: int + landmarks_xy: list[list[float]] | np.ndarray + mask: dict[str, MaskAlignmentsFileDict] + identity: dict[str, list[float]] + + +class AlignmentFileDict(PNGHeaderAlignmentsDict): + """ Typed Dictionary for storing a single faces' Alignment Information in alignments files. """ + thumb: np.ndarray | None + + +class PNGHeaderSourceDict(T.TypedDict): + """ Dictionary for storing additional meta information in PNG headers """ + alignments_version: float + original_filename: str + face_index: int + source_filename: str + source_is_video: bool + source_frame_dims: tuple[int, int] | None + + +class AlignmentDict(T.TypedDict): + """ Dictionary for holding all of the alignment information within a single alignment file """ + faces: list[AlignmentFileDict] + video_meta: dict[str, float | int] + + +class PNGHeaderDict(T.TypedDict): + """ Dictionary for storing all alignment and meta information in PNG Headers """ + alignments: PNGHeaderAlignmentsDict + source: PNGHeaderSourceDict + + +class Alignments(): # pylint:disable=too-many-public-methods + """ The alignments file is a custom serialized ``.fsa`` file that holds information for each + frame for a video or series of images. + + Specifically, it holds a list of faces that appear in each frame. Each face contains + information detailing their detected bounding box location within the frame, the 68 point + facial landmarks and any masks that have been extracted. + + Additionally it can also hold video meta information (timestamp and whether a frame is a + key frame.) + + Parameters + ---------- + folder: str + The folder that contains the alignments ``.fsa`` file + filename: str, optional + The filename of the ``.fsa`` alignments file. If not provided then the given folder will be + checked for a default alignments file filename. Default: "alignments" + """ + def __init__(self, folder: str, filename: str = "alignments") -> None: + logger.debug("Initializing %s: (folder: '%s', filename: '%s')", + self.__class__.__name__, folder, filename) + self._io = _IO(self, folder, filename) + self._data = self._load() + self._io.update_legacy() + + self._legacy = Legacy(self) + self._thumbnails = Thumbnails(self) + logger.debug("Initialized %s", self.__class__.__name__) + + # << PROPERTIES >> # + + @property + def frames_count(self) -> int: + """ int: The number of frames that appear in the alignments :attr:`data`. 
""" + retval = len(self._data) + logger.trace(retval) # type:ignore[attr-defined] + return retval + + @property + def faces_count(self) -> int: + """ int: The total number of faces that appear in the alignments :attr:`data`. """ + retval = sum(len(val["faces"]) for val in self._data.values()) + logger.trace(retval) # type:ignore[attr-defined] + return retval + + @property + def file(self) -> str: + """ str: The full path to the currently loaded alignments file. """ + return self._io.file + + @property + def data(self) -> dict[str, AlignmentDict]: + """ dict: The loaded alignments :attr:`file` in dictionary form. """ + return self._data + + @property + def have_alignments_file(self) -> bool: + """ bool: ``True`` if an alignments file exists at location :attr:`file` otherwise + ``False``. """ + return self._io.have_alignments_file + + @property + def hashes_to_frame(self) -> dict[str, dict[str, int]]: + """ dict: The SHA1 hash of the face mapped to the frame(s) and face index within the frame + that the hash corresponds to. + + Notes + ----- + This method is depractated and exists purely for updating legacy hash based alignments + to new png header storage in :class:`lib.align.update_legacy_png_header`. + """ + return self._legacy.hashes_to_frame + + @property + def hashes_to_alignment(self) -> dict[str, AlignmentFileDict]: + """ dict: The SHA1 hash of the face mapped to the alignment for the face that the hash + corresponds to. The structure of the dictionary is: + + Notes + ----- + This method is depractated and exists purely for updating legacy hash based alignments + to new png header storage in :class:`lib.align.update_legacy_png_header`. + """ + return self._legacy.hashes_to_alignment + + @property + def mask_summary(self) -> dict[str, int]: + """ dict: The mask type names stored in the alignments :attr:`data` as key with the number + of faces which possess the mask type as value. """ + masks: dict[str, int] = {} + for val in self._data.values(): + for face in val["faces"]: + if face.get("mask", None) is None: + masks["none"] = masks.get("none", 0) + 1 + for key in face.get("mask", {}): + masks[key] = masks.get(key, 0) + 1 + return masks + + @property + def video_meta_data(self) -> dict[str, list[int] | list[float] | None]: + """ dict: The frame meta data stored in the alignments file. If data does not exist in the + alignments file then ``None`` is returned for each Key """ + retval: dict[str, list[int] | list[float] | None] = {"pts_time": None, "keyframes": None} + pts_time: list[float] = [] + keyframes: list[int] = [] + for idx, key in enumerate(sorted(self.data)): + if not self.data[key].get("video_meta", {}): + return retval + meta = self.data[key]["video_meta"] + pts_time.append(T.cast(float, meta["pts_time"])) + if meta["keyframe"]: + keyframes.append(idx) + retval = {"pts_time": pts_time, "keyframes": keyframes} + return retval + + @property + def thumbnails(self) -> Thumbnails: + """ :class:`~lib.align.thumbnails.Thumbnails`: The low resolution thumbnail images that + exist within the alignments file """ + return self._thumbnails + + @property + def version(self) -> float: + """ float: The alignments file version number. """ + return self._io.version + + def _load(self) -> dict[str, AlignmentDict]: + """ Load the alignments data from the serialized alignments :attr:`file`. + + Populates :attr:`_version` with the alignment file's loaded version as well as returning + the serialized data. 
+
+        Returns
+        -------
+        dict:
+            The loaded alignments data
+        """
+        return self._io.load()
+
+    def save(self) -> None:
+        """ Write the contents of :attr:`data` and :attr:`_meta` to a serialized ``.fsa`` file at
+        the location :attr:`file`. """
+        return self._io.save()
+
+    def backup(self) -> None:
+        """ Create a backup copy of the alignments :attr:`file`.
+
+        Creates a copy of the serialized alignments :attr:`file` appending a
+        timestamp onto the end of the file name and storing in the same folder as
+        the original :attr:`file`.
+        """
+        return self._io.backup()
+
+    def save_video_meta_data(self, pts_time: list[float], keyframes: list[int]) -> None:
+        """ Save video meta data to the alignments file.
+
+        If the alignments file does not have an entry for every frame (e.g. if Extract Every N
+        was used) then the missing frames are added to the alignments file with no faces, so that
+        the video meta data can be stored.
+
+        Parameters
+        ----------
+        pts_time: list
+            A list of presentation timestamps (`float`) in frame index order for every frame in
+            the input video
+        keyframes: list
+            A list of frame indices corresponding to the key frames in the input video
+        """
+        if pts_time[0] != 0:
+            pts_time, keyframes = self._pad_leading_frames(pts_time, keyframes)
+
+        sample_filename = next(fname for fname in self.data)
+        basename = sample_filename[:sample_filename.rfind("_")]
+        ext = os.path.splitext(sample_filename)[-1]
+        logger.debug("sample filename: '%s', base filename: '%s', extension: '%s'",
+                     sample_filename, basename, ext)
+        logger.info("Saving video meta information to Alignments file")
+
+        for idx, pts in enumerate(pts_time):
+            meta: dict[str, float | int] = {"pts_time": pts, "keyframe": idx in keyframes}
+            key = f"{basename}_{idx + 1:06d}{ext}"
+            if key not in self.data:
+                self.data[key] = {"video_meta": meta, "faces": []}
+            else:
+                self.data[key]["video_meta"] = meta
+
+        logger.debug("Alignments count: %s, timestamp count: %s", len(self.data), len(pts_time))
+        if len(self.data) != len(pts_time):
+            raise FaceswapError(
+                "There is a mismatch between the number of frames found in the video file "
+                f"({len(pts_time)}) and the number of frames found in the alignments file "
+                f"({len(self.data)}).\nThis can be caused by a number of issues:"
+                "\n - The video has a Variable Frame Rate and FFMPEG is having a hard time "
+                "calculating the correct number of frames."
+                "\n - You are working with a Merged Alignments file. This is not supported for "
+                "your current use case."
+                "\nYou should either extract the video to individual frames, re-encode the "
+                "video at a constant frame rate and re-run extraction or work with a dedicated "
+                "alignments file for your requested video.")
+        self._io.save()
+
+    @classmethod
+    def _pad_leading_frames(cls, pts_time: list[float], keyframes: list[int]) -> tuple[list[float],
+                                                                                       list[int]]:
+        """ Calculate the number of frames to pad the video by when the first frame is not
+        a key frame.
+
+        A somewhat crude method that obtains the gaps between existing frames and calculates
+        how many frames should be inserted at the beginning, based on the first presentation
+        timestamp.
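+
+        As a worked example, if the first timestamps are ``[0.5, 1.0, 1.5]`` then the average
+        gap is ``0.5``, so ``round(0.5 / 0.5) = 1`` frame is prepended at pts ``0.0`` and every
+        keyframe index is shifted up by one.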
+ + Parameters + ---------- + pts_time: list + A list of presentation timestamps (`float`) in frame index order for every frame in + the input video + keyframes: list + A list of keyframes (`int`) for the input video + + Returns + ------- + tuple + The presentation time stamps with extra frames padded to the beginning and the + keyframes adjusted to include the new frames + """ + start_pts = pts_time[0] + logger.debug("Video not cut on keyframe. Start pts: %s", start_pts) + gaps: list[float] = [] + prev_time = None + for item in pts_time: + if prev_time is not None: + gaps.append(item - prev_time) + prev_time = item + data_points = len(gaps) + avg_gap = sum(gaps) / data_points + frame_count = int(round(start_pts / avg_gap)) + pad_pts = [avg_gap * i for i in range(frame_count)] + logger.debug("data_points: %s, avg_gap: %s, frame_count: %s, pad_pts: %s", + data_points, avg_gap, frame_count, pad_pts) + pts_time = pad_pts + pts_time + keyframes = [i + frame_count for i in keyframes] + return pts_time, keyframes + + # << VALIDATION >> # + def frame_exists(self, frame_name: str) -> bool: + """ Check whether a given frame_name exists within the alignments :attr:`data`. + + Parameters + ---------- + frame_name: str + The frame name to check. This should be the base name of the frame, not the full path + + Returns + ------- + bool + ``True`` if the given frame_name exists within the alignments :attr:`data` + otherwise ``False`` + """ + retval = frame_name in self._data.keys() + logger.trace("'%s': %s", frame_name, retval) # type:ignore[attr-defined] + return retval + + def frame_has_faces(self, frame_name: str) -> bool: + """ Check whether a given frame_name exists within the alignments :attr:`data` and contains + at least 1 face. + + Parameters + ---------- + frame_name: str + The frame name to check. This should be the base name of the frame, not the full path + + Returns + ------- + bool + ``True`` if the given frame_name exists within the alignments :attr:`data` and has at + least 1 face associated with it, otherwise ``False`` + """ + frame_data = self._data.get(frame_name, T.cast(AlignmentDict, {})) + retval = bool(frame_data.get("faces", [])) + logger.trace("'%s': %s", frame_name, retval) # type:ignore[attr-defined] + return retval + + def frame_has_multiple_faces(self, frame_name: str) -> bool: + """ Check whether a given frame_name exists within the alignments :attr:`data` and contains + more than 1 face. + + Parameters + ---------- + frame_name: str + The frame_name name to check. This should be the base name of the frame, not the full + path + + Returns + ------- + bool + ``True`` if the given frame_name exists within the alignments :attr:`data` and has more + than 1 face associated with it, otherwise ``False`` + """ + if not frame_name: + retval = False + else: + frame_data = self._data.get(frame_name, T.cast(AlignmentDict, {})) + retval = bool(len(frame_data.get("faces", [])) > 1) + logger.trace("'%s': %s", frame_name, retval) # type:ignore[attr-defined] + return retval + + def mask_is_valid(self, mask_type: str) -> bool: + """ Ensure the given ``mask_type`` is valid for the alignments :attr:`data`. + + Every face in the alignments :attr:`data` must have the given mask type to successfully + pass the test. 
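+
+        A typical pre-flight check (the mask name is illustrative)::
+
+            if not alignments.mask_is_valid("components"):
+                pass  # generate the missing masks before continuing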
+
+ Parameters
+ ----------
+ mask_type: str
+ The mask type to check against the current alignments :attr:`data`
+
+ Returns
+ -------
+ bool:
+ ``True`` if all faces in the current alignments possess the given ``mask_type``
+ otherwise ``False``
+ """
+ retval = all((face.get("mask") is not None and
+ face["mask"].get(mask_type) is not None)
+ for val in self._data.values()
+ for face in val["faces"])
+ logger.debug(retval)
+ return retval
+
+ # << DATA >> #
+ def get_faces_in_frame(self, frame_name: str) -> list[AlignmentFileDict]:
+ """ Obtain the faces from :attr:`data` associated with a given frame_name.
+
+ Parameters
+ ----------
+ frame_name: str
+ The frame name to return faces for. This should be the base name of the frame, not the
+ full path
+
+ Returns
+ -------
+ list
+ The list of face dictionaries that appear within the requested frame_name
+ """
+ logger.trace("Getting faces for frame_name: '%s'", frame_name) # type:ignore[attr-defined]
+ frame_data = self._data.get(frame_name, T.cast(AlignmentDict, {}))
+ return frame_data.get("faces", T.cast(list[AlignmentFileDict], []))
+
+ def count_faces_in_frame(self, frame_name: str) -> int:
+ """ Return the number of faces that appear within :attr:`data` for the given frame_name.
+
+ Parameters
+ ----------
+ frame_name: str
+ The frame name to return the count for. This should be the base name of the frame, not
+ the full path
+
+ Returns
+ -------
+ int
+ The number of faces that appear in the given frame_name
+ """
+ frame_data = self._data.get(frame_name, T.cast(AlignmentDict, {}))
+ retval = len(frame_data.get("faces", []))
+ logger.trace(retval) # type:ignore[attr-defined]
+ return retval
+
+ # << MANIPULATION >> #
+ def delete_face_at_index(self, frame_name: str, face_index: int) -> bool:
+ """ Delete the face for the given frame_name at the given face index from :attr:`data`.
+
+ Parameters
+ ----------
+ frame_name: str
+ The frame name to remove the face from. This should be the base name of the frame, not
+ the full path
+ face_index: int
+ The index number of the face within the given frame_name to remove
+
+ Returns
+ -------
+ bool
+ ``True`` if a face was successfully deleted otherwise ``False``
+ """
+ logger.debug("Deleting face %s for frame_name '%s'", face_index, frame_name)
+ face_index = int(face_index)
+ if face_index + 1 > self.count_faces_in_frame(frame_name):
+ logger.debug("No face to delete: (frame_name: '%s', face_index %s)",
+ frame_name, face_index)
+ return False
+ del self._data[frame_name]["faces"][face_index]
+ logger.debug("Deleted face: (frame_name: '%s', face_index %s)", frame_name, face_index)
+ return True
+
+ def add_face(self, frame_name: str, face: AlignmentFileDict) -> int:
+ """ Add a new face for the given frame_name in :attr:`data` and return its index.
+
+ Parameters
+ ----------
+ frame_name: str
+ The frame name to add the face to.
This should be the base name of the frame, not the + full path + face: dict + The face information to add to the given frame_name, correctly formatted for storing in + :attr:`data` + + Returns + ------- + int + The index of the newly added face within :attr:`data` for the given frame_name + """ + logger.debug("Adding face to frame_name: '%s'", frame_name) + if frame_name not in self._data: + self._data[frame_name] = {"faces": [], "video_meta": {}} + self._data[frame_name]["faces"].append(face) + retval = self.count_faces_in_frame(frame_name) - 1 + logger.debug("Returning new face index: %s", retval) + return retval + + def update_face(self, frame_name: str, face_index: int, face: AlignmentFileDict) -> None: + """ Update the face for the given frame_name at the given face index in :attr:`data`. + + Parameters + ---------- + frame_name: str + The frame name to update the face for. This should be the base name of the frame, not + the full path + face_index: int + The index number of the face within the given frame_name to update + face: dict + The face information to update to the given frame_name at the given face_index, + correctly formatted for storing in :attr:`data` + """ + logger.debug("Updating face %s for frame_name '%s'", face_index, frame_name) + self._data[frame_name]["faces"][face_index] = face + + def filter_faces(self, filter_dict: dict[str, list[int]], filter_out: bool = False) -> None: + """ Remove faces from :attr:`data` based on a given filter list. + + Parameters + ---------- + filter_dict: dict + Dictionary of source filenames as key with a list of face indices to filter as value. + filter_out: bool, optional + ``True`` if faces should be removed from :attr:`data` when there is a corresponding + match in the given filter_dict. ``False`` if faces should be kept in :attr:`data` when + there is a corresponding match in the given filter_dict, but removed if there is no + match. Default: ``False`` + """ + logger.debug("filter_dict: %s, filter_out: %s", filter_dict, filter_out) + for source_frame, frame_data in self._data.items(): + face_indices = filter_dict.get(source_frame, []) + if filter_out: + filter_list = face_indices + else: + filter_list = [idx for idx in range(len(frame_data["faces"])) + if idx not in face_indices] + logger.trace("frame: '%s', filter_list: %s", # type:ignore[attr-defined] + source_frame, filter_list) + + for face_idx in reversed(sorted(filter_list)): + logger.verbose( # type:ignore[attr-defined] + "Filtering out face: (filename: %s, index: %s)", source_frame, face_idx) + del frame_data["faces"][face_idx] + + def update_from_dict(self, data: dict[str, AlignmentDict]) -> None: + """ Replace all alignments with the contents of the given dictionary + + Parameters + ---------- + data: dict[str, AlignmentDict] + The alignments, in correctly formatted dictionary form, to be populated into this + :class:`Alignments` + """ + logger.debug("Populating alignments with %s entries", len(data)) + self._data = data + + # << GENERATORS >> # + def yield_faces(self) -> Generator[tuple[str, list[AlignmentFileDict], int, str], None, None]: + """ Generator to obtain all faces with meta information from :attr:`data`. The results + are yielded by frame. + + Notes + ----- + The yielded order is non-deterministic. + + Yields + ------ + frame_name: str + The frame name that the face belongs to. 
This is the base name of the frame, as it
+ appears in :attr:`data`, not the full path
+ faces: list
+ The list of face `dict` objects that exist for this frame
+ face_count: int
+ The number of faces that exist within :attr:`data` for this frame
+ frame_fullname: str
+ The full path (folder and filename) for the yielded frame
+ """
+ for frame_fullname, val in self._data.items():
+ frame_name = os.path.splitext(frame_fullname)[0]
+ face_count = len(val["faces"])
+ logger.trace( # type:ignore[attr-defined]
+ "Yielding: (frame: '%s', faces: %s, frame_fullname: '%s')",
+ frame_name, face_count, frame_fullname)
+ yield frame_name, val["faces"], face_count, frame_fullname
+
+ def update_legacy_has_source(self, filename: str) -> None:
+ """ Update legacy alignments files when we have the source filename available.
+
+ Updates here can only be performed when we have the source filename
+
+ Parameters
+ ----------
+ filename: str
+ The filename/folder of the original source images/video for the current alignments
+ """
+ updates = [updater.is_updated for updater in (VideoExtension(self, filename), )]
+ if any(updates):
+ self._io.update_version()
+ self.save()
+
+
+class _IO():
+ """ Class to handle the saving/loading of an alignments file.
+
+ Parameters
+ ----------
+ alignments: :class:`~Alignments`
+ The parent alignments class that these IO operations belong to
+ folder: str
+ The folder that contains the alignments ``.fsa`` file
+ filename: str
+ The filename of the ``.fsa`` alignments file.
+ """
+ def __init__(self, alignments: Alignments, folder: str, filename: str) -> None:
+ logger.debug("Initializing %s: (alignments: %s)", self.__class__.__name__, alignments)
+ self._alignments = alignments
+ self._serializer = get_serializer("compressed")
+ self._file = self._get_location(folder, filename)
+ self._version: float = _VERSION
+
+ @property
+ def file(self) -> str:
+ """ str: The full path to the currently loaded alignments file. """
+ return self._file
+
+ @property
+ def version(self) -> float:
+ """ float: The alignments file version number. """
+ return self._version
+
+ @property
+ def have_alignments_file(self) -> bool:
+ """ bool: ``True`` if an alignments file exists at location :attr:`file` otherwise
+ ``False``. """
+ retval = os.path.exists(self._file)
+ logger.trace(retval) # type:ignore[attr-defined]
+ return retval
+
+ def _get_location(self, folder: str, filename: str) -> str:
+ """ Obtains the location of an alignments file.
+
+ Parameters
+ ----------
+ folder: str
+ The folder that the alignments file is located in
+ filename: str
+ The filename of the alignments file
+
+ Returns
+ -------
+ str
+ The full path to the alignments file
+ """
+ logger.debug("Getting location: (folder: '%s', filename: '%s')", folder, filename)
+ noext_name, extension = os.path.splitext(filename)
+ if extension[1:] == self._serializer.file_extension:
+ logger.debug("Valid Alignments filename provided: '%s'", filename)
+ else:
+ filename = f"{noext_name}.{self._serializer.file_extension}"
+ logger.debug("File extension set from serializer: '%s'",
+ self._serializer.file_extension)
+ location = os.path.join(str(folder), filename)
+
+ logger.verbose("Alignments filepath: '%s'", location) # type:ignore[attr-defined]
+ return location
+
+ def update_legacy(self) -> None:
+ """ Check whether the alignments are legacy, and if so update them to the current
+ alignments format.
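+
+ The updaters applied, in order, are :class:`~lib.align.updater.FileStructure`,
+ :class:`~lib.align.updater.LandmarkRename`, :class:`~lib.align.updater.ListToNumpy`,
+ :class:`~lib.align.updater.MaskCentering` and
+ :class:`~lib.align.updater.IdentityAndVideoMeta`.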
""" + updates = [updater.is_updated for updater in (FileStructure(self._alignments), + LandmarkRename(self._alignments), + ListToNumpy(self._alignments), + MaskCentering(self._alignments), + IdentityAndVideoMeta(self._alignments))] + if any(updates): + self.update_version() + self.save() + + def update_version(self) -> None: + """ Update the version of the alignments file to the latest version """ + self._version = _VERSION + logger.info("Updating alignments file to version %s", self._version) + + def load(self) -> dict[str, AlignmentDict]: + """ Load the alignments data from the serialized alignments :attr:`file`. + + Populates :attr:`_version` with the alignment file's loaded version as well as returning + the serialized data. + + Returns + ------- + dict: + The loaded alignments data + """ + logger.debug("Loading alignments") + if not self.have_alignments_file: + raise FaceswapError(f"Error: Alignments file not found at {self._file}") + + logger.info("Reading alignments from: '%s'", self._file) + data = self._serializer.load(self._file) + meta = data.get("__meta__", {"version": 1.0}) + self._version = meta["version"] + data = data.get("__data__", data) + logger.debug("Loaded alignments") + return data + + def save(self) -> None: + """ Write the contents of :attr:`data` and :attr:`_meta` to a serialized ``.fsa`` file at + the location :attr:`file`. """ + logger.debug("Saving alignments") + logger.info("Writing alignments to: '%s'", self._file) + data = {"__meta__": {"version": self._version}, + "__data__": self._alignments.data} + self._serializer.save(self._file, data) + logger.debug("Saved alignments") + + def backup(self) -> None: + """ Create a backup copy of the alignments :attr:`file`. + + Creates a copy of the serialized alignments :attr:`file` appending a + timestamp onto the end of the file name and storing in the same folder as + the original :attr:`file`. + """ + logger.debug("Backing up alignments") + if not os.path.isfile(self._file): + logger.debug("No alignments to back up") + return + now = datetime.now().strftime("%Y%m%d_%H%M%S") + src = self._file + split = os.path.splitext(src) + dst = f"{split[0]}_{now}{split[1]}" + idx = 1 + while True: + if not os.path.exists(dst): + break + logger.debug("Backup file %s exists. Incrementing", dst) + dst = f"{split[0]}_{now}({idx}){split[1]}" + idx += 1 + + logger.info("Backing up original alignments to '%s'", dst) + os.rename(src, dst) + logger.debug("Backed up alignments") + + +__all__ = get_module_objects(__name__) diff --git a/lib/align/constants.py b/lib/align/constants.py new file mode 100644 index 0000000000..27f4eb51bf --- /dev/null +++ b/lib/align/constants.py @@ -0,0 +1,116 @@ +#!/usr/bin/env python3 +""" Constants that are required across faceswap's lib.align package """ +from __future__ import annotations + +import typing as T +from enum import Enum + +import numpy as np + +from lib.utils import get_module_objects + +CenteringType = T.Literal["face", "head", "legacy"] + +EXTRACT_RATIOS: dict[CenteringType, float] = {"legacy": 0.375, "face": 0.5, "head": 0.625} +"""dict[Literal["legacy", "face", head"] float]: The amount of padding applied to each +centering type when generating aligned faces """ + + +class LandmarkType(Enum): + """ Enumeration for the landmark types that Faceswap supports """ + LM_2D_4 = 1 + LM_2D_51 = 2 + LM_2D_68 = 3 + LM_3D_26 = 4 + + @classmethod + def from_shape(cls, shape: tuple[int, ...]) -> LandmarkType: + """ The landmark type for a given shape + + Parameters + ---------- + shape: tuple[int, ...] 
+ The shape to get the landmark type for + + Returns + ------- + Type[LandmarkType] + The enum for the given shape + + Raises + ------ + ValueError + If the requested shape is not valid + """ + shapes: dict[tuple[int, ...], LandmarkType] = {(4, 2): cls.LM_2D_4, + (51, 2): cls.LM_2D_51, + (68, 2): cls.LM_2D_68, + (26, 3): cls.LM_3D_26} + if shape not in shapes: + raise ValueError(f"The given shape {shape} is not valid. Valid shapes: {list(shapes)}") + return shapes[shape] + + +_MEAN_FACE: dict[LandmarkType, np.ndarray] = { + LandmarkType.LM_2D_4: np.array( + [[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0]]), # Clockwise from TL + LandmarkType.LM_2D_51: np.array([ + [0.010086, 0.106454], [0.085135, 0.038915], [0.191003, 0.018748], [0.300643, 0.034489], + [0.403270, 0.077391], [0.596729, 0.077391], [0.699356, 0.034489], [0.808997, 0.018748], + [0.914864, 0.038915], [0.989913, 0.106454], [0.500000, 0.203352], [0.500000, 0.307009], + [0.500000, 0.409805], [0.500000, 0.515625], [0.376753, 0.587326], [0.435909, 0.609345], + [0.500000, 0.628106], [0.564090, 0.609345], [0.623246, 0.587326], [0.131610, 0.216423], + [0.196995, 0.178758], [0.275698, 0.179852], [0.344479, 0.231733], [0.270791, 0.245099], + [0.192616, 0.244077], [0.655520, 0.231733], [0.724301, 0.179852], [0.803005, 0.178758], + [0.868389, 0.216423], [0.807383, 0.244077], [0.729208, 0.245099], [0.264022, 0.780233], + [0.350858, 0.745405], [0.438731, 0.727388], [0.500000, 0.742578], [0.561268, 0.727388], + [0.649141, 0.745405], [0.735977, 0.780233], [0.652032, 0.864805], [0.566594, 0.902192], + [0.500000, 0.909281], [0.433405, 0.902192], [0.347967, 0.864805], [0.300252, 0.784792], + [0.437969, 0.778746], [0.500000, 0.785343], [0.562030, 0.778746], [0.699747, 0.784792], + [0.563237, 0.824182], [0.500000, 0.831803], [0.436763, 0.824182]]), + LandmarkType.LM_3D_26: np.array([ + [4.056931, -11.432347, 1.636229], # 8 chin LL + [1.833492, -12.542305, 4.061275], # 7 chin L + [0.0, -12.901019, 4.070434], # 6 chin C + [-1.833492, -12.542305, 4.061275], # 5 chin R + [-4.056931, -11.432347, 1.636229], # 4 chin RR + [6.825897, 1.275284, 4.402142], # 33 L eyebrow L + [1.330353, 1.636816, 6.903745], # 29 L eyebrow R + [-1.330353, 1.636816, 6.903745], # 34 R eyebrow L + [-6.825897, 1.275284, 4.402142], # 38 R eyebrow R + [1.930245, -5.060977, 5.914376], # 54 nose LL + [0.746313, -5.136947, 6.263227], # 53 nose L + [0.0, -5.485328, 6.76343], # 52 nose C + [-0.746313, -5.136947, 6.263227], # 51 nose R + [-1.930245, -5.060977, 5.914376], # 50 nose RR + [5.311432, 0.0, 3.987654], # 13 L eye L + [1.78993, -0.091703, 4.413414], # 17 L eye R + [-1.78993, -0.091703, 4.413414], # 25 R eye L + [-5.311432, 0.0, 3.987654], # 21 R eye R + [2.774015, -7.566103, 5.048531], # 43 mouth L + [0.509714, -7.056507, 6.566167], # 42 mouth top L + [0.0, -7.131772, 6.704956], # 41 mouth top C + [-0.509714, -7.056507, 6.566167], # 40 mouth top R + [-2.774015, -7.566103, 5.048531], # 39 mouth R + [-0.589441, -8.443925, 6.109526], # 46 mouth bottom R + [0.0, -8.601736, 6.097667], # 45 mouth bottom C + [0.589441, -8.443925, 6.109526]])} # 44 mouth bottom L +"""dict[:class:`~LandmarkType, np.ndarray]: 'Mean' landmark points for various landmark types. 
Used +for aligning faces """ + +LANDMARK_PARTS: dict[LandmarkType, dict[str, tuple[int, int, bool]]] = { + LandmarkType.LM_2D_68: {"mouth_outer": (48, 60, True), + "mouth_inner": (60, 68, True), + "right_eyebrow": (17, 22, False), + "left_eyebrow": (22, 27, False), + "right_eye": (36, 42, True), + "left_eye": (42, 48, True), + "nose": (27, 36, False), + "jaw": (0, 17, False), + "chin": (8, 11, False)}, + LandmarkType.LM_2D_4: {"face": (0, 4, True)}} +"""dict[:class:`LandmarkType`, dict[str, tuple[int, int, bool]]: For each landmark type, stores +the (start index, end index, is polygon) information about each part of the face. """ + + +__all__ = get_module_objects(__name__) diff --git a/lib/align/detected_face.py b/lib/align/detected_face.py new file mode 100644 index 0000000000..87835257de --- /dev/null +++ b/lib/align/detected_face.py @@ -0,0 +1,573 @@ +#!/usr/bin python3 +""" Face and landmarks detection for faceswap.py """ +from __future__ import annotations +import logging +import os +import typing as T + +from hashlib import sha1 +from zlib import compress, decompress + +import numpy as np + +from lib.image import encode_image, read_image +from lib.logger import parse_class_init +from lib.utils import FaceswapError, get_module_objects +from .alignments import (Alignments, AlignmentFileDict, PNGHeaderAlignmentsDict, + PNGHeaderDict, PNGHeaderSourceDict) +from .aligned_face import AlignedFace +from .aligned_mask import LandmarksMask, Mask +from .constants import LANDMARK_PARTS + +if T.TYPE_CHECKING: + from .aligned_face import CenteringType + +logger = logging.getLogger(__name__) + + +class DetectedFace(): # pylint:disable=too-many-instance-attributes + """ Detected face and landmark information + + Holds information about a detected face, it's location in a source image + and the face's 68 point landmarks. + + Methods for aligning a face are also callable from here. + + Parameters + ---------- + image : :class:`numpy.ndarray` | None, optional + Original frame that holds this face. Optional (not required if just storing coordinates). + Default: ``None`` + left : int + The left most point (in pixels) of the face's bounding box as discovered in + :mod:`plugins.extract.detect` + width : int + The width (in pixels) of the face's bounding box as discovered in + :mod:`plugins.extract.detect` + top : int + The top most point (in pixels) of the face's bounding box as discovered in + :mod:`plugins.extract.detect` + height : int + The height (in pixels) of the face's bounding box as discovered in + :mod:`plugins.extract.detect` + landmarks_xy : :class:`numpy.ndarray` + The 68 point landmarks as discovered in :mod:`plugins.extract.align`. Should be an array + of 68 `(x, y)` points of each of the landmark co-ordinates. + mask : dict[str: :class:`~lib.align.aligned_mask.Mask`] + The generated mask(s) for the face as generated in :mod:`plugins.extract.mask`. + """ + def __init__(self, + image: np.ndarray | None = None, + left: int | None = None, + width: int | None = None, + top: int | None = None, + height: int | None = None, + landmarks_xy: np.ndarray | None = None, + mask: dict[str, Mask] | None = None) -> None: + logger.trace(parse_class_init(locals())) # type:ignore[attr-defined] + self.image = image + """ :class:`numpy.ndarray` | None : This is a generic image placeholder that should not be + relied on to be holding a particular image. It may hold the source frame that holds the + face, a cropped face or a scaled image depending on the method using this object. 
""" + self.left = left + """ int : The left most point (in pixels) of the face's bounding box as discovered in + :mod:`plugins.extract.detect` """ + self.width = width + """ int : The width (in pixels) of the face's bounding box as discovered in + :mod:`plugins.extract.detect` """ + self.top = top + """ int : The top most point (in pixels) of the face's bounding box as discovered in + :mod:`plugins.extract.detect` """ + self.height = height + """ int : The height (in pixels) of the face's bounding box as discovered in + :mod:`plugins.extract.detect` """ + self._landmarks_xy = landmarks_xy + self._identity: dict[str, np.ndarray] = {} + self.thumbnail: np.ndarray | None = None + + self.mask = {} if mask is None else mask + """ dict[str: :class:`~lib.align.aligned_mask.Mask`] : The generated mask(s) for the face + as generated in :mod:`plugins.extract.mask` """ + + self._training_masks: tuple[bytes, tuple[int, int, int]] | None = None + self._aligned: AlignedFace | None = None + logger.trace("Initialized %s", self.__class__.__name__) # type:ignore[attr-defined] + + @property + def aligned(self) -> AlignedFace: + """ :class:`~lib.align.aligned_face.AlignedFace` : The aligned face connected to this + detected face. """ + assert self._aligned is not None + return self._aligned + + @property + def landmarks_xy(self) -> np.ndarray: + """ :class:`numpy.ndarray` : The aligned face connected to this detected face. """ + assert self._landmarks_xy is not None + return self._landmarks_xy + + @property + def right(self) -> int: + """int : Right point (in pixels) of face detection bounding box within the parent image """ + assert self.left is not None and self.width is not None + return self.left + self.width + + @property + def bottom(self) -> int: + """int : Bottom point (in pixels) of face detection bounding box within the parent + image """ + assert self.top is not None and self.height is not None + return self.top + self.height + + @property + def identity(self) -> dict[str, np.ndarray]: + """ dict[str, :class:`numpy.ndarray`] : Identity mechanism as key, identity embedding as + value. """ + return self._identity + + def add_mask(self, + name: str, + mask: np.ndarray, + affine_matrix: np.ndarray, + interpolator: int, + storage_size: int = 128, + storage_centering: CenteringType = "face") -> None: + """ Add a :class:`~lib.align.aligned_mask.Mask` to this detected face + + The mask should be the original output from :mod:`plugins.extract.mask` + If a mask with this name already exists it will be overwritten by the given + mask. + + Parameters + ---------- + name : str + The name of the mask as defined by the :attr:`plugins.extract.mask._base.name` + parameter. + mask : :class:`numpy.ndarray` + The mask that is to be added as output from :mod:`plugins.extract.mask` + It should be in the range 0.0 - 1.0 ideally with a ``dtype`` of ``float32`` + affine_matrix : :class:`numpy.ndarray` + The transformation matrix required to transform the mask to the original frame. + interpolator : int + The CV2 interpolator required to transform this mask to it's original frame. + storage_size : int, optional + The size the mask is to be stored at. Default: 128 + storage_centering : Literal["face", "head", "legacy"], optional: + The centering to store the mask at. One of `"legacy"`, `"face"`, `"head"`. 
+
+ Default: `"face"`
+ """
+ logger.trace("name: '%s', mask shape: %s, affine_matrix: %s, " # type:ignore[attr-defined]
+ "interpolator: %s, storage_size: %s, storage_centering: %s)", name,
+ mask.shape, affine_matrix, interpolator, storage_size, storage_centering)
+ fsmask = Mask(storage_size=storage_size, storage_centering=storage_centering)
+ fsmask.add(mask, affine_matrix, interpolator)
+ self.mask[name] = fsmask
+
+ def add_landmarks_xy(self, landmarks: np.ndarray) -> None:
+ """ Add landmarks to the detected face object. If landmarks already exist, they will be
+ overwritten.
+
+ Parameters
+ ----------
+ landmarks : :class:`numpy.ndarray`
+ The 68 point face landmarks to add for the face
+ """
+ logger.trace("landmarks shape: '%s'", landmarks.shape) # type:ignore[attr-defined]
+ self._landmarks_xy = landmarks
+
+ def add_identity(self, name: str, embedding: np.ndarray) -> None:
+ """ Add an identity embedding to this detected face. If an identity already exists for the
+ given :attr:`name` it will be overwritten
+
+ Parameters
+ ----------
+ name : str
+ The name of the mechanism that calculated the identity
+ embedding : :class:`numpy.ndarray`
+ The identity embedding
+ """
+ logger.trace("name: '%s', embedding shape: %s", # type:ignore[attr-defined]
+ name, embedding.shape)
+ assert name == "vggface2"
+ assert embedding.shape[0] == 512
+ self._identity[name] = embedding
+
+ def clear_all_identities(self) -> None:
+ """ Remove all stored identity embeddings """
+ self._identity = {}
+
+ def get_landmark_mask(self,
+ area: T.Literal["eye", "face", "mouth"],
+ blur_kernel: int,
+ dilation: float) -> np.ndarray:
+ """ Add a :class:`~lib.align.aligned_mask.LandmarksMask` to this detected face
+
+ Landmark based masks are generated from the aligned face's landmark points. An aligned
+ face must be loaded. As the data is coming from the already aligned face, no further mask
+ cropping is required.
+
+ Parameters
+ ----------
+ area : Literal["face", "mouth", "eye"]
+ The type of mask to obtain. `face` is a full face mask; the others are masks for those
+ specific areas
+ blur_kernel : int
+ The size of the kernel for blurring the mask edges
+ dilation : float
+ The amount of dilation to apply to the mask,
as a percentage of the mask size
+
+ Returns
+ -------
+ :class:`numpy.ndarray`
+ The generated landmarks mask for the selected area
+
+ Raises
+ ------
+ :class:`lib.utils.FaceSwapError`
+ If the aligned face does not contain the correct landmarks to generate a landmark mask
+ """
+ # TODO Face mask generation from landmarks
+ logger.trace("area: %s, dilation: %s", area, dilation) # type:ignore[attr-defined]
+
+ lm_type = self.aligned.landmark_type
+ if lm_type not in LANDMARK_PARTS:
+ raise FaceswapError(f"Landmark based masks cannot be created for {lm_type.name}")
+
+ lm_parts = LANDMARK_PARTS[self.aligned.landmark_type]
+ mapped = {"mouth": ["mouth_outer"], "eye": ["right_eye", "left_eye"]}
+ if not all(part in lm_parts for parts in mapped.values() for part in parts):
+ raise FaceswapError(f"Landmark based masks cannot be created for {lm_type.name}")
+
+ areas = {key: [slice(*lm_parts[v][:2]) for v in val] for key, val in mapped.items()}
+ points = [self.aligned.landmarks[zone] for zone in areas[area]]
+
+ lmmask = LandmarksMask(points,
+ storage_size=self.aligned.size,
+ storage_centering=self.aligned.centering,
+ dilation=dilation)
+ lmmask.set_blur_and_threshold(blur_kernel=blur_kernel)
+ lmmask.generate_mask(
+ self.aligned.adjusted_matrix,
+ self.aligned.interpolators[1])
+ return lmmask.mask
+
+ def store_training_masks(self,
+ masks: list[np.ndarray | None],
+ delete_masks: bool = False) -> None:
+ """ Concatenate and compress the given training masks and store them for retrieval.
+
+ Parameters
+ ----------
+ masks : list[:class:`numpy.ndarray` | None]
+ A list of training masks. These must all be uint-8 3D arrays of the same size in the
+ 0-255 range
+ delete_masks : bool, optional
+ ``True`` to delete any of the :class:`~lib.align.aligned_mask.Mask` objects owned by
+ this detected face. Use this to free up memory that is no longer required.
+ Default: ``False``
+ """
+ if delete_masks:
+ del self.mask
+ self.mask = {}
+
+ valid = [msk for msk in masks if msk is not None]
+ if not valid:
+ return
+ combined = np.concatenate(valid, axis=-1)
+ self._training_masks = (compress(combined), combined.shape)
+
+ def get_training_masks(self) -> np.ndarray | None:
+ """ Obtain the decompressed combined training masks.
+
+ Returns
+ -------
+ :class:`numpy.ndarray`
+ A 3D array containing the decompressed training masks as uint8 in the 0-255 range if
+ training masks are present, otherwise ``None``
+ """
+ if not self._training_masks:
+ return None
+ return np.frombuffer(decompress(self._training_masks[0]),
+ dtype="uint8").reshape(self._training_masks[1])
+
+ def to_alignment(self) -> AlignmentFileDict:
+ """ Return the detected face formatted for an alignments file
+
+ Returns
+ -------
+ alignment : :class:`lib.align.alignments.AlignmentFileDict`
+ The alignment dict will be returned with the keys ``x``, ``w``, ``y``, ``h``,
+ ``landmarks_xy``, ``mask``. The additional key ``thumb`` will be provided if the
+ detected face object contains a thumbnail.
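+ The returned dictionary can be passed back into :func:`from_alignment` to
+ reconstruct an equivalent face object.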
+
+ """
+ if (self.left is None or self.width is None or self.top is None or self.height is None):
+ raise AssertionError("Some detected face variables have not been initialized")
+ alignment = AlignmentFileDict(x=self.left,
+ w=self.width,
+ y=self.top,
+ h=self.height,
+ landmarks_xy=self.landmarks_xy,
+ mask={name: mask.to_dict()
+ for name, mask in self.mask.items()},
+ identity={k: v.tolist() for k, v in self._identity.items()},
+ thumb=self.thumbnail)
+ logger.trace("Returning: %s", alignment) # type:ignore[attr-defined]
+ return alignment
+
+ def from_alignment(self, alignment: AlignmentFileDict,
+ image: np.ndarray | None = None, with_thumb: bool = False) -> None:
+ """ Set the attributes of this class from an alignments file and optionally load the face
+ into the ``image`` attribute.
+
+ Parameters
+ ----------
+ alignment : :class:`lib.align.alignments.AlignmentFileDict`
+ A dictionary entry for a face from an alignments file containing the keys
+ ``x``, ``w``, ``y``, ``h``, ``landmarks_xy``.
+ Optionally the key ``thumb`` will be provided. This is for use in the manual tool and
+ contains the compressed jpg thumbnail of the face to be allocated to :attr:`thumbnail`.
+ Optionally the key ``mask`` will be provided, but legacy alignments will not have
+ this key.
+ image : :class:`numpy.ndarray`, optional
+ If an image is passed in, then the ``image`` attribute will
+ be set to the cropped face based on the passed in bounding box co-ordinates
+ with_thumb : bool, optional
+ Whether to load the jpg thumbnail into the detected face object, if provided.
+ Default: ``False``
+ """
+
+ logger.trace("Creating from alignment: (alignment: %s," # type:ignore[attr-defined]
+ " has_image: %s)", alignment, bool(image is not None))
+ self.left = alignment["x"]
+ self.width = alignment["w"]
+ self.top = alignment["y"]
+ self.height = alignment["h"]
+ landmarks = alignment["landmarks_xy"]
+ if not isinstance(landmarks, np.ndarray):
+ landmarks = np.array(landmarks, dtype="float32")
+ self._identity = {T.cast(T.Literal["vggface2"], k): np.array(v, dtype="float32")
+ for k, v in alignment.get("identity", {}).items()}
+ self._landmarks_xy = landmarks.copy()
+
+ if with_thumb:
+ # Thumbnails currently only used for manual tool. Default to None
+ self.thumbnail = alignment.get("thumb")
+ # Manual tool and legacy alignments will not have a mask
+ self._aligned = None
+
+ if alignment.get("mask", None) is not None:
+ self.mask = {}
+ for name, mask_dict in alignment["mask"].items():
+ self.mask[name] = Mask()
+ self.mask[name].from_dict(mask_dict)
+ if image is not None and image.any():
+ self._image_to_face(image)
+ logger.trace("Created from alignment: (left: %s, width: %s, " # type:ignore[attr-defined]
+ "top: %s, height: %s, landmarks: %s, mask: %s)",
+ self.left, self.width, self.top, self.height, self.landmarks_xy, self.mask)
+
+ def to_png_meta(self) -> PNGHeaderAlignmentsDict:
+ """ Return the detected face formatted for insertion into a png iTXt header.
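+
+ A minimal sketch of how the returned dictionary might be embedded into a png
+ header (``face`` is assumed to be a populated instance of this class and
+ ``aligned_image`` an aligned face image array; only the ``alignments`` portion of
+ the full header dictionary is shown)::
+
+ from lib.image import encode_image
+
+ meta = {"alignments": face.to_png_meta()}
+ png_bytes = encode_image(aligned_image, ".png", metadata=meta)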
+
+ Returns
+ -------
+ :class:`lib.align.alignments.PNGHeaderAlignmentsDict`
+ The alignments dict will be returned with the keys ``x``, ``w``, ``y``, ``h``,
+ ``landmarks_xy`` and ``mask``
+ """
+ if (self.left is None or self.width is None or self.top is None or self.height is None):
+ raise AssertionError("Some detected face variables have not been initialized")
+ alignment = PNGHeaderAlignmentsDict(
+ x=self.left,
+ w=self.width,
+ y=self.top,
+ h=self.height,
+ landmarks_xy=self.landmarks_xy.tolist(),
+ mask={name: mask.to_png_meta() for name, mask in self.mask.items()},
+ identity={k: v.tolist() for k, v in self._identity.items()})
+ return alignment
+
+ def from_png_meta(self, alignment: PNGHeaderAlignmentsDict) -> None:
+ """ Set the attributes of this class from alignments stored in a png exif header.
+
+ Parameters
+ ----------
+ alignment : :class:`lib.align.alignments.PNGHeaderAlignmentsDict`
+ A dictionary entry for a face from alignments stored in a png exif header containing
+ the keys ``x``, ``w``, ``y``, ``h``, ``landmarks_xy`` and ``mask``
+ """
+ self.left = alignment["x"]
+ self.width = alignment["w"]
+ self.top = alignment["y"]
+ self.height = alignment["h"]
+ self._landmarks_xy = np.array(alignment["landmarks_xy"], dtype="float32")
+ self.mask = {}
+ for name, mask_dict in alignment["mask"].items():
+ self.mask[name] = Mask()
+ self.mask[name].from_dict(mask_dict)
+ self._identity = {}
+ for key, val in alignment.get("identity", {}).items():
+ assert key in ["vggface2"]
+ self._identity[T.cast(T.Literal["vggface2"], key)] = np.array(val, dtype="float32")
+ logger.trace("Created from png exif header: (left: %s, " # type:ignore[attr-defined]
+ "width: %s, top: %s height: %s, landmarks: %s, mask: %s, identity: %s)",
+ self.left, self.width, self.top, self.height, self.landmarks_xy, self.mask,
+ {k: v.shape for k, v in self._identity.items()})
+
+ def _image_to_face(self, image: np.ndarray) -> None:
+ """ Set :attr:`image` to be the cropped face from the detected bounding box
+
+ Parameters
+ ----------
+ image : :class:`numpy.ndarray`
+ The image to be cropped
+ """
+ logger.trace("Cropping face from image") # type:ignore[attr-defined]
+ self.image = image[self.top: self.bottom,
+ self.left: self.right]
+
+ # <<< Aligned Face methods and properties >>> #
+ def load_aligned(self,
+ image: np.ndarray | None,
+ size: int = 256,
+ dtype: str | None = None,
+ centering: CenteringType = "head",
+ coverage_ratio: float = 1.0,
+ y_offset: float = 0.0,
+ force: bool = False,
+ is_aligned: bool = False,
+ is_legacy: bool = False) -> None:
+ """ Align a face from a given image.
+
+ Aligning a face is a relatively expensive task and is not required for all uses of
+ the :class:`~lib.align.DetectedFace` object, so call this function explicitly to
+ load an aligned face.
+
+ This method plugs into :class:`~lib.align.aligned_face.AlignedFace` to perform face
+ alignment based on this face's ``landmarks_xy``. If the face has already been aligned,
+ then this function will return having performed no action.
+
+ Parameters
+ ----------
+ image : :class:`numpy.ndarray` | None, optional
+ The image that contains the face to be aligned. Default: ``None``
+ size : int, optional
+ The size of the output face in pixels. Default: `256`
+ dtype : str, optional
+ Optionally set a ``dtype`` for the final face to be formatted in. Default: ``None``
+ centering : Literal["legacy", "face", "head"], optional
+ The type of extracted face that should be loaded.
"legacy" places the nose in the + center of the image (the original method for aligning). "face" aligns for the nose to + be in the center of the face (top to bottom) but the center of the skull for left to + right. "head" aligns for the center of the skull (in 3D space) being the center of the + extracted image, with the crop holding the full head. + Default: `"head"` + coverage_ratio : float, optional + The amount of the aligned image to return. A ratio of 1.0 will return the full contents + of the aligned image. A ratio of 0.5 will return an image of the given size, but will + crop to the central 50%% of the image. Default: `1.0` + y_offset : float, optional + The amount to adjust the aligned face along the y_axis in -1. to 1. range. + Default: `0.0` + force : bool, optional + Force an update of the aligned face, even if it is already loaded. Default: ``False`` + is_aligned : bool, optional + Indicates that the :attr:`image` is an aligned face rather than a frame. + Default: ``False`` + is_legacy : bool, optional + Only used if `is_aligned` is ``True``. ``True`` indicates that the aligned image being + loaded is a legacy extracted face rather than a current head extracted face + + Notes + ----- + This method must be executed to get access to the following a + :class:`lib.align.aligned_face.AlignedFace` object + """ + if self._aligned and not force: + # Don't reload an already aligned face + logger.trace("Skipping alignment calculation for already " # type:ignore[attr-defined] + "aligned face") + else: + logger.trace("Loading aligned face: (size: %s, " # type:ignore[attr-defined] + "dtype: %s)", size, dtype) + self._aligned = AlignedFace(self.landmarks_xy, + image=image, + centering=centering, + size=size, + coverage_ratio=coverage_ratio, + y_offset=y_offset, + dtype=dtype, + is_aligned=is_aligned, + is_legacy=is_aligned and is_legacy) + + +_HASHES_SEEN: dict[str, dict[str, int]] = {} + + +def update_legacy_png_header(filename: str, alignments: Alignments + ) -> PNGHeaderDict | None: + """ Update a legacy extracted face from pre v2.1 alignments by placing the alignment data for + the face in the png exif header for the given filename with the given alignment data. + + If the given file is not a .png then a png is created and the original file is removed + + Parameters + ---------- + filename : str + The image file to update + alignments : :class:`lib.align.alignments.Alignments` + The alignments data the contains the information to store in the image header. This must be + a v2.0 or less alignments file as later versions no longer store the face hash (not + required) + + Returns + ------- + :class:`lib.align.alignments.PNGHeaderDict` + The metadata that has been applied to the given image + """ + if alignments.version > 2.0: + raise FaceswapError("The faces being passed in do not correspond to the given Alignments " + "file. Please double check your sources and try again.") + # Track hashes for multiple files with the same hash. 
Not the most robust but should be
+ # effective enough
+ folder = os.path.dirname(filename)
+ if folder not in _HASHES_SEEN:
+ _HASHES_SEEN[folder] = {}
+ hashes_seen = _HASHES_SEEN[folder]
+
+ in_image = read_image(filename, raise_error=True)
+ in_hash = sha1(T.cast(bytes, in_image)).hexdigest()
+ hashes_seen[in_hash] = hashes_seen.get(in_hash, -1) + 1
+
+ alignment = alignments.hashes_to_alignment.get(in_hash)
+ if not alignment:
+ logger.debug("Alignments not found for image: '%s'", filename)
+ return None
+
+ detected_face = DetectedFace()
+ detected_face.from_alignment(alignment)
+ # For dupe hash handling, make sure we get a different filename for repeat hashes
+ src_fname, face_idx = list(alignments.hashes_to_frame[in_hash].items())[hashes_seen[in_hash]]
+ orig_filename = f"{os.path.splitext(src_fname)[0]}_{face_idx}.png"
+ meta = PNGHeaderDict(alignments=detected_face.to_png_meta(),
+ source=PNGHeaderSourceDict(
+ alignments_version=alignments.version,
+ original_filename=orig_filename,
+ face_index=face_idx,
+ source_filename=src_fname,
+ source_is_video=False, # Can't check so set false
+ source_frame_dims=None))
+
+ out_filename = f"{os.path.splitext(filename)[0]}.png" # Make sure saved file is png
+ out_image = encode_image(in_image, ".png", metadata=meta)
+
+ with open(out_filename, "wb") as out_file:
+ out_file.write(out_image)
+
+ if filename != out_filename: # Remove the old non-png
+ logger.debug("Removing replaced face with deprecated extension: '%s'", filename)
+ os.remove(filename)
+
+ return meta
+
+
+__all__ = get_module_objects(__name__)
diff --git a/lib/align/pose.py b/lib/align/pose.py
new file mode 100644
index 0000000000..cac8337cfd
--- /dev/null
+++ b/lib/align/pose.py
@@ -0,0 +1,191 @@
+#!/usr/bin/env python3
+""" Holds estimated pose information for a faceswap aligned face """
+from __future__ import annotations
+
+import logging
+import typing as T
+
+import cv2
+import numpy as np
+
+from lib.logger import parse_class_init
+from lib.utils import get_module_objects
+
+from .constants import _MEAN_FACE, LandmarkType
+
+logger = logging.getLogger(__name__)
+
+if T.TYPE_CHECKING:
+ from .constants import CenteringType
+
+
+class PoseEstimate():
+ """ Estimates pose from a generic 3D head model for the given 2D face landmarks.
+
+ Parameters
+ ----------
+ landmarks: :class:`numpy.ndarray`
+ The original 68 point landmarks aligned to 0.0 - 1.0 range
+ landmarks_type: :class:`~LandmarkType`
+ The type of landmarks that are generating this face
+
+ References
+ ----------
+ Head Pose Estimation using OpenCV and Dlib - https://www.learnopencv.com/tag/solvepnp/
+ 3D Model points - http://aifi.isr.uc.pt/Downloads/OpenGL/glAnthropometric3DModel.cpp
+ """
+ _logged_once = False
+
+ def __init__(self, landmarks: np.ndarray, landmarks_type: LandmarkType) -> None:
+ logger.trace(parse_class_init(locals())) # type:ignore[attr-defined]
+ self._distortion_coefficients = np.zeros((4, 1)) # Assuming no lens distortion
+ self._xyz_2d: np.ndarray | None = None
+
+ if landmarks_type != LandmarkType.LM_2D_68:
+ self._log_once("Pose estimation is not available for non-68 point landmarks.
Pose and " + "offset data will all be returned as the incorrect value of '0'") + self._landmarks_type = landmarks_type + self._camera_matrix = self._get_camera_matrix() + self._rotation, self._translation = self._solve_pnp(landmarks) + self._offset = self._get_offset() + self._pitch_yaw_roll: tuple[float, float, float] = (0, 0, 0) + logger.trace("Initialized %s", self.__class__.__name__) # type:ignore[attr-defined] + + @property + def xyz_2d(self) -> np.ndarray: + """ :class:`numpy.ndarray` projected (x, y) coordinates for each x, y, z point at a + constant distance from adjusted center of the skull (0.5, 0.5) in the 2D space. """ + if self._xyz_2d is None: + xyz = cv2.projectPoints(np.array([[6., 0., -2.3], + [0., 6., -2.3], + [0., 0., 3.7]]).astype("float32"), + self._rotation, + self._translation, + self._camera_matrix, + self._distortion_coefficients)[0].squeeze() + self._xyz_2d = xyz - self._offset["head"] + return self._xyz_2d + + @property + def offset(self) -> dict[CenteringType, np.ndarray]: + """ dict: The amount to offset a standard 0.0 - 1.0 umeyama transformation matrix for a + from the center of the face (between the eyes) or center of the head (middle of skull) + rather than the nose area. """ + return self._offset + + @property + def pitch(self) -> float: + """ float: The pitch of the aligned face in eular angles """ + if not any(self._pitch_yaw_roll): + self._get_pitch_yaw_roll() + return self._pitch_yaw_roll[0] + + @property + def yaw(self) -> float: + """ float: The yaw of the aligned face in eular angles """ + if not any(self._pitch_yaw_roll): + self._get_pitch_yaw_roll() + return self._pitch_yaw_roll[1] + + @property + def roll(self) -> float: + """ float: The roll of the aligned face in eular angles """ + if not any(self._pitch_yaw_roll): + self._get_pitch_yaw_roll() + return self._pitch_yaw_roll[2] + + @classmethod + def _log_once(cls, message: str) -> None: + """ Log a warning about unsupported landmarks if a message has not already been logged """ + if cls._logged_once: + return + logger.warning(message) + cls._logged_once = True + + def _get_pitch_yaw_roll(self) -> None: + """ Obtain the yaw, roll and pitch from the :attr:`_rotation` in eular angles. """ + proj_matrix = np.zeros((3, 4), dtype="float32") + proj_matrix[:3, :3] = cv2.Rodrigues(self._rotation)[0] + euler = cv2.decomposeProjectionMatrix(proj_matrix)[-1] + self._pitch_yaw_roll = T.cast(tuple[float, float, float], tuple(euler.squeeze())) + logger.trace("yaw_pitch: %s", self._pitch_yaw_roll) # type:ignore[attr-defined] + + @classmethod + def _get_camera_matrix(cls) -> np.ndarray: + """ Obtain an estimate of the camera matrix based off the original frame dimensions. + + Returns + ------- + :class:`numpy.ndarray` + An estimated camera matrix + """ + focal_length = 4 + camera_matrix = np.array([[focal_length, 0, 0.5], + [0, focal_length, 0.5], + [0, 0, 1]], dtype="double") + logger.trace("camera_matrix: %s", camera_matrix) # type:ignore[attr-defined] + return camera_matrix + + def _solve_pnp(self, landmarks: np.ndarray) -> tuple[np.ndarray, np.ndarray]: + """ Solve the Perspective-n-Point for the given landmarks. + + Takes 2D landmarks in world space and estimates the rotation and translation vectors + in 3D space. 
+
+ Parameters
+ ----------
+ landmarks: :class:`numpy.ndarray`
+ The original 68 point landmark co-ordinates relating to the original frame
+
+ Returns
+ -------
+ rotation: :class:`numpy.ndarray`
+ The solved rotation vector
+ translation: :class:`numpy.ndarray`
+ The solved translation vector
+ """
+ if self._landmarks_type != LandmarkType.LM_2D_68:
+ points: np.ndarray = np.empty([])
+ rotation = np.array([[0.0], [0.0], [0.0]])
+ translation = rotation.copy()
+ else:
+ points = landmarks[[6, 7, 8, 9, 10, 17, 21, 22, 26, 31, 32, 33, 34,
+ 35, 36, 39, 42, 45, 48, 50, 51, 52, 54, 56, 57, 58]]
+ _, rotation, translation = cv2.solvePnP(_MEAN_FACE[LandmarkType.LM_3D_26],
+ points,
+ self._camera_matrix,
+ self._distortion_coefficients,
+ flags=cv2.SOLVEPNP_ITERATIVE)
+ logger.trace("points: %s, rotation: %s, translation: %s", # type:ignore[attr-defined]
+ points, rotation, translation)
+ return rotation, translation
+
+ def _get_offset(self) -> dict[CenteringType, np.ndarray]:
+ """ Obtain the offset between the original center of the extracted face and the new center
+ of the head in 2D space.
+
+ Returns
+ -------
+ dict[Literal["legacy", "face", "head"], :class:`numpy.ndarray`]
+ The x, y offset of each new center from the old center
+ """
+ offset: dict[CenteringType, np.ndarray] = {"legacy": np.array([0.0, 0.0])}
+ if self._landmarks_type != LandmarkType.LM_2D_68:
+ offset["face"] = np.array([0.0, 0.0])
+ offset["head"] = np.array([0.0, 0.0])
+ else:
+ points: dict[T.Literal["face", "head"], tuple[float, ...]] = {"head": (0.0, 0.0, -2.3),
+ "face": (0.0, -1.5, 4.2)}
+ for key, pnts in points.items():
+ center = cv2.projectPoints(np.array([pnts]).astype("float32"),
+ self._rotation,
+ self._translation,
+ self._camera_matrix,
+ self._distortion_coefficients)[0].squeeze()
+ logger.trace("center %s: %s", key, center) # type:ignore[attr-defined]
+ offset[key] = center - np.array([0.5, 0.5])
+ logger.trace("offset: %s", offset) # type:ignore[attr-defined]
+ return offset
+
+
+__all__ = get_module_objects(__name__)
diff --git a/lib/align/thumbnails.py b/lib/align/thumbnails.py
new file mode 100644
index 0000000000..ccdfa1ed56
--- /dev/null
+++ b/lib/align/thumbnails.py
@@ -0,0 +1,85 @@
+#!/usr/bin/env python3
+""" Handles the generation of thumbnail jpgs for storing inside an alignments file/png header """
+from __future__ import annotations
+
+import logging
+import typing as T
+
+import numpy as np
+
+from lib.logger import parse_class_init
+from lib.utils import get_module_objects
+
+if T.TYPE_CHECKING:
+ from lib import align
+
+logger = logging.getLogger(__name__)
+
+
+class Thumbnails():
+ """ Thumbnail images stored in the alignments file.
+
+ The thumbnails are stored as low resolution (64px), low quality jpg in the alignments file
+ and are used for the Manual Alignments tool.
+
+ Parameters
+ ----------
+ alignments: :class:`~lib.align.alignments.Alignments`
+ The parent alignments class that these thumbs belong to
+ """
+ def __init__(self, alignments: align.alignments.Alignments) -> None:
+ logger.debug(parse_class_init(locals()))
+ self._alignments_dict = alignments.data
+ self._frame_list = list(sorted(self._alignments_dict))
+ logger.debug("Initialized %s", self.__class__.__name__)
+
+ @property
+ def has_thumbnails(self) -> bool:
+ """ bool: ``True`` if all faces in the alignments file contain thumbnail images
+ otherwise ``False``.
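+ A single face without stored thumbnail data is enough to make this ``False``.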
""" + retval = all(np.any(T.cast(np.ndarray, face.get("thumb"))) + for frame in self._alignments_dict.values() + for face in frame["faces"]) + logger.trace(retval) # type:ignore[attr-defined] + return retval + + def get_thumbnail_by_index(self, frame_index: int, face_index: int) -> np.ndarray: + """ Obtain a jpg thumbnail from the given frame index for the given face index + + Parameters + ---------- + frame_index: int + The frame index that contains the thumbnail + face_index: int + The face index within the frame to retrieve the thumbnail for + + Returns + ------- + :class:`numpy.ndarray` + The encoded jpg thumbnail + """ + retval = self._alignments_dict[self._frame_list[frame_index]]["faces"][face_index]["thumb"] + assert retval is not None + logger.trace( # type:ignore[attr-defined] + "frame index: %s, face_index: %s, thumb shape: %s", + frame_index, face_index, retval.shape) + return retval + + def add_thumbnail(self, frame: str, face_index: int, thumb: np.ndarray) -> None: + """ Add a thumbnail for the given face index for the given frame. + + Parameters + ---------- + frame: str + The name of the frame to add the thumbnail for + face_index: int + The face index within the given frame to add the thumbnail for + thumb: :class:`numpy.ndarray` + The encoded jpg thumbnail at 64px to add to the alignments file + """ + logger.debug("frame: %s, face_index: %s, thumb shape: %s thumb dtype: %s", + frame, face_index, thumb.shape, thumb.dtype) + self._alignments_dict[frame]["faces"][face_index]["thumb"] = thumb + + +__all__ = get_module_objects(__name__) diff --git a/lib/align/updater.py b/lib/align/updater.py new file mode 100644 index 0000000000..a877656613 --- /dev/null +++ b/lib/align/updater.py @@ -0,0 +1,384 @@ +#!/usr/bin/env python3 +""" Handles updating of an alignments file from an older version to the current version. """ +from __future__ import annotations + +import logging +import os +import typing as T + +import numpy as np + +from lib.logger import parse_class_init +from lib.utils import get_module_objects, VIDEO_EXTENSIONS + +logger = logging.getLogger(__name__) + +if T.TYPE_CHECKING: + from lib import align + + +class _Updater(): + """ Base class for inheriting to test for and update of an alignments file property + + Parameters + ---------- + alignments : :class:`~lib.align.alignments.Alignments` + The alignments object that is being tested and updated + """ + def __init__(self, alignments: align.alignments.Alignments) -> None: + logger.debug(parse_class_init(locals())) + self._alignments = alignments + self._needs_update = self._test() + if self._needs_update: + self._update() + logger.debug("Initialized: %s", self.__class__.__name__) + + @property + def is_updated(self) -> bool: + """ bool : ``True`` if this updater has been run otherwise ``False`` """ + return self._needs_update + + def _test(self) -> bool: + """ Calls the child's :func:`test` method and logs output + + Returns + ------- + bool + ``True`` if the test condition is met otherwise ``False`` + """ + logger.debug("checking %s", self.__class__.__name__) + retval = self.test() + logger.debug("legacy %s: %s", self.__class__.__name__, retval) + return retval + + def test(self) -> bool: + """ Override to set the condition to test for. 
+
+ Returns
+ -------
+ bool
+ ``True`` if the test condition is met otherwise ``False``
+ """
+ raise NotImplementedError()
+
+ def _update(self) -> int:
+ """ Calls the child's :func:`update` method, logs output and sets the
+ :attr:`is_updated` flag
+
+ Returns
+ -------
+ int
+ The number of items that were updated
+ """
+ retval = self.update()
+ logger.debug("Updated %s: %s", self.__class__.__name__, retval)
+ return retval
+
+ def update(self) -> int:
+ """ Override to set the action to perform on the alignments object if the test has
+ passed
+
+ Returns
+ -------
+ int
+ The number of items that were updated
+ """
+ raise NotImplementedError()
+
+
+class VideoExtension(_Updater):
+ """ Alignments files from video files used to have a dummy '.png' extension for each of the
+ keys. This has been changed to the file extension of the original input video, for better
+ identification of alignments files generated from video files
+
+ Parameters
+ ----------
+ alignments : :class:`~lib.align.alignments.Alignments`
+ The alignments object that is being tested and updated
+ video_filename : str
+ The video filename that holds these alignments
+ """
+ def __init__(self, alignments: align.alignments.Alignments, video_filename: str) -> None:
+ self._video_name, self._extension = os.path.splitext(video_filename)
+ super().__init__(alignments)
+
+ def test(self) -> bool:
+ """ Requires update if the extension of the key in the alignment file is not the same
+ as for the input video file
+
+ Returns
+ -------
+ bool
+ ``True`` if the key extensions need updating otherwise ``False``
+ """
+ # Note: Don't check on alignments file version. It's possible that the file gets updated to
+ # a newer version before this check is run
+ if self._extension.lower() not in VIDEO_EXTENSIONS:
+ return False
+
+ exts = set(os.path.splitext(k)[-1] for k in self._alignments.data)
+ if len(exts) != 1:
+ logger.debug("Alignments file has multiple key extensions. Skipping")
+ return False
+
+ if self._extension in exts:
+ logger.debug("Alignments file contains correct key extensions. Skipping")
+ return False
+
+ logger.debug("Needs update for video extension (version: %s, extension: %s)",
+ self._alignments.version, self._extension)
+ return True
+
+ def update(self) -> int:
+ """ Update alignments files that have been extracted from videos to have the key end in
+ the video file extension rather than '.png' (the old way)
+
+ Returns
+ -------
+ int
+ The number of keys that were updated
+ """
+ updated = 0
+ for key in list(self._alignments.data):
+ fname = os.path.splitext(key)[0]
+ if fname.rsplit("_", maxsplit=1)[0] != self._video_name:
+ continue # Key is from a different source
+
+ val = self._alignments.data[key]
+ new_key = f"{fname}{self._extension}"
+
+ del self._alignments.data[key]
+ self._alignments.data[new_key] = val
+
+ updated += 1
+
+ logger.debug("Updated alignment keys for video extension: %s", updated)
+ return updated
+
+
+class FileStructure(_Updater):
+ """ Alignments were structured: {frame_name: [faces]}.
We need to be able to store
+ information at the frame level, so the new structure is: {frame_name: {faces: [faces]}}
+ """
+ def test(self) -> bool:
+ """ Test whether the alignments file is laid out in the old structure of
+ `{frame_name: [faces]}`
+
+ Returns
+ -------
+ bool
+ ``True`` if the file has the legacy structure otherwise ``False``
+ """
+ return any(isinstance(val, list) for val in self._alignments.data.values())
+
+ def update(self) -> int:
+ """ Update legacy alignments files from the format `{frame_name: [faces]}` to the
+ format `{frame_name: {faces: [faces]}}`.
+
+ Returns
+ -------
+ int
+ The number of items that were updated
+ """
+ updated = 0
+ for key, val in self._alignments.data.items():
+ if not isinstance(val, list):
+ continue
+ self._alignments.data[key] = {"faces": val}
+ updated += 1
+ return updated
+
+
+class LandmarkRename(_Updater):
+ """ Landmarks renamed from landmarksXY to landmarks_xy for PEP compliance """
+ def test(self) -> bool:
+ """ Check for legacy landmarksXY keys.
+
+ Returns
+ -------
+ bool
+ ``True`` if the alignments file contains legacy `landmarksXY` keys otherwise ``False``
+ """
+ return (any(key == "landmarksXY"
+ for val in self._alignments.data.values()
+ for alignment in val["faces"]
+ for key in alignment))
+
+ def update(self) -> int:
+ """ Update legacy `landmarksXY` keys to PEP compliant `landmarks_xy` keys.
+
+ Returns
+ -------
+ int
+ The number of landmarks keys that were changed
+ """
+ update_count = 0
+ for val in self._alignments.data.values():
+ for alignment in val["faces"]:
+ if "landmarksXY" in alignment:
+ alignment["landmarks_xy"] = alignment.pop("landmarksXY") # type:ignore
+ update_count += 1
+ return update_count
+
+
+class ListToNumpy(_Updater):
+ """ Landmarks stored as list instead of numpy array """
+ def test(self) -> bool:
+ """ Check for legacy landmarks stored as `list` rather than :class:`numpy.ndarray`.
+
+ Returns
+ -------
+ bool
+ ``True`` if not all landmarks are :class:`numpy.ndarray` otherwise ``False``
+ """
+ return not all(isinstance(face["landmarks_xy"], np.ndarray)
+ for val in self._alignments.data.values()
+ for face in val["faces"])
+
+ def update(self) -> int:
+ """ Update landmarks stored as `list` to :class:`numpy.ndarray`.
+
+ Returns
+ -------
+ int
+ The number of landmarks keys that were changed
+ """
+ update_count = 0
+ for val in self._alignments.data.values():
+ for alignment in val["faces"]:
+ test = alignment["landmarks_xy"]
+ if not isinstance(test, np.ndarray):
+ alignment["landmarks_xy"] = np.array(test, dtype="float32")
+ update_count += 1
+ return update_count
+
+
+class MaskCentering(_Updater):
+ """ Masks not containing the stored_centering parameter.
Prior to this implementation all
+ masks were stored with face centering """
+
+ def test(self) -> bool:
+ """ Mask centering was introduced in alignments version 2.2
+
+ Returns
+ -------
+ bool
+ ``True`` if mask centering requires updating otherwise ``False``
+ """
+ return self._alignments.version < 2.2
+
+ def update(self) -> int:
+ """ Add the mask key to the alignment file and update the centering of existing masks
+
+ Returns
+ -------
+ int
+ The number of masks that were updated
+ """
+ update_count = 0
+ for val in self._alignments.data.values():
+ for alignment in val["faces"]:
+ if "mask" not in alignment:
+ alignment["mask"] = {}
+ for mask in alignment["mask"].values():
+ mask["stored_centering"] = "face"
+ update_count += 1
+ return update_count
+
+
+class IdentityAndVideoMeta(_Updater):
+ """ Prior to version 2.3 the identity key did not exist and the video_meta key was not
+ compulsory. These should now both always appear, but do not need to be populated. """
+
+ def test(self) -> bool:
+ """ The identity key was introduced in alignments version 2.3
+
+ Returns
+ -------
+ bool
+ ``True`` if the identity key needs inserting otherwise ``False``
+ """
+ return self._alignments.version < 2.3
+
+ # Identity information was not previously stored in the alignments file.
+ def update(self) -> int:
+ """ Add the video_meta and identity keys to the alignment file and leave them empty
+
+ Returns
+ -------
+ int
+ The number of keys inserted
+ """
+ update_count = 0
+ for val in self._alignments.data.values():
+ this_update = 0
+ if "video_meta" not in val:
+ val["video_meta"] = {}
+ this_update = 1
+ for alignment in val["faces"]:
+ if "identity" not in alignment:
+ alignment["identity"] = {}
+ this_update = 1
+ update_count += this_update
+ return update_count
+
+
+class Legacy():
+ """ Legacy alignments properties that are no longer used, but are still required for backwards
+ compatibility/upgrading reasons.
+
+ Parameters
+ ----------
+ alignments : :class:`~lib.align.alignments.Alignments`
+ The alignments object that requires these legacy properties
+ """
+ def __init__(self, alignments: align.alignments.Alignments) -> None:
+ self._alignments = alignments
+ self._hashes_to_frame: dict[str, dict[str, int]] = {}
+ self._hashes_to_alignment: dict[str, align.alignments.AlignmentFileDict] = {}
+
+ @property
+ def hashes_to_frame(self) -> dict[str, dict[str, int]]:
+ """ dict: The SHA1 hash of the face mapped to the frame(s) and face index within the frame
+ that the hash corresponds to. The structure of the dictionary is:
+
+ {**SHA1_hash** (`str`): {**filename** (`str`): **face_index** (`int`)}}.
+
+ Notes
+ -----
+ This method is deprecated and exists purely for updating legacy hash based alignments
+ to new png header storage in :class:`lib.align.update_legacy_png_header`.
+
+ The first time this property is referenced, the dictionary will be created and cached.
+ Subsequent references will be made to this cached dictionary.
+ """
+ if not self._hashes_to_frame:
+ logger.debug("Generating hashes to frame")
+ for frame_name, val in self._alignments.data.items():
+ for idx, face in enumerate(val["faces"]):
+ self._hashes_to_frame.setdefault(
+ face["hash"], {})[frame_name] = idx # type:ignore
+ return self._hashes_to_frame
+
+ @property
+ def hashes_to_alignment(self) -> dict[str, align.alignments.AlignmentFileDict]:
+ """ dict: The SHA1 hash of the face mapped to the alignment for the face that the hash
+ corresponds to.
+class Legacy():
+    """ Legacy alignments properties that are no longer used, but are still required for backwards
+    compatibility/upgrading reasons.
+
+    Parameters
+    ----------
+    alignments : :class:`~lib.align.alignments.Alignments`
+        The alignments object that requires these legacy properties
+    """
+    def __init__(self, alignments: align.alignments.Alignments) -> None:
+        self._alignments = alignments
+        self._hashes_to_frame: dict[str, dict[str, int]] = {}
+        self._hashes_to_alignment: dict[str, align.alignments.AlignmentFileDict] = {}
+
+    @property
+    def hashes_to_frame(self) -> dict[str, dict[str, int]]:
+        """ dict: The SHA1 hash of the face mapped to the frame(s) and face index within the frame
+        that the hash corresponds to. The structure of the dictionary is:
+
+        {**SHA1_hash** (`str`): {**filename** (`str`): **face_index** (`int`)}}.
+
+        Notes
+        -----
+        This method is deprecated and exists purely for updating legacy hash based alignments
+        to new png header storage in :class:`lib.align.update_legacy_png_header`.
+
+        The first time this property is referenced, the dictionary will be created and cached.
+        Subsequent references will be made to this cached dictionary.
+        """
+        if not self._hashes_to_frame:
+            logger.debug("Generating hashes to frame")
+            for frame_name, val in self._alignments.data.items():
+                for idx, face in enumerate(val["faces"]):
+                    self._hashes_to_frame.setdefault(
+                        face["hash"], {})[frame_name] = idx  # type:ignore
+        return self._hashes_to_frame
+
+    @property
+    def hashes_to_alignment(self) -> dict[str, align.alignments.AlignmentFileDict]:
+        """ dict: The SHA1 hash of the face mapped to the alignment for the face that the hash
+        corresponds to. The structure of the dictionary is:
+
+        {**SHA1_hash** (`str`): **alignment** (`dict`)}.
+
+        Notes
+        -----
+        This method is deprecated and exists purely for updating legacy hash based alignments
+        to new png header storage in :class:`lib.align.update_legacy_png_header`.
+
+        The first time this property is referenced, the dictionary will be created and cached.
+        Subsequent references will be made to this cached dictionary.
+        """
+        if not self._hashes_to_alignment:
+            logger.debug("Generating hashes to alignment")
+            self._hashes_to_alignment = {face["hash"]: face  # type:ignore
+                                         for val in self._alignments.data.values()
+                                         for face in val["faces"]}
+        return self._hashes_to_alignment
+
+
+__all__ = get_module_objects(__name__)
diff --git a/lib/align_eyes.py b/lib/align_eyes.py
deleted file mode 100644
index dc8a1ef2d6..0000000000
--- a/lib/align_eyes.py
+++ /dev/null
@@ -1,71 +0,0 @@
-# Code borrowed from https://github.com/jrosebr1/imutils/blob/d5cb29d02cf178c399210d5a139a821dfb0ae136/imutils/face_utils/helpers.py
-"""
-The MIT License (MIT)
-
-Copyright (c) 2015-2016 Adrian Rosebrock, http://www.pyimagesearch.com
-
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
-
-The above copyright notice and this permission notice shall be included in
-all copies or substantial portions of the Software.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-THE SOFTWARE.
-""" - -from collections import OrderedDict -import numpy as np -import cv2 - -# define a dictionary that maps the indexes of the facial -# landmarks to specific face regions -FACIAL_LANDMARKS_IDXS = OrderedDict([ - ("mouth", (48, 68)), - ("right_eyebrow", (17, 22)), - ("left_eyebrow", (22, 27)), - ("right_eye", (36, 42)), - ("left_eye", (42, 48)), - ("nose", (27, 36)), - ("jaw", (0, 17)), - ("chin", (8, 11)) -]) - -# Returns a rotation matrix that when applied to the 68 input facial landmarks -# results in landmarks with eyes aligned horizontally -def align_eyes(landmarks, size): - desiredLeftEye = (0.35, 0.35) # (y, x) value - desiredFaceWidth = desiredFaceHeight = size - - # extract the left and right eye (x, y)-coordinates - (lStart, lEnd) = FACIAL_LANDMARKS_IDXS["left_eye"] - (rStart, rEnd) = FACIAL_LANDMARKS_IDXS["right_eye"] - leftEyePts = landmarks[lStart:lEnd] - rightEyePts = landmarks[rStart:rEnd] - - # compute the center of mass for each eye - leftEyeCenter = leftEyePts.mean(axis=0).astype("int") - rightEyeCenter = rightEyePts.mean(axis=0).astype("int") - - # compute the angle between the eye centroids - dY = rightEyeCenter[0,1] - leftEyeCenter[0,1] - dX = rightEyeCenter[0,0] - leftEyeCenter[0,0] - angle = np.degrees(np.arctan2(dY, dX)) - 180 - - # compute center (x, y)-coordinates (i.e., the median point) - # between the two eyes in the input image - eyesCenter = ((leftEyeCenter[0,0] + rightEyeCenter[0,0]) // 2, (leftEyeCenter[0,1] + rightEyeCenter[0,1]) // 2) - - # grab the rotation matrix for rotating and scaling the face - M = cv2.getRotationMatrix2D(eyesCenter, angle, 1.0) - - return M diff --git a/lib/aligner.py b/lib/aligner.py deleted file mode 100644 index 4770f908eb..0000000000 --- a/lib/aligner.py +++ /dev/null @@ -1,180 +0,0 @@ -#!/usr/bin/env python3 -""" Aligner for faceswap.py """ - -import logging - -import cv2 -import numpy as np - -from lib.umeyama import umeyama -from lib.align_eyes import align_eyes as func_align_eyes, FACIAL_LANDMARKS_IDXS - -logger = logging.getLogger(__name__) # pylint: disable=invalid-name - - -class Extract(): - """ Based on the original https://www.reddit.com/r/deepfakes/ - code sample + contribs """ - - def extract(self, image, face, size, align_eyes): - """ Extract a face from an image """ - logger.trace("size: %s. align_eyes: %s", size, align_eyes) - padding = int(size * 0.1875) - alignment = get_align_mat(face, size, align_eyes) - extracted = self.transform(image, alignment, size, padding) - logger.trace("Returning face and alignment matrix: (alignment_matrix: %s)", alignment) - return extracted, alignment - - @staticmethod - def transform_matrix(mat, size, padding): - """ Transform the matrix for current size and padding """ - logger.trace("size: %s. padding: %s", size, padding) - matrix = mat * (size - 2 * padding) - matrix[:, 2] += padding - logger.trace("Returning: %s", matrix) - return matrix - - def transform(self, image, mat, size, padding=0): - """ Transform Image """ - logger.trace("matrix: %s, size: %s. padding: %s", mat, size, padding) - matrix = self.transform_matrix(mat, size, padding) - interpolators = get_matrix_scaling(matrix) - return cv2.warpAffine( # pylint: disable=no-member - image, matrix, (size, size), flags=interpolators[0]) - - def transform_points(self, points, mat, size, padding=0): - """ Transform points along matrix """ - logger.trace("points: %s, matrix: %s, size: %s. 
padding: %s", points, mat, size, padding) - matrix = self.transform_matrix(mat, size, padding) - points = np.expand_dims(points, axis=1) - points = cv2.transform( # pylint: disable=no-member - points, matrix, points.shape) - retval = np.squeeze(points) - logger.trace("Returning: %s", retval) - return retval - - def get_original_roi(self, mat, size, padding=0): - """ Return the square aligned box location on the original - image """ - logger.trace("matrix: %s, size: %s. padding: %s", mat, size, padding) - matrix = self.transform_matrix(mat, size, padding) - points = np.array([[0, 0], - [0, size - 1], - [size - 1, size - 1], - [size - 1, 0]], np.int32) - points = points.reshape((-1, 1, 2)) - matrix = cv2.invertAffineTransform(matrix) # pylint: disable=no-member - logger.trace("Returning: (points: %s, matrix: %s", points, matrix) - return cv2.transform(points, matrix) # pylint: disable=no-member - - @staticmethod - def get_feature_mask(aligned_landmarks_68, size, - padding=0, dilation=30): - """ Return the face feature mask """ - # pylint: disable=no-member - logger.trace("aligned_landmarks_68: %s, size: %s, padding: %s, dilation: %s", - aligned_landmarks_68, size, padding, dilation) - scale = size - 2 * padding - translation = padding - pad_mat = np.matrix([[scale, 0.0, translation], - [0.0, scale, translation]]) - aligned_landmarks_68 = np.expand_dims(aligned_landmarks_68, axis=1) - aligned_landmarks_68 = cv2.transform(aligned_landmarks_68, - pad_mat, - aligned_landmarks_68.shape) - aligned_landmarks_68 = np.squeeze(aligned_landmarks_68) - - (l_start, l_end) = FACIAL_LANDMARKS_IDXS["left_eye"] - (r_start, r_end) = FACIAL_LANDMARKS_IDXS["right_eye"] - (m_start, m_end) = FACIAL_LANDMARKS_IDXS["mouth"] - (n_start, n_end) = FACIAL_LANDMARKS_IDXS["nose"] - (lb_start, lb_end) = FACIAL_LANDMARKS_IDXS["left_eyebrow"] - (rb_start, rb_end) = FACIAL_LANDMARKS_IDXS["right_eyebrow"] - (c_start, c_end) = FACIAL_LANDMARKS_IDXS["chin"] - - l_eye_points = aligned_landmarks_68[l_start:l_end].tolist() - l_brow_points = aligned_landmarks_68[lb_start:lb_end].tolist() - r_eye_points = aligned_landmarks_68[r_start:r_end].tolist() - r_brow_points = aligned_landmarks_68[rb_start:rb_end].tolist() - nose_points = aligned_landmarks_68[n_start:n_end].tolist() - chin_points = aligned_landmarks_68[c_start:c_end].tolist() - mouth_points = aligned_landmarks_68[m_start:m_end].tolist() - l_eye_points = l_eye_points + l_brow_points - r_eye_points = r_eye_points + r_brow_points - mouth_points = mouth_points + nose_points + chin_points - - l_eye_hull = cv2.convexHull(np.array(l_eye_points).reshape( - (-1, 2)).astype(int)).flatten().reshape((-1, 2)) - r_eye_hull = cv2.convexHull(np.array(r_eye_points).reshape( - (-1, 2)).astype(int)).flatten().reshape((-1, 2)) - mouth_hull = cv2.convexHull(np.array(mouth_points).reshape( - (-1, 2)).astype(int)).flatten().reshape((-1, 2)) - - mask = np.zeros((size, size, 3), dtype=float) - cv2.fillConvexPoly(mask, l_eye_hull, (1, 1, 1)) - cv2.fillConvexPoly(mask, r_eye_hull, (1, 1, 1)) - cv2.fillConvexPoly(mask, mouth_hull, (1, 1, 1)) - - if dilation > 0: - kernel = np.ones((dilation, dilation), np.uint8) - mask = cv2.dilate(mask, kernel, iterations=1) - - logger.trace("Returning: %s", mask) - return mask - - -def get_matrix_scaling(mat): - """ Get the correct interpolator """ - x_scale = np.sqrt(mat[0, 0] * mat[0, 0] + mat[0, 1] * mat[0, 1]) - y_scale = (mat[0, 0] * mat[1, 1] - mat[0, 1] * mat[1, 0]) / x_scale - avg_scale = (x_scale + y_scale) * 0.5 - if avg_scale >= 1.0: - interpolators = 
cv2.INTER_CUBIC, cv2.INTER_AREA # pylint: disable=no-member - else: - interpolators = cv2.INTER_AREA, cv2.INTER_CUBIC # pylint: disable=no-member - logger.trace("interpolator: %s, inverse interpolator: %s", interpolators[0], interpolators[1]) - return interpolators - - -def get_align_mat(face, size, should_align_eyes): - """ Return the alignment Matrix """ - logger.trace("size: %s, should_align_eyes: %s", size, should_align_eyes) - mat_umeyama = umeyama(np.array(face.landmarks_as_xy[17:]), True)[0:2] - - if should_align_eyes is False: - return mat_umeyama - - mat_umeyama = mat_umeyama * size - - # Convert to matrix - landmarks = np.matrix(face.landmarks_as_xy) - - # cv2 expects points to be in the form - # np.array([ [[x1, y1]], [[x2, y2]], ... ]), we'll expand the dim - landmarks = np.expand_dims(landmarks, axis=1) - - # Align the landmarks using umeyama - umeyama_landmarks = cv2.transform( # pylint: disable=no-member - landmarks, - mat_umeyama, - landmarks.shape) - - # Determine a rotation matrix to align eyes horizontally - mat_align_eyes = func_align_eyes(umeyama_landmarks, size) - - # Extend the 2x3 transform matrices to 3x3 so we can multiply them - # and combine them as one - mat_umeyama = np.matrix(mat_umeyama) - mat_umeyama.resize((3, 3)) - mat_align_eyes = np.matrix(mat_align_eyes) - mat_align_eyes.resize((3, 3)) - mat_umeyama[2] = mat_align_eyes[2] = [0, 0, 1] - - # Combine the umeyama transform with the extra rotation matrix - transform_mat = mat_align_eyes * mat_umeyama - - # Remove the extra row added, shape needs to be 2x3 - transform_mat = np.delete(transform_mat, 2, 0) - transform_mat = transform_mat / size - logger.trace("Returning: %s", transform_mat) - return transform_mat diff --git a/lib/alignments.py b/lib/alignments.py deleted file mode 100644 index 2780631529..0000000000 --- a/lib/alignments.py +++ /dev/null @@ -1,361 +0,0 @@ -#!/usr/bin/env python3 -""" Alignments file functions for reading, writing and manipulating - a serialized alignments file """ - -import logging -import os -from datetime import datetime - -import cv2 - -from lib import Serializer -from lib.utils import rotate_landmarks - -logger = logging.getLogger(__name__) # pylint: disable=invalid-name - - -class Alignments(): - """ Holds processes pertaining to the alignments file. - - folder: folder alignments file is stored in - filename: Filename of alignments file excluding extension. If a - valid extension is provided, then it will be used to - decide the serializer, and the serializer argument will - be ignored. - serializer: If provided, this will be the format that the data is - saved in (if data is to be saved). 
Can be 'json', 'pickle' - or 'yaml' - """ - # pylint: disable=too-many-public-methods - def __init__(self, folder, filename="alignments", serializer="json"): - logger.debug("Initializing %s: (folder: '%s', filename: '%s', serializer: '%s')", - self.__class__.__name__, folder, filename, serializer) - self.serializer = self.get_serializer(filename, serializer) - self.file = self.get_location(folder, filename) - - self.data = self.load() - logger.debug("Initialized %s", self.__class__.__name__) - - # << PROPERTIES >> # - - @property - def frames_count(self): - """ Return current frames count """ - retval = len(self.data) - logger.trace(retval) - return retval - - @property - def faces_count(self): - """ Return current faces count """ - retval = sum(len(faces) for faces in self.data.values()) - logger.trace(retval) - return retval - - @property - def have_alignments_file(self): - """ Return whether an alignments file exists """ - retval = os.path.exists(self.file) - logger.trace(retval) - return retval - - @property - def hashes_to_frame(self): - """ Return a dict of each face_hash with their parent - frame name(s) and their index in the frame - """ - hash_faces = dict() - for frame_name, faces in self.data.items(): - for idx, face in enumerate(faces): - hash_faces.setdefault(face["hash"], dict())[frame_name] = idx - return hash_faces - - # << INIT FUNCTIONS >> # - - @staticmethod - def get_serializer(filename, serializer): - """ Set the serializer to be used for loading and - saving alignments - - If a filename with a valid extension is passed in - this will be used as the serializer, otherwise the - specified serializer will be used """ - logger.debug("Getting serializer: (filename: '%s', serializer: '%s')", - filename, serializer) - extension = os.path.splitext(filename)[1] - if extension in (".json", ".p", ".yaml", ".yml"): - logger.debug("Serializer set from file extension: '%s'", extension) - retval = Serializer.get_serializer_from_ext(extension) - elif serializer not in ("json", "pickle", "yaml"): - raise ValueError("Error: {} is not a valid serializer. 
Use " - "'json', 'pickle' or 'yaml'") - else: - logger.debug("Serializer set from argument: '%s'", serializer) - retval = Serializer.get_serializer(serializer) - logger.verbose("Using '%s' serializer for alignments", retval.ext) - return retval - - def get_location(self, folder, filename): - """ Return the path to alignments file """ - logger.debug("Getting location: (folder: '%s', filename: '%s')", folder, filename) - extension = os.path.splitext(filename)[1] - if extension in (".json", ".p", ".yaml", ".yml"): - logger.debug("File extension set from filename: '%s'", extension) - location = os.path.join(str(folder), filename) - else: - location = os.path.join(str(folder), - "{}.{}".format(filename, - self.serializer.ext)) - logger.debug("File extension set from serializer: '%s'", self.serializer.ext) - logger.verbose("Alignments filepath: '%s'", location) - return location - - # << I/O >> # - - def load(self): - """ Load the alignments data - Override for custom loading logic """ - logger.debug("Loading alignments") - if not self.have_alignments_file: - raise ValueError("Error: Alignments file not found at " - "{}".format(self.file)) - - try: - logger.info("Reading alignments from: '%s'", self.file) - with open(self.file, self.serializer.roptions) as align: - data = self.serializer.unmarshal(align.read()) - except IOError as err: - logger.error("'%s' not read: %s", self.file, err.strerror) - exit(1) - logger.debug("Loaded alignments") - return data - - def reload(self): - """ Read the alignments data from the correct format """ - logger.debug("Re-loading alignments") - self.data = self.load() - logger.debug("Re-loaded alignments") - - def save(self): - """ Write the serialized alignments file """ - logger.debug("Saving alignments") - try: - logger.info("Writing alignments to: '%s'", self.file) - with open(self.file, self.serializer.woptions) as align: - align.write(self.serializer.marshal(self.data)) - logger.debug("Saved alignments") - except IOError as err: - logger.error("'%s' not written: %s", self.file, err.strerror) - - def backup(self): - """ Backup copy of old alignments """ - logger.debug("Backing up alignments") - if not os.path.isfile(self.file): - logger.debug("No alignments to back up") - return - now = datetime.now().strftime("%Y%m%d_%H%M%S") - src = self.file - split = os.path.splitext(src) - dst = split[0] + "_" + now + split[1] - logger.info("Backing up original alignments to '%s'", dst) - os.rename(src, dst) - logger.debug("Backed up alignments") - - # << VALIDATION >> # - - def frame_exists(self, frame): - """ return path of images that have faces """ - retval = frame in self.data.keys() - logger.trace("'%s': %s", frame, retval) - return retval - - def frame_has_faces(self, frame): - """ Return true if frame exists and has faces """ - retval = bool(self.data.get(frame, list())) - logger.trace("'%s': %s", frame, retval) - return retval - - def frame_has_multiple_faces(self, frame): - """ Return true if frame exists and has faces """ - if not frame: - retval = False - else: - retval = bool(len(self.data.get(frame, list())) > 1) - logger.trace("'%s': %s", frame, retval) - return retval - - # << DATA >> # - - def get_faces_in_frame(self, frame): - """ Return the alignments for the selected frame """ - logger.trace("Getting faces for frame: '%s'", frame) - return self.data.get(frame, list()) - - def get_full_frame_name(self, frame): - """ Return a frame with extension for when the extension is - not known """ - retval = next(key for key in self.data.keys() - if 
key.startswith(frame)) - logger.trace("Requested: '%s', Returning: '%s'", frame, retval) - return retval - - def count_faces_in_frame(self, frame): - """ Return number of alignments within frame """ - retval = len(self.data.get(frame, list())) - logger.trace(retval) - return retval - - # << MANIPULATION >> # - - def delete_face_at_index(self, frame, idx): - """ Delete the face alignment for given frame at given index """ - logger.debug("Deleting face %s for frame '%s'", idx, frame) - idx = int(idx) - if idx + 1 > self.count_faces_in_frame(frame): - logger.debug("No face to delete: (frame: '%s', idx %s)", frame, idx) - return False - del self.data[frame][idx] - logger.debug("Deleted face: (frame: '%s', idx %s)", frame, idx) - return True - - def add_face(self, frame, alignment): - """ Add a new face for a frame and return it's index """ - logger.debug("Adding face to frame: '%s'", frame) - self.data[frame].append(alignment) - retval = self.count_faces_in_frame(frame) - 1 - logger.debug("Returning new face index: %s", retval) - return retval - - def update_face(self, frame, idx, alignment): - """ Replace a face for given frame and index """ - logger.debug("Updating face %s for frame '%s'", idx, frame) - self.data[frame][idx] = alignment - - def filter_hashes(self, hashlist, filter_out=False): - """ Filter in or out faces that match the hashlist - - filter_out=True: Remove faces that match in the hashlist - filter_out=False: Remove faces that are not in the hashlist - """ - hashset = set(hashlist) - for filename, frame in self.data.items(): - for idx, face in reversed(list(enumerate(frame))): - if ((filter_out and face.get("hash", None) in hashset) or - (not filter_out and face.get("hash", None) not in hashset)): - logger.verbose("Filtering out face: (filename: %s, index: %s)", filename, idx) - del frame[idx] - else: - logger.trace("Not filtering out face: (filename: %s, index: %s)", - filename, idx) - - # << GENERATORS >> # - - def yield_faces(self): - """ Yield face alignments for one image """ - for frame_fullname, alignments in self.data.items(): - frame_name = os.path.splitext(frame_fullname)[0] - face_count = len(alignments) - logger.trace("Yielding: (frame: '%s', faces: %s, frame_fullname: '%s')", - frame_name, face_count, frame_fullname) - yield frame_name, alignments, face_count, frame_fullname - - @staticmethod - def yield_original_index_reverse(image_alignments, number_alignments): - """ Return the correct original index for - alignment in reverse order """ - for idx, _ in enumerate(reversed(image_alignments)): - original_idx = number_alignments - 1 - idx - logger.trace("Yielding: face index %s", original_idx) - yield original_idx - - # << LEGACY FUNCTIONS >> # - - # < Rotation > # - # The old rotation method would rotate the image to find a face, then - # store the rotated landmarks along with a rotation value to tell the - # convert process that it had to rotate the frame to find the landmarks. - # This is problematic for numerous reasons. The process now rotates the - # landmarks to correctly correspond with the original frame. The below are - # functions to convert legacy alignments to the currently supported - # infrastructure. 
- # This can eventually be removed - - def get_legacy_rotation(self): - """ Return a list of frames with legacy rotations - Looks for an 'r' value in the alignments file that - is not zero """ - logger.debug("Getting alignments containing legacy rotations") - keys = list() - for key, val in self.data.items(): - if any(alignment.get("r", None) for alignment in val): - keys.append(key) - logger.debug("Got alignments containing legacy rotations: %s", len(keys)) - return keys - - def rotate_existing_landmarks(self, frame_name, frame): - """ Backwards compatability fix. Rotates the landmarks to - their correct position and deletes r - - NB: The original frame must be passed in otherwise - the transformation cannot be performed """ - logger.trace("Rotating existing landmarks for frame: '%s'", frame_name) - dims = frame.shape[:2] - for face in self.get_faces_in_frame(frame_name): - angle = face.get("r", 0) - if not angle: - logger.trace("Landmarks do not require rotation: '%s'", frame_name) - return - logger.trace("Rotating landmarks: (frame: '%s', angle: %s)", frame_name, angle) - r_mat = self.get_original_rotation_matrix(dims, angle) - rotate_landmarks(face, r_mat) - del face["r"] - logger.trace("Rotatated existing landmarks for frame: '%s'", frame_name) - - @staticmethod - def get_original_rotation_matrix(dimensions, angle): - """ Calculate original rotation matrix and invert """ - logger.trace("Getting original rotation matrix: (dimensions: %s, angle: %s)", - dimensions, angle) - height, width = dimensions - center = (width/2, height/2) - r_mat = cv2.getRotationMatrix2D( # pylint: disable=no-member - center, -1.0 * angle, 1.) - - abs_cos = abs(r_mat[0, 0]) - abs_sin = abs(r_mat[0, 1]) - rotated_width = int(height*abs_sin + width*abs_cos) - rotated_height = int(height*abs_cos + width*abs_sin) - r_mat[0, 2] += rotated_width/2 - center[0] - r_mat[1, 2] += rotated_height/2 - center[1] - logger.trace("Returning rotation matrix: %s", r_mat) - return r_mat - - # # - # The old index based method of face matching is problematic. - # The SHA1 Hash of the extracted face is now stored in the alignments file. - # This has it's own issues, but they are far reduced from the index/filename method - # This can eventually be removed - def get_legacy_no_hashes(self): - """ Get alignments without face hashes """ - logger.debug("Getting alignments without face hashes") - keys = list() - for key, val in self.data.items(): - for alignment in val: - if "hash" not in alignment.keys(): - keys.append(key) - break - logger.debug("Got alignments without face hashes: %s", len(keys)) - return keys - - def add_face_hashes(self, frame_name, hashes): - """ Backward compatability fix. Add face hash to alignments """ - logger.trace("Adding face hash: (frame: '%s', hashes: %s)", frame_name, hashes) - faces = self.get_faces_in_frame(frame_name) - count_match = len(faces) - len(hashes) - if count_match != 0: - msg = "more" if count_match > 0 else "fewer" - logger.warning("There are %s %s face(s) in the alignments file than exist in the " - "faces folder. 
Check your sources for frame '%s'.", - abs(count_match), msg, frame_name) - for idx, i_hash in hashes.items(): - faces[idx]["hash"] = i_hash diff --git a/lib/cli.py b/lib/cli.py deleted file mode 100644 index 21eef4e642..0000000000 --- a/lib/cli.py +++ /dev/null @@ -1,881 +0,0 @@ -#!/usr/bin/env python3 -""" Command Line Arguments """ -import argparse -import logging -import os -import platform -import sys - -from importlib import import_module - -from lib.logger import crash_log, log_setup -from lib.utils import safe_shutdown -from plugins.plugin_loader import PluginLoader - -logger = logging.getLogger(__name__) # pylint: disable=invalid-name - - -class ScriptExecutor(): - """ Loads the relevant script modules and executes the script. - This class is initialised in each of the argparsers for the relevant - command, then execute script is called within their set_default - function. """ - - def __init__(self, command, subparsers=None): - self.command = command.lower() - self.subparsers = subparsers - - def import_script(self): - """ Only import a script's modules when running that script.""" - self.test_for_gui() - cmd = os.path.basename(sys.argv[0]) - src = "tools" if cmd == "tools.py" else "scripts" - mod = ".".join((src, self.command.lower())) - module = import_module(mod) - script = getattr(module, self.command.title()) - return script - - def test_for_gui(self): - """ If running the gui, check the prerequisites """ - if self.command != "gui": - return - self.test_tkinter() - self.check_display() - - @staticmethod - def test_tkinter(): - """ If the user is running the GUI, test whether the - tkinter app is available on their machine. If not - exit gracefully. - - This avoids having to import every tk function - within the GUI in a wrapper and potentially spamming - traceback errors to console """ - - try: - # pylint: disable=unused-variable - import tkinter # noqa pylint: disable=unused-import - except ImportError: - logger.warning( - "It looks like TkInter isn't installed for your OS, so " - "the GUI has been disabled. To enable the GUI please " - "install the TkInter application. You can try:") - logger.info("Anaconda: conda install tk") - logger.info("Windows/macOS: Install ActiveTcl Community Edition from " - "http://www.activestate.com") - logger.info("Ubuntu/Mint/Debian: sudo apt install python3-tk") - logger.info("Arch: sudo pacman -S tk") - logger.info("CentOS/Redhat: sudo yum install tkinter") - logger.info("Fedora: sudo dnf install python3-tkinter") - exit(1) - - @staticmethod - def check_display(): - """ Check whether there is a display to output the GUI. If running on - Windows then assume not running in headless mode """ - if not os.environ.get("DISPLAY", None) and os.name != "nt": - logger.warning("No display detected. GUI mode has been disabled.") - if platform.system() == "Darwin": - logger.info("macOS users need to install XQuartz. " - "See https://support.apple.com/en-gb/HT201341") - exit(1) - - def execute_script(self, arguments): - """ Run the script for called command """ - log_setup(arguments.loglevel, arguments.logfile, self.command) - logger.debug("Executing: %s. 
PID: %s", self.command, os.getpid()) - try: - script = self.import_script() - process = script(arguments) - process.process() - except KeyboardInterrupt: # pylint: disable=try-except-raise - raise - except SystemExit: - pass - except Exception: # pylint: disable=broad-except - crash_file = crash_log() - logger.exception("Got Exception on main handler:") - logger.critical("An unexpected crash has occurred. Crash report written to %s. " - "Please verify you are running the latest version of faceswap " - "before reporting", crash_file) - - finally: - safe_shutdown() - - -class Slider(argparse.Action): # pylint: disable=too-few-public-methods - """ Adds support for the GUI slider - - An additional option 'min_max' must be provided containing tuple of min and max accepted - values. - - 'rounding' sets the decimal places for floats or the step interval for ints. - """ - def __init__(self, option_strings, dest, nargs=None, min_max=None, rounding=None, **kwargs): - if nargs is not None: - raise ValueError("nargs not allowed") - super().__init__(option_strings, dest, **kwargs) - self.min_max = min_max - self.rounding = rounding - - def _get_kwargs(self): - names = ["option_strings", - "dest", - "nargs", - "const", - "default", - "type", - "choices", - "help", - "metavar", - "min_max", # Tuple containing min and max values of scale - "rounding"] # Decimal places to round floats to or step interval for ints - return [(name, getattr(self, name)) for name in names] - - def __call__(self, parser, namespace, values, option_string=None): - setattr(namespace, self.dest, values) - - -class FullPaths(argparse.Action): # pylint: disable=too-few-public-methods - """ Expand user- and relative-paths """ - def __call__(self, parser, namespace, values, option_string=None): - setattr(namespace, self.dest, os.path.abspath( - os.path.expanduser(values))) - - -class DirFullPaths(FullPaths): - """ Class that gui uses to determine if you need to open a directory """ - # pylint: disable=too-few-public-methods,unnecessary-pass - pass - - -class FileFullPaths(FullPaths): - """ - Class that gui uses to determine if you need to open a file. - - see lib/gui/utils.py FileHandler for current GUI filetypes - """ - # pylint: disable=too-few-public-methods - def __init__(self, option_strings, dest, nargs=None, filetypes=None, **kwargs): - super(FileFullPaths, self).__init__(option_strings, dest, **kwargs) - if nargs is not None: - raise ValueError("nargs not allowed") - self.filetypes = filetypes - - def _get_kwargs(self): - names = ["option_strings", - "dest", - "nargs", - "const", - "default", - "type", - "choices", - "help", - "metavar", - "filetypes"] - return [(name, getattr(self, name)) for name in names] - - -class DirOrFileFullPaths(FileFullPaths): # pylint: disable=too-few-public-methods - """ Class that the gui uses to determine that the input can take a folder or a filename. - Inherits functionality from FileFullPaths - Has the effect of giving the user 2 Open Dialogue buttons in the gui """ - pass - - -class SaveFileFullPaths(FileFullPaths): - """ - Class that gui uses to determine if you need to save a file. 
- - see lib/gui/utils.py FileHandler for current GUI filetypes - """ - # pylint: disable=too-few-public-methods,unnecessary-pass - pass - - -class ContextFullPaths(FileFullPaths): - """ - Class that gui uses to determine if you need to open a file or a - directory based on which action you are choosing - - To use ContextFullPaths the action_option item should indicate which - cli option dictates the context of the filesystem dialogue - - Bespoke actions are then set in lib/gui/utils.py FileHandler - """ - # pylint: disable=too-few-public-methods, too-many-arguments - def __init__(self, option_strings, dest, nargs=None, filetypes=None, - action_option=None, **kwargs): - if nargs is not None: - raise ValueError("nargs not allowed") - super(ContextFullPaths, self).__init__(option_strings, dest, - filetypes=None, **kwargs) - self.action_option = action_option - self.filetypes = filetypes - - def _get_kwargs(self): - names = ["option_strings", - "dest", - "nargs", - "const", - "default", - "type", - "choices", - "help", - "metavar", - "filetypes", - "action_option"] - return [(name, getattr(self, name)) for name in names] - - -class FullHelpArgumentParser(argparse.ArgumentParser): - """ Identical to the built-in argument parser, but on error it - prints full help message instead of just usage information """ - def error(self, message): - self.print_help(sys.stderr) - args = {"prog": self.prog, "message": message} - self.exit(2, "%(prog)s: error: %(message)s\n" % args) - - -class SmartFormatter(argparse.HelpFormatter): - """ Smart formatter for allowing raw formatting in help - text. - - To use prefix the help item with "R|" to overide - default formatting - - from: https://stackoverflow.com/questions/3853722 """ - - def _split_lines(self, text, width): - if text.startswith("R|"): - return text[2:].splitlines() - # this is the RawTextHelpFormatter._split_lines - return argparse.HelpFormatter._split_lines(self, text, width) - - -class FaceSwapArgs(): - """ Faceswap argument parser functions that are universal - to all commands. Should be the parent function of all - subsequent argparsers """ - def __init__(self, subparser, command, - description="default", subparsers=None): - - self.global_arguments = self.get_global_arguments() - self.argument_list = self.get_argument_list() - self.optional_arguments = self.get_optional_arguments() - if not subparser: - return - - self.parser = self.create_parser(subparser, command, description) - - self.add_arguments() - - script = ScriptExecutor(command, subparsers) - self.parser.set_defaults(func=script.execute_script) - - @staticmethod - def get_argument_list(): - """ Put the arguments in a list so that they are accessible from both - argparse and gui override for command specific arguments """ - argument_list = [] - return argument_list - - @staticmethod - def get_optional_arguments(): - """ Put the arguments in a list so that they are accessible from both - argparse and gui. This is used for when there are sub-children - (e.g. convert and extract) Override this for custom arguments """ - argument_list = [] - return argument_list - - @staticmethod - def get_global_arguments(): - """ Arguments that are used in ALL parts of Faceswap - DO NOT override this """ - global_args = list() - global_args.append({"opts": ("-L", "--loglevel"), - "type": str.upper, - "dest": "loglevel", - "default": "INFO", - "choices": ("INFO", "VERBOSE", "DEBUG", "TRACE"), - "help": "Log level. Stick with INFO or VERBOSE unless you need to " - "file an error report. 
Be careful with TRACE as it will " - "generate a lot of data"}) - global_args.append({"opts": ("-LF", "--logfile"), - "action": SaveFileFullPaths, - "filetypes": 'log', - "type": str, - "dest": "logfile", - "help": "Path to store the logfile. Leave blank to store in the " - "faceswap folder", - "default": None}) - # This is a hidden argument to indicate that the GUI is being used, - # so the preview window should be redirected Accordingly - global_args.append({"opts": ("-gui", "--gui"), - "action": "store_true", - "dest": "redirect_gui", - "default": False, - "help": argparse.SUPPRESS}) - return global_args - - @staticmethod - def create_parser(subparser, command, description): - """ Create the parser for the selected command """ - parser = subparser.add_parser( - command, - help=description, - description=description, - epilog="Questions and feedback: \ - https://github.com/deepfakes/faceswap-playground", - formatter_class=SmartFormatter) - return parser - - def add_arguments(self): - """ Parse the arguments passed in from argparse """ - options = self.global_arguments + self.argument_list + self.optional_arguments - for option in options: - args = option["opts"] - kwargs = {key: option[key] - for key in option.keys() if key != "opts"} - self.parser.add_argument(*args, **kwargs) - - -class ExtractConvertArgs(FaceSwapArgs): - """ This class is used as a parent class to capture arguments that - will be used in both the extract and convert process. - - Arguments that can be used in both of these processes should be - placed here, but no further processing should be done. This class - just captures arguments """ - - @staticmethod - def get_argument_list(): - """ Put the arguments in a list so that they are accessible from both - argparse and gui """ - argument_list = list() - argument_list.append({"opts": ("-i", "--input-dir"), - "action": DirOrFileFullPaths, - "filetypes": "video", - "dest": "input_dir", - "default": "input", - "help": "Input directory or video. Either a " - "directory containing the image files " - "you wish to process or path to a " - "video file. Defaults to 'input'"}) - argument_list.append({"opts": ("-o", "--output-dir"), - "action": DirFullPaths, - "dest": "output_dir", - "default": "output", - "help": "Output directory. This is where the " - "converted files will be stored. " - "Defaults to 'output'"}) - argument_list.append({"opts": ("-al", "--alignments"), - "action": FileFullPaths, - "filetypes": 'alignments', - "type": str, - "dest": "alignments_path", - "help": "Optional path to an alignments file."}) - argument_list.append({"opts": ("-l", "--ref_threshold"), - "action": Slider, - "min_max": (0.01, 0.99), - "rounding": 2, - "type": float, - "dest": "ref_threshold", - "default": 0.6, - "help": "Threshold for positive face recognition. For use with " - "nfilter or filter. Lower values are stricter."}) - argument_list.append({"opts": ("-n", "--nfilter"), - "type": str, - "dest": "nfilter", - "nargs": "+", - "default": None, - "help": "Reference image for the persons you do " - "not want to process. Should be a front " - "portrait. Multiple images can be added " - "space separated"}) - argument_list.append({"opts": ("-f", "--filter"), - "type": str, - "dest": "filter", - "nargs": "+", - "default": None, - "help": "Reference images for the person you " - "want to process. Should be a front " - "portrait. 
Multiple images can be added " - "space separated"}) - return argument_list - - -class ExtractArgs(ExtractConvertArgs): - """ Class to parse the command line arguments for extraction. - Inherits base options from ExtractConvertArgs where arguments - that are used for both extract and convert should be placed """ - - @staticmethod - def get_optional_arguments(): - """ Put the arguments in a list so that they are accessible from both - argparse and gui """ - argument_list = [] - argument_list.append({"opts": ("--serializer", ), - "type": str.lower, - "dest": "serializer", - "default": "json", - "choices": ("json", "pickle", "yaml"), - "help": "Serializer for alignments file. If " - "yaml is chosen and not available, then " - "json will be used as the default " - "fallback."}) - argument_list.append({ - "opts": ("-D", "--detector"), - "type": str.lower, - "choices": PluginLoader.get_available_extractors( - "detect"), - "default": "mtcnn", - "help": "R|Detector to use." - "\n'dlib-hog': uses least resources, but is the" - "\n\tleast reliable." - "\n'dlib-cnn': faster than mtcnn but detects" - "\n\tfewer faces and fewer false positives." - "\n'mtcnn': slower than dlib, but uses fewer" - "\n\tresources whilst detecting more faces and" - "\n\tmore false positives. Has superior" - "\n\talignment to dlib"}) - argument_list.append({ - "opts": ("-A", "--aligner"), - "type": str.lower, - "choices": PluginLoader.get_available_extractors( - "align"), - "default": "fan", - "help": "R|Aligner to use." - "\n'dlib': Dlib Pose Predictor. Faster, less " - "\n\tresource intensive, but less accurate." - "\n'fan': Face Alignment Network. Best aligner." - "\n\tGPU heavy, slow when not running on GPU"}) - argument_list.append({"opts": ("-r", "--rotate-images"), - "type": str, - "dest": "rotate_images", - "default": None, - "help": "If a face isn't found, rotate the " - "images to try to find a face. Can find " - "more faces at the cost of extraction " - "speed. Pass in a single number to use " - "increments of that size up to 360, or " - "pass in a list of numbers to enumerate " - "exactly what angles to check"}) - argument_list.append({"opts": ("-bt", "--blur-threshold"), - "type": float, - "action": Slider, - "min_max": (0.0, 100.0), - "rounding": 1, - "dest": "blur_thresh", - "default": 0.0, - "help": "Automatically discard images blurrier than the specified " - "threshold. Discarded images are moved into a \"blurry\" " - "sub-folder. Lower values allow more blur. Set to 0.0 to " - "turn off."}) - argument_list.append({"opts": ("-mp", "--multiprocess"), - "action": "store_true", - "default": False, - "help": "Run extraction in parallel. Offers " - "speed up for some extractor/detector " - "combinations, less so for others. " - "Only has an effect if both the " - "aligner and detector use the GPU, " - "otherwise this is automatic."}) - argument_list.append({"opts": ("-sz", "--size"), - "type": int, - "action": Slider, - "min_max": (128, 512), - "default": 256, - "rounding": 64, - "help": "The output size of extracted faces. Make sure that the " - "model you intend to train supports your required size. " - "This will only need to be changed for hi-res models."}) - argument_list.append({"opts": ("-min", "--min-size"), - "type": int, - "action": Slider, - "dest": "min_size", - "min_max": (0, 1080), - "default": 0, - "rounding": 20, - "help": "Filters out faces detected below this size. Length, in " - "pixels across the diagonal of the bounding box. 
Set to 0 " - "for off"}) - argument_list.append({"opts": ("-s", "--skip-existing"), - "action": "store_true", - "dest": "skip_existing", - "default": False, - "help": "Skips frames that have already been " - "extracted and exist in the alignments " - "file"}) - argument_list.append({"opts": ("-sf", "--skip-existing-faces"), - "action": "store_true", - "dest": "skip_faces", - "default": False, - "help": "Skip frames that already have " - "detected faces in the alignments " - "file"}) - argument_list.append({"opts": ("-dl", "--debug-landmarks"), - "action": "store_true", - "dest": "debug_landmarks", - "default": False, - "help": "Draw landmarks on the ouput faces for " - "debug"}) - argument_list.append({"opts": ("-ae", "--align-eyes"), - "action": "store_true", - "dest": "align_eyes", - "default": False, - "help": "Perform extra alignment to ensure " - "left/right eyes are at the same " - "height"}) - argument_list.append({"opts": ("-si", "--save-interval"), - "dest": "save_interval", - "type": int, - "action": Slider, - "min_max": (0, 1000), - "rounding": 10, - "default": 0, - "help": "Automatically save the alignments file after a set amount " - "of frames. Will only save at the end of extracting by " - "default. WARNING: Don't interrupt the script when writing " - "the file because it might get corrupted. Set to 0 to turn " - "off"}) - return argument_list - - -class ConvertArgs(ExtractConvertArgs): - """ Class to parse the command line arguments for conversion. - Inherits base options from ExtractConvertArgs where arguments - that are used for both extract and convert should be placed """ - - @staticmethod - def get_optional_arguments(): - """ Put the arguments in a list so that they are accessible from both - argparse and gui """ - argument_list = [] - argument_list.append({"opts": ("-m", "--model-dir"), - "action": DirFullPaths, - "dest": "model_dir", - "default": "models", - "help": "Model directory. A directory " - "containing the trained model you wish " - "to process. Defaults to 'models'"}) - argument_list.append({"opts": ("-a", "--input-aligned-dir"), - "action": DirFullPaths, - "dest": "input_aligned_dir", - "default": None, - "help": "Input \"aligned directory\". A " - "directory that should contain the " - "aligned faces extracted from the input " - "files. If you delete faces from this " - "folder, they'll be skipped during " - "conversion. If no aligned dir is " - "specified, all faces will be " - "converted"}) - argument_list.append({"opts": ("-t", "--trainer"), - "type": str.lower, - "choices": PluginLoader.get_available_models(), - "default": PluginLoader.get_default_model(), - "help": "Select the trainer that was used to " - "create the model"}) - argument_list.append({"opts": ("-c", "--converter"), - "type": str.lower, - "choices": PluginLoader.get_available_converters(), - "default": "masked", - "help": "Converter to use"}) - argument_list.append({ - "opts": ("-M", "--mask-type"), - "type": str.lower, - "dest": "mask_type", - "choices": ["ellipse", - "facehull", - "dfl", - # "cnn", Removed until implemented - "none"], - "default": "facehull", - "help": "R|Mask to use to replace faces." - "\nellipse: Oval around face." - "\nfacehull: Face cutout based on landmarks." - "\ndfl: A Face Hull mask from DeepFaceLabs." - # "\ncnn: Not yet implemented" Removed until implemented - "\nnone: No mask. 
Can still use blur and erode on the edges of the swap box."}) - argument_list.append({"opts": ("-b", "--blur-size"), - "type": float, - "action": Slider, - "min_max": (0.0, 100.0), - "rounding": 2, - "default": 5.0, - "help": "Blur kernel size as a percentage of the swap area. Smooths " - "the transition between the swapped face and the background " - "image."}) - argument_list.append({"opts": ("-e", "--erosion-size"), - "dest": "erosion_size", - "type": float, - "action": Slider, - "min_max": (-100.0, 100.0), - "rounding": 2, - "default": 0.0, - "help": "Erosion kernel size as a percentage of the mask radius " - "area. Positive values apply erosion which reduces the size " - "of the swapped area. Negative values apply dilation which " - "increases the swapped area"}) - argument_list.append({"opts": ("-g", "--gpus"), - "type": int, - "action": Slider, - "min_max": (1, 10), - "rounding": 1, - "default": 1, - "help": "Number of GPUs to use for conversion"}) - argument_list.append({"opts": ("-sh", "--sharpen"), - "type": str.lower, - "dest": "sharpen_image", - "choices": ["box_filter", "gaussian_filter", "none"], - "default": "none", - "help": "Sharpen the masked facial region of " - "the converted images. Choice of filter " - "to use in sharpening process -- box" - "filter or gaussian filter."}) - argument_list.append({"opts": ("-fr", "--frame-ranges"), - "nargs": "+", - "type": str, - "help": "frame ranges to apply transfer to e.g. " - "For frames 10 to 50 and 90 to 100 use " - "--frame-ranges 10-50 90-100. Files " - "must have the frame-number as the last " - "number in the name!"}) - argument_list.append({"opts": ("-d", "--discard-frames"), - "action": "store_true", - "dest": "discard_frames", - "default": False, - "help": "When used with --frame-ranges discards " - "frames that are not processed instead " - "of writing them out unchanged"}) - argument_list.append({"opts": ("-s", "--swap-model"), - "action": "store_true", - "dest": "swap_model", - "default": False, - "help": "Swap the model. Instead of A -> B, " - "swap B -> A"}) - argument_list.append({"opts": ("-S", "--seamless"), - "action": "store_true", - "dest": "seamless_clone", - "default": False, - "help": "Use cv2's seamless clone function to " - "remove extreme gradients at the mask " - "seam by smoothing colors."}) - argument_list.append({"opts": ("-mh", "--match-histogram"), - "action": "store_true", - "dest": "match_histogram", - "default": False, - "help": "Adjust the histogram of each color " - "channel in the swapped reconstruction " - "to equal the histogram of the masked " - "area in the orginal image"}) - argument_list.append({"opts": ("-aca", "--avg-color-adjust"), - "action": "store_true", - "dest": "avg_color_adjust", - "default": False, - "help": "Adjust the mean of each color channel " - " in the swapped reconstruction to " - "equal the mean of the masked area in " - "the orginal image"}) - argument_list.append({"opts": ("-sb", "--smooth-box"), - "action": "store_true", - "dest": "smooth_box", - "default": False, - "help": "Perform a Gaussian blur on the edges of the face box " - "received from the model. 
Helps reduce pronounced edges " - "of the swap area"}) - argument_list.append({"opts": ("-dt", "--draw-transparent"), - "action": "store_true", - "dest": "draw_transparent", - "default": False, - "help": "Place the swapped face on a " - "transparent layer rather than the " - "original frame."}) - return argument_list - - -class TrainArgs(FaceSwapArgs): - """ Class to parse the command line arguments for training """ - - @staticmethod - def get_argument_list(): - """ Put the arguments in a list so that they are accessible from both - argparse and gui """ - argument_list = list() - argument_list.append({"opts": ("-A", "--input-A"), - "action": DirFullPaths, - "dest": "input_a", - "default": "input_a", - "help": "Input directory. A directory " - "containing training images for face A. " - "Defaults to 'input'"}) - argument_list.append({"opts": ("-B", "--input-B"), - "action": DirFullPaths, - "dest": "input_b", - "default": "input_b", - "help": "Input directory. A directory " - "containing training images for face B. " - "Defaults to 'input'"}) - argument_list.append({"opts": ("-ala", "--alignments-A"), - "action": FileFullPaths, - "filetypes": 'alignments', - "type": str, - "dest": "alignments_path_a", - "default": None, - "help": "Path to alignments file for training set A. Only required " - "if you are using a masked model or warp-to-landmarks is " - "enabled. Defaults to /alignments.json if not " - "provided."}) - argument_list.append({"opts": ("-alb", "--alignments-B"), - "action": FileFullPaths, - "filetypes": 'alignments', - "type": str, - "dest": "alignments_path_b", - "default": None, - "help": "Path to alignments file for training set B. Only required " - "if you are using a masked model or warp-to-landmarks is " - "enabled. Defaults to /alignments.json if not " - "provided."}) - argument_list.append({"opts": ("-m", "--model-dir"), - "action": DirFullPaths, - "dest": "model_dir", - "default": "models", - "help": "Model directory. This is where the " - "training data will be stored. 
" - "Defaults to 'model'"}) - argument_list.append({"opts": ("-t", "--trainer"), - "type": str.lower, - "choices": PluginLoader.get_available_models(), - "default": PluginLoader.get_default_model(), - "help": "Select which trainer to use, Use " - "LowMem for cards with less than 2GB of " - "VRAM"}) - argument_list.append({"opts": ("-s", "--save-interval"), - "type": int, - "action": Slider, - "min_max": (10, 1000), - "rounding": 10, - "dest": "save_interval", - "default": 100, - "help": "Sets the number of iterations before saving the model"}) - argument_list.append({"opts": ("-bs", "--batch-size"), - "type": int, - "action": Slider, - "min_max": (2, 256), - "rounding": 2, - "dest": "batch_size", - "default": 64, - "help": "Batch size, as a power of 2 (64, 128, 256, etc)"}) - argument_list.append({"opts": ("-it", "--iterations"), - "type": int, - "action": Slider, - "min_max": (0, 5000000), - "rounding": 20000, - "default": 1000000, - "help": "Length of training in iterations."}) - argument_list.append({"opts": ("-g", "--gpus"), - "type": int, - "action": Slider, - "min_max": (1, 10), - "rounding": 1, - "default": 1, - "help": "Number of GPUs to use for training"}) - argument_list.append({"opts": ("-ps", "--preview-scale"), - "type": int, - "action": Slider, - "dest": "preview_scale", - "min_max": (25, 200), - "rounding": 25, - "default": 100, - "help": "Percentage amount to scale the preview by."}) - argument_list.append({"opts": ("-p", "--preview"), - "action": "store_true", - "dest": "preview", - "default": False, - "help": "Show preview output. If not specified, " - "write progress to file"}) - argument_list.append({"opts": ("-w", "--write-image"), - "action": "store_true", - "dest": "write_image", - "default": False, - "help": "Writes the training result to a file " - "even on preview mode"}) - argument_list.append({"opts": ("-ag", "--allow-growth"), - "action": "store_true", - "dest": "allow_growth", - "default": False, - "help": "Sets allow_growth option of Tensorflow " - "to spare memory on some configs"}) - argument_list.append({"opts": ("-nl", "--no-logs"), - "action": "store_true", - "dest": "no_logs", - "default": False, - "help": "Disables TensorBoard logging. NB: Disabling logs means " - "that you will not be able to use the graph or analysis " - "for this session in the GUI."}) - argument_list.append({"opts": ("-wl", "--warp-to-landmarks"), - "action": "store_true", - "dest": "warp_to_landmarks", - "default": False, - "help": "Warps training faces to closely matched Landmarks from the " - "opposite face-set rather than randomly warping the face. " - "This is the 'dfaker' way of doing warping. Alignments " - "files for both sets of faces must be provided if using " - "this option."}) - argument_list.append({"opts": ("-nf", "--no-flip"), - "action": "store_true", - "dest": "no_flip", - "default": False, - "help": "To effectively learn, a random set of images are flipped " - "horizontally. Sometimes it is desirable for this not to " - "occur. Generally this should be left off except for " - "during 'fit training'."}) - argument_list.append({"opts": ("-tia", "--timelapse-input-A"), - "action": DirFullPaths, - "dest": "timelapse_input_a", - "default": None, - "help": "For if you want a timelapse: " - "The input folder for the timelapse. " - "This folder should contain faces of A " - "which will be converted for the " - "timelapse. 
You must supply a " - "--timelapse-output and a " - "--timelapse-input-B parameter."}) - argument_list.append({"opts": ("-tib", "--timelapse-input-B"), - "action": DirFullPaths, - "dest": "timelapse_input_b", - "default": None, - "help": "For if you want a timelapse: " - "The input folder for the timelapse. " - "This folder should contain faces of B " - "which will be converted for the " - "timelapse. You must supply a " - "--timelapse-output and a " - "--timelapse-input-A parameter."}) - argument_list.append({"opts": ("-to", "--timelapse-output"), - "action": DirFullPaths, - "dest": "timelapse_output", - "default": None, - "help": "The output folder for the timelapse. " - "If the input folders are supplied but " - "no output folder, it will default to " - "your model folder /timelapse/"}) - return argument_list - - -class GuiArgs(FaceSwapArgs): - """ Class to parse the command line arguments for training """ - - @staticmethod - def get_argument_list(): - """ Put the arguments in a list so that they are accessible from both - argparse and gui """ - argument_list = [] - argument_list.append({"opts": ("-d", "--debug"), - "action": "store_true", - "dest": "debug", - "default": False, - "help": "Output to Shell console instead of " - "GUI console"}) - return argument_list diff --git a/lib/cli/__init__.py b/lib/cli/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lib/cli/actions.py b/lib/cli/actions.py new file mode 100644 index 0000000000..634b5983f8 --- /dev/null +++ b/lib/cli/actions.py @@ -0,0 +1,421 @@ +#!/usr/bin/env python3 +""" Custom :class:`argparse.Action` objects for Faceswap's Command Line Interface. + +The custom actions within this module allow for custom manipulation of Command Line Arguments +as well as adding a mechanism for indicating to the GUI how specific options should be rendered. +""" + +import argparse +import os +import typing as T + +from lib.utils import get_module_objects + + +# << FILE HANDLING >> + +class _FullPaths(argparse.Action): + """ Parent class for various file type and file path handling classes. + + Expands out given paths to their full absolute paths. This class should not be + called directly. It is the base class for the various different file handling + methods. + """ + def __call__(self, parser, namespace, values, option_string=None) -> None: + if isinstance(values, (list, tuple)): + vals = [os.path.abspath(os.path.expanduser(val)) for val in values] + else: + vals = os.path.abspath(os.path.expanduser(values)) + setattr(namespace, self.dest, vals) + + +class DirFullPaths(_FullPaths): + """ Adds support for a Directory browser in the GUI. + + This is a standard :class:`argparse.Action` (with stock parameters) which indicates to the GUI + that a dialog box should be opened in order to browse for a folder. + + No additional parameters are required. + + Example + ------- + >>> argument_list = [] + >>> argument_list.append(dict( + >>> opts=("-f", "--folder_location"), + >>> action=DirFullPaths)), + """ + pass # pylint:disable=unnecessary-pass + + +class FileFullPaths(_FullPaths): + """ Adds support for a File browser to select a single file in the GUI. + + This extends the standard :class:`argparse.Action` and adds an additional parameter + :attr:`filetypes`, indicating to the GUI that it should pop a file browser for opening a file + and limit the results to the file types listed. 
As well as the standard parameters, the
+    following parameter is required:
+
+    Parameters
+    ----------
+    filetypes: str
+        The accepted file types for this option. This is the key for the GUIs lookup table which
+        can be found in :class:`lib.gui.utils.FileHandler`
+
+    Example
+    -------
+    >>> argument_list = []
+    >>> argument_list.append(dict(
+    >>>     opts=("-f", "--video_location"),
+    >>>     action=FileFullPaths,
+    >>>     filetypes="video"))
+    """
+    def __init__(self, *args, filetypes: str | None = None, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+        self.filetypes = filetypes
+
+    def _get_kwargs(self):
+        names = ["option_strings",
+                 "dest",
+                 "nargs",
+                 "const",
+                 "default",
+                 "type",
+                 "choices",
+                 "help",
+                 "metavar",
+                 "filetypes"]
+        return [(name, getattr(self, name)) for name in names]
+
+
+class FilesFullPaths(FileFullPaths):
+    """ Adds support for a File browser to select multiple files in the GUI.
+
+    This extends the standard :class:`argparse.Action` and adds an additional parameter
+    :attr:`filetypes`, indicating to the GUI that it should pop a file browser, and limit
+    the results to the file types listed. Multiple files can be selected for opening, so the
+    :attr:`nargs` parameter must be set. As well as the standard parameters, the following
+    parameter is required:
+
+    Parameters
+    ----------
+    filetypes: str
+        The accepted file types for this option. This is the key for the GUIs lookup table which
+        can be found in :class:`lib.gui.utils.FileHandler`
+
+    Example
+    -------
+    >>> argument_list = []
+    >>> argument_list.append(dict(
+    >>>     opts=("-f", "--images"),
+    >>>     action=FilesFullPaths,
+    >>>     filetypes="image",
+    >>>     nargs="+"))
+    """
+    def __init__(self, *args, filetypes: str | None = None, **kwargs) -> None:
+        if kwargs.get("nargs", None) is None:
+            opt = kwargs["option_strings"]
+            raise ValueError(f"nargs must be provided for FilesFullPaths: {opt}")
+        super().__init__(*args, filetypes=filetypes, **kwargs)
+
+
+class DirOrFileFullPaths(FileFullPaths):
+    """ Adds support to the GUI to launch either a file browser or a folder browser.
+
+    Some inputs (for example source frames) can come from a folder of images or from a
+    video file. This indicates to the GUI that it should place 2 buttons (one for a folder
+    browser, one for a file browser) for file/folder browsing.
+
+    The standard :class:`argparse.Action` is extended with the additional parameter
+    :attr:`filetypes`, indicating to the GUI that it should pop a file browser, and limit
+    the results to the file types listed. As well as the standard parameters, the following
+    parameter is required:
+
+    Parameters
+    ----------
+    filetypes: str
+        The accepted file types for this option. This is the key for the GUIs lookup table which
+        can be found in :class:`lib.gui.utils.FileHandler`. NB: This parameter is only used for
+        the file browser and not the folder browser
+
+    Example
+    -------
+    >>> argument_list = []
+    >>> argument_list.append(dict(
+    >>>     opts=("-f", "--input_frames"),
+    >>>     action=DirOrFileFullPaths,
+    >>>     filetypes="video"))
+    """
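A short usage sketch of the path handling above (this assumes it is run from the faceswap repository root so that `lib.cli.actions` is importable; the option name and path are invented):

    import argparse

    from lib.cli.actions import DirOrFileFullPaths

    parser = argparse.ArgumentParser()
    # "filetypes" is metadata for the GUI file browser; on the command line
    # the action simply expands the given value to an absolute path
    parser.add_argument("-i", "--input-dir",
                        action=DirOrFileFullPaths,
                        filetypes="video",
                        dest="input_dir")
    args = parser.parse_args(["-i", "~/videos/source.mp4"])
    print(args.input_dir)  # an absolute path, with "~" expanded to the home folder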
+
+    The standard :class:`argparse.Action` is extended with the additional parameter
+    :attr:`filetypes`, indicating to the GUI that it should pop a file browser, and limit
+    the results to the file types listed. As well as the standard parameters, the following
+    parameter is required:
+
+    Parameters
+    ----------
+    filetypes: str
+        The accepted file types for this option. This is the key for the GUIs lookup table which
+        can be found in :class:`lib.gui.utils.FileHandler`. NB: This parameter is only used for
+        the file browser and not the folder browser
+
+    Example
+    -------
+    >>> argument_list = []
+    >>> argument_list.append(dict(
+    >>>        opts=("-f", "--input_frames"),
+    >>>        action=DirOrFilesFullPaths,
+    >>>        filetypes="video"))
+    """
+    def __call__(self, parser, namespace, values, option_string=None) -> None:
+        """ Override :class:`_FullPaths` __call__ function.
+
+        The input for this option can be a space separated list of files or a single folder.
+        Folders can have spaces in them, so we don't want to blindly expand the paths.
+
+        We check whether the input can be resolved to a folder first before expanding.
+        """
+        assert isinstance(values, (list, tuple))
+        folder = os.path.abspath(os.path.expanduser(" ".join(values)))
+        if os.path.isdir(folder):
+            setattr(namespace, self.dest, [folder])
+        else:  # file list so call parent method
+            super().__call__(parser, namespace, values, option_string)
+
+
+class SaveFileFullPaths(FileFullPaths):
+    """ Adds support for a Save File dialog in the GUI.
+
+    This extends the standard :class:`argparse.Action` and adds an additional parameter
+    :attr:`filetypes`, indicating to the GUI that it should pop a save file browser, and limit
+    the results to the file types listed. As well as the standard parameters, the following
+    parameter is required:
+
+    Parameters
+    ----------
+    filetypes: str
+        The accepted file types for this option. This is the key for the GUIs lookup table which
+        can be found in :class:`lib.gui.utils.FileHandler`
+
+    Example
+    -------
+    >>> argument_list = []
+    >>> argument_list.append(dict(
+    >>>        opts=("-f", "--video_out"),
+    >>>        action=SaveFileFullPaths,
+    >>>        filetypes="video"))
+    """
+    pass  # pylint:disable=unnecessary-pass
+
+
+class ContextFullPaths(FileFullPaths):
+    """ Adds support for context sensitive browser dialog opening in the GUI.
+
+    For some tasks, the type of action (file load, folder open, file save etc.) can vary
+    depending on the task to be performed (a good example of this is the effmpeg tool).
+    Using this action indicates to the GUI that the type of dialog to be launched can change
+    depending on another option. As well as the standard parameters, the below parameters are
+    required. NB: :attr:`nargs` are explicitly disallowed.
+
+    Parameters
+    ----------
+    filetypes: str
+        The accepted file types for this option. This is the key for the GUIs lookup table which
+        can be found in :class:`lib.gui.utils.FileHandler`
+    action_option: str
+        The command line option that dictates the context of the file dialog to be opened.
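The folder-first resolution in DirOrFilesFullPaths.__call__ above can be exercised directly; a quick sketch (illustrative only):

import argparse

from lib.cli.actions import DirOrFilesFullPaths

parser = argparse.ArgumentParser()
parser.add_argument("-f", "--filter", action=DirOrFilesFullPaths, filetypes="image", nargs="+")

# If the space-joined values resolve to an existing folder, a single folder path
# is stored; otherwise each value is expanded as an individual file path.
args = parser.parse_args(["-f", "face1.png", "face2.png"])
print(args.filter)  # e.g. ["/abs/path/face1.png", "/abs/path/face2.png"]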
+        Bespoke actions are set in :class:`lib.gui.utils.FileHandler`
+
+    Example
+    -------
+    Assuming an argument has already been set with option string `-a` indicating the action to be
+    performed, the following will pop a different type of dialog depending on the action selected:
+
+    >>> argument_list = []
+    >>> argument_list.append(dict(
+    >>>        opts=("-f", "--input_video"),
+    >>>        action=ContextFullPaths,
+    >>>        filetypes="video",
+    >>>        action_option="-a"))
+    """
+    # pylint:disable=too-many-arguments
+    def __init__(self,
+                 *args,
+                 filetypes: str | None = None,
+                 action_option: str | None = None,
+                 **kwargs) -> None:
+        opt = kwargs["option_strings"]
+        if kwargs.get("nargs", None) is not None:
+            raise ValueError(f"nargs not allowed for ContextFullPaths: {opt}")
+        if filetypes is None:
+            raise ValueError(f"filetypes is required for ContextFullPaths: {opt}")
+        if action_option is None:
+            raise ValueError(f"action_option is required for ContextFullPaths: {opt}")
+        super().__init__(*args, filetypes=filetypes, **kwargs)
+        self.action_option = action_option
+
+    def _get_kwargs(self) -> list[tuple[str, T.Any]]:
+        names = ["option_strings",
+                 "dest",
+                 "nargs",
+                 "const",
+                 "default",
+                 "type",
+                 "choices",
+                 "help",
+                 "metavar",
+                 "filetypes",
+                 "action_option"]
+        return [(name, getattr(self, name)) for name in names]
+
+
+# << GUI DISPLAY OBJECTS >>
+
+class Radio(argparse.Action):
+    """ Adds support for a GUI Radio options box.
+
+    This is a standard :class:`argparse.Action` (with stock parameters) which indicates to the GUI
+    that the options passed should be rendered as a group of Radio Buttons rather than a combo box.
+
+    No additional parameters are required, but the :attr:`choices` parameter must be provided as
+    these will be the Radio Box options. :attr:`nargs` are explicitly disallowed.
+
+    Example
+    -------
+    >>> argument_list = []
+    >>> argument_list.append(dict(
+    >>>        opts=("-f", "--foobar"),
+    >>>        action=Radio,
+    >>>        choices=["foo", "bar"]))
+    """
+    def __init__(self, *args, **kwargs) -> None:
+        opt = kwargs["option_strings"]
+        if kwargs.get("nargs", None) is not None:
+            raise ValueError(f"nargs not allowed for Radio buttons: {opt}")
+        if not kwargs.get("choices", []):
+            raise ValueError(f"Choices must be provided for Radio buttons: {opt}")
+        super().__init__(*args, **kwargs)
+
+    def __call__(self, parser, namespace, values, option_string=None) -> None:
+        setattr(namespace, self.dest, values)
+
+
+class MultiOption(argparse.Action):
+    """ Adds support for multiple option checkboxes in the GUI.
+
+    This is a standard :class:`argparse.Action` (with stock parameters) which indicates to the GUI
+    that the options passed should be rendered as a group of checkboxes rather than a combo box.
+
+    The :attr:`choices` parameter must be provided as this provides the valid option choices.
+
+    Example
+    -------
+    >>> argument_list = []
+    >>> argument_list.append(dict(
+    >>>        opts=("-f", "--foobar"),
+    >>>        action=MultiOption,
+    >>>        choices=["foo", "bar"]))
+    """
+    def __init__(self, *args, **kwargs) -> None:
+        opt = kwargs["option_strings"]
+        if not kwargs.get("nargs", []):
+            raise ValueError(f"nargs must be provided for MultiOption: {opt}")
+        if not kwargs.get("choices", []):
+            raise ValueError(f"Choices must be provided for MultiOption: {opt}")
+        super().__init__(*args, **kwargs)
+
+    def __call__(self, parser, namespace, values, option_string=None) -> None:
+        setattr(namespace, self.dest, values)
+
+
+class Slider(argparse.Action):
+    """ Adds support for a slider in the GUI.
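Both GUI actions above validate their keyword arguments when the parser is built; a small sketch (not from the commit) of the failure mode:

import argparse

from lib.cli.actions import Radio

parser = argparse.ArgumentParser()
try:
    parser.add_argument("-n", "--normalization", action=Radio)  # no choices given
except ValueError as err:
    print(err)  # Choices must be provided for Radio buttons: ['-n', '--normalization']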
+
+    The standard :class:`argparse.Action` is extended with the additional parameters listed below.
+    The :attr:`default` value must be supplied and the :attr:`type` must be either :class:`int` or
+    :class:`float`. :attr:`nargs` are explicitly disallowed.
+
+    Parameters
+    ----------
+    min_max: tuple
+        The (`min`, `max`) values that the slider's range should be set to. The values should be a
+        pair of `float` or `int` data types, depending on the data type of the slider. NB: These
+        min/max values are not enforced, they are purely for setting the slider range. Values
+        outside of this range can still be explicitly passed in from the cli.
+    rounding: int
+        If the underlying data type for the option is a `float` then this value is the number of
+        decimal places to round the slider values to. If the underlying data type for the option is
+        an `int` then this is the step interval between each value for the slider.
+
+    Examples
+    --------
+    For integer values:
+
+    >>> argument_list = []
+    >>> argument_list.append(dict(
+    >>>        opts=("-f", "--foobar"),
+    >>>        action=Slider,
+    >>>        min_max=(0, 10),
+    >>>        rounding=1,
+    >>>        type=int,
+    >>>        default=5))
+
+    For floating point values:
+
+    >>> argument_list = []
+    >>> argument_list.append(dict(
+    >>>        opts=("-f", "--foobar"),
+    >>>        action=Slider,
+    >>>        min_max=(0.00, 1.00),
+    >>>        rounding=2,
+    >>>        type=float,
+    >>>        default=0.50))
+    """
+    def __init__(self,
+                 *args,
+                 min_max: tuple[int, int] | tuple[float, float] | None = None,
+                 rounding: int | None = None,
+                 **kwargs) -> None:
+        opt = kwargs["option_strings"]
+        if kwargs.get("nargs", None) is not None:
+            raise ValueError(f"nargs not allowed for Slider: {opt}")
+        if kwargs.get("default", None) is None:
+            raise ValueError(f"A default value must be supplied for Slider: {opt}")
+        if kwargs.get("type", None) not in (int, float):
+            raise ValueError(f"Sliders only accept int and float data types: {opt}")
+        if min_max is None:
+            raise ValueError(f"min_max must be provided for Sliders: {opt}")
+        if rounding is None:
+            raise ValueError(f"rounding must be provided for Sliders: {opt}")
+
+        super().__init__(*args, **kwargs)
+        self.min_max = min_max
+        self.rounding = rounding
+
+    def _get_kwargs(self) -> list[tuple[str, T.Any]]:
+        names = ["option_strings",
+                 "dest",
+                 "nargs",
+                 "const",
+                 "default",
+                 "type",
+                 "choices",
+                 "help",
+                 "metavar",
+                 "min_max",  # Tuple containing min and max values of scale
+                 "rounding"]  # Decimal places to round floats to or step interval for ints
+        return [(name, getattr(self, name)) for name in names]
+
+    def __call__(self, parser, namespace, values, option_string=None) -> None:
+        setattr(namespace, self.dest, values)
+
+
+__all__ = get_module_objects(__name__)
diff --git a/lib/cli/args.py b/lib/cli/args.py
new file mode 100644
index 0000000000..10002045f8
--- /dev/null
+++ b/lib/cli/args.py
@@ -0,0 +1,314 @@
+#!/usr/bin/env python3
+""" The global and GUI Command Line Argument options for faceswap.py """
+
+import argparse
+import gettext
+import logging
+import re
+import sys
+import textwrap
+import typing as T
+
+from lib.utils import get_backend, get_module_objects
+from lib.gpu_stats import GPUStats
+
+from .actions import FileFullPaths, MultiOption, SaveFileFullPaths
+from .launcher import ScriptExecutor
+
+logger = logging.getLogger(__name__)
+
+
+if GPUStats is None:
+    _GPUS = []
+else:
+    _GPUS = GPUStats().cli_devices
+
+# LOCALES
+_LANG = gettext.translation("lib.cli.args", localedir="locales", fallback=True)
+_ = _LANG.gettext
+
+
+class 
FullHelpArgumentParser(argparse.ArgumentParser):
+    """ Extends :class:`argparse.ArgumentParser` to output full help on bad arguments. """
+    def error(self, message: str) -> T.NoReturn:
+        self.print_help(sys.stderr)
+        self.exit(2, f"{self.prog}: error: {message}\n")
+
+
+class SmartFormatter(argparse.HelpFormatter):
+    """ Extends the class :class:`argparse.HelpFormatter` to allow custom formatting in help text.
+
+    Adapted from: https://stackoverflow.com/questions/3853722
+
+    Notes
+    -----
+    Prefix help text with "R|" to override default formatting and use explicitly defined formatting
+    within the help text.
+    Prefixing a new line within the help text with "L|" will turn that line into a list item in
+    both the cli help text and the GUI.
+    """
+    def __init__(self,
+                 prog: str,
+                 indent_increment: int = 2,
+                 max_help_position: int = 24,
+                 width: int | None = None) -> None:
+        super().__init__(prog, indent_increment, max_help_position, width)
+        self._whitespace_matcher_limited = re.compile(r'[ \r\f\v]+', re.ASCII)
+
+    def _split_lines(self, text: str, width: int) -> list[str]:
+        """ Split the given text by the given display width.
+
+        If the text is not prefixed with "R|" then the standard
+        :func:`argparse.HelpFormatter._split_lines` function is used, otherwise raw
+        formatting is processed.
+
+        Parameters
+        ----------
+        text: str
+            The help text that is to be formatted for display
+        width: int
+            The display width, in characters, for the help text
+
+        Returns
+        -------
+        list
+            A list of split strings
+        """
+        if text.startswith("R|"):
+            text = self._whitespace_matcher_limited.sub(' ', text).strip()[2:]
+            output = []
+            for txt in text.splitlines():
+                indent = ""
+                if txt.startswith("L|"):
+                    indent = "    "
+                    txt = f"  - {txt[2:]}"
+                output.extend(textwrap.wrap(txt, width, subsequent_indent=indent))
+            return output
+        return argparse.HelpFormatter._split_lines(self,  # pylint:disable=protected-access
+                                                   text,
+                                                   width)
+
+
+class FaceSwapArgs():
+    """ Faceswap argument parser functions that are universal to all commands.
+
+    This is the parent class to all subsequent argparsers, and it holds the global arguments that
+    pertain to all commands.
+
+    Processes the incoming command line arguments, validates them, and then launches the relevant
+    faceswap script with the given arguments.
+
+    Parameters
+    ----------
+    subparser: :class:`argparse._SubParsersAction` | None
+        The subparser for the given command. ``None`` if the class is being called for reading
+        rather than processing
+    command: str
+        The faceswap command that is to be executed
+    description: str, optional
+        The description for the given command. Default: "default"
+    """
+    def __init__(self,
+                 subparser: argparse._SubParsersAction | None,
+                 command: str,
+                 description: str = "default") -> None:
+        self.global_arguments = self._get_global_arguments()
+        self.info: str = self.get_info()
+        self.argument_list = self.get_argument_list()
+        self.optional_arguments = self.get_optional_arguments()
+        self._process_suppressions()
+        if not subparser:
+            return
+        self.parser = self._create_parser(subparser, command, description)
+        self._add_arguments()
+        script = ScriptExecutor(command)
+        self.parser.set_defaults(func=script.execute_script)
+
+    @staticmethod
+    def get_info() -> str:
+        """ Returns the information text for the current command.
+
+        This function should be overridden with the actual command help text for each
+        command's parser.
+
+        Returns
+        -------
+        str
+            The information text for this command.
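A hypothetical subclass sketch (names invented for illustration) showing the override pattern that FaceSwapArgs expects, including the "R|"/"L|" help markup consumed by SmartFormatter above:

import typing as T

from lib.cli.actions import DirFullPaths
from lib.cli.args import FaceSwapArgs


class ExampleArgs(FaceSwapArgs):
    """ Hypothetical parser for an imaginary 'example' command. """

    @staticmethod
    def get_info() -> str:
        return "Demonstrate the FaceSwapArgs override pattern"

    @staticmethod
    def get_argument_list() -> list[dict[str, T.Any]]:
        return [{"opts": ("-i", "--input-dir"),
                 "action": DirFullPaths,
                 "dest": "input_dir",
                 "group": "Data",
                 "help": "R|Folder to process."
                         "\nL|Rendered as a list item in the cli help and the GUI"}]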
+        """
+        return ""
+
+    @staticmethod
+    def get_argument_list() -> list[dict[str, T.Any]]:
+        """ Returns the argument list for the current command.
+
+        The argument list should be a list of dictionaries pertaining to each option for a command.
+        This function should be overridden with the actual argument list for each command.
+
+        See existing parsers for examples.
+
+        Returns
+        -------
+        list
+            The list of command line options for the given command
+        """
+        argument_list: list[dict[str, T.Any]] = []
+        return argument_list
+
+    @staticmethod
+    def get_optional_arguments() -> list[dict[str, T.Any]]:
+        """ Returns the optional argument list for the current command.
+
+        The optional arguments list is not always required, but is used when there are shared
+        options between multiple commands (e.g. convert and extract). Only override if required.
+
+        Returns
+        -------
+        list
+            The list of optional command line options for the given command
+        """
+        argument_list: list[dict[str, T.Any]] = []
+        return argument_list
+
+    @staticmethod
+    def _get_global_arguments() -> list[dict[str, T.Any]]:
+        """ Returns the global arguments list that is required for ALL commands in Faceswap.
+
+        This method should NOT be overridden.
+
+        Returns
+        -------
+        list
+            The list of global command line options for all Faceswap commands.
+        """
+        global_args: list[dict[str, T.Any]] = []
+        if _GPUS:
+            global_args.append({
+                "opts": ("-X", "--exclude-gpus"),
+                "dest": "exclude_gpus",
+                "action": MultiOption,
+                "type": str.lower,
+                "nargs": "+",
+                "choices": [str(idx) for idx in range(len(_GPUS))],
+                "group": _("Global Options"),
+                "help": _(
+                    "R|Exclude GPUs from use by Faceswap. Select the number(s) which correspond "
+                    "to any GPU(s) that you do not wish to be made available to Faceswap. "
+                    "Selecting all GPUs here will force Faceswap into CPU mode."
+                    "\nL|{}".format(' \nL|'.join(_GPUS)))})
+        global_args.append({
+            "opts": ("-C", "--configfile"),
+            "action": FileFullPaths,
+            "filetypes": "ini",
+            "type": str,
+            "group": _("Global Options"),
+            "help": _(
+                "Optionally override the saved config with the path to a custom config file.")})
+        global_args.append({
+            "opts": ("-L", "--loglevel"),
+            "type": str.upper,
+            "dest": "loglevel",
+            "default": "INFO",
+            "choices": ("INFO", "VERBOSE", "DEBUG", "TRACE"),
+            "group": _("Global Options"),
+            "help": _(
+                "Log level. Stick with INFO or VERBOSE unless you need to file an error report. "
+                "Be careful with TRACE as it will generate a lot of data")})
+        global_args.append({
+            "opts": ("-F", "--logfile"),
+            "action": SaveFileFullPaths,
+            "filetypes": 'log',
+            "type": str,
+            "dest": "logfile",
+            "default": None,
+            "group": _("Global Options"),
+            "help": _("Path to store the logfile. Leave blank to store in the faceswap folder")})
+        # These are hidden arguments to indicate that the GUI/Colab is being used
+        global_args.append({
+            "opts": ("-G", "--gui"),
+            "action": "store_true",
+            "dest": "redirect_gui",
+            "default": False,
+            "help": argparse.SUPPRESS})
+        return global_args
+
+    @staticmethod
+    def _create_parser(subparser: argparse._SubParsersAction,
+                       command: str,
+                       description: str) -> argparse.ArgumentParser:
+        """ Create the parser for the selected command.
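For orientation, the -X/--exclude-gpus entry above is plain data; with two detected devices it is equivalent to the following direct argparse call (sketch only, device count assumed):

import argparse

from lib.cli.actions import MultiOption

parser = argparse.ArgumentParser()
parser.add_argument("-X", "--exclude-gpus",
                    dest="exclude_gpus",
                    action=MultiOption,
                    type=str.lower,
                    nargs="+",
                    choices=["0", "1"],  # one entry per detected GPU index
                    help="Exclude GPUs from use by Faceswap.")

args = parser.parse_args(["-X", "0"])  # hide GPU 0 from this session
print(args.exclude_gpus)  # ['0']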
+ + Parameters + ---------- + subparser: :class:`argparse._SubParsersAction` + The subparser for the given command + command: str + The faceswap command that is to be executed + description: str + The description for the given command + + + Returns + ------- + :class:`~lib.cli.args.FullHelpArgumentParser` + The parser for the given command + """ + parser = subparser.add_parser(command, + help=description, + description=description, + epilog="Questions and feedback: https://faceswap.dev/forum", + formatter_class=SmartFormatter) + return parser + + def _add_arguments(self) -> None: + """ Parse the list of dictionaries containing the command line arguments and convert to + argparse parser arguments. """ + options = self.global_arguments + self.argument_list + self.optional_arguments + for option in options: + args = option["opts"] + kwargs = {key: option[key] for key in option.keys() if key not in ("opts", "group")} + self.parser.add_argument(*args, **kwargs) + + def _process_suppressions(self) -> None: + """ Certain options are only available for certain backends. + + Suppresses command line options that are not available for the running backend. + """ + fs_backend = get_backend() + for opt_list in [self.global_arguments, self.argument_list, self.optional_arguments]: + for opts in opt_list: + if opts.get("backend", None) is None: + continue + opt_backend = opts.pop("backend") + if isinstance(opt_backend, (list, tuple)): + opt_backend = [backend.lower() for backend in opt_backend] + else: + opt_backend = [opt_backend.lower()] + if fs_backend not in opt_backend: + opts["help"] = argparse.SUPPRESS + + +class GuiArgs(FaceSwapArgs): + """ Creates the command line arguments for the GUI. """ + + @staticmethod + def get_argument_list() -> list[dict[str, T.Any]]: + """ Returns the argument list for GUI arguments. + + Returns + ------- + list + The list of command line options for the GUI + """ + argument_list: list[dict[str, T.Any]] = [] + argument_list.append({ + "opts": ("-d", "--debug"), + "action": "store_true", + "dest": "debug", + "default": False, + "help": _("Output to Shell console instead of GUI console")}) + return argument_list + + +__all__ = get_module_objects(__name__) diff --git a/lib/cli/args_extract_convert.py b/lib/cli/args_extract_convert.py new file mode 100644 index 0000000000..5c0b3ce903 --- /dev/null +++ b/lib/cli/args_extract_convert.py @@ -0,0 +1,663 @@ +#!/usr/bin/env python3 +""" The Command Line Argument options for extracting and converting with faceswap.py """ +import gettext +import typing as T + +from lib.utils import get_module_objects +from lib.utils import get_backend +from plugins.plugin_loader import PluginLoader + +from .actions import (DirFullPaths, DirOrFileFullPaths, DirOrFilesFullPaths, FileFullPaths, + FilesFullPaths, MultiOption, Radio, Slider) +from .args import FaceSwapArgs + + +# LOCALES +_LANG = gettext.translation("lib.cli.args_extract_convert", localedir="locales", fallback=True) +_ = _LANG.gettext + + +class ExtractConvertArgs(FaceSwapArgs): + """ Parent class to capture arguments that will be used in both extract and convert processes. + + Extract and Convert share a fair amount of arguments, so arguments that can be used in both of + these processes should be placed here. + + No further processing is done in this class (this is handled by the children), this just + captures the shared arguments. + """ + + @staticmethod + def get_argument_list() -> list[dict[str, T.Any]]: + """ Returns the argument list for shared Extract and Convert arguments. 
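The "backend" key consumed by _process_suppressions above never reaches argparse; a sketch (illustrative) of an option that only surfaces on the listed backends:

# _process_suppressions() pops the "backend" entry and swaps the help text for
# argparse.SUPPRESS when lib.utils.get_backend() is not in the list, so the
# option is hidden (but still parseable) on other backends.
option = {"opts": ("-P", "--singleprocess"),
          "action": "store_true",
          "default": False,
          "backend": ("nvidia", "rocm"),  # hidden on cpu / apple_silicon builds
          "help": "Don't run extraction in parallel."}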
+ + Returns + ------- + list + The list of command line options for the given Extract and Convert + """ + argument_list: list[dict[str, T.Any]] = [] + argument_list.append({ + "opts": ("-i", "--input-dir"), + "action": DirOrFileFullPaths, + "filetypes": "video", + "dest": "input_dir", + "required": True, + "group": _("Data"), + "help": _( + "Input directory or video. Either a directory containing the image files you wish " + "to process or path to a video file. NB: This should be the source video/frames " + "NOT the source faces.")}) + argument_list.append({ + "opts": ("-o", "--output-dir"), + "action": DirFullPaths, + "dest": "output_dir", + "required": True, + "group": _("Data"), + "help": _("Output directory. This is where the converted files will be saved.")}) + argument_list.append({ + "opts": ("-p", "--alignments"), + "action": FileFullPaths, + "filetypes": "alignments", + "type": str, + "dest": "alignments_path", + "group": _("Data"), + "help": _( + "Optional path to an alignments file. Leave blank if the alignments file is at " + "the default location.")}) + return argument_list + + +class ExtractArgs(ExtractConvertArgs): + """ Creates the command line arguments for extraction. + + This class inherits base options from :class:`ExtractConvertArgs` where arguments that are used + for both Extract and Convert should be placed. + + Commands explicit to Extract should be added in :func:`get_optional_arguments` + """ + + @staticmethod + def get_info() -> str: + """ The information text for the Extract command. + + Returns + ------- + str + The information text for the Extract command. + """ + return _("Extract faces from image or video sources.\n" + "Extraction plugins can be configured in the 'Settings' Menu") + + @staticmethod + def get_optional_arguments() -> list[dict[str, T.Any]]: + """ Returns the argument list unique to the Extract command. + + Returns + ------- + list + The list of optional command line options for the Extract command + """ + if get_backend() == "cpu": + default_detector = "mtcnn" + default_aligner = "cv2-dnn" + else: + default_detector = "s3fd" + default_aligner = "fan" + + argument_list: list[dict[str, T.Any]] = [] + argument_list.append({ + "opts": ("-b", "--batch-mode"), + "action": "store_true", + "dest": "batch_mode", + "default": False, + "group": _("Data"), + "help": _( + "R|If selected then the input_dir should be a parent folder containing multiple " + "videos and/or folders of images you wish to extract from. The faces will be " + "output to separate sub-folders in the output_dir.")}) + argument_list.append({ + "opts": ("-D", "--detector"), + "action": Radio, + "type": str.lower, + "default": default_detector, + "choices": PluginLoader.get_available_extractors("detect"), + "group": _("Plugins"), + "help": _( + "R|Detector to use. Some of these have configurable settings in " + "'/config/extract.ini' or 'Settings > Configure Extract 'Plugins':" + "\nL|cv2-dnn: A CPU only extractor which is the least reliable and least resource " + "intensive. Use this if not using a GPU and time is important." + "\nL|mtcnn: Good detector. Fast on CPU, faster on GPU. Uses fewer resources than " + "other GPU detectors but can often return more false positives." + "\nL|s3fd: Best detector. Slow on CPU, faster on GPU. Can detect more faces and " + "fewer false positives than other GPU detectors, but is a lot more resource " + "intensive." + "\nL|external: Import a face detection bounding box from a json file. 
(" + "configurable in Detect settings)")}) + argument_list.append({ + "opts": ("-A", "--aligner"), + "action": Radio, + "type": str.lower, + "default": default_aligner, + "choices": PluginLoader.get_available_extractors("align"), + "group": _("Plugins"), + "help": _( + "R|Aligner to use." + "\nL|cv2-dnn: A CPU only landmark detector. Faster, less resource intensive, but " + "less accurate. Only use this if not using a GPU and time is important." + "\nL|fan: Best aligner. Fast on GPU, slow on CPU." + "\nL|external: Import 68 point 2D landmarks or an aligned bounding box from a " + "json file. (configurable in Align settings)")}) + argument_list.append({ + "opts": ("-M", "--masker"), + "action": MultiOption, + "type": str.lower, + "nargs": "+", + "choices": [mask for mask in PluginLoader.get_available_extractors("mask") + if mask not in ("components", "extended")], + "group": _("Plugins"), + "help": _( + "R|Additional Masker(s) to use. The masks generated here will all take up GPU " + "RAM. You can select none, one or multiple masks, but the extraction may take " + "longer the more you select. NB: The Extended and Components (landmark based) " + "masks are automatically generated on extraction." + "\nL|bisenet-fp: Relatively lightweight NN based mask that provides more refined " + "control over the area to be masked including full head masking (configurable in " + "mask settings)." + "\nL|custom: A dummy mask that fills the mask area with all 1s or 0s (" + "configurable in settings). This is only required if you intend to manually edit " + "the custom masks yourself in the manual tool. This mask does not use the GPU so " + "will not use any additional VRAM." + "\nL|vgg-clear: Mask designed to provide smart segmentation of mostly frontal " + "faces clear of obstructions. Profile faces and obstructions may result in " + "sub-par performance." + "\nL|vgg-obstructed: Mask designed to provide smart segmentation of mostly " + "frontal faces. The mask model has been specifically trained to recognize some " + "facial obstructions (hands and eyeglasses). Profile faces may result in sub-par " + "performance." + "\nL|unet-dfl: Mask designed to provide smart segmentation of mostly frontal " + "faces. The mask model has been trained by community members and will need " + "testing for further description. Profile faces may result in sub-par " + "performance." + "\nThe auto generated masks are as follows:" + "\nL|components: Mask designed to provide facial segmentation based on the " + "positioning of landmark locations. A convex hull is constructed around the " + "exterior of the landmarks to create a mask." + "\nL|extended: Mask designed to provide facial segmentation based on the " + "positioning of landmark locations. A convex hull is constructed around the " + "exterior of the landmarks and the mask is extended upwards onto the forehead." + "\n(eg: `-M unet-dfl vgg-clear`, `--masker vgg-obstructed`)")}) + argument_list.append({ + "opts": ("-O", "--normalization"), + "action": Radio, + "type": str.lower, + "dest": "normalization", + "default": "none", + "choices": ["none", "clahe", "hist", "mean"], + "group": _("Plugins"), + "help": _( + "R|Performing normalization can help the aligner better align faces with " + "difficult lighting conditions at an extraction speed cost. Different methods " + "will yield different results on different sets. NB: This does not impact the " + "output face, just the input to the aligner." + "\nL|none: Don't perform normalization on the face." 
+ "\nL|clahe: Perform Contrast Limited Adaptive Histogram Equalization on the face." + "\nL|hist: Equalize the histograms on the RGB channels." + "\nL|mean: Normalize the face colors to the mean.")}) + argument_list.append({ + "opts": ("-R", "--re-feed"), + "action": Slider, + "min_max": (0, 10), + "rounding": 1, + "type": int, + "dest": "re_feed", + "default": 0, + "group": _("Plugins"), + "help": _( + "The number of times to re-feed the detected face into the aligner. Each time the " + "face is re-fed into the aligner the bounding box is adjusted by a small amount. " + "The final landmarks are then averaged from each iteration. Helps to remove " + "'micro-jitter' but at the cost of slower extraction speed. The more times the " + "face is re-fed into the aligner, the less micro-jitter should occur but the " + "longer extraction will take.")}) + argument_list.append({ + "opts": ("-a", "--re-align"), + "action": "store_true", + "dest": "re_align", + "default": False, + "group": _("Plugins"), + "help": _( + "Re-feed the initially found aligned face through the aligner. Can help produce " + "better alignments for faces that are rotated beyond 45 degrees in the frame or " + "are at extreme angles. Slows down extraction.")}) + argument_list.append({ + "opts": ("-r", "--rotate-images"), + "type": str, + "dest": "rotate_images", + "default": None, + "group": _("Plugins"), + "help": _( + "If a face isn't found, rotate the images to try to find a face. Can find more " + "faces at the cost of extraction speed. Pass in a single number to use increments " + "of that size up to 360, or pass in a list of numbers to enumerate exactly what " + "angles to check.")}) + argument_list.append({ + "opts": ("-I", "--identity"), + "action": "store_true", + "default": False, + "group": _("Plugins"), + "help": _( + "Obtain and store face identity encodings from VGGFace2. Slows down extract a " + "little, but will save time if using 'sort by face'")}) + argument_list.append({ + "opts": ("-m", "--min-size"), + "action": Slider, + "min_max": (0, 1080), + "rounding": 20, + "type": int, + "dest": "min_size", + "default": 0, + "group": _("Face Processing"), + "help": _( + "Filters out faces detected below this size. Length, in pixels across the " + "diagonal of the bounding box. Set to 0 for off")}) + argument_list.append({ + "opts": ("-n", "--nfilter"), + "action": DirOrFilesFullPaths, + "filetypes": "image", + "dest": "nfilter", + "default": None, + "nargs": "+", + "group": _("Face Processing"), + "help": _( + "Optionally filter out people who you do not wish to extract by passing in images " + "of those people. Should be a small variety of images at different angles and in " + "different conditions. A folder containing the required images or multiple image " + "files, space separated, can be selected.")}) + argument_list.append({ + "opts": ("-f", "--filter"), + "action": DirOrFilesFullPaths, + "filetypes": "image", + "dest": "filter", + "default": None, + "nargs": "+", + "group": _("Face Processing"), + "help": _( + "Optionally select people you wish to extract by passing in images of that " + "person. 
Should be a small variety of images at different angles and in different "
+                "conditions. A folder containing the required images or multiple image files, "
+                "space separated, can be selected.")})
+        argument_list.append({
+            "opts": ("-l", "--ref_threshold"),
+            "action": Slider,
+            "min_max": (0.01, 0.99),
+            "rounding": 2,
+            "type": float,
+            "dest": "ref_threshold",
+            "default": 0.60,
+            "group": _("Face Processing"),
+            "help": _(
+                "For use with the optional nfilter/filter files. Threshold for positive face "
+                "recognition. Higher values are stricter.")})
+        argument_list.append({
+            "opts": ("-z", "--size"),
+            "action": Slider,
+            "min_max": (256, 1024),
+            "rounding": 64,
+            "type": int,
+            "default": 512,
+            "group": _("output"),
+            "help": _(
+                "The output size of extracted faces. Make sure that the model you intend to train "
+                "supports your required size. This will only need to be changed for hi-res "
+                "models.")})
+        argument_list.append({
+            "opts": ("-N", "--extract-every-n"),
+            "action": Slider,
+            "min_max": (1, 100),
+            "rounding": 1,
+            "type": int,
+            "dest": "extract_every_n",
+            "default": 1,
+            "group": _("output"),
+            "help": _(
+                "Extract every 'nth' frame. This option will skip frames when extracting faces. "
+                "For example a value of 1 will extract faces from every frame, a value of 10 will "
+                "extract faces from every 10th frame.")})
+        argument_list.append({
+            "opts": ("-v", "--save-interval"),
+            "action": Slider,
+            "min_max": (0, 1000),
+            "rounding": 10,
+            "type": int,
+            "dest": "save_interval",
+            "default": 0,
+            "group": _("output"),
+            "help": _(
+                "Automatically save the alignments file after a set amount of frames. By default "
+                "the alignments file is only saved at the end of the extraction process. NB: If "
+                "extracting in 2 passes then the alignments file will only start to be saved out "
+                "during the second pass. WARNING: Don't interrupt the script when writing the "
+                "file because it might get corrupted. Set to 0 to turn off")})
+        argument_list.append({
+            "opts": ("-B", "--debug-landmarks"),
+            "action": "store_true",
+            "dest": "debug_landmarks",
+            "default": False,
+            "group": _("output"),
+            "help": _("Draw landmarks on the output faces for debugging purposes.")})
+        argument_list.append({
+            "opts": ("-P", "--singleprocess"),
+            "action": "store_true",
+            "default": False,
+            "backend": ("nvidia", "rocm", "apple_silicon"),
+            "group": _("settings"),
+            "help": _(
+                "Don't run extraction in parallel. Will run each part of the extraction process "
+                "separately (one after the other) rather than all at the same time. Useful if "
+                "VRAM is at a premium.")})
+        argument_list.append({
+            "opts": ("-s", "--skip-existing"),
+            "action": "store_true",
+            "dest": "skip_existing",
+            "default": False,
+            "group": _("settings"),
+            "help": _(
+                "Skips frames that have already been extracted and exist in the alignments file")})
+        argument_list.append({
+            "opts": ("-e", "--skip-existing-faces"),
+            "action": "store_true",
+            "dest": "skip_faces",
+            "default": False,
+            "group": _("settings"),
+            "help": _("Skip frames that already have detected faces in the alignments file")})
+        argument_list.append({
+            "opts": ("-K", "--skip-saving-faces"),
+            "action": "store_true",
+            "dest": "skip_saving_faces",
+            "default": False,
+            "group": _("settings"),
+            "help": _("Skip saving the detected faces to disk. Just create an alignments file")})
+        return argument_list
+
+
+class ConvertArgs(ExtractConvertArgs):
+    """ Creates the command line arguments for conversion.
+ + This class inherits base options from :class:`ExtractConvertArgs` where arguments that are used + for both Extract and Convert should be placed. + + Commands explicit to Convert should be added in :func:`get_optional_arguments` + """ + + @staticmethod + def get_info() -> str: + """ The information text for the Convert command. + + Returns + ------- + str + The information text for the Convert command. + """ + return _("Swap the original faces in a source video/images to your final faces.\n" + "Conversion plugins can be configured in the 'Settings' Menu") + + @staticmethod + def get_optional_arguments() -> list[dict[str, T.Any]]: + """ Returns the argument list unique to the Convert command. + + Returns + ------- + list + The list of optional command line options for the Convert command + """ + + argument_list: list[dict[str, T.Any]] = [] + argument_list.append({ + "opts": ("-r", "--reference-video"), + "action": FileFullPaths, + "filetypes": "video", + "type": str, + "dest": "reference_video", + "group": _("Data"), + "help": _( + "Only required if converting from images to video. Provide The original video " + "that the source frames were extracted from (for extracting the fps and audio).")}) + argument_list.append({ # pylint:disable=duplicate-code + "opts": ("-m", "--model-dir"), + "action": DirFullPaths, + "dest": "model_dir", + "required": True, + "group": _("Data"), + "help": _( + "Model directory. The directory containing the trained model you wish to use for " + "conversion.")}) + argument_list.append({ + "opts": ("-c", "--color-adjustment"), + "action": Radio, + "type": str.lower, + "dest": "color_adjustment", + "default": "avg-color", + "choices": PluginLoader.get_available_convert_plugins("color", True), + "group": _("Plugins"), + "help": _( + "R|Performs color adjustment to the swapped face. Some of these options have " + "configurable settings in '/config/convert.ini' or 'Settings > Configure Convert " + "Plugins':" + "\nL|avg-color: Adjust the mean of each color channel in the swapped " + "reconstruction to equal the mean of the masked area in the original image." + "\nL|color-transfer: Transfers the color distribution from the source to the " + "target image using the mean and standard deviations of the L*a*b* color space." + "\nL|manual-balance: Manually adjust the balance of the image in a variety of " + "color spaces. Best used with the Preview tool to set correct values." + "\nL|match-hist: Adjust the histogram of each color channel in the swapped " + "reconstruction to equal the histogram of the masked area in the original image." + "\nL|seamless-clone: Use cv2's seamless clone function to remove extreme " + "gradients at the mask seam by smoothing colors. Generally does not give very " + "satisfactory results." + "\nL|none: Don't perform color adjustment.")}) + argument_list.append({ + "opts": ("-M", "--mask-type"), + "action": Radio, + "type": str.lower, + "dest": "mask_type", + "default": "extended", + "choices": PluginLoader.get_available_extractors("mask", + add_none=True, + extend_plugin=True) + ["predicted"], + "group": _("Plugins"), + "help": _( + "R|Masker to use. NB: The mask you require must exist within the alignments file. " + "You can add additional masks with the Mask Tool." + "\nL|none: Don't use a mask." + "\nL|bisenet-fp_face: Relatively lightweight NN based mask that provides more " + "refined control over the area to be masked (configurable in mask settings). 
Use " + "this version of bisenet-fp if your model is trained with 'face' or " + "'legacy' centering." + "\nL|bisenet-fp_head: Relatively lightweight NN based mask that provides more " + "refined control over the area to be masked (configurable in mask settings). Use " + "this version of bisenet-fp if your model is trained with 'head' centering." + "\nL|custom_face: Custom user created, face centered mask." + "\nL|custom_head: Custom user created, head centered mask." + "\nL|components: Mask designed to provide facial segmentation based on the " + "positioning of landmark locations. A convex hull is constructed around the " + "exterior of the landmarks to create a mask." + "\nL|extended: Mask designed to provide facial segmentation based on the " + "positioning of landmark locations. A convex hull is constructed around the " + "exterior of the landmarks and the mask is extended upwards onto the forehead." + "\nL|vgg-clear: Mask designed to provide smart segmentation of mostly frontal " + "faces clear of obstructions. Profile faces and obstructions may result in sub-" + "par performance." + "\nL|vgg-obstructed: Mask designed to provide smart segmentation of mostly " + "frontal faces. The mask model has been specifically trained to recognize some " + "facial obstructions (hands and eyeglasses). Profile faces may result in sub-par " + "performance." + "\nL|unet-dfl: Mask designed to provide smart segmentation of mostly frontal " + "faces. The mask model has been trained by community members and will need " + "testing for further description. Profile faces may result in sub-par " + "performance." + "\nL|predicted: If the 'Learn Mask' option was enabled during training, this will " + "use the mask that was created by the trained model.")}) + argument_list.append({ + "opts": ("-w", "--writer"), + "action": Radio, + "type": str, + "default": "opencv", + "choices": PluginLoader.get_available_convert_plugins("writer", False), + "group": _("Plugins"), + "help": _( + "R|The plugin to use to output the converted images. The writers are configurable " + "in '/config/convert.ini' or 'Settings > Configure Convert Plugins:'" + "\nL|ffmpeg: [video] Writes out the convert straight to video. When the input is " + "a series of images then the '-ref' (--reference-video) parameter must be set." + "\nL|gif: [animated image] Create an animated gif." + "\nL|opencv: [images] The fastest image writer, but less options and formats than " + "other plugins." + "\nL|patch: [images] Outputs the raw swapped face patch, along with the " + "transformation matrix required to re-insert the face back into the original " + "frame. Use this option if you wish to post-process and composite the final face " + "within external tools." + "\nL|pillow: [images] Slower than opencv, but has more options and supports more " + "formats.")}) + argument_list.append({ + "opts": ("-O", "--output-scale"), + "action": Slider, + "min_max": (25, 400), + "rounding": 1, + "type": int, + "dest": "output_scale", + "default": 100, + "group": _("Frame Processing"), + "help": _( + "Scale the final output frames by this amount. 100%% will output the frames at " + "source dimensions. 50%% at half size 200%% at double size")}) + argument_list.append({ + "opts": ("-R", "--frame-ranges"), + "type": str, + "nargs": "+", + "dest": "frame_ranges", + "group": _("Frame Processing"), + "help": _( + "Frame ranges to apply transfer to e.g. For frames 10 to 50 and 90 to 100 use " + "--frame-ranges 10-50 90-100. 
Frames falling outside of the selected range will " + "be discarded unless '-k' (--keep-unchanged) is selected. NB: If you are " + "converting from images, then the filenames must end with the frame-number!")}) + argument_list.append({ + "opts": ("-S", "--face-scale"), + "action": Slider, + "min_max": (-10.0, 10.0), + "rounding": 2, + "dest": "face_scale", + "type": float, + "default": 0.0, + "group": _("Face Processing"), + "help": _( + "Scale the swapped face by this percentage. Positive values will enlarge the " + "face, Negative values will shrink the face.")}) + argument_list.append({ + "opts": ("-a", "--input-aligned-dir"), + "action": DirFullPaths, + "dest": "input_aligned_dir", + "default": None, + "group": _("Face Processing"), + "help": _( + "If you have not cleansed your alignments file, then you can filter out faces by " + "defining a folder here that contains the faces extracted from your input files/" + "video. If this folder is defined, then only faces that exist within your " + "alignments file and also exist within the specified folder will be converted. " + "Leaving this blank will convert all faces that exist within the alignments " + "file.")}) + argument_list.append({ + "opts": ("-n", "--nfilter"), + "action": FilesFullPaths, + "filetypes": "image", + "dest": "nfilter", + "default": None, + "nargs": "+", + "group": _("Face Processing"), + "help": _( + "Optionally filter out people who you do not wish to process by passing in an " + "image of that person. Should be a front portrait with a single person in the " + "image. Multiple images can be added space separated. NB: Using face filter will " + "significantly decrease extraction speed and its accuracy cannot be guaranteed.")}) + argument_list.append({ + "opts": ("-f", "--filter"), + "action": FilesFullPaths, + "filetypes": "image", + "dest": "filter", + "default": None, + "nargs": "+", + "group": _("Face Processing"), + "help": _( + "Optionally select people you wish to process by passing in an image of that " + "person. Should be a front portrait with a single person in the image. Multiple " + "images can be added space separated. NB: Using face filter will significantly " + "decrease extraction speed and its accuracy cannot be guaranteed.")}) + argument_list.append({ + "opts": ("-l", "--ref_threshold"), + "action": Slider, + "min_max": (0.01, 0.99), + "rounding": 2, + "type": float, + "dest": "ref_threshold", + "default": 0.4, + "group": _("Face Processing"), + "help": _( + "For use with the optional nfilter/filter files. Threshold for positive face " + "recognition. Lower values are stricter. NB: Using face filter will significantly " + "decrease extraction speed and its accuracy cannot be guaranteed.")}) + argument_list.append({ + "opts": ("-j", "--jobs"), + "action": Slider, + "min_max": (0, 40), + "rounding": 1, + "type": int, + "dest": "jobs", + "default": 0, + "group": _("settings"), + "help": _( + "The maximum number of parallel processes for performing conversion. Converting " + "images is system RAM heavy so it is possible to run out of memory if you have a " + "lot of processes and not enough RAM to accommodate them all. Setting this to 0 " + "will use the maximum available. No matter what you set this to, it will never " + "attempt to use more processes than are available on your system. 
If " + "singleprocess is enabled this setting will be ignored.")}) + argument_list.append({ + "opts": ("-T", "--on-the-fly"), + "action": "store_true", + "dest": "on_the_fly", + "default": False, + "group": _("settings"), + "help": _( + "Enable On-The-Fly Conversion. NOT recommended. You should generate a clean " + "alignments file for your destination video. However, if you wish you can " + "generate the alignments on-the-fly by enabling this option. This will use an " + "inferior extraction pipeline and will lead to substandard results. If an " + "alignments file is found, this option will be ignored.")}) + argument_list.append({ + "opts": ("-k", "--keep-unchanged"), + "action": "store_true", + "dest": "keep_unchanged", + "default": False, + "group": _("Frame Processing"), + "help": _( + "When used with --frame-ranges outputs the unchanged frames that are not " + "processed instead of discarding them.")}) + argument_list.append({ + "opts": ("-s", "--swap-model"), + "action": "store_true", + "dest": "swap_model", + "default": False, + "group": _("settings"), + "help": _("Swap the model. Instead converting from of A -> B, converts B -> A")}) + argument_list.append({ + "opts": ("-P", "--singleprocess"), + "action": "store_true", + "default": False, + "group": _("settings"), + "help": _("Disable multiprocessing. Slower but less resource intensive.")}) + return argument_list + + +__all__ = get_module_objects(__name__) diff --git a/lib/cli/args_train.py b/lib/cli/args_train.py new file mode 100644 index 0000000000..821268f321 --- /dev/null +++ b/lib/cli/args_train.py @@ -0,0 +1,324 @@ +#!/usr/bin/env python3 +""" The Command Line Argument options for training with faceswap.py """ +import gettext +import typing as T + +from lib.utils import get_module_objects +from plugins.plugin_loader import PluginLoader + +from .actions import DirFullPaths, FileFullPaths, Radio, Slider +from .args import FaceSwapArgs + + +# LOCALES +_LANG = gettext.translation("lib.cli.args_train", localedir="locales", fallback=True) +_ = _LANG.gettext + + +class TrainArgs(FaceSwapArgs): + """ Creates the command line arguments for training. """ + + @staticmethod + def get_info() -> str: + """ The information text for the Train command. + + Returns + ------- + str + The information text for the Train command. + """ + return _("Train a model on extracted original (A) and swap (B) faces.\n" + "Training models can take a long time. Anything from 24hrs to over a week\n" + "Model plugins can be configured in the 'Settings' Menu") + + @staticmethod + def get_argument_list() -> list[dict[str, T.Any]]: + """ Returns the argument list for Train arguments. + + Returns + ------- + list + The list of command line options for training + """ + argument_list: list[dict[str, T.Any]] = [] + argument_list.append({ + "opts": ("-A", "--input-A"), + "action": DirFullPaths, + "dest": "input_a", + "required": True, + "group": _("faces"), + "help": _( + "Input directory. A directory containing training images for face A. This is the " + "original face, i.e. the face that you want to remove and replace with face B.")}) + argument_list.append({ + "opts": ("-B", "--input-B"), + "action": DirFullPaths, + "dest": "input_b", + "required": True, + "group": _("faces"), + "help": _( + "Input directory. A directory containing training images for face B. This is the " + "swap face, i.e. 
the face that you want to place onto the head of person A.")}) + argument_list.append({ + "opts": ("-m", "--model-dir"), + "action": DirFullPaths, + "dest": "model_dir", + "required": True, + "group": _("model"), + "help": _( + "Model directory. This is where the training data will be stored. You should " + "always specify a new folder for new models. If starting a new model, select " + "either an empty folder, or a folder which does not exist (which will be " + "created). If continuing to train an existing model, specify the location of the " + "existing model.")}) + argument_list.append({ + "opts": ("-l", "--load-weights"), + "action": FileFullPaths, + "filetypes": "model", + "dest": "load_weights", + "required": False, + "group": _("model"), + "help": _( + "R|Load the weights from a pre-existing model into a newly created model. For " + "most models this will load weights from the Encoder of the given model into the " + "encoder of the newly created model. Some plugins may have specific configuration " + "options allowing you to load weights from other layers. Weights will only be " + "loaded when creating a new model. This option will be ignored if you are " + "resuming an existing model. Generally you will also want to 'freeze-weights' " + "whilst the rest of your model catches up with your Encoder.\n" + "NB: Weights can only be loaded from models of the same plugin as you intend to " + "train.")}) + argument_list.append({ + "opts": ("-t", "--trainer"), + "action": Radio, + "type": str.lower, + "default": PluginLoader.get_default_model(), + "choices": PluginLoader.get_available_models(), + "group": _("model"), + "help": _( + "R|Select which trainer to use. Trainers can be configured from the Settings menu " + "or the config folder." + "\nL|original: The original model created by /u/deepfakes." + "\nL|dfaker: 64px in/128px out model from dfaker. Enable 'warp-to-landmarks' for " + "full dfaker method." + "\nL|dfl-h128: 128px in/out model from deepfacelab" + "\nL|dfl-sae: Adaptable model from deepfacelab" + "\nL|dlight: A lightweight, high resolution DFaker variant." + "\nL|iae: A model that uses intermediate layers to try to get better details" + "\nL|lightweight: A lightweight model for low-end cards. Don't expect great " + "results. Can train as low as 1.6GB with batch size 8." + "\nL|realface: A high detail, dual density model based on DFaker, with " + "customizable in/out resolution. The autoencoders are unbalanced so B>A swaps " + "won't work so well. By andenixa et al. Very configurable." + "\nL|unbalanced: 128px in/out model from andenixa. The autoencoders are " + "unbalanced so B>A swaps won't work so well. Very configurable." + "\nL|villain: 128px in/out model from villainguy. Very resource hungry (You will " + "require a GPU with a fair amount of VRAM). Good for details, but more " + "susceptible to color differences.")}) + argument_list.append({ + "opts": ("-u", "--summary"), + "action": "store_true", + "dest": "summary", + "default": False, + "group": _("model"), + "help": _( + "Output a summary of the model and exit. If a model folder is provided then a " + "summary of the saved model is displayed. Otherwise a summary of the model that " + "would be created by the chosen plugin and configuration settings is displayed.")}) + argument_list.append({ + "opts": ("-f", "--freeze-weights"), + "action": "store_true", + "dest": "freeze_weights", + "default": False, + "group": _("model"), + "help": _( + "Freeze the weights of the model. 
Freezing weights means that some of the "
+                "parameters in the model will no longer continue to learn, but those that are not "
+                "frozen will continue to learn. For most models, this will freeze the encoder, "
+                "but some models may have configuration options for freezing other layers.")})
+        argument_list.append({
+            "opts": ("-b", "--batch-size"),
+            "action": Slider,
+            "min_max": (1, 256),
+            "rounding": 1,
+            "type": int,
+            "dest": "batch_size",
+            "default": 16,
+            "group": _("training"),
+            "help": _(
+                "Batch size. This is the number of images processed through the model for each "
+                "side per iteration. NB: As the model is fed 2 sides at a time, the actual number "
+                "of images within the model at any one time is double the number that you set "
+                "here. Larger batches require more GPU RAM.")})
+        argument_list.append({
+            "opts": ("-i", "--iterations"),
+            "action": Slider,
+            "min_max": (0, 5000000),
+            "rounding": 20000,
+            "type": int,
+            "default": 1000000,
+            "group": _("training"),
+            "help": _(
+                "Length of training in iterations. This is only really used for automation. There "
+                "is no 'correct' number of iterations a model should be trained for. You should "
+                "stop training when you are happy with the previews. However, if you want the "
+                "model to stop automatically at a set number of iterations, you can set that "
+                "value here.")})
+        argument_list.append({
+            "opts": ("-a", "--warmup"),
+            "action": Slider,
+            "min_max": (0, 5000),
+            "rounding": 100,
+            "type": int,
+            "default": 0,
+            "group": _("training"),
+            "help": _(
+                "Learning rate warmup. Linearly increase the learning rate from 0 to the chosen "
+                "target rate over the number of iterations given here. 0 to disable.")})
+        argument_list.append({
+            "opts": ("-d", "--distributed"),
+            "dest": "distributed",
+            "action": "store_true",
+            "default": False,
+            "backend": ("nvidia", "rocm"),
+            "group": _("training"),
+            "help": _("Use distributed training on multi-GPU setups.")})
+        argument_list.append({
+            "opts": ("-n", "--no-logs"),
+            "action": "store_true",
+            "dest": "no_logs",
+            "default": False,
+            "group": _("training"),
+            "help": _(
+                "Disables TensorBoard logging. NB: Disabling logs means that you will not be able "
+                "to use the graph or analysis for this session in the GUI.")})
+        argument_list.append({
+            "opts": ("-r", "--use-lr-finder"),
+            "action": "store_true",
+            "dest": "use_lr_finder",
+            "default": False,
+            "group": _("training"),
+            "help": _(
+                "Use the Learning Rate Finder to discover the optimal learning rate for training. "
+                "For new models, this will calculate the optimal learning rate for the model. For "
+                "existing models this will use the optimal learning rate that was discovered when "
+                "initializing the model. Setting this option will ignore the manually configured "
+                "learning rate (configurable in train settings).")})
+        argument_list.append({
+            "opts": ("-s", "--save-interval"),
+            "action": Slider,
+            "min_max": (10, 1000),
+            "rounding": 10,
+            "type": int,
+            "dest": "save_interval",
+            "default": 250,
+            "group": _("Saving"),
+            "help": _("Sets the number of iterations between each model save.")})
+        argument_list.append({
+            "opts": ("-I", "--snapshot-interval"),
+            "action": Slider,
+            "min_max": (0, 100000),
+            "rounding": 5000,
+            "type": int,
+            "dest": "snapshot_interval",
+            "default": 25000,
+            "group": _("Saving"),
+            "help": _(
+                "Sets the number of iterations before saving a backup snapshot of the model in "
+                "its current state.
Set to 0 for off.")})
+        argument_list.append({
+            "opts": ("-x", "--timelapse-input-A"),
+            "action": DirFullPaths,
+            "dest": "timelapse_input_a",
+            "default": None,
+            "group": _("timelapse"),
+            "help": _(
+                "Optional for creating a timelapse. Timelapse will save an image of your selected "
+                "faces into the timelapse-output folder at every save iteration. This should be "
+                "the input folder of 'A' faces that you would like to use for creating the "
+                "timelapse. You must also supply a --timelapse-output and a --timelapse-input-B "
+                "parameter.")})
+        argument_list.append({
+            "opts": ("-y", "--timelapse-input-B"),
+            "action": DirFullPaths,
+            "dest": "timelapse_input_b",
+            "default": None,
+            "group": _("timelapse"),
+            "help": _(
+                "Optional for creating a timelapse. Timelapse will save an image of your selected "
+                "faces into the timelapse-output folder at every save iteration. This should be "
+                "the input folder of 'B' faces that you would like to use for creating the "
+                "timelapse. You must also supply a --timelapse-output and a --timelapse-input-A "
+                "parameter.")})
+        argument_list.append({
+            "opts": ("-z", "--timelapse-output"),
+            "action": DirFullPaths,
+            "dest": "timelapse_output",
+            "default": None,
+            "group": _("timelapse"),
+            "help": _(
+                "Optional for creating a timelapse. Timelapse will save an image of your selected "
+                "faces into the timelapse-output folder at every save iteration. If the input "
+                "folders are supplied but no output folder, it will default to your model folder/"
+                "timelapse/")})
+        argument_list.append({
+            "opts": ("-p", "--preview"),
+            "action": "store_true",
+            "dest": "preview",
+            "default": False,
+            "group": _("preview"),
+            "help": _("Show training preview output in a separate window.")})
+        argument_list.append({
+            "opts": ("-w", "--write-image"),
+            "action": "store_true",
+            "dest": "write_image",
+            "default": False,
+            "group": _("preview"),
+            "help": _(
+                "Writes the training result to a file. The image will be stored in the root of "
+                "your FaceSwap folder.")})
+        argument_list.append({
+            "opts": ("-M", "--warp-to-landmarks"),
+            "action": "store_true",
+            "dest": "warp_to_landmarks",
+            "default": False,
+            "group": _("augmentation"),
+            "help": _(
+                "Warps training faces to closely matched Landmarks from the opposite face-set "
+                "rather than randomly warping the face. This is the 'dfaker' way of doing "
+                "warping.")})
+        argument_list.append({
+            "opts": ("-P", "--no-flip"),
+            "action": "store_true",
+            "dest": "no_flip",
+            "default": False,
+            "group": _("augmentation"),
+            "help": _(
+                "To effectively learn, a random set of images is flipped horizontally. Sometimes "
+                "it is desirable for this not to occur. Generally this should be left off except "
+                "for during 'fit training'.")})
+        argument_list.append({
+            "opts": ("-c", "--no-augment-color"),
+            "action": "store_true",
+            "dest": "no_augment_color",
+            "default": False,
+            "group": _("augmentation"),
+            "help": _(
+                "Color augmentation helps make the model less susceptible to color differences "
+                "between the A and B sets, at an increased training time cost. Enable this option "
+                "to disable color augmentation.")})
+        argument_list.append({
+            "opts": ("-W", "--no-warp"),
+            "action": "store_true",
+            "dest": "no_warp",
+            "default": False,
+            "group": _("augmentation"),
+            "help": _(
+                "Warping is integral to training the Neural Network. This option should only be "
+                "enabled towards the very end of training to try to bring out more detail. Think "
+                "of it as 'fine-tuning'.
Enabling this option from the beginning is likely to " + "kill a model and lead to terrible results.")}) + return argument_list + + +__all__ = get_module_objects(__name__) diff --git a/lib/cli/launcher.py b/lib/cli/launcher.py new file mode 100644 index 0000000000..b96b2cc13f --- /dev/null +++ b/lib/cli/launcher.py @@ -0,0 +1,247 @@ +#!/usr/bin/env python3 +""" Launches the correct script with the given Command Line Arguments """ +from __future__ import annotations +import logging +import os +import platform +import sys +import typing as T + +from importlib import import_module + +from lib.gpu_stats import GPUStats +from lib.logger import crash_log, log_setup +from lib.utils import (FaceswapError, get_backend, get_torch_version, + get_module_objects, safe_shutdown, set_backend) + +if T.TYPE_CHECKING: + import argparse + from collections.abc import Callable + +logger = logging.getLogger(__name__) + + +class ScriptExecutor(): + """ Loads the relevant script modules and executes the script. + + This class is initialized in each of the argparsers for the relevant + command, then execute script is called within their set_default + function. + + Parameters + ---------- + command: str + The faceswap command that is being executed + """ + def __init__(self, command: str) -> None: + self._command = command.lower() + + def _set_environment_variables(self) -> None: + """ Set the number of threads that numexpr can use. """ + # Allocate a decent number of threads to numexpr to suppress warnings + cpu_count = os.cpu_count() + allocate = max(1, cpu_count - cpu_count // 3 if cpu_count is not None else 1) + if "OMP_NUM_THREADS" in os.environ: + # If this is set above NUMEXPR_MAX_THREADS, numexpr will error. + # ref: https://github.com/pydata/numexpr/issues/322 + os.environ.pop("OMP_NUM_THREADS") + logger.debug("Setting NUMEXPR_MAX_THREADS to %s", allocate) + os.environ["NUMEXPR_MAX_THREADS"] = str(allocate) + + if get_backend() == "apple_silicon": # Let apple put unsupported ops on the CPU + logger.debug("Enabling unsupported Ops on CPU for Apple Silicon") + os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" + + def _import_script(self) -> Callable: + """ Imports the relevant script as indicated by :attr:`_command` from the scripts folder. + + Returns + ------- + class: Faceswap Script + The uninitialized script from the faceswap scripts folder. + """ + self._set_environment_variables() + self._test_for_torch_version() + self._test_for_gui() + cmd = os.path.basename(sys.argv[0]) + src = f"tools.{self._command.lower()}" if cmd == "tools.py" else "scripts" + mod = ".".join((src, self._command.lower())) + module = import_module(mod) + script = getattr(module, self._command.title()) + return script + + def _test_for_torch_version(self) -> None: + """ Check that the required PyTorch version is installed. + + Raises + ------ + FaceswapError + If PyTorch is not found, or is not between versions 2.3 and 2.9 + """ + min_ver = (2, 3) + max_ver = (2, 9) + try: + import torch # noqa:F401 pylint:disable=unused-import,import-outside-toplevel + except ImportError as err: + msg = ( + f"There was an error importing PyTorch. This is most likely because you do " + f"not have PyTorch installed. Original import error: {str(err)}") + self._handle_import_error(msg) + + torch_ver = get_torch_version() + if torch_ver < min_ver: + msg = (f"The minimum supported PyTorch is version {min_ver} but you have version " + f"{torch_ver} installed. 
Please upgrade PyTorch.") + self._handle_import_error(msg) + if torch_ver > max_ver: + msg = (f"The maximum supported PyTorch is version {max_ver} but you have version " + f"{torch_ver} installed. Please downgrade PyTorch.") + self._handle_import_error(msg) + logger.debug("Installed PyTorch Version: %s", torch_ver) + + @classmethod + def _handle_import_error(cls, message: str) -> None: + """ Display the error message to the console and wait for user input to dismiss it, if + running GUI under Windows, otherwise use standard error handling. + + Parameters + ---------- + message: str + The error message to display + """ + if "gui" in sys.argv and platform.system() == "Windows": + logger.error(message) + logger.info("Press \"ENTER\" to dismiss the message and close FaceSwap") + input() + sys.exit(1) + else: + raise FaceswapError(message) + + def _test_for_gui(self) -> None: + """ If running the gui, performs checks to ensure necessary prerequisites are present. """ + if self._command != "gui": + return + self._test_tkinter() + self._check_display() + + @classmethod + def _test_tkinter(cls) -> None: + """ If the user is running the GUI, test whether the tkinter app is available on their + machine. If not, exit gracefully. + + This avoids having to import every tkinter function within the GUI in a wrapper and + potentially spamming traceback errors to console. + + Raises + ------ + FaceswapError + If tkinter cannot be imported + """ + try: + import tkinter # noqa pylint:disable=unused-import,import-outside-toplevel + except ImportError as err: + logger.error("It looks like TkInter isn't installed for your OS, so the GUI has been " + "disabled. To enable the GUI please install the TkInter application. You " + "can try:") + logger.info("Anaconda: conda install tk") + logger.info("Windows/macOS: Install ActiveTcl Community Edition from " + "http://www.activestate.com") + logger.info("Ubuntu/Mint/Debian: sudo apt install python3-tk") + logger.info("Arch: sudo pacman -S tk") + logger.info("CentOS/Redhat: sudo yum install tkinter") + logger.info("Fedora: sudo dnf install python3-tkinter") + raise FaceswapError("TkInter not found") from err + + @classmethod + def _check_display(cls) -> None: + """ Check whether there is a display to output the GUI to. + + If running on Windows then it is assumed that we are not running in headless mode + + Raises + ------ + FaceswapError + If a DISPLAY environment variable cannot be found + """ + if not os.environ.get("DISPLAY", None) and os.name != "nt": + if platform.system() == "Darwin": + logger.info("macOS users need to install XQuartz. " + "See https://support.apple.com/en-gb/HT201341") + raise FaceswapError("No display detected. GUI mode has been disabled.") + + def execute_script(self, arguments: argparse.Namespace) -> None: + """ Performs final setup and launches the requested :attr:`_command` with the given + command line arguments. + + Monitors for errors and attempts to shut down the process cleanly on exit. + + Parameters + ---------- + arguments: :class:`argparse.Namespace` + The command line arguments to be passed to the executing script. 
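The version gate above works because the PyTorch version is reduced to a tuple of ints, so range checks become lexicographic tuple comparisons. Here is a self-contained sketch (the parsing helper is illustrative; the real code obtains the tuple from `get_torch_version()`):

```python
# Tuple-based version gating, as in _test_for_torch_version above
MIN_VER, MAX_VER = (2, 3), (2, 9)

def parse_version(version_string: str) -> tuple[int, int]:
    """ Reduce a version string such as '2.4.1+cu121' to (major, minor) """
    major, minor = version_string.split("+")[0].split(".")[:2]
    return int(major), int(minor)

for candidate in ("2.2.0", "2.4.1+cu121", "2.10.0"):
    version = parse_version(candidate)
    print(candidate, "ok" if MIN_VER <= version <= MAX_VER else "unsupported")
# 2.2.0 unsupported / 2.4.1+cu121 ok / 2.10.0 unsupported
```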
+ """ + is_gui = hasattr(arguments, "redirect_gui") and arguments.redirect_gui + log_setup(arguments.loglevel, arguments.logfile, self._command, is_gui) + success = False + + if self._command != "gui": + self._configure_backend(arguments) + try: + script = self._import_script() + process = script(arguments) + process.process() + success = True + except FaceswapError as err: + for line in str(err).splitlines(): + logger.error(line) + except KeyboardInterrupt: # pylint:disable=try-except-raise + raise + except SystemExit: + pass + except Exception: # pylint:disable=broad-except + crash_file = crash_log() + logger.exception("Got Exception on main handler:") + logger.critical("An unexpected crash has occurred. Crash report written to '%s'. " + "You MUST provide this file if seeking assistance. Please verify you " + "are running the latest version of faceswap before reporting", + crash_file) + + finally: + safe_shutdown(got_error=not success) + + def _configure_backend(self, arguments: argparse.Namespace) -> None: + """ Configure the backend. + + Exclude any GPUs for use by Faceswap when requested. + + Set Faceswap backend to CPU if all GPUs have been deselected. + + Parameters + ---------- + arguments: :class:`argparse.Namespace` + The command line arguments passed to Faceswap. + """ + if not hasattr(arguments, "exclude_gpus"): + # CPU backends and systems where no GPU was detected will not have this attribute + logger.debug("Adding missing exclude gpus argument to namespace") + setattr(arguments, "exclude_gpus", None) + return + + assert GPUStats is not None + if arguments.exclude_gpus: + if not all(idx.isdigit() for idx in arguments.exclude_gpus): + logger.error("GPUs passed to the ['-X', '--exclude-gpus'] argument must all be " + "integers.") + sys.exit(1) + arguments.exclude_gpus = [int(idx) for idx in arguments.exclude_gpus] + GPUStats().exclude_devices(arguments.exclude_gpus) + + if GPUStats().exclude_all_devices: + msg = "Switching backend to CPU" + set_backend("cpu") + logger.info(msg) + + logger.debug("Executing: %s. 
PID: %s", self._command, os.getpid()) + + +__all__ = get_module_objects(__name__) diff --git a/lib/config.py b/lib/config.py deleted file mode 100644 index fa4d9af215..0000000000 --- a/lib/config.py +++ /dev/null @@ -1,301 +0,0 @@ -#!/usr/bin/env python3 -""" Default configurations for faceswap - Extends out configparser funcionality - by checking for default config updates - and returning data in it's correct format """ - -import logging -import os -import sys -from collections import OrderedDict -from configparser import ConfigParser - -logger = logging.getLogger(__name__) # pylint: disable=invalid-name - - -class FaceswapConfig(): - """ Config Items """ - def __init__(self, section): - """ Init Configuration """ - logger.debug("Initializing: %s", self.__class__.__name__) - self.configfile = self.get_config_file() - self.config = ConfigParser(allow_no_value=True) - self.defaults = OrderedDict() - self.config.optionxform = str - self.section = section - - self.set_defaults() - self.handle_config() - logger.debug("Initialized: %s", self.__class__.__name__) - - def set_defaults(self): - """ Override for plugin specific config defaults - - Should be a series of self.add_section() and self.add_item() calls - - e.g: - - section = "sect_1" - self.add_section(title=section, - info="Section 1 Information") - - self.add_item(section=section, - title="option_1", - datatype=bool, - default=False, - info="sect_1 option_1 information") - """ - raise NotImplementedError - - @property - def config_dict(self): - """ Collate global options and requested section into a dictionary - with the correct datatypes """ - conf = dict() - for sect in ("global", self.section): - if sect not in self.config.sections(): - continue - for key in self.config[sect]: - if key.startswith(("#", "\n")): # Skip comments - continue - conf[key] = self.get(sect, key) - return conf - - def get(self, section, option): - """ Return a config item in it's correct format """ - logger.debug("Getting config item: (section: '%s', option: '%s')", section, option) - datatype = self.defaults[section][option]["type"] - if datatype == bool: - func = self.config.getboolean - elif datatype == int: - func = self.config.getint - elif datatype == float: - func = self.config.getfloat - else: - func = self.config.get - retval = func(section, option) - if isinstance(retval, str) and retval.lower() == "none": - retval = None - logger.debug("Returning item: (type: %s, value: %s)", datatype, retval) - return retval - - def get_config_file(self): - """ Return the config file from the calling folder """ - dirname = os.path.dirname(sys.modules[self.__module__].__file__) - folder, fname = os.path.split(dirname) - retval = os.path.join(os.path.dirname(folder), "config", "{}.ini".format(fname)) - logger.debug("Config File location: '%s'", retval) - return retval - - def add_section(self, title=None, info=None): - """ Add a default section to config file """ - logger.debug("Add section: (title: '%s', info: '%s')", title, info) - if None in (title, info): - raise ValueError("Default config sections must have a title and " - "information text") - self.defaults[title] = OrderedDict() - self.defaults[title]["helptext"] = info - - def add_item(self, section=None, title=None, datatype=str, - default=None, info=None, rounding=None, min_max=None, choices=None): - """ Add a default item to a config section - - For int or float values, rounding and min_max must be set - This is for the slider in the GUI. 
The min/max values are not enforced: - rounding: sets the decimal places for floats or the step interval for ints. - min_max: tuple of min and max accepted values - - For str values choices can be set to validate input and create a combo box - in the GUI - - """ - logger.debug("Add item: (section: '%s', title: '%s', datatype: '%s', default: '%s', " - "info: '%s', rounding: '%s', min_max: %s, choices: %s)", - section, title, datatype, default, info, rounding, min_max, choices) - - choices = list() if not choices else choices - - if None in (section, title, default, info): - raise ValueError("Default config items must have a section, " - "title, defult and " - "information text") - if not self.defaults.get(section, None): - raise ValueError("Section does not exist: {}".format(section)) - if datatype not in (str, bool, float, int): - raise ValueError("'datatype' must be one of str, bool, float or " - "int: {} - {}".format(section, title)) - if datatype in (float, int) and (rounding is None or min_max is None): - raise ValueError("'rounding' and 'min_max' must be set for numerical options") - if not isinstance(choices, (list, tuple)): - raise ValueError("'choices' must be a list or tuple") - self.defaults[section][title] = {"default": default, - "helptext": info, - "type": datatype, - "rounding": rounding, - "min_max": min_max, - "choices": choices} - - def check_exists(self): - """ Check that a config file exists """ - if not os.path.isfile(self.configfile): - logger.debug("Config file does not exist: '%s'", self.configfile) - return False - logger.debug("Config file exists: '%s'", self.configfile) - return True - - def create_default(self): - """ Generate a default config if it does not exist """ - logger.debug("Creating default Config") - for section, items in self.defaults.items(): - logger.debug("Adding section: '%s')", section) - self.insert_config_section(section, items["helptext"]) - for item, opt in items.items(): - logger.debug("Adding option: (item: '%s', opt: '%s'", item, opt) - if item == "helptext": - continue - self.insert_config_item(section, - item, - opt["default"], - opt) - self.save_config() - - def insert_config_section(self, section, helptext, config=None): - """ Insert a section into the config """ - logger.debug("Inserting section: (section: '%s', helptext: '%s', config: '%s')", - section, helptext, config) - config = self.config if config is None else config - helptext = self.format_help(helptext, is_section=True) - config.add_section(section) - config.set(section, helptext) - logger.debug("Inserted section: '%s'", section) - - def insert_config_item(self, section, item, default, option, - config=None): - """ Insert an item into a config section """ - logger.debug("Inserting item: (section: '%s', item: '%s', default: '%s', helptext: '%s', " - "config: '%s')", section, item, default, option["helptext"], config) - config = self.config if config is None else config - helptext = option["helptext"] - helptext += self.set_helptext_choices(option) - helptext += "\n[Default: {}]".format(default) - helptext = self.format_help(helptext, is_section=False) - config.set(section, helptext) - config.set(section, item, str(default)) - logger.debug("Inserted item: '%s'", item) - - @staticmethod - def set_helptext_choices(option): - """ Set the helptext choices """ - choices = "" - if option["choices"]: - choices = "\nChoose from: {}".format(option["choices"]) - elif option["type"] == bool: - choices = "\nChoose from: True, False" - elif option["type"] == int: - cmin, cmax = 
option["min_max"] - choices = "\nSelect an integer between {} and {}".format(cmin, cmax) - elif option["type"] == float: - cmin, cmax = option["min_max"] - choices = "\nSelect a decimal number between {} and {}".format(cmin, cmax) - return choices - - @staticmethod - def format_help(helptext, is_section=False): - """ Format comments for default ini file """ - logger.debug("Formatting help: (helptext: '%s', is_section: '%s')", helptext, is_section) - helptext = '# {}'.format(helptext.replace("\n", "\n# ")) - if is_section: - helptext = helptext.upper() - else: - helptext = "\n{}".format(helptext) - logger.debug("formatted help: '%s'", helptext) - return helptext - - def load_config(self): - """ Load values from config """ - logger.info("Loading config: '%s'", self.configfile) - self.config.read(self.configfile) - - def save_config(self): - """ Save a config file """ - logger.info("Updating config at: '%s'", self.configfile) - f_cfgfile = open(self.configfile, "w") - self.config.write(f_cfgfile) - f_cfgfile.close() - - def validate_config(self): - """ Check for options in default config against saved config - and add/remove as appropriate """ - logger.debug("Validating config") - if self.check_config_change(): - self.add_new_config_items() - self.check_config_choices() - logger.debug("Validated config") - - def add_new_config_items(self): - """ Add new items to the config file """ - logger.debug("Updating config") - new_config = ConfigParser(allow_no_value=True) - for section, items in self.defaults.items(): - self.insert_config_section(section, items["helptext"], new_config) - for item, opt in items.items(): - if item == "helptext": - continue - if section not in self.config.sections(): - logger.debug("Adding new config section: '%s'", section) - opt_value = opt["default"] - else: - opt_value = self.config[section].get(item, opt["default"]) - self.insert_config_item(section, - item, - opt_value, - opt, - new_config) - self.config = new_config - self.config.optionxform = str - self.save_config() - logger.debug("Updated config") - - def check_config_choices(self): - """ Check that config items are valid choices """ - logger.debug("Checking config choices") - for section, items in self.defaults.items(): - for item, opt in items.items(): - if item == "helptext" or not opt["choices"]: - continue - opt_value = self.config.get(section, item) - if opt_value.lower() == "none" and any(choice.lower() == "none" - for choice in opt["choices"]): - continue - if opt_value not in opt["choices"]: - default = str(opt["default"]) - logger.warning("'%s' is not a valid config choice for '%s': '%s'. 
Defaulting " - "to: '%s'", opt_value, section, item, default) - self.config.set(section, item, default) - logger.debug("Checked config choices") - - def check_config_change(self): - """ Check whether new default items have been added or removed - from the config file compared to saved version """ - if set(self.config.sections()) != set(self.defaults.keys()): - logger.debug("Default config has new section(s)") - return True - - for section, items in self.defaults.items(): - opts = [opt for opt in items.keys() if opt != "helptext"] - exists = [opt for opt in self.config[section].keys() - if not opt.startswith(("# ", "\n# "))] - if set(exists) != set(opts): - logger.debug("Default config has new item(s)") - return True - logger.debug("Default config has not changed") - return False - - def handle_config(self): - """ Handle the config """ - logger.debug("Handling config") - if not self.check_exists(): - self.create_default() - self.load_config() - self.validate_config() - logger.debug("Handled config") diff --git a/lib/config/__init__.py b/lib/config/__init__.py new file mode 100644 index 0000000000..c24ec1d2ab --- /dev/null +++ b/lib/config/__init__.py @@ -0,0 +1,4 @@ +#!/usr/bin/env python3 +""" Config handling for Faceswap """ +from .objects import ConfigItem, ConfigValueType, GlobalSection +from .config import generate_configs, get_configs, FaceswapConfig diff --git a/lib/config/config.py b/lib/config/config.py new file mode 100644 index 0000000000..16de36811e --- /dev/null +++ b/lib/config/config.py @@ -0,0 +1,271 @@ +#!/usr/bin/env python3 +""" Default configurations for faceswap. Handles parsing and validating of Faceswap Configs and +interfacing with :class:`configparser.ConfigParser` """ +from __future__ import annotations + +import inspect +import logging +import os +import sys + +from importlib import import_module + +from lib.utils import full_path_split, get_module_objects, PROJECT_ROOT + +from .ini import ConfigFile +from .objects import ConfigItem, ConfigSection, GlobalSection + + +logger = logging.getLogger(__name__) + +_CONFIGS: dict[str, FaceswapConfig] = {} +""" dict[str, FaceswapConfig] : plugin group to FaceswapConfig mapping for all loaded configs """ + + +class FaceswapConfig(): + """ Config Items """ + def __init__(self, configfile: str | None = None) -> None: + """ Init Configuration + + Parameters + ---------- + configfile : str, optional + Optional path to a config file. ``None`` for default location. 
Default: ``None`` + """ + logger.debug("Initializing: %s", self.__class__.__name__) + + self._plugin_group = self._get_plugin_group() + + self._ini = ConfigFile(self._plugin_group, ini_path=configfile) + self.sections: dict[str, ConfigSection] = {} + """ dict[str, :class:`ConfigSection`] : The Faceswap config sections and options """ + + self._set_defaults() + self._ini.on_load(self.sections) + _CONFIGS[self._plugin_group] = self + + logger.debug("Initialized: %s", self.__class__.__name__) + + def _get_plugin_group(self) -> str: + """ Obtain the name of the plugin group based on the child module's folder path + + Returns + ------- + str + The plugin group for this Config object + """ + mod_split = self.__module__.split(".") + mod_name = mod_split[-1] + retval = mod_name.rsplit("_", maxsplit=1)[0] + logger.debug("Got plugin group '%s' from module '%s'", + retval, self.__module__) + # Sanity check in case of defaults config file name/location changes + parent = mod_split[-2] + assert mod_name == f"{parent}_config" + return retval + + def add_section(self, title: str, info: str) -> None: + """ Add a default section to config file + + Parameters + ---------- + title : str + The title for the section + info : str + The helptext for the section + """ + logger.debug("Add section: (title: '%s', info: '%s')", title, info) + self.sections[title] = ConfigSection(helptext=info, options={}) + + def add_item(self, section: str, title: str, config_item: ConfigItem) -> None: + """ Add a default item to a config section + + Parameters + ---------- + section : str + The section of the config to add the item to + title : str + The name of the config item + config_item : :class:`~lib.config.objects.ConfigItem` + The default config item object to add to the config + """ + logger.debug("Add item: (section: '%s', item: %s", section, config_item) + self.sections[section].options[title] = config_item + + def _import_defaults_from_module(self, + filename: str, + module_path: str, + plugin_type: str) -> None: + """ Load the plugin's defaults module, extract defaults and add to default configuration. + + Parameters + ---------- + filename : str + The filename to load the defaults from + module_path : str + The path to load the module from + plugin_type : str + The type of plugin that the defaults are being loaded for + """ + logger.debug("Adding defaults: (filename: %s, module_path: %s, plugin_type: %s", + filename, module_path, plugin_type) + module = os.path.splitext(filename)[0] + section = ".".join((plugin_type, module.replace("_defaults", ""))) + logger.debug("Importing defaults module: %s.%s", module_path, module) + mod = import_module(f"{module_path}.{module}") + self.add_section(section, mod.HELPTEXT) # type:ignore[attr-defined] + for key, val in vars(mod).items(): + if isinstance(val, ConfigItem): + self.add_item(section=section, title=key, config_item=val) + logger.debug("Added defaults: %s", section) + + def _defaults_from_plugin(self, plugin_folder: str) -> None: + """ Scan the given plugins folder for config defaults.py files and update the + default configuration. 
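To make the discovery mechanism concrete, here is a hypothetical `*_defaults.py` module that the walk below would pick up. The option name and values are illustrative, not taken from the repository; the loader only requires a module-level `HELPTEXT` string and module-level `ConfigItem` instances:

```python
# Hypothetical contents of plugins/<type>/<name>_defaults.py
from lib.config import ConfigItem

HELPTEXT = "Options for an example plugin"  # read by _import_defaults_from_module

preview_images = ConfigItem(
    datatype=int,
    default=14,
    group="preview",
    info="Number of sample faces to show in the training preview window",
    min_max=(2, 16),
    rounding=2)
```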
+ + Parameters + ---------- + plugin_folder : str + The folder to scan for plugins + """ + for dirpath, _, filenames in os.walk(plugin_folder): + default_files = [fname for fname in filenames if fname.endswith("_defaults.py")] + if not default_files: + continue + base_path = os.path.dirname(os.path.realpath(sys.argv[0])) + # Can't use replace as there is a bug on some Windows installs that lowers some paths + import_path = ".".join(full_path_split(dirpath[len(base_path):])[1:]) + plugin_type = import_path.rsplit(".", maxsplit=1)[-1] + for filename in default_files: + self._import_defaults_from_module(filename, import_path, plugin_type) + + def set_defaults(self, helptext: str = "") -> None: + """ Override for plugin specific config defaults. + + This method should always be overridden to add the help text for the global plugin group. + If `helptext` is not provided, then it is assumed that there is no global section for this + plugin group. + + The default action will parse the child class' module for + :class:`~lib.config.objects.ConfigItem` objects and add them to this plugin group's + "global" section of :attr:`sections`. + + The name of each config option will be the variable name found in the module. + + It will then parse the child class' module for subclasses of + :class:`~lib.config.objects.GlobalSection` objects and add each of these sections to this + plugin group's :attr:`sections`, adding any :class:`~lib.config.objects.ConfigItem` within + the GlobalSection to that sub-section. + + The section name will be the name of the GlobalSection subclass, lowercased + + Parameters + ---------- + helptext : str + The help text to display for the plugin group + + Notes + ----- + If no help text is provided, no global section is created for this plugin group + """ + section = "global" + logger.debug("[%s:%s] Adding defaults", self._plugin_group, section) + + if not helptext: + logger.debug("No help text provided for '%s'. 
Not creating global section", + self.__module__) + return + + self.add_section(section, helptext) + + for key, val in vars(sys.modules[self.__module__]).items(): + if isinstance(val, ConfigItem): + self.add_item(section=section, title=key, config_item=val) + logger.debug("[%s:%s] Added defaults", self._plugin_group, section) + + # Add global sub-sections + for key, val in vars(sys.modules[self.__module__]).items(): + if inspect.isclass(val) and issubclass(val, GlobalSection) and val != GlobalSection: + section_name = f"{section}.{key.lower()}" + self.add_section(section_name, val.helptext) + for opt_name, opt in val.__dict__.items(): + if isinstance(opt, ConfigItem): + self.add_item(section=section_name, title=opt_name, config_item=opt) + + def _set_defaults(self) -> None: + """Load the plugin's default values, set the object names and order the sections, global + first then alphabetically.""" + self.set_defaults() + for section_name, section in self.sections.items(): + for opt_name, opt in section.options.items(): + opt.set_name(f"{self._plugin_group}.{section_name}.{opt_name}") + + global_keys = sorted(s for s in self.sections if s.startswith("global")) + remaining_keys = sorted(s for s in self.sections if not s.startswith("global")) + ordered = {k: self.sections[k] for k in global_keys + remaining_keys} + + self.sections = ordered + + def save_config(self) -> None: + """Update the ini file with the currently stored app values and save the config file.""" + self._ini.update_from_app(self.sections) + + +def get_configs() -> dict[str, FaceswapConfig]: + """ Get all of the FaceswapConfig options. Loads any configs that have not been loaded and + returns a dictionary of all configs. + + Returns + ------- + dict[str, :class:`FaceswapConfig`] + All of the loaded faceswap config objects + """ + generate_configs(force=True) + return _CONFIGS + + +def generate_configs(force: bool = False) -> None: + """ Generate config files if they don't exist. + + This script is run prior to anything being set up, so don't use logging. + Generates the default config files for plugins in the faceswap config folder. + + Logic: + - Scan the plugins path for files named `<plugin_group>_config.py` + - Import the discovered module and look for subclasses of FaceswapConfig + - If found, initialize the class + + Parameters + ---------- + force : bool + Force the loading of all plugin configs even if their .ini files pre-exist + """ + configs_path = os.path.join(PROJECT_ROOT, "config") + plugins_path = os.path.join(PROJECT_ROOT, "plugins") + for dirpath, _, filenames in os.walk(plugins_path): + relative_path = dirpath.replace(PROJECT_ROOT, "")[1:] + if len(full_path_split(relative_path)) > 2: # don't dig further than 1 folder deep + continue + plugin_group = os.path.basename(dirpath) + filename = f"{plugin_group}_config.py" + if filename not in filenames: + continue + + if plugin_group in _CONFIGS: + continue + + config_file = os.path.join(configs_path, f"{plugin_group}.ini") + if not os.path.exists(config_file) or force: + modname = os.path.splitext(filename)[0] + modpath = os.path.join(dirpath.replace(PROJECT_ROOT, ""), + modname)[1:].replace(os.sep, ".") + mod = import_module(modpath) + for obj in vars(mod).values(): + if (inspect.isclass(obj) + and issubclass(obj, FaceswapConfig) + and obj != FaceswapConfig): + obj() + + +__all__ = get_module_objects(__name__) diff --git a/lib/config/ini.py b/lib/config/ini.py new file mode 100644 index 0000000000..e48c91862c --- /dev/null +++ b/lib/config/ini.py @@ -0,0 +1,410 @@ +#! 
/usr/bin/env python3 +""" Handles interfacing between Faceswap Configs and ConfigParser .ini files """ +from __future__ import annotations + +import logging +import os +import textwrap +import typing as T + +from configparser import ConfigParser + +from lib.logger import parse_class_init +from lib.utils import get_module_objects, PROJECT_ROOT + +if T.TYPE_CHECKING: + from .objects import ConfigSection, ConfigValueType + +logger = logging.getLogger(__name__) + + +class ConfigFile(): + """ Handles the interfacing between saved faceswap .ini configs and internal Config objects + + Parameters + ---------- + plugin_group : str + The plugin group that is requesting a config file + ini_path : str | None, optional + Optional path to a .ini config file. ``None`` for default location. Default: ``None`` + """ + def __init__(self, plugin_group: str, ini_path: str | None = None) -> None: + parse_class_init(locals()) + self._plugin_group = plugin_group + self._file_path = self._get_config_path(ini_path) + self._parser = self._get_new_configparser() + if self._exists: # Load the existing saved config + self.load() + + @property + def _exists(self) -> bool: + """ bool : ``True`` if the config.ini file exists """ + return os.path.isfile(self._file_path) + + def _get_config_path(self, ini_path: str | None) -> str: + """ Return the path to the config file from the calling folder or the provided file + + Parameters + ---------- + ini_path : str | None + Path to a config ini file. ``None`` for default location. + + Returns + ------- + str + The full path to the configuration file + """ + if ini_path is not None: + if not os.path.isfile(ini_path): + err = f"Config file does not exist at: {ini_path}" + logger.error(err) + raise ValueError(err) + return ini_path + + retval = os.path.join(PROJECT_ROOT, "config", f"{self._plugin_group}.ini") + logger.debug("[%s] Config File location: '%s'", os.path.basename(retval), retval) + return retval + + def _get_new_configparser(self) -> ConfigParser: + """ Obtain a fresh ConfigParser object and set it to case-sensitive + + Returns + ------- + :class:`configparser.ConfigParser` + A new ConfigParser object set to case-sensitive + """ + retval = ConfigParser(allow_no_value=True) + retval.optionxform = str # type:ignore[assignment,method-assign] + return retval + + # I/O + def load(self) -> None: + """ Load values from the saved config ini file into our Config object """ + logger.verbose("[%s] Loading config: '%s'", # type:ignore[attr-defined] + self._plugin_group, self._file_path) + self._parser.read(self._file_path, encoding="utf-8") + + def save(self) -> None: + """ Save a config file """ + logger.debug("[%s] %s config: '%s'", + self._plugin_group, "Updating" if self._exists else "Saving", self._file_path) + # TODO in python >= 3.14 this will error when there are delimiters in the comments + with open(self._file_path, "w", encoding="utf-8", errors="replace") as f_cfgfile: + self._parser.write(f_cfgfile) + logger.info("[%s] Saved config: '%s'", self._plugin_group, self._file_path) + + # .ini vs Faceswap Config checking + def _sections_synced(self, app_config: dict[str, ConfigSection]) -> bool: + """ Validate that all of the sections within the application config match with all of the + sections in the ini file + + Parameters + ---------- + app_config : dict[str, :class:`ConfigSection`] + The latest configuration settings from the application. 
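A short aside on why `_get_new_configparser` overrides `optionxform`: by default, `configparser.ConfigParser` lowercases option names, which would silently rename Faceswap's case-sensitive option keys on a round trip. A minimal demonstration:

```python
from configparser import ConfigParser

default = ConfigParser()
default.read_string("[global]\nMyOption = 1\n")
print(list(default["global"]))   # ['myoption'] - lowercased by default

sensitive = ConfigParser(allow_no_value=True)
sensitive.optionxform = str      # keep option names exactly as written
sensitive.read_string("[global]\nMyOption = 1\n")
print(list(sensitive["global"])) # ['MyOption']
```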
Section name is key + + Returns + ------- + bool + ``True`` if application sections and saved ini sections match + """ + given_sections = set(app_config) + loaded_sections = set(self._parser.sections()) + retval = given_sections == loaded_sections + if not retval: + logger.debug("[%s] Config sections are not synced: (app: %s, ini: %s)", + self._plugin_group, sorted(given_sections), sorted(loaded_sections)) + return retval + + def _options_synced(self, app_config: dict[str, ConfigSection]) -> bool: + """ Validate that all of the option names within the application config match with all of + the option names in the ini file + + Note + ---- + As we need to write a new config anyway, we return on the first change found + + Parameters + ---------- + app_config : dict[str, :class:`ConfigSection`] + The latest configuration settings from the application. Section name is key + + Returns + ------- + bool + ``True`` if application option names match with saved ini option names + """ + for name, section in app_config.items(): + given_opts = set(opt for opt in section.options) + loaded_opts = set(self._parser[name].keys()) + if given_opts != loaded_opts: + logger.debug("[%s:%s] Config options are not synced: (app: %s, ini: %s)", + self._plugin_group, name, sorted(given_opts), sorted(loaded_opts)) + return False + return True + + def _values_synced(self, app_section: ConfigSection, section: str) -> bool: + """ Validate that all of the option values within the application config match with all of + the option values in the ini file + + Parameters + ---------- + app_section : :class:`ConfigSection` + The latest configuration settings from the application for the given section + section : str + The section name to check the option values for + + Returns + ------- + bool + ``True`` if application option values match with saved ini option values + """ + # Need to also pull in keys as False is omitted from the set with just values which can + # cause edge-case false negatives + given_vals = set((k, v.ini_value) for k, v in app_section.options.items()) + loaded_vals = set((k, v) for k, v in self._parser[section].items()) + retval = given_vals == loaded_vals + if not retval: + logger.debug("[%s:%s] Config values are not synced: (app: %s, ini: %s)", + self._plugin_group, section, sorted(given_vals), sorted(loaded_vals)) + return retval + + def _is_synced_structure(self, app_config: dict[str, ConfigSection]) -> bool: + """ Validate that all the given sections and option names within the application config + match with their corresponding items in the saved .ini file + + Parameters + ---------- + app_config: dict[str, :class:`ConfigSection`] + The latest configuration settings from the application. Section name is key + + Returns + ------- + bool + ``True`` if the app config and saved ini config structure match + """ + if not self._sections_synced(app_config): + return False + if not self._options_synced(app_config): + return False + + logger.debug("[%s] Configs are synced", self._plugin_group) + return True + + # .ini file insertion + def format_help(self, helptext: str, is_section: bool = False) -> str: + """ Format comments for insertion into a config ini file + + Parameters + ---------- + helptext : str + The help text to be formatted + is_section : bool, optional + ``True`` if the help text pertains to a section. ``False`` if it pertains to an option. 
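The sync checks above all reduce to set comparisons, so ordering differences are ignored and any added or removed name triggers a rewrite. A toy illustration (the section names are made up for the example):

```python
# Set-based structure check, mirroring _sections_synced above
app_sections = {"global", "color.manual_balance"}
ini_sections = {"global", "color.match_hist"}

print(app_sections == ini_sections)          # False -> rewrite needed
print(sorted(app_sections - ini_sections))   # ['color.manual_balance'] new in app
print(sorted(ini_sections - app_sections))   # ['color.match_hist'] stale in ini
```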
+ Default: ``False`` + + Returns + ------- + str + The formatted help text + """ + logger.debug("[%s] Formatting help: (helptext: '%s', is_section: '%s')", + self._plugin_group, helptext, is_section) + formatted = "" + for hlp in helptext.split("\n"): + subsequent_indent = "\t\t" if hlp.startswith("\t") else "" + hlp = f"\t- {hlp[1:].strip()}" if hlp.startswith("\t") else hlp + formatted += textwrap.fill(hlp, + 100, + tabsize=4, + subsequent_indent=subsequent_indent) + "\n" + helptext = '# {}'.format(formatted[:-1].replace("\n", "\n# ")) # Strip last newline + helptext = helptext.upper() if is_section else f"\n{helptext}" + return helptext + + def _insert_section(self, section: str, helptext: str, config: ConfigParser) -> None: + """ Insert a section into the config + + Parameters + ---------- + section : str + The section title to insert + helptext : str + The help text for the config section + config : :class:`configparser.ConfigParser` + The config parser object to insert the section into. + """ + logger.debug("[%s:%s] Inserting section: (helptext: '%s', config: '%s')", + self._plugin_group, section, helptext, config) + helptext = self.format_help(helptext, is_section=True) + config.add_section(section) + config.set(section, helptext) + + def _insert_option(self, + section: str, + name: str, + helptext: str, + value: str, + config: ConfigParser) -> None: + """ Insert an option into a config section + + Parameters + ---------- + section : str + The section to insert the option into + name : str + The name of the option to insert + helptext : str + The help text for the option + value : str + The value for the option + config : :class:`configparser.ConfigParser` + The config parser object to insert the option into + """ + logger.debug( + "[%s:%s] Inserting option: (name: '%s', helptext: %s, value: '%s', config: '%s')", + self._plugin_group, section, name, helptext, value, config) + helptext = self.format_help(helptext, is_section=False) + config.set(section, helptext) + config.set(section, name, value) + + def _sync_from_app(self, app_config: dict[str, ConfigSection]) -> None: + """ Update the saved config.ini file from the values stored in the application config + + Existing options keep their saved values as per the .ini files. New options are added with + their application defined default value. Options in the .ini file not in application + provided config are removed. + + Note + ---- + A new configuration object is created as comments are stripped from the loaded ini files. + + Parameters + ---------- + app_config: dict[str, :class:`ConfigSection`] + The latest configuration settings from the application. Section name is key + """ + logger.debug("[%s] Syncing from app", self._plugin_group) + parser = self._get_new_configparser() if self._exists else self._parser + for section_name, section in app_config.items(): + self._insert_section(section_name, section.helptext, parser) + for name, opt in section.options.items(): + + value = self._parser.get(section_name, name, fallback=None) + if value is None: + value = opt.ini_value + logger.debug( + "[%s:%s] Setting default value for non-existent config option '%s': '%s'", + self._plugin_group, section_name, name, value) + + self._insert_option(section_name, name, opt.helptext, value, parser) + + if parser != self._parser: + self._parser = parser + + self.save() + + # .ini extraction + def _get_converted_value(self, section: str, option: str, datatype: type) -> ConfigValueType: + """ Return a config item from the .ini file in its correct type. 
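For reference, the typed extraction that follows relies on ConfigParser's built-in converters. A standalone sketch of the same dispatch:

```python
from configparser import ConfigParser

parser = ConfigParser()
parser.read_string("[global]\nenabled = True\nbatch_size = 16\nrate = 0.5\n")

# Map a declared datatype to the matching typed getter, falling back to str
getters = {bool: parser.getboolean, int: parser.getint, float: parser.getfloat}
for option, datatype in (("enabled", bool), ("batch_size", int), ("rate", float)):
    func = getters.get(datatype, parser.get)
    print(option, repr(func("global", option)))  # True, 16, 0.5 as native types
```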
+ + Parameters + ---------- + section : str + The configuration section to obtain the config option for + option : str + The configuration option to obtain the converted value for + datatype : type + The type to return the value as + + Returns + ------- + bool | int | float | list[str] | str + The selected configuration option in the correct data format + """ + logger.debug("[%s:%s] Getting config item: (option: '%s', datatype: %s)", + self._plugin_group, section, option, datatype) + + assert datatype in (bool, int, float, str, list), ( + f"Expected (bool, int, float, str, list). Got {datatype}") + + retval: ConfigValueType + if datatype == bool: + retval = self._parser.getboolean(section, option) + elif datatype == int: + retval = self._parser.getint(section, option) + elif datatype == float: + retval = self._parser.getfloat(section, option) + else: + retval = self._parser.get(section, option) + + logger.debug("[%s:%s] Got config item: (value: %s, type: %s)", + self._plugin_group, section, retval, type(retval)) + return retval + + def _sync_to_app(self, app_config: dict[str, ConfigSection]) -> None: + """ Update the values in the application config to those loaded from the saved config.ini. + + Parameters + ---------- + app_config: dict[str, :class:`ConfigSection`] + The latest configuration settings from the application. Section name is key + """ + logger.debug("[%s] Syncing to app", self._plugin_group) + for section_name, section in app_config.items(): + if self._values_synced(section, section_name): + continue + for opt_name, opt in section.options.items(): + if section_name not in self._parser or opt_name not in self._parser[section_name]: + logger.debug("[%s:%s] Skipping new option: '%s'", + self._plugin_group, section_name, opt_name) + continue + + ini_opt = self._parser[section_name][opt_name] + if opt.ini_value != ini_opt: + logger.debug("[%s:%s] Updating '%s' from '%s' to '%s'", + self._plugin_group, section_name, + opt_name, ini_opt, opt.ini_value) + opt.set(self._get_converted_value(section_name, opt_name, opt.datatype)) + + # .ini insertion and extraction + def on_load(self, app_config: dict[str, ConfigSection]) -> None: + """ Check whether there has been any change between the current application config and + the loaded ini config. If so, update the relevant object(s) appropriately. This check will + also create new config.ini files if they do not pre-exist + + Parameters + ---------- + app_config : dict[str, :class:`ConfigSection`] + The latest configuration settings from the application. Section name is key + """ + if not self._exists: + logger.debug("[%s] Creating new ini file", self._plugin_group) + self._sync_from_app(app_config) + + if not self._is_synced_structure(app_config): + self._sync_from_app(app_config) + + self._sync_to_app(app_config) + + def update_from_app(self, app_config: dict[str, ConfigSection]) -> None: + """ Update the config.ini file to those values that are currently in Faceswap's app + config + + Parameters + ---------- + app_config : dict[str, :class:`ConfigSection`] + The latest configuration settings from the application. 
Section name is key + """ + logger.debug("[%s] Updating saved config", self._plugin_group) + parser = self._get_new_configparser() if self._exists else self._parser + for section_name, section in app_config.items(): + self._insert_section(section_name, section.helptext, parser) + for name, opt in section.options.items(): + self._insert_option(section_name, name, opt.helptext, opt.ini_value, parser) + if parser != self._parser: + self._parser = parser + self.save() + + +__all__ = get_module_objects(__name__) diff --git a/lib/config/objects.py b/lib/config/objects.py new file mode 100644 index 0000000000..d496c56904 --- /dev/null +++ b/lib/config/objects.py @@ -0,0 +1,463 @@ +#!/usr/bin/env python3 +""" Dataclass objects for holding and validating Faceswap Config items """ +from __future__ import annotations + +import gettext +import logging +from typing import (Any, cast, Generic, get_args, get_origin, get_type_hints, + Literal, TypeVar, Union) +import types + +from dataclasses import dataclass, field + +from lib.utils import get_module_objects + + +# LOCALES +_LANG = gettext.translation("lib.config", localedir="locales", fallback=True) +_ = _LANG.gettext + + +logger = logging.getLogger(__name__) +ConfigValueType = bool | int | float | list[str] | str +T = TypeVar("T") + + +# TODO allow list items other than strings +@dataclass +class ConfigItem(Generic[T]): # pylint:disable=too-many-instance-attributes + """ A dataclass for storing config items loaded from config.ini files and dynamically assigning + and validating that the correct datatype is used. + + The value loaded from the .ini config file can be accessed with either: + + >>> conf.value + >>> conf() + >>> conf.get() + + Parameters + ---------- + datatype : type + A python type class. This limits the type of data that can be provided in the .ini file + and ensures that the value returned to faceswap is correct. Valid datatypes are: + `int`, `float`, `str`, `bool` or `list`. Note that `list` items must all be strings. + default : Any + The default value for this option. It must be of the same type as :attr:`datatype`. + group : str + The group within the config section that this config item belongs to + info : str + A description of what this option does. + choices : list[str] | Literal["colorchooser"], optional + If this option's datatype is a `str` then valid selections can be defined here, empty list + for any value. If the option's datatype is a `list`, then this option must be populated + with the valid selections. This validates the option and also enables a combobox / radio + option in the GUI. If the default value is a hex color value, then this should be the + literal "colorchooser" to present a color choosing interface in the GUI. Ignored for all + other datatypes + Default: [] (empty list: no options) + gui_radio : bool, optional + If :attr:`choices` are defined, this indicates that the GUI should use radio buttons rather + than a combobox to display this option. Default: ``False`` + min_max : tuple[int | float, int | float] | None, optional + For `int` and `float` :attr:`datatype` this is required otherwise it is ignored. Should be + a tuple of min and max accepted values of the same datatype as the option value. This is + used for controlling the GUI slider range. Values are not enforced. Default: ``None`` + rounding : int, optional + For `int` and `float` :attr:`datatype` this is required to be > 0 otherwise it is ignored. + Used for the GUI slider. For `float`, this is the number of decimal places to display. 
For + `int` this is the step size. Default: `-1` (ignored) + fixed : bool, optional + [train only]. Training configurations are fixed when the model is created, and then + reloaded from the state file. Marking an item as fixed=``False`` indicates that this value + can be changed for existing models, and will override the value saved in the state file + with the updated value in config. Default: ``True`` + """ + datatype: type[T] + """ type : A python type class. The datatype of the config value. One of `int`, `float`, `str`, + `bool` or `list`. `list` will only contain `str` items """ + default: T + """ Any : The default value for this option. It is of the same type as :attr:`datatype` """ + group: str + """ str : The group that this config option belongs to """ + info: str + """ str : A description of what this option does """ + choices: list[str] | Literal["colorchooser"] = field(default_factory=list) + """ list[str] | Literal["colorchooser"]: If this option's datatype is a `str` then valid + selections may be defined here; empty list if any value is valid. If the datatype is a `list` + then valid choices will be populated here. If the default value is a hex color code, then the + literal "colorchooser" will display a color choosing interface in the GUI. """ + gui_radio: bool = False + """ bool : indicates that the GUI should use radio buttons rather than a combobox to display + this option if :attr:`choices` is populated """ + min_max: tuple[T, T] | None = None + """ tuple[int | float, int | float] | None : For `int` and `float` :attr:`datatype` this will + be populated otherwise it will be ``None``. Used for controlling the GUI slider range. Values + are not enforced. """ + rounding: int = -1 + """ int : For `int` and `float` :attr:`datatypes` this will be > 0 otherwise it will be `-1`. + Used for the GUI slider. For `float`, this is the number of decimal places to display. For + `int` this is the step size. """ + fixed: bool = True + """ bool : Only used for train.model configurations. Options marked as fixed=``False`` + indicate that this value can be changed for existing models, otherwise the option set when the + model commenced training is fixed and cannot be changed. Default: ``True`` """ + _value: T = field(init=False) + """ Any : The value of the config item of type :attr:`datatype`""" + _name: str = field(init=False) + """ str: The option name for this object. Set when the config is first loaded """ + + @property + def helptext(self) -> str: + """ str : Description of the config option with additional formatting and helptext added + from the item parameters """ + retval = f"{self.info}\n" + if not self.fixed: + retval += _("\nThis option can be updated for existing models.\n") + if self.datatype == list: + retval += _("\nIf selecting multiple options then each option should be separated " + "by a space or a comma (e.g. 
item1, item2, item3)\n") + if self.choices and self.choices != "colorchooser": + retval += _("\nChoose from: {}").format(self.choices) + elif self.datatype == bool: + retval += _("\nChoose from: True, False") + elif self.datatype == int: + assert self.min_max is not None + cmin, cmax = self.min_max + retval += _("\nSelect an integer between {} and {}").format(cmin, cmax) + elif self.datatype == float: + assert self.min_max is not None + cmin, cmax = self.min_max + retval += _("\nSelect a decimal number between {} and {}").format(cmin, cmax) + default = ", ".join(self.default) if isinstance(self.default, list) else self.default + retval += _("\n[Default: {}]").format(default) + return retval + + @property + def value(self) -> T: + """ Any : The config value for this item loaded from the config .ini file. String values + will always be lowercase, regardless of what is loaded from Config """ + retval = self._value + if isinstance(self._value, str): + retval = cast(T, self._value.lower()) + if isinstance(self._value, list): + retval = cast(T, [x.lower() for x in self._value]) + return retval + + @property + def ini_value(self) -> str: + """ str : The current value of the ConfigItem as a string for writing to a .ini file """ + if isinstance(self._value, list): + return ", ".join(str(x) for x in self._value) + return str(self._value) + + @property + def name(self) -> str: + """str: The name associated with this option """ + return self._name + + def _validate_type(self, # pylint:disable=too-many-return-statements + expected_type: Any, + attr: Any, + depth=1) -> bool: + """ Validate that provided types are correct when this Dataclass is initialized + + Parameters + ---------- + expected_type : Any + The expected data type for the given attribute + attr : Any + The attribute to test for correctness + depth : int, optional + The current recursion depth + + Returns + ------- + bool + ``True`` if the given attribute is a valid datatype + + Raises + ------ + AssertionError + On explicit data type failure + ValueError + On unhandled data type failure + """ + value = getattr(self, attr) + attr_type = type(value) + expected_type = self.datatype if expected_type == T else expected_type # type:ignore[misc] + + if attr_type is expected_type: + return True + + if attr == "datatype": + assert value in (str, bool, float, int, list), ( + f"'datatype' must be one of str, bool, float, int or list. Got {value}") + return True + + if expected_type == T: # type:ignore[misc] + assert attr_type == self.datatype, ( + f"'{attr}' expected: {self.datatype}. Got: {attr_type}") + return True + + if get_origin(expected_type) is Literal: + return value in get_args(expected_type) + + if get_origin(expected_type) in (Union, types.UnionType): + for subtype in get_args(expected_type): + if self._validate_type(subtype, attr, depth=depth + 1): + return True + + if get_origin(expected_type) in (list, tuple) and attr_type in (list, tuple): + sub_expected = [self.datatype if v == T # type:ignore[misc] + else v for v in get_args(expected_type)] + return set(type(v) for v in value).issubset(sub_expected) + + if depth == 1: + raise ValueError(f"'{attr}' expected: {expected_type}. 
Got: {attr_type}") + + return False + + def _validate_required(self) -> None: + """ Validate that required parameters are populated + + Raises + ------ + ValueError + If any required parameters are empty + """ + if not self.group: + raise ValueError("A group must be provided") + if not self.info: + raise ValueError("Option info must me provided") + + def _validate_choices(self) -> None: + """ Validate that choices have been used correctly + + Raises + ------ + ValueError + If any choices options have not been populated correctly + """ + if self.choices == "colorchooser": + if not isinstance(self.default, str): + raise ValueError(f"Config Item default must be a string when selecting " + f"choice='colorchooser'. Got {type(self.default)}") + if not self.default.startswith("#") or len(self.default) != 7: + raise ValueError(f"Hex color codes should start with a '#' and be 6 " + f"characters long. Got: '{self.default}'") + elif self.choices and isinstance(self.default, str) and self.default not in self.choices: + raise ValueError(f"Config item default value '{self.default}' must exist in " + f"in choices {self.choices}") + + if isinstance(self.choices, list) and self.choices: + unique_choices = set(x.lower() for x in self.choices) + if len(unique_choices) != len(self.choices): + raise ValueError("Config item choices must be a unique list") + if isinstance(self.default, list): + defaults = set(x.lower() for x in self.default) + else: + assert isinstance(self.default, str), type(self.default) + defaults = {self.default.lower()} + if not defaults.issubset(unique_choices): + raise ValueError(f"Config item default {self.default} must exist in choices " + f"{self.choices}") + + if not self.choices and isinstance(self.default, list): + raise ValueError("Config item of type list must have choices defined") + + def _validate_numeric(self) -> None: + """ Validate that float and int values have been set correctly + + Raises + ------ + ValueError + If any float or int options have not been configured correctly + """ + # NOTE: Have to include datatype filter in next check to exclude bools + if self.datatype in (float, int) and isinstance(self.default, (float, int)): + if self.rounding <= 0: + raise ValueError(f"Config Item rounding must be a positive number for " + f"datatypes float and int. Got {self.rounding}") + if self.min_max is None or len(self.min_max) != 2: + raise ValueError(f"Config Item min_max must be a tuple of (, " + f") values. Got {self.min_max}") + + def __post_init__(self) -> None: + """ Validate and type check that the given parameters are valid and set the default value. + + Raises + ------ + ValueError + If the Dataclass fails validation checks + """ + self._name = "" + self._value = self.default + try: + for attr, dtype in get_type_hints(self.__class__).items(): + self._validate_type(dtype, attr) + except (AssertionError, ValueError) as err: + raise ValueError(f"Config item failed type checking: {str(err)}") from err + + self._validate_required() + self._validate_choices() + self._validate_numeric() + + def get(self) -> T: + """ Obtain the currently stored configuration value + + Returns + ------- + Any + The config value for this item loaded from the config .ini file. String values will + always be lowecase, regardless of what is loaded from Config """ + return self.value + + def _parse_list(self, value: str | list[str]) -> list[str]: + """ Parse inbound list values. These can be space/comma-separated strings or a list. 
+ + Parameters + ---------- + value : str | list[str] + The inbound value to be converted to a list + + Returns + ------- + list[str] + List of strings representing the inbound values. + """ + if not value: + return [] + if isinstance(value, list): + return [str(x) for x in value] + delimiter = "," if "," in value else None + retval = list(set(x.strip() for x in value.split(delimiter))) + logger.debug("[%s] Processed str value '%s' to unique list %s", self._name, value, retval) + return retval + + def _validate_selection(self, value: str | list[str]) -> str | list[str]: + """ Validate that the given value is valid within the stored choices + + Parameters + ---------- + value : str | list[str] + The inbound config value to validate + + Returns + ------- + str | list[str] + The validated value. Invalid selections are replaced with the default + """ + assert isinstance(self.choices, list) + choices = [x.lower() for x in self.choices] + logger.debug("[%s] Checking config choices", self._name) + + if isinstance(value, str): + if value.lower() not in choices: + logger.warning("[%s] '%s' is not a valid config choice. Defaulting to '%s'", + self._name, value, self.default) + return cast(str, self.default) + return value + + if all(x.lower() in choices for x in value): + return value + + valid = [x for x in value if x.lower() in choices] + valid = valid if valid else cast(list[str], self.default) + invalid = [x for x in value if x.lower() not in choices] + logger.warning("[%s] The option(s) %s are not valid selections. Setting to: %s", + self._name, invalid, valid) + + return valid + + def set(self, value: T) -> None: + """ Set the item's option value + + Parameters + ---------- + value : Any + The value to set this item to. Must be of type :attr:`datatype` + + Raises + ------ + ValueError + If the given value does not pass type and content validation checks + """ + if not self._name: + raise ValueError("The name of this object should have been set before any value is " + "added") + + if self.datatype is list: + if not isinstance(value, (str, list)): + raise ValueError(f"[{self._name}] List values should be set as a Str or List. Got " + f"{type(value)} ({value})") + value = cast(T, self._parse_list(value)) + + if not isinstance(value, self.datatype): + raise ValueError( + f"[{self._name}] Expected {self.datatype} got {type(value)} ({value})") + + if isinstance(self.choices, list) and self.choices: + assert isinstance(value, (list, str)) + value = cast(T, self._validate_selection(value)) + + if self.choices == "colorchooser": + assert isinstance(value, str) + if not value.startswith("#") or len(value) != 7: + raise ValueError(f"Hex color codes should start with a '#' followed by 6 hex " + f"characters. Got: '{value}'") + + self._value = value + + def set_name(self, name: str) -> None: + """ Set the logging name for this object for display purposes + + Parameters + ---------- + name : str + The name to assign to this option + """ + logger.debug("Setting name to '%s'", name) + assert isinstance(name, str) and name + self._name = name + + def __call__(self) -> T: + """ Obtain the currently stored configuration value + + Returns + ------- + Any + The config value for this item loaded from the config .ini file. 
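Taken together, the accessors behave as in this usage sketch (constructed with illustrative values; note that `set_name` must be called before `set`, and that string reads are lowercased while `ini_value` preserves the raw stored value):

```python
# Round trip through a ConfigItem
item = ConfigItem(datatype=str,
                  default="bilinear",
                  group="settings",
                  info="Interpolation method",
                  choices=["bilinear", "bicubic"])
item.set_name("convert.global.interpolation")
item.set("Bicubic")                        # validated against choices, stored as given
assert item() == item.get() == "bicubic"   # reads are lowercased
print(item.ini_value)                      # 'Bicubic' - raw value written to the .ini
```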
String values will + always be lowercase, regardless of what is loaded from Config """ + return self.value + + +@dataclass +class ConfigSection: + """ Dataclass for holding information about configuration sections and the contained + configuration items + + Parameters + ---------- + helptext : str + The helptext to be displayed for the configuration section + options : dict[str, :class:`ConfigItem`] + Dictionary of configuration option name to the options for the section + """ + helptext: str + options: dict[str, ConfigItem] + + +@dataclass +class GlobalSection: + """ A dataclass for holding and identifying global sub-sections for plugin groups. Any global + subsections must inherit from this. + + Parameters + ---------- + helptext : str + The helptext to be displayed for the global configuration section + """ + helptext: str + + +__all__ = get_module_objects(__name__) diff --git a/lib/convert.py b/lib/convert.py new file mode 100644 index 0000000000..5b41a7817c --- /dev/null +++ b/lib/convert.py @@ -0,0 +1,519 @@ +#!/usr/bin/env python3 +""" Converter for Faceswap """ +from __future__ import annotations +import logging +import typing as T +from dataclasses import dataclass + +import cv2 +import numpy as np + +from lib.utils import get_module_objects +from plugins.plugin_loader import PluginLoader + +if T.TYPE_CHECKING: + from argparse import Namespace + from collections.abc import Callable + from lib.align.aligned_face import AlignedFace, CenteringType + from lib.align.detected_face import DetectedFace + from lib.queue_manager import EventQueue + from scripts.convert import ConvertItem + from plugins.convert.color._base import Adjustment as ColorAdjust + from plugins.convert.color.seamless_clone import Color as SeamlessAdjust + from plugins.convert.mask.mask_blend import Mask as MaskAdjust + from plugins.convert.scaling._base import Adjustment as ScalingAdjust + +logger = logging.getLogger(__name__) + + +@dataclass +class Adjustments: + """ Dataclass to hold the optional processing plugins + + Parameters + ---------- + color: :class:`~plugins.color._base.Adjustment`, Optional + The selected color processing plugin. Default: `None` + mask: :class:`~plugins.mask_blend.Mask`, Optional + The selected mask processing plugin. Default: `None` + seamless: :class:`~plugins.color.seamless_clone.Color`, Optional + The selected seamless clone processing plugin. Default: `None` + sharpening: :class:`~plugins.scaling._base.Adjustment`, Optional + The selected sharpening (scaling) processing plugin. Default: `None` + """ + color: ColorAdjust | None = None + mask: MaskAdjust | None = None + seamless: SeamlessAdjust | None = None + sharpening: ScalingAdjust | None = None + + +class Converter(): # pylint:disable=too-many-instance-attributes + """ The converter is responsible for swapping the original face(s) in a frame with the output + of a trained Faceswap model. + + Parameters + ---------- + output_size: int + The size of the face, in pixels, that is output from the Faceswap model + coverage_ratio: float + The ratio of the training image that was used for training the Faceswap model + centering: str + The extracted face centering that the model was trained on (`"face"` or `"legacy"`) + draw_transparent: bool + Whether the final output should be drawn onto a transparent layer rather than the original + frame. Only available with certain writer plugins. + pre_encode: python function + Some writer plugins support the pre-encoding of images prior to saving out.
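For reference, the two dataclasses above combine along these lines (a hypothetical section definition; the ConfigItem keyword names are taken from the validation code earlier, but the exact constructor signature may differ):

    section = ConfigSection(
        helptext="Options for an example plugin",
        options={"threshold": ConfigItem(default=0.5,
                                         info="Example detection threshold",
                                         datatype=float,
                                         rounding=2,
                                         min_max=(0.0, 1.0),
                                         group="settings")})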
As patching is + done in multiple threads, but writing is done in a single thread, it can speed up the + process to do any pre-encoding as part of the converter process. + arguments: :class:`argparse.Namespace` + The arguments that were passed to the convert process as generated from Faceswap's command + line arguments + configfile: str, optional + Optional location of custom configuration ``ini`` file. If ``None`` then use the default + config location. Default: ``None`` + """ + def __init__(self, + output_size: int, + coverage_ratio: float, + centering: CenteringType, + draw_transparent: bool, + pre_encode: Callable | None, + arguments: Namespace, + configfile: str | None = None) -> None: + logger.debug("Initializing %s: (output_size: %s, coverage_ratio: %s, centering: %s, " + "draw_transparent: %s, pre_encode: %s, arguments: %s, configfile: %s)", + self.__class__.__name__, output_size, coverage_ratio, centering, + draw_transparent, pre_encode, arguments, configfile) + self._output_size = output_size + self._coverage_ratio = coverage_ratio + self._centering: CenteringType = centering + self._draw_transparent = draw_transparent + self._writer_pre_encode = pre_encode + self._args = arguments + self._configfile = configfile + + self._scale = arguments.output_scale / 100 + self._face_scale = 1.0 - arguments.face_scale / 100. + self._adjustments = Adjustments() + self._full_frame_output: bool = arguments.writer != "patch" + + self._load_plugins() + logger.debug("Initialized %s", self.__class__.__name__) + + @property + def cli_arguments(self) -> Namespace: + """:class:`argparse.Namespace`: The command line arguments passed to the convert + process """ + return self._args + + def reinitialize(self) -> None: + """ Reinitialize this :class:`Converter`. + + Called as part of the :mod:`~tools.preview` tool. Resets all adjustments then loads the + plugins as specified in the current config. + """ + logger.debug("Reinitializing converter") + self._face_scale = 1.0 - self._args.face_scale / 100. + self._adjustments = Adjustments() + self._load_plugins(disable_logging=True) + logger.debug("Reinitialized converter") + + def _load_plugins(self, disable_logging: bool = False) -> None: + """ Load the requested adjustment plugins. + + Loads the :mod:`plugins.converter` plugins that have been requested for this conversion + session. + + Parameters + ---------- + disable_logging: bool, optional + ``True`` to suppress logging output while the plugins are loaded (used when + reinitializing from the preview tool). Default: ``False`` + """ + logger.debug("Loading plugins. disable_logging: %s", disable_logging) + self._adjustments.mask = PluginLoader.get_converter("mask", + "mask_blend", + disable_logging=disable_logging)( + self._args.mask_type, + self._output_size, + self._coverage_ratio, + configfile=self._configfile) + + if self._args.color_adjustment is not None: + self._adjustments.color = PluginLoader.get_converter("color", + self._args.color_adjustment, + disable_logging=disable_logging)( + configfile=self._configfile) + + sharpening = PluginLoader.get_converter("scaling", + "sharpen", + disable_logging=disable_logging)( + configfile=self._configfile) + self._adjustments.sharpening = sharpening + logger.debug("Loaded plugins: %s", self._adjustments) + + def process(self, in_queue: EventQueue, out_queue: EventQueue): + """ Main convert process.
+ + Takes items from the in queue, runs the relevant adjustments, patches faces to final frame + and outputs patched frame to the out queue. + + Parameters + ---------- + in_queue: :class:`~lib.queue_manager.EventQueue` + The output from :class:`scripts.convert.Predictor`. Contains detected faces from the + Faceswap model as well as the frame to be patched. + out_queue: :class:`~lib.queue_manager.EventQueue` + The queue to place patched frames into for writing by one of Faceswap's + :mod:`plugins.convert.writer` plugins. + """ + logger.debug("Starting convert process. (in_queue: %s, out_queue: %s)", + in_queue, out_queue) + logged = False + while True: + inbound: T.Literal["EOF"] | ConvertItem | list[ConvertItem] = in_queue.get() + if inbound == "EOF": + logger.debug("EOF Received") + logger.debug("Patch queue finished") + # Signal EOF to other processes in pool + logger.debug("Putting EOF back to in_queue") + in_queue.put(inbound) + break + + items = inbound if isinstance(inbound, list) else [inbound] + for item in items: + logger.trace("Patch queue got: '%s'", # type: ignore[attr-defined] + item.inbound.filename) + try: + image = self._patch_image(item) + except Exception as err: # pylint:disable=broad-except + # Log error and output original frame + logger.error("Failed to convert image: '%s'. Reason: %s", + item.inbound.filename, str(err)) + image = item.inbound.image + + lvl = logger.trace if logged else logger.warning # type: ignore[attr-defined] + lvl("Convert error traceback:", exc_info=True) + logged = True + # UNCOMMENT THIS CODE BLOCK TO PRINT TRACEBACK ERRORS + # import sys; import traceback + # exc_info = sys.exc_info(); traceback.print_exception(*exc_info) + logger.trace("Out queue put: %s", # type: ignore[attr-defined] + item.inbound.filename) + out_queue.put((item.inbound.filename, image)) + logger.debug("Completed convert process") + + def _get_warp_matrix(self, matrix: np.ndarray, size: int) -> np.ndarray: + """ Obtain the final scaled warp transformation matrix based on face scaling from the + original transformation matrix + + Parameters + ---------- + matrix: :class:`numpy.ndarray` + The transformation for patching the swapped face back onto the output frame + size: int + The size of the face patch, in pixels + + Returns + ------- + :class:`numpy.ndarray` + The final transformation matrix with any scaling applied + """ + if self._face_scale == 1.0: + mat = matrix + else: + mat = matrix * self._face_scale + patch_center = (size / 2, size / 2) + mat[..., 2] += (1 - self._face_scale) * np.array(patch_center) + + return mat + + def _patch_image(self, predicted: ConvertItem) -> np.ndarray | list[bytes]: + """ Patch a swapped face onto a frame. + + Run selected adjustments and swap the faces in a frame. + + Parameters + ---------- + predicted: :class:`~scripts.convert.ConvertItem` + The output from :class:`scripts.convert.Predictor`. + + Returns + ------- + :class: `numpy.ndarray` or pre-encoded image output + The final frame ready for writing by a :mod:`plugins.convert.writer` plugin. 
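The face-scaling arithmetic in _get_warp_matrix above is worth seeing in isolation: scaling a 2x3 affine matrix shrinks the patch, and the translation column is then offset so the shrunken patch stays centred. A standalone sketch of the same math:

    import numpy as np

    def scale_warp_matrix(matrix: np.ndarray, size: int, face_scale: float) -> np.ndarray:
        if face_scale == 1.0:
            return matrix
        mat = matrix * face_scale
        patch_center = (size / 2, size / 2)
        # Compensate the translation so scaling happens about the patch centre
        mat[..., 2] += (1 - face_scale) * np.array(patch_center)
        return mat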
+ Frame is either an array, or the pre-encoded output from the writer's pre-encode + function (if it has one) + + """ + logger.trace("Patching image: '%s'", # type: ignore[attr-defined] + predicted.inbound.filename) + frame_size = (predicted.inbound.image.shape[1], predicted.inbound.image.shape[0]) + new_image, background = self._get_new_image(predicted, frame_size) + + if self._full_frame_output: + patched_face = self._post_warp_adjustments(background, new_image) + patched_face = self._scale_image(patched_face) + patched_face *= 255.0 + patched_face = np.rint(patched_face, + out=np.empty(patched_face.shape, dtype="uint8"), + casting='unsafe') + else: + patched_face = new_image + + if self._writer_pre_encode is None: + retval: np.ndarray | list[bytes] = patched_face + else: + kwargs: dict[str, T.Any] = {} + if self.cli_arguments.writer == "patch": + kwargs["canvas_size"] = (background.shape[1], background.shape[0]) + kwargs["matrices"] = np.array([self._get_warp_matrix(face.adjusted_matrix, + patched_face.shape[1]) + for face in predicted.reference_faces], + dtype="float32") + retval = self._writer_pre_encode(patched_face, **kwargs) + logger.trace("Patched image: '%s'", # type: ignore[attr-defined] + predicted.inbound.filename) + return retval + + def _warp_to_frame(self, + reference: AlignedFace, + face: np.ndarray, + frame: np.ndarray, + multiple_faces: bool) -> None: + """ Perform affine transformation to place a face patch onto the given frame. + + Affine is done in place on the `frame` array, so this function does not return a value + + Parameters + ---------- + reference: :class:`lib.align.AlignedFace` + The object holding the original aligned face + face: :class:`numpy.ndarray` + The swapped face patch + frame: :class:`numpy.ndarray` + The frame to affine the face onto + multiple_faces: bool + Controls the border mode to use. Uses BORDER_CONSTANT if there is only 1 face in + the image, otherwise uses the inferior BORDER_TRANSPARENT + """ + # Warp face with the mask + mat = self._get_warp_matrix(reference.adjusted_matrix, face.shape[0]) + border = cv2.BORDER_TRANSPARENT if multiple_faces else cv2.BORDER_CONSTANT + cv2.warpAffine(face, + mat, + (frame.shape[1], frame.shape[0]), + frame, + flags=cv2.WARP_INVERSE_MAP | reference.interpolators[1], + borderMode=border) + + def _get_new_image(self, + predicted: ConvertItem, + frame_size: tuple[int, int]) -> tuple[np.ndarray, np.ndarray]: + """ Get the new face from the predictor and apply pre-warp manipulations. + + Applies any requested adjustments to the raw output of the Faceswap model + before transforming the image into the target frame. + + Parameters + ---------- + predicted: :class:`~scripts.convert.ConvertItem` + The output from :class:`scripts.convert.Predictor`. 
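The in-place compositing trick used by _warp_to_frame can be shown with a minimal sketch (synthetic inputs; the matrix values are invented for illustration): passing the frame as cv2.warpAffine's dst argument makes OpenCV composite directly onto it instead of allocating a new image, and BORDER_TRANSPARENT leaves untouched pixels as they were:

    import cv2
    import numpy as np

    frame = np.zeros((720, 1280, 4), dtype="float32")   # target canvas
    face = np.ones((256, 256, 4), dtype="float32")      # dummy face patch
    mat = np.array([[1., 0., 100.], [0., 1., 50.]], dtype="float32")
    cv2.warpAffine(face, mat, (frame.shape[1], frame.shape[0]), frame,
                   flags=cv2.WARP_INVERSE_MAP | cv2.INTER_LINEAR,
                   borderMode=cv2.BORDER_TRANSPARENT)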
+ frame_size: tuple + The (`width`, `height`) of the final frame in pixels + + Returns + ------- + placeholder: :class: `numpy.ndarray` + The original frame with the swapped faces patched onto it + background: :class: `numpy.ndarray` + The original frame + """ + logger.trace("Getting: (filename: '%s', faces: %s)", # type: ignore[attr-defined] + predicted.inbound.filename, len(predicted.swapped_faces)) + + placeholder: np.ndarray = np.zeros((frame_size[1], frame_size[0], 4), dtype="float32") + faces: list[np.ndarray] | None = None + if self._full_frame_output: + background = predicted.inbound.image / np.array(255.0, dtype="float32") + placeholder[:, :, :3] = background + else: + faces = [] # Collect the faces into final array + background = placeholder # Used for obtaining original frame dimensions + + for new_face, detected_face, reference_face in zip(predicted.swapped_faces, + predicted.inbound.detected_faces, + predicted.reference_faces): + predicted_mask = new_face[:, :, -1] if new_face.shape[2] == 4 else None + new_face = new_face[:, :, :3] + new_face = self._pre_warp_adjustments(new_face, + detected_face, + reference_face, + predicted_mask) + + if self._full_frame_output: + self._warp_to_frame(reference_face, + new_face, placeholder, + len(predicted.swapped_faces) > 1) + else: + assert faces is not None + faces.append(new_face) + + if not self._full_frame_output: + placeholder = np.array(faces, dtype="float32") + + logger.trace("Got filename: '%s'. (placeholders: %s)", # type: ignore[attr-defined] + predicted.inbound.filename, placeholder.shape) + + return placeholder, background + + def _pre_warp_adjustments(self, + new_face: np.ndarray, + detected_face: DetectedFace, + reference_face: AlignedFace, + predicted_mask: np.ndarray | None) -> np.ndarray: + """ Run any requested adjustments that can be performed on the raw output from the Faceswap + model. + + Any adjustments that can be performed before warping the face into the final frame are + performed here. + + Parameters + ---------- + new_face: :class:`numpy.ndarray` + The swapped face received from the faceswap model. + detected_face: :class:`~lib.align.DetectedFace` + The detected_face object as defined in :class:`scripts.convert.Predictor` + reference_face: :class:`~lib.align.AlignedFace` + The aligned face object sized to the model output of the original face for reference + predicted_mask: :class:`numpy.ndarray` or ``None`` + The predicted mask output from the Faceswap model. ``None`` if the model + did not learn a mask + + Returns + ------- + :class:`numpy.ndarray` + The face output from the Faceswap Model with any requested pre-warp adjustments + performed. 
+ """ + logger.trace("new_face shape: %s, predicted_mask shape: %s", # type: ignore[attr-defined] + new_face.shape, predicted_mask.shape if predicted_mask is not None else None) + old_face = T.cast(np.ndarray, reference_face.face)[..., :3] / 255.0 + new_face, raw_mask = self._get_image_mask(new_face, + detected_face, + predicted_mask, + reference_face) + if self._adjustments.color is not None: + new_face = self._adjustments.color.run(old_face, new_face, raw_mask) + if self._adjustments.seamless is not None: + new_face = self._adjustments.seamless.run(old_face, new_face, raw_mask) + logger.trace("returning: new_face shape %s", new_face.shape) # type: ignore[attr-defined] + return new_face + + def _get_image_mask(self, + new_face: np.ndarray, + detected_face: DetectedFace, + predicted_mask: np.ndarray | None, + reference_face: AlignedFace) -> tuple[np.ndarray, np.ndarray]: + """ Return any selected image mask + + Places the requested mask into the new face's Alpha channel. + + Parameters + ---------- + new_face: :class:`numpy.ndarray` + The swapped face received from the faceswap model. + detected_face: :class:`~lib.DetectedFace` + The detected_face object as defined in :class:`scripts.convert.Predictor` + predicted_mask: :class:`numpy.ndarray` or ``None`` + The predicted mask output from the Faceswap model. ``None`` if the model + did not learn a mask + reference_face: :class:`~lib.align.AlignedFace` + The aligned face object sized to the model output of the original face for reference + + Returns + ------- + :class:`numpy.ndarray` + The swapped face with the requested mask added to the Alpha channel + :class:`numpy.ndarray` + The raw mask with no erosion or blurring applied + """ + logger.trace("Getting mask. Image shape: %s", new_face.shape) # type: ignore[attr-defined] + mask_centering: CenteringType + if self._args.mask_type not in ("none", "predicted"): + mask_centering = detected_face.mask[self._args.mask_type].stored_centering + else: + mask_centering = "face" # Unused but requires a valid value + assert self._adjustments.mask is not None + mask, raw_mask = self._adjustments.mask.run(detected_face, + reference_face.pose.offset[mask_centering], + reference_face.pose.offset[self._centering], + self._centering, + predicted_mask=predicted_mask) + logger.trace("Adding mask to alpha channel") # type: ignore[attr-defined] + new_face = np.concatenate((new_face, mask), -1) + logger.trace("Got mask. Image shape: %s", new_face.shape) # type: ignore[attr-defined] + return new_face, raw_mask + + def _post_warp_adjustments(self, background: np.ndarray, new_image: np.ndarray) -> np.ndarray: + """ Perform any requested adjustments to the swapped faces after they have been transformed + into the final frame. 
+ + Parameters + ---------- + background: :class:`numpy.ndarray` + The original frame + new_image: :class:`numpy.ndarray` + A blank frame of original frame size with the faces warped onto it + + Returns + ------- + :class:`numpy.ndarray` + The final merged and swapped frame with any requested post-warp adjustments applied + """ + if self._adjustments.sharpening is not None: + new_image = self._adjustments.sharpening.run(new_image) + + if self._draw_transparent: + frame = new_image + else: + foreground, mask = np.split(new_image, # pylint:disable=unbalanced-tuple-unpacking + (3, ), + axis=-1) + foreground *= mask + background *= (1.0 - mask) + background += foreground + frame = background + np.clip(frame, 0.0, 1.0, out=frame) + return frame + + def _scale_image(self, frame: np.ndarray) -> np.ndarray: + """ Scale the final image if requested. + + If output scale has been requested in command line arguments, scale the output + otherwise return the final frame. + + Parameters + ---------- + frame: :class:`numpy.ndarray` + The final frame with faces swapped + + Returns + ------- + :class:`numpy.ndarray` + The final frame scaled by the requested scaling factor + """ + if self._scale == 1: + return frame + logger.trace("source frame: %s", frame.shape) # type: ignore[attr-defined] + interp = cv2.INTER_CUBIC if self._scale > 1 else cv2.INTER_AREA + dims = (round((frame.shape[1] / 2 * self._scale) * 2), + round((frame.shape[0] / 2 * self._scale) * 2)) + frame = cv2.resize(frame, dims, interpolation=interp) + logger.trace("resized frame: %s", frame.shape) # type: ignore[attr-defined] + np.clip(frame, 0.0, 1.0, out=frame) + return frame + + +__all__ = get_module_objects(__name__) diff --git a/lib/face_filter.py b/lib/face_filter.py deleted file mode 100644 index a373fcb458..0000000000 --- a/lib/face_filter.py +++ /dev/null @@ -1,91 +0,0 @@ -#!/usr/bin python3 -""" Face Filterer for extraction in faceswap.py """ - -import logging - -import face_recognition - - -logger = logging.getLogger(__name__) # pylint: disable=invalid-name - - -def avg(arr): - """ Return an average """ - return sum(arr) * 1.0 / len(arr) - - -class FaceFilter(): - """ Face filter for extraction """ - def __init__(self, reference_file_paths, nreference_file_paths, threshold=0.6): - logger.debug("Initializing %s: (reference_file_paths: %s, nreference_file_paths: %s, " - "threshold: %s)", self.__class__.__name__, reference_file_paths, - nreference_file_paths, threshold) - images = list(map(face_recognition.load_image_file, reference_file_paths)) - nimages = list(map(face_recognition.load_image_file, nreference_file_paths)) - # Note: we take only first face, so the reference file should only contain one face. - self.encodings = list(map(lambda im: face_recognition.face_encodings(im)[0], images)) - self.nencodings = list(map(lambda im: face_recognition.face_encodings(im)[0], nimages)) - self.threshold = threshold - logger.trace("encodings: %s", self.encodings) - logger.trace("nencodings: %s", self.nencodings) - logger.debug("Initialized %s", self.__class__.__name__) - - def check(self, detected_face): - """ Check Face - we could use detected landmarks, but I did not manage to do so. 
- TODO The copy/paste below should help """ - logger.trace("Checking face with FaceFilter") - encodings = face_recognition.face_encodings(detected_face.image) - if not encodings: - logger.verbose("No face encodings found") - return False - - if self.encodings: - distances = list(face_recognition.face_distance(self.encodings, encodings[0])) - logger.trace("Distances: %s", distances) - distance = avg(distances) - logger.trace("Average Distance: %s", distance) - mindistance = min(distances) - logger.trace("Minimum Distance: %s", mindistance) - if distance > self.threshold: - logger.verbose("Distance above threshold: %f < %f", distance, self.threshold) - return False - if self.nencodings: - ndistances = list(face_recognition.face_distance(self.nencodings, encodings[0])) - logger.trace("nDistances: %s", ndistances) - ndistance = avg(ndistances) - logger.trace("Average nDistance: %s", ndistance) - nmindistance = min(ndistances) - logger.trace("Minimum nDistance: %s", nmindistance) - if not self.encodings and ndistance < self.threshold: - logger.verbose("nDistance below threshold: %f < %f", ndistance, self.threshold) - return False - if self.encodings: - if mindistance > nmindistance: - logger.verbose("Distance to negative sample is smaller") - return False - if distance > ndistance: - logger.verbose("Average distance to negative sample is smaller") - return False - # k-nn classifier - var_k = min(5, min(len(distances), len(ndistances)) + 1) - var_n = sum(list(map(lambda x: x[0], - list(sorted([(1, d) for d in distances] + - [(0, d) for d in ndistances], - key=lambda x: x[1]))[:var_k]))) - ratio = var_n/var_k - if ratio < 0.5: - logger.verbose("K-nn is %.2f", ratio) - return False - return True - - -# # Copy/Paste (mostly) from private method in face_recognition -# face_recognition_model = face_recognition_models.face_recognition_model_location() -# face_encoder = dlib.face_recognition_model_v1(face_recognition_model) - -# def convert(detected_face): -# return np.array(face_encoder.compute_face_descriptor(detected_face.image, -# detected_face.landmarks, -# 1)) -# # end of Copy/Paste diff --git a/lib/faces_detect.py b/lib/faces_detect.py deleted file mode 100644 index c060a71a41..0000000000 --- a/lib/faces_detect.py +++ /dev/null @@ -1,153 +0,0 @@ -#!/usr/bin python3 -""" Face and landmarks detection for faceswap.py """ -import logging - -from dlib import rectangle as d_rectangle # pylint: disable=no-name-in-module -from lib.aligner import Extract as AlignerExtract, get_align_mat, get_matrix_scaling - -logger = logging.getLogger(__name__) # pylint: disable=invalid-name - - -class DetectedFace(): - """ Detected face and landmark information """ - def __init__( # pylint: disable=invalid-name - self, image=None, x=None, w=None, y=None, h=None, - landmarksXY=None): - logger.trace("Initializing %s", self.__class__.__name__) - self.image = image - self.x = x - self.w = w - self.y = y - self.h = h - self.landmarksXY = landmarksXY - self.hash = None # Hash must be set when the file is saved due to image compression - - self.aligned = dict() - logger.trace("Initialized %s", self.__class__.__name__) - - @property - def landmarks_as_xy(self): - """ Landmarks as XY """ - return self.landmarksXY - - def to_dlib_rect(self): - """ Return Bounding Box as Dlib Rectangle """ - left = self.x - top = self.y - right = self.x + self.w - bottom = self.y + self.h - retval = d_rectangle(left, top, right, bottom) - logger.trace("Returning: %s", retval) - return retval - - def from_dlib_rect(self, d_rect, image=None): - """ 
Set Bounding Box from a Dlib Rectangle """ - logger.trace("Creating from dlib_rectangle: %s", d_rect) - if not isinstance(d_rect, d_rectangle): - raise ValueError("Supplied Bounding Box is not a dlib.rectangle.") - self.x = d_rect.left() - self.w = d_rect.right() - d_rect.left() - self.y = d_rect.top() - self.h = d_rect.bottom() - d_rect.top() - if image is not None and image.any(): - self.image_to_face(image) - logger.trace("Created from dlib_rectangle: (x: %s, w: %s, y: %s. h: %s)", - self.x, self.w, self.y, self.h) - - def image_to_face(self, image): - """ Crop an image around bounding box to the face - and capture it's dimensions """ - logger.trace("Cropping face from image") - self.image = image[self.y: self.y + self.h, - self.x: self.x + self.w] - - def to_alignment(self): - """ Convert a detected face to alignment dict """ - alignment = dict() - alignment["x"] = self.x - alignment["w"] = self.w - alignment["y"] = self.y - alignment["h"] = self.h - alignment["landmarksXY"] = self.landmarksXY - alignment["hash"] = self.hash - logger.trace("Returning: %s", alignment) - return alignment - - def from_alignment(self, alignment, image=None): - """ Convert a face alignment to detected face object """ - logger.trace("Creating from alignment: (alignment: %s, has_image: %s)", - alignment, bool(image is not None)) - self.x = alignment["x"] - self.w = alignment["w"] - self.y = alignment["y"] - self.h = alignment["h"] - self.landmarksXY = alignment["landmarksXY"] - # Manual tool does not know the final hash so default to None - self.hash = alignment.get("hash", None) - if image is not None and image.any(): - self.image_to_face(image) - logger.trace("Created from alignment: (x: %s, w: %s, y: %s. h: %s, " - "landmarks: %s)", - self.x, self.w, self.y, self.h, self.landmarksXY) - - # <<< Aligned Face methods and properties >>> # - def load_aligned(self, image, size=256, align_eyes=False): - """ No need to load aligned information for all uses of this - class, so only call this to load the information for easy - reference to aligned properties for this face """ - logger.trace("Loading aligned face: (size: %s, align_eyes: %s)", size, align_eyes) - padding = int(size * 0.1875) - self.aligned["size"] = size - self.aligned["padding"] = padding - self.aligned["align_eyes"] = align_eyes - self.aligned["matrix"] = get_align_mat(self, size, align_eyes) - if image is None: - self.aligned["face"] = None - else: - self.aligned["face"] = AlignerExtract().transform( - image, - self.aligned["matrix"], - size, - padding) - logger.trace("Loaded aligned face: %s", {key: val - for key, val in self.aligned.items() - if key != "face"}) - - @property - def original_roi(self): - """ Return the square aligned box location on the original - image """ - roi = AlignerExtract().get_original_roi(self.aligned["matrix"], - self.aligned["size"], - self.aligned["padding"]) - logger.trace("Returning: %s", roi) - return roi - - @property - def aligned_landmarks(self): - """ Return the landmarks location transposed to extracted face """ - landmarks = AlignerExtract().transform_points(self.landmarksXY, - self.aligned["matrix"], - self.aligned["size"], - self.aligned["padding"]) - logger.trace("Returning: %s", landmarks) - return landmarks - - @property - def aligned_face(self): - """ Return aligned detected face """ - return self.aligned["face"] - - @property - def adjusted_matrix(self): - """ Return adjusted matrix for size/padding combination """ - mat = AlignerExtract().transform_matrix(self.aligned["matrix"], - self.aligned["size"], 
- self.aligned["padding"]) - logger.trace("Returning: %s", mat) - return mat - - @property - def adjusted_interpolators(self): - """ Return the interpolator and reverse interpolator for the adjusted matrix """ - return get_matrix_scaling(self.adjusted_matrix) diff --git a/lib/git.py b/lib/git.py new file mode 100644 index 0000000000..3460eba394 --- /dev/null +++ b/lib/git.py @@ -0,0 +1,162 @@ +#!/usr/bin python3 +""" Handles command line calls to git """ +import logging +import os +import sys + +from subprocess import PIPE, Popen + +from lib.utils import get_module_objects + +logger = logging.getLogger(__name__) + + +class Git(): + """ Handles calls to git """ + def __init__(self) -> None: + logger.debug("Initializing: %s", self.__class__.__name__) + self._working_dir = os.path.dirname(os.path.realpath(sys.argv[0])) + self._available = self._check_available() + logger.debug("Initialized: %s", self.__class__.__name__) + + def _from_git(self, command: str) -> tuple[bool, list[str]]: + """ Execute a git command + + Parameters + ---------- + command : str + The command to send to git + + Returns + ------- + success: bool + ``True`` if the command successfully executed otherwise ``False`` + list[str] + The output lines from stdout if there was no error, otherwise from stderr + """ + logger.debug("command: '%s'", command) + cmd = f"git {command}" + with Popen(cmd, shell=True, stdout=PIPE, stderr=PIPE, cwd=self._working_dir) as proc: + stdout, stderr = proc.communicate() + retcode = proc.returncode + success = retcode == 0 + lines = stdout.decode("utf-8", errors="replace").splitlines() + if not lines: + lines = stderr.decode("utf-8", errors="replace").splitlines() + logger.debug("command: '%s', returncode: %s, success: %s, lines: %s", + cmd, retcode, success, lines) + return success, lines + + def _check_available(self) -> bool: + """ Check if git is available. Does a call to git status. If the process errors due to + folder ownership, attempts to add the folder to git's safe.directory list and tries + again + + Returns + ------- + bool + ``True`` if git is available otherwise ``False`` + + """ + success, msg = self._from_git("status") + if success: + return True + config = next((line.strip() for line in msg if "add safe.directory" in line), None) + if not config: + return False + success, _ = self._from_git(config.split("git ", 1)[-1]) + return success + + @property + def status(self) -> list[str]: + """ Obtain the output of git status for tracked files only """ + if not self._available: + return [] + success, status = self._from_git("status -uno") + if not success or not status: + return [] + return status + + @property + def branch(self) -> str: + """ str: The git branch that is currently being used to execute Faceswap. """ + status = next((line.strip() for line in self.status if "On branch" in line), "Not Found") + return status.replace("On branch ", "") + + @property + def branches(self) -> list[str]: + """ list[str]: List of all available branches.
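The _from_git wrapper above reduces to a small subprocess pattern, sketched here standalone (the function name is illustrative):

    from subprocess import PIPE, Popen

    def run_git(command: str, cwd: str) -> tuple[bool, list[str]]:
        with Popen(f"git {command}", shell=True, stdout=PIPE, stderr=PIPE,
                   cwd=cwd) as proc:
            stdout, stderr = proc.communicate()
        # Prefer stdout; fall back to stderr so error text is still captured
        lines = stdout.decode("utf-8", errors="replace").splitlines()
        if not lines:
            lines = stderr.decode("utf-8", errors="replace").splitlines()
        return proc.returncode == 0, lines

    ok, output = run_git("status -uno", ".")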
""" + if not self._available: + return [] + success, branches = self._from_git("branch -a") + if not success or not branches: + return [] + return branches + + def update_remote(self) -> bool: + """ Update all branches to track remote + + Returns + ------- + bool + ``True`` if update was succesful otherwise ``False`` + """ + if not self._available: + return False + return self._from_git("remote update")[0] + + def pull(self) -> bool: + """ Pull the current branch + + Returns + ------- + bool + ``True`` if pull is successful otherwise ``False`` + """ + if not self._available: + return False + return self._from_git("pull")[0] + + def checkout(self, branch: str) -> bool: + """ Checkout the requested branch + + Parameters + ---------- + branch : str + The branch to checkout + + Returns + ------- + bool + ``True`` if the branch was succesfully checkout out otherwise ``False`` + """ + if not self._available: + return False + return self._from_git(f"checkout {branch}")[0] + + def get_commits(self, count: int) -> list[str]: + """ Obtain the last commits to the repo + + Parameters + ---------- + count : int + The last number of commits to obtain + + Returns + ------- + list[str] + list of commits, or empty list if none found + """ + if not self._available: + return [] + success, commits = self._from_git(f"log --pretty=oneline --abbrev-commit -n {count}") + if not success or not commits: + return [] + return commits + + +git = Git() +""" :class:`Git`: Handles calls to github """ + + +__all__ = get_module_objects(__name__) diff --git a/lib/gpu_stats.py b/lib/gpu_stats.py deleted file mode 100644 index a0ecbc1657..0000000000 --- a/lib/gpu_stats.py +++ /dev/null @@ -1,189 +0,0 @@ -#!/usr/bin python3 -""" Information on available Nvidia GPUs """ - -import logging -import platform - -if platform.system() == 'Darwin': - import pynvx # pylint: disable=import-error - IS_MACOS = True -else: - import pynvml - IS_MACOS = False - - -class GPUStats(): - """ Holds information about system GPU(s) """ - def __init__(self, log=True): - self.logger = None - if log: - # Logger is held internally, as we don't want to log - # when obtaining system stats on crash - self.logger = logging.getLogger(__name__) # pylint: disable=invalid-name - self.logger.debug("Initializing %s", self.__class__.__name__) - - self.initialized = False - self.device_count = 0 - self.handles = None - self.driver = None - self.devices = None - self.vram = None - - self.initialize() - - self.driver = self.get_driver() - self.devices = self.get_devices() - self.vram = self.get_vram() - if self.device_count == 0: - if self.logger: - self.logger.warning("No GPU detected. Switching to CPU mode") - return - - self.shutdown() - if self.logger: - self.logger.debug("Initialized %s", self.__class__.__name__) - - def initialize(self): - """ Initialize pynvml """ - if not self.initialized: - if IS_MACOS: - if self.logger: - self.logger.debug("macOS Detected. Using pynvx") - try: - pynvx.cudaInit() - except RuntimeError: - self.initialized = True - return - else: - try: - if self.logger: - self.logger.debug("OS is not macOS. 
Using pynvml") - pynvml.nvmlInit() - except (pynvml.NVMLError_LibraryNotFound, pynvml.NVMLError_DriverNotLoaded): - self.initialized = True - return - self.initialized = True - self.get_device_count() - self.get_handles() - - def shutdown(self): - """ Shutdown pynvml """ - if self.initialized: - self.handles = None - if not IS_MACOS: - pynvml.nvmlShutdown() - self.initialized = False - - def get_device_count(self): - """ Return count of Nvidia devices """ - if IS_MACOS: - self.device_count = pynvx.cudaDeviceGetCount(ignore=True) - else: - try: - self.device_count = pynvml.nvmlDeviceGetCount() - except pynvml.NVMLError: - self.device_count = 0 - if self.logger: - self.logger.debug("GPU Device count: %s", self.device_count) - - def get_handles(self): - """ Return all listed Nvidia handles """ - if IS_MACOS: - self.handles = pynvx.cudaDeviceGetHandles(ignore=True) - else: - self.handles = [pynvml.nvmlDeviceGetHandleByIndex(i) - for i in range(self.device_count)] - if self.logger: - self.logger.debug("GPU Handles found: %s", len(self.handles)) - - def get_driver(self): - """ Get the driver version """ - if IS_MACOS: - driver = pynvx.cudaSystemGetDriverVersion(ignore=True) - else: - try: - driver = pynvml.nvmlSystemGetDriverVersion().decode("utf-8") - except pynvml.NVMLError: - driver = "No Nvidia driver found" - if self.logger: - self.logger.debug("GPU Driver: %s", driver) - return driver - - def get_devices(self): - """ Return name of devices """ - self.initialize() - if self.device_count == 0: - names = list() - elif IS_MACOS: - names = [pynvx.cudaGetName(handle, ignore=True) - for handle in self.handles] - else: - names = [pynvml.nvmlDeviceGetName(handle).decode("utf-8") - for handle in self.handles] - if self.logger: - self.logger.debug("GPU Devices: %s", names) - return names - - def get_vram(self): - """ Return total vram in megabytes per device """ - self.initialize() - if self.device_count == 0: - vram = list() - elif IS_MACOS: - vram = [pynvx.cudaGetMemTotal(handle, ignore=True) / (1024 * 1024) - for handle in self.handles] - else: - vram = [pynvml.nvmlDeviceGetMemoryInfo(handle).total / - (1024 * 1024) - for handle in self.handles] - if self.logger: - self.logger.debug("GPU VRAM: %s", vram) - return vram - - def get_used(self): - """ Return the vram in use """ - self.initialize() - if IS_MACOS: - vram = [pynvx.cudaGetMemUsed(handle, ignore=True) / (1024 * 1024) - for handle in self.handles] - else: - vram = [pynvml.nvmlDeviceGetMemoryInfo(handle).used / (1024 * 1024) - for handle in self.handles] - self.shutdown() - - if self.logger: - self.logger.verbose("GPU VRAM used: %s", vram) - return vram - - def get_free(self): - """ Return the vram available """ - self.initialize() - if IS_MACOS: - vram = [pynvx.cudaGetMemFree(handle, ignore=True) / (1024 * 1024) - for handle in self.handles] - else: - vram = [pynvml.nvmlDeviceGetMemoryInfo(handle).free / (1024 * 1024) - for handle in self.handles] - self.shutdown() - if self.logger: - self.logger.debug("GPU VRAM free: %s", vram) - return vram - - def get_card_most_free(self): - """ Return the card and available VRAM for card with - most VRAM free """ - if self.device_count == 0: - return {"card_id": -1, - "device": "No Nvidia devices found", - "free": 2048, - "total": 2048} - free_vram = self.get_free() - vram_free = max(free_vram) - card_id = free_vram.index(vram_free) - retval = {"card_id": card_id, - "device": self.devices[card_id], - "free": vram_free, - "total": self.vram[card_id]} - if self.logger: - self.logger.debug("GPU Card with 
most free VRAM: %s", retval) - return retval diff --git a/lib/gpu_stats/__init__.py b/lib/gpu_stats/__init__.py new file mode 100644 index 0000000000..71246ba31c --- /dev/null +++ b/lib/gpu_stats/__init__.py @@ -0,0 +1,22 @@ +#!/usr/bin/env python3 +""" Dynamically import the correct GPU Stats library based on the faceswap backend and the machine +being used. """ + +from lib.utils import get_backend + +from ._base import GPUInfo, _GPUStats + +backend = get_backend() + +GPUStats: type[_GPUStats] | None +try: + if backend == "nvidia": + from .nvidia import NvidiaStats as GPUStats + elif backend == "apple_silicon": + from .apple_silicon import AppleSiliconStats as GPUStats + elif backend == "rocm": + from .rocm import ROCm as GPUStats + else: + from .cpu import CPUStats as GPUStats +except (ImportError, ModuleNotFoundError): + GPUStats = None diff --git a/lib/gpu_stats/_base.py b/lib/gpu_stats/_base.py new file mode 100644 index 0000000000..cc7df62785 --- /dev/null +++ b/lib/gpu_stats/_base.py @@ -0,0 +1,255 @@ +#!/usr/bin/env python3 +""" Parent class for obtaining Stats for various GPU/TPU backends. All GPU Stats should inherit +from the :class:`_GPUStats` class contained here. """ + +import logging + +from dataclasses import dataclass + +from lib.utils import get_backend + +_EXCLUDE_DEVICES: list[int] = [] + + +@dataclass +class GPUInfo(): + """Dataclass for storing information about the available GPUs on the system. + + Attributes: + ---------- + vram: list[int] + List of integers representing the total VRAM available on each GPU, in MB. + vram_free: list[int] + List of integers representing the free VRAM available on each GPU, in MB. + driver: str + String representing the driver version being used for the GPUs. + devices: list[str] + List of strings representing the names of each GPU device. + devices_active: list[int] + List of integers representing the indices of the active GPU devices. + """ + vram: list[int] + vram_free: list[int] + driver: str + devices: list[str] + devices_active: list[int] + + +@dataclass +class BiggestGPUInfo(): + """ Dataclass for holding GPU Information about the card with most available VRAM. + + Attributes + ---------- + card_id: int + Integer representing the index of the GPU device. + device: str + The name of the device + free: float + The amount of available VRAM on the GPU + total: float + the total amount of VRAM on the GPU + """ + card_id: int + device: str + free: float + total: float + + +class _GPUStats(): + """ Parent class for collecting GPU device information. + + Parameters: + ----------- + log : bool, optional + Flag indicating whether or not to log debug messages. Default: `True`. 
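Because the dynamic import above can leave GPUStats as ``None``, callers are expected to guard before instantiating. A minimal consumer sketch (assuming the import succeeds for the current backend):

    from lib.gpu_stats import GPUStats

    if GPUStats is not None:
        stats = GPUStats(log=False)   # log=False: safe before logging is set up
        print(stats.sys_info)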
+ """ + + def __init__(self, log: bool = True) -> None: + # Logger is held internally, as we don't want to log when obtaining system stats on crash + # or when querying the backend for command line options + self._logger: logging.Logger | None = logging.getLogger(__name__) if log else None + self._log("debug", f"Initializing {self.__class__.__name__}") + + self._is_initialized = False + self._initialize() + + self._device_count: int = self._get_device_count() + self._active_devices: list[int] = self._get_active_devices() + self._handles: list = self._get_handles() + self._driver: str = self._get_driver() + self._device_names: list[str] = self._get_device_names() + self._vram: list[int] = self._get_vram() + self._vram_free: list[int] = self._get_free_vram() + + if get_backend() != "cpu" and not self._active_devices: + self._log("warning", "No GPU detected") + + self._shutdown() + self._log("debug", f"Initialized {self.__class__.__name__}") + + @property + def device_count(self) -> int: + """int: The number of GPU devices discovered on the system. """ + return self._device_count + + @property + def cli_devices(self) -> list[str]: + """ list[str]: Formatted index: name text string for each GPU """ + return [f"{idx}: {device}" for idx, device in enumerate(self._device_names)] + + @property + def exclude_all_devices(self) -> bool: + """ bool: ``True`` if all GPU devices have been explicitly disabled otherwise ``False`` """ + return all(idx in _EXCLUDE_DEVICES for idx in range(self._device_count)) + + @property + def sys_info(self) -> GPUInfo: + """ :class:`GPUInfo`: The GPU Stats that are required for system information logging """ + return GPUInfo(vram=self._vram, + vram_free=self._get_free_vram(), + driver=self._driver, + devices=self._device_names, + devices_active=self._active_devices) + + def _log(self, level: str, message: str) -> None: + """ If the class has been initialized with :attr:`log` as `True` then log the message + otherwise skip logging. + + Parameters + ---------- + level: str + The log level to log at + message: str + The message to log + """ + if self._logger is None: + return + logger = getattr(self._logger, level.lower()) + logger(message) + + def _initialize(self) -> None: + """ Override to initialize the GPU device handles and any other necessary resources. """ + self._is_initialized = True + + def _shutdown(self) -> None: + """ Override to shutdown the GPU device handles and any other necessary resources. """ + self._is_initialized = False + + def _get_device_count(self) -> int: + """ Override to obtain the number of GPU devices + + Returns + ------- + int + The total number of GPUs connected to the PC + """ + raise NotImplementedError() + + def _get_active_devices(self) -> list[int]: + """ Obtain the indices of active GPUs (those that have not been explicitly excluded in + the command line arguments). + + Notes + ----- + Override for GPU specific checking + + Returns + ------- + list + The list of device indices that are available for Faceswap to use + """ + devices = [idx for idx in range(self._device_count) if idx not in _EXCLUDE_DEVICES] + self._log("debug", f"Active GPU Devices: {devices}") + return devices + + def _get_handles(self) -> list: + """ Override to obtain GPU specific device handles for all connected devices. + + Returns + ------- + list + The device handle for each connected GPU + """ + raise NotImplementedError() + + def _get_driver(self) -> str: + """ Override to obtain the GPU specific driver version. 
+ + Returns + ------- + str + The GPU driver currently in use + """ + raise NotImplementedError() + + def _get_device_names(self) -> list[str]: + """ Override to obtain the names of all connected GPUs. The quality of this information + depends on the backend and OS being used, but it should be sufficient for identifying + cards. + + Returns + ------- + list + List of device names for connected GPUs as corresponding to the values in + :attr:`_handles` + """ + raise NotImplementedError() + + def _get_vram(self) -> list[int]: + """ Override to obtain the total VRAM in Megabytes for each connected GPU. + + Returns + ------- + list + List of `float`s containing the total amount of VRAM in Megabytes for each + connected GPU as corresponding to the values in :attr:`_handles` + """ + raise NotImplementedError() + + def _get_free_vram(self) -> list[int]: + """ Override to obtain the amount of VRAM that is available, in Megabytes, for each + connected GPU. + + Returns + ------- + list + List of `float`s containing the amount of VRAM available, in Megabytes, for each + connected GPU as corresponding to the values in :attr:`_handles` + """ + raise NotImplementedError() + + def get_card_most_free(self) -> BiggestGPUInfo: + """ Obtain statistics for the GPU with the most available free VRAM. + + Returns + ------- + :class:`BiggestGPUInfo` + If a GPU is not detected then the **card_id** is returned as ``-1`` and the amount + of free and total RAM available is fixed to 2048 Megabytes. + """ + if len(self._active_devices) == 0: + retval = BiggestGPUInfo(card_id=-1, + device="No GPU devices found", + free=2048, + total=2048) + else: + free_vram = [self._vram_free[i] for i in self._active_devices] + vram_free = max(free_vram) + card_id = self._active_devices[free_vram.index(vram_free)] + retval = BiggestGPUInfo(card_id=card_id, + device=self._device_names[card_id], + free=vram_free, + total=self._vram[card_id]) + self._log("debug", f"Active GPU Card with most free VRAM: {retval}") + return retval + + def exclude_devices(self, devices: list[int]) -> None: + """ Exclude GPU devices from being used by Faceswap. Override for backend specific logic + + Parameters + ---------- + devices: list[int] + The GPU device IDS to be excluded + """ + raise NotImplementedError() diff --git a/lib/gpu_stats/apple_silicon.py b/lib/gpu_stats/apple_silicon.py new file mode 100644 index 0000000000..467cc20819 --- /dev/null +++ b/lib/gpu_stats/apple_silicon.py @@ -0,0 +1,199 @@ +#!/usr/bin/env python3 +""" Collects and returns Information on available Apple Silicon SoCs in Apple Macs. """ +import typing as T + +import os +import psutil +import torch + +from lib.utils import FaceswapError, get_module_objects + + +from ._base import _GPUStats + + +_metal_initialized: bool = False + + +class AppleSiliconStats(_GPUStats): + """ Holds information and statistics about Apple Silicon SoC(s) available on the currently + running Apple system. + + Notes + ----- + Apple Silicon is a bit different from other backends, as it does not have a dedicated GPU with + its own dedicated VRAM, rather the RAM is shared with the CPU and GPU. A combination of psutil + and torch is used to pull as much useful information as possible. + + Parameters + ---------- + log: bool, optional + Whether the class should output information to the logger. There may be occasions where the + logger has not yet been set up when this class is queried. Attempting to log in these + instances will raise an error.
If GPU stats are being queried prior to the logger being + available then this parameter should be set to ``False``. Otherwise set to ``True``. + Default: ``True`` + """ + def __init__(self, log: bool = True) -> None: + # Following attribute set in :func:`_initialize` + self._mps_devices: list[T.Any] = [] + + super().__init__(log=log) + + def _initialize(self) -> None: + """ Initialize Metal for Apple Silicon SoC(s). + + If :attr:`_is_initialized` is ``True`` then this function just returns performing no + action. Otherwise :attr:`_is_initialized` is set to ``True`` after successfully + initializing Metal. + """ + if self._is_initialized: + return + self._log("debug", "Initializing Metal for Apple Silicon SoC.") + self._initialize_metal() + + self._mps_devices = [torch.device("mps")] + + super()._initialize() + + def _initialize_metal(self) -> None: + """ Initialize Metal on first call to this class and set global + :attr:`_metal_initialized` to ``True``. If Metal has already been initialized then return + performing no action. + """ + global _metal_initialized # pylint:disable=global-statement + + if _metal_initialized: + return + + self._log("debug", "Performing first time Apple SoC setup.") + + os.environ["DISPLAY"] = ":0" + + try: + os.system("open -a XQuartz") + except Exception as err: # pylint:disable=broad-except + self._log("debug", f"Swallowing error opening XQuartz: {str(err)}") + + self._test_torch() + + _metal_initialized = True + + def _test_torch(self) -> None: + """ Test that torch can execute correctly. + + Raises + ------ + FaceswapError + If the Torch library could not be successfully initialized + """ + try: + meminfo = torch.mps.driver_allocated_memory() + self._log("debug", + f"Torch initialization test: (mem_info: {meminfo})") + except RuntimeError as err: + msg = ("An unhandled exception occurred initializing the device via Torch " + f"Library. Original error: {str(err)}") + raise FaceswapError(msg) from err + + def _get_device_count(self) -> int: + """ Detect the number of SoCs attached to the system. + + Returns + ------- + int + The total number of SoCs available + """ + retval = len(self._mps_devices) + self._log("debug", f"GPU Device count: {retval}") + return retval + + def _get_handles(self) -> list: + """ Obtain the device handles for all available Apple Silicon SoCs. + + Notes + ----- + Apple SoC does not use handles, so return a list of indices corresponding to found + GPU devices + + Returns + ------- + list + The list of indices for available Apple Silicon SoCs + """ + handles = list(range(self._device_count)) + self._log("debug", f"GPU Handles found: {handles}") + return handles + + def _get_driver(self) -> str: + """ Obtain the Apple Silicon driver version currently in use. + + Notes + ----- + As the SoC is not a discrete GPU it does not technically have a driver version, so just + return `'Not Applicable'` as a string + + Returns + ------- + str + The current SoC driver version + """ + driver = "Not Applicable" + self._log("debug", f"GPU Driver: {driver}") + return driver + + def _get_device_names(self) -> list[str]: + """ Obtain the list of names of available Apple Silicon SoC(s) as identified in + :attr:`_handles`.
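The Metal probes used above rely on real torch APIs and can be exercised standalone (reported values vary per machine):

    import torch

    if torch.backends.mps.is_available():
        allocated = torch.mps.driver_allocated_memory()  # bytes held by the Metal driver
        print(f"MPS available; driver has allocated {allocated / 1024**2:.0f} MB")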
+ + Returns + ------- + list + The list of available Apple Silicon SoC names + """ + names = [d.type for d in self._mps_devices] + self._log("debug", f"GPU Devices: {names}") + return names + + def _get_vram(self) -> list[int]: + """ Obtain the VRAM in Megabytes for each available Apple Silicon SoC(s) as identified in + :attr:`_handles`. + + Returns + ------- + list + The RAM in Megabytes for each available Apple Silicon SoC + """ + vram = [int((torch.mps.driver_allocated_memory() / self._device_count) / (1024 * 1024)) + for _ in range(self._device_count)] + self._log("debug", f"SoC RAM: {vram}") + return vram + + def _get_free_vram(self) -> list[int]: + """ Obtain the amount of VRAM that is available, in Megabytes, for each available Apple + Silicon SoC. + + Returns + ------- + list + List of `float`s containing the amount of RAM available, in Megabytes, for each + available SoC as corresponding to the values in :attr:`_handles + """ + vram = [int((psutil.virtual_memory().available / self._device_count) / (1024 * 1024)) + for _ in range(self._device_count)] + self._log("debug", f"SoC RAM free: {vram}") + return vram + + def exclude_devices(self, devices: list[int]) -> None: + """ Apple-Silicon does not support excluding devices + + Parameters + ---------- + devices: list[int] + The GPU device IDS to be excluded + """ + self._log("warning", "Apple Silicon does not support excluding GPUs. This option has been " + "ignored") + + +__all__ = get_module_objects(__name__) diff --git a/lib/gpu_stats/cpu.py b/lib/gpu_stats/cpu.py new file mode 100644 index 0000000000..0a4194ee9b --- /dev/null +++ b/lib/gpu_stats/cpu.py @@ -0,0 +1,114 @@ +#!/usr/bin/env python3 +""" Dummy functions for running faceswap on CPU. """ + +from lib.utils import get_module_objects + +from ._base import _GPUStats + + +class CPUStats(_GPUStats): + """ Holds information and statistics about the CPU on the currently running system. + + Notes + ----- + The information held here is not useful, but _GPUStats is dynamically imported depending on + the backend used, so we need to make sure this class is available for Faceswap run on the CPU + Backend. + + The base :class:`_GPUStats` handles the dummying in of information when no GPU is detected. + + Parameters + ---------- + log: bool, optional + Whether the class should output information to the logger. There may be occasions where the + logger has not yet been set up when this class is queried. Attempting to log in these + instances will raise an error. If GPU stats are being queried prior to the logger being + available then this parameter should be set to ``False``. Otherwise set to ``True``. + Default: ``True`` + """ + + def _get_device_count(self) -> int: + """ Detect the number of GPUs attached to the system. Always returns zero for CPU + backends. + + Returns + ------- + int + The total number of GPUs connected to the PC + """ + retval = 0 + self._log("debug", f"GPU Device count: {retval}") + return retval + + def _get_handles(self) -> list: + """ Obtain the device handles for all connected GPUs. + + Returns + ------- + list + An empty list for CPU Backends + """ + handles: list = [] + self._log("debug", f"GPU Handles found: {len(handles)}") + return handles + + def _get_driver(self) -> str: + """ Obtain the driver version currently in use. 
+ + Returns + ------- + str + An empty string for CPU backends + """ + driver = "" + self._log("debug", f"GPU Driver: {driver}") + return driver + + def _get_device_names(self) -> list[str]: + """ Obtain the list of names of connected GPUs as identified in :attr:`_handles`. + + Returns + ------- + list + An empty list for CPU backends + """ + names: list[str] = [] + self._log("debug", f"GPU Devices: {names}") + return names + + def _get_vram(self) -> list[int]: + """ Obtain the RAM in Megabytes for the running system. + + Returns + ------- + list + An empty list for CPU backends + """ + vram: list[int] = [] + self._log("debug", f"GPU VRAM: {vram}") + return vram + + def _get_free_vram(self) -> list[int]: + """ Obtain the amount of RAM that is available, in Megabytes, for the running system. + + Returns + ------- + list + An empty list for CPU backends + """ + vram: list[int] = [] + self._log("debug", f"GPU VRAM free: {vram}") + return vram + + def exclude_devices(self, devices: list[int]) -> None: + """ CPU does not support excluding devices + + Parameters + ---------- + devices: list[int] + The GPU device IDS to be excluded + """ + self._log("warning", "CPU does not support excluding GPUs. This option has been ignored") + + +__all__ = get_module_objects(__name__) diff --git a/lib/gpu_stats/nvidia.py b/lib/gpu_stats/nvidia.py new file mode 100644 index 0000000000..29f1a872f4 --- /dev/null +++ b/lib/gpu_stats/nvidia.py @@ -0,0 +1,210 @@ +#!/usr/bin/env python3 +""" Collects and returns Information on available Nvidia GPUs. """ +import os + +import pynvml # pylint:disable=import-error + +from lib.utils import FaceswapError, get_module_objects + +from ._base import _GPUStats, _EXCLUDE_DEVICES + + +class NvidiaStats(_GPUStats): + """ Holds information and statistics about Nvidia GPU(s) available on the currently + running system. + + Notes + ----- + PyNVML is used for hooking in to Nvidia's Machine Learning Library and allows for pulling + fairly extensive statistics for Nvidia GPUs + + Parameters + ---------- + log: bool, optional + Whether the class should output information to the logger. There may be occasions where the + logger has not yet been set up when this class is queried. Attempting to log in these + instances will raise an error. If GPU stats are being queried prior to the logger being + available then this parameter should be set to ``False``. Otherwise set to ``True``. + Default: ``True`` + """ + + def _initialize(self) -> None: + """ Initialize PyNVML for Nvidia GPUs. + + If :attr:`_is_initialized` is ``True`` then this function just returns performing no + action. Otherwise :attr:`is_initialized` is set to ``True`` after successfully + initializing NVML. + + Raises + ------ + FaceswapError + If the NVML library could not be successfully loaded + """ + if self._is_initialized: + return + try: + self._log("debug", "Initializing PyNVML for Nvidia GPU.") + pynvml.nvmlInit() + except (pynvml.NVMLError_LibraryNotFound, # pylint:disable=no-member + pynvml.NVMLError_DriverNotLoaded, # pylint:disable=no-member + pynvml.NVMLError_NoPermission) as err: # pylint:disable=no-member + msg = ("There was an error reading from the Nvidia Machine Learning Library. The most " + "likely cause is incorrectly installed drivers. If this is the case, Please " + "remove and reinstall your Nvidia drivers before reporting. 
Original " + f"Error: {str(err)}") + raise FaceswapError(msg) from err + except Exception as err: # pylint:disable=broad-except + msg = ("An unhandled exception occured reading from the Nvidia Machine Learning " + f"Library. Original error: {str(err)}") + raise FaceswapError(msg) from err + + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" + super()._initialize() + + def _shutdown(self) -> None: + """ Cleanly close access to NVML and set :attr:`_is_initialized` back to ``False``. """ + self._log("debug", "Shutting down NVML") + pynvml.nvmlShutdown() + super()._shutdown() + + def _get_device_count(self) -> int: + """ Detect the number of GPUs attached to the system. + + Returns + ------- + int + The total number of GPUs connected to the PC + """ + try: + retval = pynvml.nvmlDeviceGetCount() + except pynvml.NVMLError as err: + self._log("debug", "Error obtaining device count. Setting to 0. " + f"Original error: {str(err)}") + retval = 0 + self._log("debug", f"GPU Device count: {retval}") + return retval + + def _get_active_devices(self) -> list[int]: + """ Obtain the indices of active GPUs (those that have not been explicitly excluded by + CUDA_VISIBLE_DEVICES environment variable or explicitly excluded in the command line + arguments). + + Returns + ------- + list + The list of device indices that are available for Faceswap to use + """ + # pylint:disable=duplicate-code + devices = super()._get_active_devices() + env_devices = os.environ.get("CUDA_VISIBLE_DEVICES") + if env_devices: + new_devices = [int(i) for i in env_devices.split(",")] + devices = [idx for idx in devices if idx in new_devices] + self._log("debug", f"Active GPU Devices: {devices}") + return devices + + def _get_handles(self) -> list: + """ Obtain the device handles for all connected Nvidia GPUs. + + Returns + ------- + list + The list of pointers for connected Nvidia GPUs + """ + handles = [pynvml.nvmlDeviceGetHandleByIndex(i) + for i in range(self._device_count)] + self._log("debug", f"GPU Handles found: {len(handles)}") + return handles + + def _get_driver(self) -> str: + """ Obtain the Nvidia driver version currently in use. + + Returns + ------- + str + The current GPU driver version + """ + try: + driver = pynvml.nvmlSystemGetDriverVersion() + except pynvml.NVMLError as err: + self._log("debug", f"Unable to obtain driver. Original error: {str(err)}") + driver = "No Nvidia driver found" + self._log("debug", f"GPU Driver: {driver}") + return driver + + def _get_device_names(self) -> list[str]: + """ Obtain the list of names of connected Nvidia GPUs as identified in :attr:`_handles`. + + Returns + ------- + list + The list of connected Nvidia GPU names + """ + names = [pynvml.nvmlDeviceGetName(handle) + for handle in self._handles] + self._log("debug", f"GPU Devices: {names}") + return names + + def _get_vram(self) -> list[int]: + """ Obtain the VRAM in Megabytes for each connected Nvidia GPU as identified in + :attr:`_handles`. + + Returns + ------- + list + The VRAM in Megabytes for each connected Nvidia GPU + """ + vram = [pynvml.nvmlDeviceGetMemoryInfo(handle).total / (1024 * 1024) + for handle in self._handles] + self._log("debug", f"GPU VRAM: {vram}") + return vram + + def _get_free_vram(self) -> list[int]: + """ Obtain the amount of VRAM that is available, in Megabytes, for each connected Nvidia + GPU. 
+    def _get_free_vram(self) -> list[int]:
+        """ Obtain the amount of VRAM that is available, in Megabytes, for each connected Nvidia
+        GPU.
+
+        Returns
+        -------
+        list
+            List of `int`s containing the amount of VRAM available, in Megabytes, for each
+            connected GPU as corresponding to the values in :attr:`_handles`
+        """
+        is_initialized = self._is_initialized
+        if not is_initialized:
+            self._initialize()
+            self._handles = self._get_handles()
+
+        vram = [int(pynvml.nvmlDeviceGetMemoryInfo(handle).free / (1024 * 1024))
+                for handle in self._handles]
+        if not is_initialized:
+            self._shutdown()
+
+        self._log("debug", f"GPU VRAM free: {vram}")
+        return vram
+
+    def exclude_devices(self, devices: list[int]) -> None:
+        """ Exclude GPU devices from being used by Faceswap. Sets the CUDA_VISIBLE_DEVICES
+        environment variable. This must be called before Torch/Keras are imported.
+
+        Parameters
+        ----------
+        devices: list[int]
+            The GPU device IDs to be excluded
+        """
+        # pylint:disable=duplicate-code
+        if not devices:
+            return
+        self._log("debug", f"Excluding GPU indices: {devices}")
+
+        _EXCLUDE_DEVICES.extend(devices)
+
+        active = self._get_active_devices()
+
+        os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(d) for d in active
+                                                      if d not in _EXCLUDE_DEVICES)
+
+        env_vars = [f"{k}: {v}" for k, v in os.environ.items() if k.lower().startswith("cuda")]
+        self._log("debug", f"CUDA environment variables: {env_vars}")
+
+
+__all__ = get_module_objects(__name__)
diff --git a/lib/gpu_stats/rocm.py b/lib/gpu_stats/rocm.py
new file mode 100644
index 0000000000..eddbd4a920
--- /dev/null
+++ b/lib/gpu_stats/rocm.py
@@ -0,0 +1,477 @@
+#!/usr/bin/env python3
+""" Collects and returns information about connected AMD GPUs for ROCm, using sysfs and
+modinfo.
+
+As no ROCm compatible hardware was available for testing, this just returns information on all AMD
+GPUs discovered on the system regardless of ROCm compatibility.
+ +It is a good starting point but may need to be refined over time +""" +import os +import re +from subprocess import run + +from lib.utils import get_module_objects +from ._base import _GPUStats, _EXCLUDE_DEVICES + +_DEVICE_LOOKUP = { # ref: https://gist.github.com/roalercon/51f13a387f3754615cce + int("0x130F", 0): "AMD Radeon(TM) R7 Graphics", + int("0x1313", 0): "AMD Radeon(TM) R7 Graphics", + int("0x1316", 0): "AMD Radeon(TM) R5 Graphics", + int("0x6600", 0): "AMD Radeon HD 8600/8700M", + int("0x6601", 0): "AMD Radeon (TM) HD 8500M/8700M", + int("0x6604", 0): "AMD Radeon R7 M265 Series", + int("0x6605", 0): "AMD Radeon R7 M260 Series", + int("0x6606", 0): "AMD Radeon HD 8790M", + int("0x6607", 0): "AMD Radeon (TM) HD8530M", + int("0x6610", 0): "AMD Radeon HD 8670 Graphics", + int("0x6611", 0): "AMD Radeon HD 8570 Graphics", + int("0x6613", 0): "AMD Radeon R7 200 Series", + int("0x6640", 0): "AMD Radeon HD 8950", + int("0x6658", 0): "AMD Radeon R7 200 Series", + int("0x665C", 0): "AMD Radeon HD 7700 Series", + int("0x665D", 0): "AMD Radeon R7 200 Series", + int("0x6660", 0): "AMD Radeon HD 8600M Series", + int("0x6663", 0): "AMD Radeon HD 8500M Series", + int("0x6664", 0): "AMD Radeon R5 M200 Series", + int("0x6665", 0): "AMD Radeon R5 M230 Series", + int("0x6667", 0): "AMD Radeon R5 M200 Series", + int("0x666F", 0): "AMD Radeon HD 8500M", + int("0x6704", 0): "AMD FirePro V7900 (FireGL V)", + int("0x6707", 0): "AMD FirePro V5900 (FireGL V)", + int("0x6718", 0): "AMD Radeon HD 6900 Series", + int("0x6719", 0): "AMD Radeon HD 6900 Series", + int("0x671D", 0): "AMD Radeon HD 6900 Series", + int("0x671F", 0): "AMD Radeon HD 6900 Series", + int("0x6720", 0): "AMD Radeon HD 6900M Series", + int("0x6738", 0): "AMD Radeon HD 6800 Series", + int("0x6739", 0): "AMD Radeon HD 6800 Series", + int("0x673E", 0): "AMD Radeon HD 6700 Series", + int("0x6740", 0): "AMD Radeon HD 6700M Series", + int("0x6741", 0): "AMD Radeon 6600M and 6700M Series", + int("0x6742", 0): "AMD Radeon HD 5570", + int("0x6743", 0): "AMD Radeon E6760", + int("0x6749", 0): "AMD FirePro V4900 (FireGL V)", + int("0x674A", 0): "AMD FirePro V3900 (ATI FireGL)", + int("0x6750", 0): "AMD Radeon HD 6500 series", + int("0x6751", 0): "AMD Radeon HD 7600A Series", + int("0x6758", 0): "AMD Radeon HD 6670", + int("0x6759", 0): "AMD Radeon HD 6570 Graphics", + int("0x675B", 0): "AMD Radeon HD 7600 Series", + int("0x675D", 0): "AMD Radeon HD 7500 Series", + int("0x675F", 0): "AMD Radeon HD 5500 Series", + int("0x6760", 0): "AMD Radeon HD 6400M Series", + int("0x6761", 0): "AMD Radeon HD 6430M", + int("0x6763", 0): "AMD Radeon E6460", + int("0x6770", 0): "AMD Radeon HD 6400 Series", + int("0x6771", 0): "AMD Radeon R5 235X", + int("0x6772", 0): "AMD Radeon HD 7400A Series", + int("0x6778", 0): "AMD Radeon HD 7000 series", + int("0x6779", 0): "AMD Radeon HD 6450", + int("0x677B", 0): "AMD Radeon HD 7400 Series", + int("0x6780", 0): "AMD FirePro W9000 (FireGL V)", + int("0x678A", 0): "AMD FirePro S10000 (FireGL V)", + int("0x6798", 0): "AMD Radeon HD 7900 Series", + int("0x679A", 0): "AMD Radeon HD 7900 Series", + int("0x679B", 0): "AMD Radeon HD 7900 Series", + int("0x679E", 0): "AMD Radeon HD 7800 Series", + int("0x67B0", 0): "AMD Radeon R9 200 Series", + int("0x67B1", 0): "AMD Radeon R9 200 Series", + int("0x6800", 0): "AMD Radeon HD 7970M", + int("0x6801", 0): "AMD Radeon(TM) HD8970M", + int("0x6808", 0): "AMD FirePro S7000 (FireGL V)", + int("0x6809", 0): "AMD FirePro R5000 (FireGL V)", + int("0x6810", 0): "AMD Radeon R9 200 Series", + 
int("0x6811", 0): "AMD Radeon R9 200 Series", + int("0x6818", 0): "AMD Radeon HD 7800 Series", + int("0x6819", 0): "AMD Radeon HD 7800 Series", + int("0x6820", 0): "AMD Radeon HD 8800M Series", + int("0x6821", 0): "AMD Radeon HD 8800M Series", + int("0x6822", 0): "AMD Radeon E8860", + int("0x6823", 0): "AMD Radeon HD 8800M Series", + int("0x6825", 0): "AMD Radeon HD 7800M Series", + int("0x6827", 0): "AMD Radeon HD 7800M Series", + int("0x6828", 0): "AMD FirePro W600", + int("0x682B", 0): "AMD Radeon HD 8800M Series", + int("0x682D", 0): "AMD Radeon HD 7700M Series", + int("0x682F", 0): "AMD Radeon HD 7700M Series", + int("0x6835", 0): "AMD Radeon R7 Series / HD 9000 Series", + int("0x6837", 0): "AMD Radeon HD 6570", + int("0x683D", 0): "AMD Radeon HD 7700 Series", + int("0x683F", 0): "AMD Radeon HD 7700 Series", + int("0x6840", 0): "AMD Radeon HD 7600M Series", + int("0x6841", 0): "AMD Radeon HD 7500M/7600M Series", + int("0x6842", 0): "AMD Radeon HD 7000M Series", + int("0x6843", 0): "AMD Radeon HD 7670M", + int("0x6858", 0): "AMD Radeon HD 7400 Series", + int("0x6859", 0): "AMD Radeon HD 7400 Series", + int("0x6888", 0): "ATI FirePro V8800 (FireGL V)", + int("0x6889", 0): "ATI FirePro V7800 (FireGL V)", + int("0x688A", 0): "ATI FirePro V9800 (FireGL V)", + int("0x688C", 0): "AMD FireStream 9370", + int("0x688D", 0): "AMD FireStream 9350", + int("0x6898", 0): "AMD Radeon HD 5800 Series", + int("0x6899", 0): "AMD Radeon HD 5800 Series", + int("0x689B", 0): "AMD Radeon HD 6800 Series", + int("0x689C", 0): "AMD Radeon HD 5900 Series", + int("0x689E", 0): "AMD Radeon HD 5800 Series", + int("0x68A0", 0): "AMD Mobility Radeon HD 5800 Series", + int("0x68A1", 0): "AMD Mobility Radeon HD 5800 Series", + int("0x68A8", 0): "AMD Radeon HD 6800M Series", + int("0x68A9", 0): "ATI FirePro V5800 (FireGL V)", + int("0x68B8", 0): "AMD Radeon HD 5700 Series", + int("0x68B9", 0): "AMD Radeon HD 5600/5700", + int("0x68BA", 0): "AMD Radeon HD 6700 Series", + int("0x68BE", 0): "AMD Radeon HD 5700 Series", + int("0x68BF", 0): "AMD Radeon HD 6700 Green Edition", + int("0x68C0", 0): "AMD Mobility Radeon HD 5000", + int("0x68C1", 0): "AMD Mobility Radeon HD 5000 Series", + int("0x68C7", 0): "AMD Mobility Radeon HD 5570", + int("0x68C8", 0): "ATI FirePro V4800 (FireGL V)", + int("0x68C9", 0): "ATI FirePro 3800 (FireGL) Graphics Adapter", + int("0x68D8", 0): "AMD Radeon HD 5670", + int("0x68D9", 0): "AMD Radeon HD 5570", + int("0x68DA", 0): "AMD Radeon HD 5500 Series", + int("0x68E0", 0): "AMD Mobility Radeon HD 5000 Series", + int("0x68E1", 0): "AMD Mobility Radeon HD 5000 Series", + int("0x68E4", 0): "AMD Radeon HD 5450", + int("0x68E5", 0): "AMD Radeon HD 6300M Series", + int("0x68F1", 0): "AMD FirePro 2460", + int("0x68F2", 0): "AMD FirePro 2270 (ATI FireGL)", + int("0x68F9", 0): "AMD Radeon HD 5450", + int("0x68FA", 0): "AMD Radeon HD 7300 Series", + int("0x9640", 0): "AMD Radeon HD 6550D", + int("0x9641", 0): "AMD Radeon HD 6620G", + int("0x9642", 0): "AMD Radeon HD 6370D", + int("0x9643", 0): "AMD Radeon HD 6380G", + int("0x9644", 0): "AMD Radeon HD 6410D", + int("0x9645", 0): "AMD Radeon HD 6410D", + int("0x9647", 0): "AMD Radeon HD 6520G", + int("0x9648", 0): "AMD Radeon HD 6480G", + int("0x9649", 0): "AMD Radeon(TM) HD 6480G", + int("0x964A", 0): "AMD Radeon HD 6530D", + int("0x9802", 0): "AMD Radeon HD 6310 Graphics", + int("0x9803", 0): "AMD Radeon HD 6250 Graphics", + int("0x9804", 0): "AMD Radeon HD 6250 Graphics", + int("0x9805", 0): "AMD Radeon HD 6250 Graphics", + int("0x9806", 0): "AMD Radeon HD 
6320 Graphics",
+    int("0x9807", 0): "AMD Radeon HD 6290 Graphics",
+    int("0x9808", 0): "AMD Radeon HD 7340 Graphics",
+    int("0x9809", 0): "AMD Radeon HD 7310 Graphics",
+    int("0x980A", 0): "AMD Radeon HD 7290 Graphics",
+    int("0x9830", 0): "AMD Radeon HD 8400",
+    int("0x9831", 0): "AMD Radeon(TM) HD 8400E",
+    int("0x9832", 0): "AMD Radeon HD 8330",
+    int("0x9833", 0): "AMD Radeon(TM) HD 8330E",
+    int("0x9834", 0): "AMD Radeon HD 8210",
+    int("0x9835", 0): "AMD Radeon(TM) HD 8210E",
+    int("0x9836", 0): "AMD Radeon HD 8280",
+    int("0x9837", 0): "AMD Radeon(TM) HD 8280E",
+    int("0x9838", 0): "AMD Radeon HD 8240",
+    int("0x9839", 0): "AMD Radeon HD 8180",
+    int("0x983D", 0): "AMD Radeon HD 8250",
+    int("0x9900", 0): "AMD Radeon HD 7660G",
+    int("0x9901", 0): "AMD Radeon HD 7660D",
+    int("0x9903", 0): "AMD Radeon HD 7640G",
+    int("0x9904", 0): "AMD Radeon HD 7560D",
+    int("0x9906", 0): "AMD FirePro A300 Series (FireGL V) Graphics Adapter",
+    int("0x9907", 0): "AMD Radeon HD 7620G",
+    int("0x9908", 0): "AMD Radeon HD 7600G",
+    int("0x990A", 0): "AMD Radeon HD 7500G",
+    int("0x990B", 0): "AMD Radeon HD 8650G",
+    int("0x990C", 0): "AMD Radeon HD 8670D",
+    int("0x990D", 0): "AMD Radeon HD 8550G",
+    int("0x990E", 0): "AMD Radeon HD 8570D",
+    int("0x990F", 0): "AMD Radeon HD 8610G",
+    int("0x9910", 0): "AMD Radeon HD 7660G",
+    int("0x9913", 0): "AMD Radeon HD 7640G",
+    int("0x9917", 0): "AMD Radeon HD 7620G",
+    int("0x9918", 0): "AMD Radeon HD 7600G",
+    int("0x9919", 0): "AMD Radeon HD 7500G",
+    int("0x9990", 0): "AMD Radeon HD 7520G",
+    int("0x9991", 0): "AMD Radeon HD 7540D",
+    int("0x9992", 0): "AMD Radeon HD 7420G",
+    int("0x9993", 0): "AMD Radeon HD 7480D",
+    int("0x9994", 0): "AMD Radeon HD 7400G",
+    int("0x9995", 0): "AMD Radeon HD 8450G",
+    int("0x9996", 0): "AMD Radeon HD 8470D",
+    int("0x9997", 0): "AMD Radeon HD 8350G",
+    int("0x9998", 0): "AMD Radeon HD 8370D",
+    int("0x9999", 0): "AMD Radeon HD 8510G",
+    int("0x999A", 0): "AMD Radeon HD 8410G",
+    int("0x999B", 0): "AMD Radeon HD 8310G",
+    int("0x999C", 0): "AMD Radeon HD 8650D",
+    int("0x999D", 0): "AMD Radeon HD 8550D",
+    int("0x99A0", 0): "AMD Radeon HD 7520G",
+    int("0x99A2", 0): "AMD Radeon HD 7420G",
+    int("0x99A4", 0): "AMD Radeon HD 7400G"}
+
+
+class ROCm(_GPUStats):
+    """ Holds information and statistics about AMD GPUs connected to the system, using sysfs.
+
+    Parameters
+    ----------
+    log: bool, optional
+        Whether the class should output information to the logger. There may be occasions where
+        the logger has not yet been set up when this class is queried. Attempting to log in these
+        instances will raise an error. If GPU stats are being queried prior to the logger being
+        available then this parameter should be set to ``False``. Otherwise set to ``True``.
+        Default: ``True``
+    """
+    def __init__(self, log: bool = True) -> None:
+        self._vendor_id = "0x1002"  # AMD VendorID
+        self._sysfs_paths: list[str] = []
+        super().__init__(log=log)
+
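+    # The module-level _DEVICE_LOOKUP table maps PCI device ids (as read from
+    # sysfs) to marketing names; a sketch of the resolution performed later in
+    # :func:`_get_device_names`:
+    #
+    #     device_id = "0x6798"  # example value read from '<card>/device/device'
+    #     name = _DEVICE_LOOKUP.get(int(device_id, 0), device_id)
+    #     # name == "AMD Radeon HD 7900 Series"
+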
+    def _from_sysfs_file(self, path: str) -> str:
+        """ Obtain the value from a sysfs file. If the file does not exist, or cannot be read due
+        to a permission error, log and return an empty value.
+
+        Parameters
+        ----------
+        path: str
+            The path to a sysfs file to obtain the value from
+
+        Returns
+        -------
+        str
+            The obtained value from the given path
+        """
+        if not os.path.isfile(path):
+            self._log("debug", f"File '{path}' does not exist. Returning empty string")
+            return ""
+        try:
+            with open(path, "r", encoding="utf-8", errors="ignore") as sysfile:
+                val = sysfile.read().strip()
+        except PermissionError:
+            self._log("debug", f"Permission error accessing file '{path}'. Returning empty string")
+            val = ""
+        return val
+
+    def _get_sysfs_paths(self) -> list[str]:
+        """ Obtain a list of sysfs paths to AMD branded GPUs connected to the system.
+
+        Returns
+        -------
+        list[str]
+            List of full paths to the sysfs entries for connected AMD GPUs
+        """
+        base_dir = "/sys/class/drm/"
+
+        retval: list[str] = []
+        if not os.path.exists(base_dir):
+            self._log("warning", f"sysfs not found at '{base_dir}'")
+            return retval
+
+        for folder in sorted(os.listdir(base_dir)):
+            folder_path = os.path.join(base_dir, folder, "device")
+            vendor_path = os.path.join(folder_path, "vendor")
+            if not os.path.isfile(vendor_path) or not re.match(r"^card\d+$", folder):
+                self._log("debug", f"Skipping path '{folder_path}'")
+                continue
+
+            vendor_id = self._from_sysfs_file(vendor_path)
+            if vendor_id != self._vendor_id:
+                self._log("debug", f"Skipping non AMD Vendor '{vendor_id}' for device: '{folder}'")
+                continue
+
+            retval.append(folder_path)
+
+        self._log("debug", f"sysfs AMD devices: {retval}")
+        return retval
+
+    def _initialize(self) -> None:
+        """ Initialize sysfs for the ROCm backend.
+
+        If :attr:`_is_initialized` is ``True`` then this function just returns, performing no
+        action. If ``False``, then the location of AMD cards within sysfs is collected.
+        """
+        if self._is_initialized:
+            return
+        self._log("debug", "Initializing sysfs for AMDGPU (ROCm).")
+        self._sysfs_paths = self._get_sysfs_paths()
+        super()._initialize()
+
+    def _get_device_count(self) -> int:
+        """ The number of AMD cards found in sysfs.
+
+        Returns
+        -------
+        int
+            The total number of GPUs available
+        """
+        retval = len(self._sysfs_paths)
+        self._log("debug", f"GPU Device count: {retval}")
+        return retval
+
+    def _get_handles(self) -> list:
+        """ sysfs doesn't use device handles, so we just return the list of sysfs locations, one
+        per card.
+
+        Returns
+        -------
+        list
+            The list of all discovered GPUs
+        """
+        handles = self._sysfs_paths
+        self._log("debug", f"sysfs GPU Handles found: {handles}")
+        return handles
+
+    def _get_driver(self) -> str:
+        """ Obtain the driver version currently in use from modinfo.
+
+        Returns
+        -------
+        str
+            The current AMDGPU driver version
+        """
+        retval = ""
+        cmd = ["modinfo", "amdgpu"]
+        try:
+            proc = run(cmd,
+                       check=True,
+                       timeout=5,
+                       capture_output=True,
+                       encoding="utf-8",
+                       errors="ignore")
+            for line in proc.stdout.split("\n"):
+                if line.startswith("version:"):
+                    retval = line.split()[-1]
+                    break
+        except Exception as err:  # pylint:disable=broad-except
+            self._log("debug", f"Error reading modinfo: '{str(err)}'")
+
+        self._log("debug", f"GPU Driver: {retval}")
+        return retval
+
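+    # For reference, `modinfo amdgpu` emits one "key: value" pair per line and
+    # the loop above keeps only the last whitespace-separated token of the
+    # "version:" line (the value itself differs per system):
+    #
+    #     line = "version:        x.y.z"   # illustrative output line only
+    #     assert line.split()[-1] == "x.y.z"
+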
+    def _get_device_names(self) -> list[str]:
+        """ Obtain the list of names of connected GPUs as identified in :attr:`_handles`.
+
+        Returns
+        -------
+        list
+            The list of connected AMD GPU names
+        """
+        retval = []
+        for device in self._sysfs_paths:
+            name = self._from_sysfs_file(os.path.join(device, "product_name"))
+            number = self._from_sysfs_file(os.path.join(device, "product_number"))
+            if name or number:  # product_name or product_number populated
+                self._log("debug", f"Got name from product_name: '{name}', product_number: "
+                          f"'{number}'")
+                retval.append(f"{name + ' ' if name else ''}{number}")
+                continue
+
+            device_id = self._from_sysfs_file(os.path.join(device, "device"))
+            self._log("debug", f"Got device_id: '{device_id}'")
+
+            if not device_id:  # Can't get device name
+                retval.append("Not found")
+                continue
+            try:
+                lookup = int(device_id, 0)
+            except ValueError:
+                retval.append(device_id)
+                continue
+
+            device_name = _DEVICE_LOOKUP.get(lookup, device_id)
+            retval.append(device_name)
+
+        self._log("debug", f"Device names: {retval}")
+        return retval
+
+    def _get_active_devices(self) -> list[int]:
+        """ Obtain the indices of active GPUs (those that have not been explicitly excluded by
+        the HIP_VISIBLE_DEVICES environment variable or explicitly excluded in the command line
+        arguments).
+
+        Returns
+        -------
+        list
+            The list of device indices that are available for Faceswap to use
+        """
+        devices = super()._get_active_devices()
+        env_devices = os.environ.get("HIP_VISIBLE_DEVICES")
+        if env_devices:
+            new_devices = [int(i) for i in env_devices.split(",")]
+            devices = [idx for idx in devices if idx in new_devices]
+        self._log("debug", f"Active GPU Devices: {devices}")
+        return devices
+
+    def _get_vram(self) -> list[int]:
+        """ Obtain the VRAM in Megabytes for each connected AMD GPU as identified in
+        :attr:`_handles`.
+
+        Returns
+        -------
+        list
+            The VRAM in Megabytes for each connected AMD GPU
+        """
+        retval = []
+        for device in self._sysfs_paths:
+            query = self._from_sysfs_file(os.path.join(device, "mem_info_vram_total"))
+            try:
+                vram = int(query)
+            except ValueError:
+                self._log("debug", f"Couldn't extract VRAM from string: '{query}'")
+                vram = 0
+            retval.append(int(vram / (1024 * 1024)))
+
+        self._log("debug", f"GPU VRAM: {retval}")
+        return retval
+
+    def _get_free_vram(self) -> list[int]:
+        """ Obtain the amount of VRAM that is available, in Megabytes, for each connected AMD
+        GPU.
+
+        Returns
+        -------
+        list
+            List of `int`s containing the amount of VRAM available, in Megabytes, for each
+            connected GPU as corresponding to the values in :attr:`_handles`
+        """
+        retval = []
+        total_vram = self._get_vram()
+        for device, vram in zip(self._sysfs_paths, total_vram):
+            if not vram:
+                retval.append(0)
+                continue
+            query = self._from_sysfs_file(os.path.join(device, "mem_info_vram_used"))
+            try:
+                used = int(query)
+            except ValueError:
+                self._log("debug", f"Couldn't extract used VRAM from string: '{query}'")
+                used = 0
+
+            retval.append(vram - int(used / (1024 * 1024)))
+        self._log("debug", f"GPU VRAM free: {retval}")
+        return retval
+
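+    # The free VRAM calculation above boils down to two sysfs reads per card,
+    # both reported in bytes (a sketch, assuming the GPU is exposed as card0;
+    # the path will differ per system):
+    #
+    #     base = "/sys/class/drm/card0/device"
+    #     with open(f"{base}/mem_info_vram_total", encoding="utf-8") as f:
+    #         total = int(f.read().strip()) // (1024 * 1024)
+    #     with open(f"{base}/mem_info_vram_used", encoding="utf-8") as f:
+    #         used = int(f.read().strip()) // (1024 * 1024)
+    #     free = total - used
+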
+    def exclude_devices(self, devices: list[int]) -> None:
+        """ Exclude GPU devices from being used by Faceswap. Sets the HIP_VISIBLE_DEVICES
+        environment variable. This must be called before Torch/Keras are imported.
+
+        Parameters
+        ----------
+        devices: list[int]
+            The GPU device IDs to be excluded
+        """
+        if not devices:
+            return
+        self._log("debug", f"Excluding GPU indices: {devices}")
+
+        _EXCLUDE_DEVICES.extend(devices)
+
+        active = self._get_active_devices()
+
+        os.environ["HIP_VISIBLE_DEVICES"] = ",".join(str(d) for d in active
+                                                     if d not in _EXCLUDE_DEVICES)
+
+        env_vars = [f"{k}: {v}" for k, v in os.environ.items() if k.lower().startswith("hip")]
+        self._log("debug", f"HIP environment variables: {env_vars}")
+
+
+__all__ = get_module_objects(__name__)
diff --git a/lib/gui/.cache/icons/LICENSE.md b/lib/gui/.cache/icons/LICENSE.md
new file mode 100644
index 0000000000..31157e1ea9
--- /dev/null
+++ b/lib/gui/.cache/icons/LICENSE.md
@@ -0,0 +1,3 @@
+Icons made by [smashicons](https://www.flaticon.com/authors/smashicons) from [www.flaticon.com](https://www.flaticon.com)
+
+Colorized and adapted by @torzdf
\ No newline at end of file
diff --git a/lib/gui/.cache/icons/beginning.png b/lib/gui/.cache/icons/beginning.png
new file mode 100755
index 0000000000..a9fdb1f788
Binary files /dev/null and b/lib/gui/.cache/icons/beginning.png differ
diff --git a/lib/gui/.cache/icons/boundingbox.png b/lib/gui/.cache/icons/boundingbox.png
new file mode 100755
index 0000000000..1863fcb4d7
Binary files /dev/null and b/lib/gui/.cache/icons/boundingbox.png differ
diff --git a/lib/gui/.cache/icons/clear.png b/lib/gui/.cache/icons/clear.png
index 0f2c6364a9..551de5259f 100755
Binary files a/lib/gui/.cache/icons/clear.png and b/lib/gui/.cache/icons/clear.png differ
diff --git a/lib/gui/.cache/icons/clear2.png b/lib/gui/.cache/icons/clear2.png
new file mode 100644
index 0000000000..f7e5826ca8
Binary files /dev/null and b/lib/gui/.cache/icons/clear2.png differ
diff --git a/lib/gui/.cache/icons/context.png b/lib/gui/.cache/icons/context.png
new file mode 100644
index 0000000000..4354c1bc99
Binary files /dev/null and b/lib/gui/.cache/icons/context.png differ
diff --git a/lib/gui/.cache/icons/copy_next.png b/lib/gui/.cache/icons/copy_next.png
new file mode 100755
index 0000000000..e6df6fc7ad
Binary files /dev/null and b/lib/gui/.cache/icons/copy_next.png differ
diff --git a/lib/gui/.cache/icons/copy_prev.png b/lib/gui/.cache/icons/copy_prev.png
new file mode 100755
index 0000000000..41b84b305b
Binary files /dev/null and b/lib/gui/.cache/icons/copy_prev.png differ
diff --git a/lib/gui/.cache/icons/draw.png b/lib/gui/.cache/icons/draw.png
new file mode 100755
index 0000000000..c79809bf42
Binary files /dev/null and b/lib/gui/.cache/icons/draw.png differ
diff --git a/lib/gui/.cache/icons/end.png b/lib/gui/.cache/icons/end.png
new file mode 100755
index 0000000000..c79ee55ebd
Binary files /dev/null and b/lib/gui/.cache/icons/end.png differ
diff --git a/lib/gui/.cache/icons/erase.png b/lib/gui/.cache/icons/erase.png
new file mode 100755
index 0000000000..113e4b8bc4
Binary files /dev/null and b/lib/gui/.cache/icons/erase.png differ
diff --git a/lib/gui/.cache/icons/extractbox.png b/lib/gui/.cache/icons/extractbox.png
new file mode 100755
index 0000000000..bd82f0c54a
Binary files /dev/null and b/lib/gui/.cache/icons/extractbox.png differ
diff --git a/lib/gui/.cache/icons/favicon.png b/lib/gui/.cache/icons/favicon.png
new file mode 100644
index 0000000000..4c8f094327
Binary files /dev/null and b/lib/gui/.cache/icons/favicon.png differ
diff --git a/lib/gui/.cache/icons/folder.png b/lib/gui/.cache/icons/folder.png
new file mode 100644
index 0000000000..e2c1be628a
Binary files /dev/null and
b/lib/gui/.cache/icons/folder.png differ diff --git a/lib/gui/.cache/icons/generate.png b/lib/gui/.cache/icons/generate.png new file mode 100644 index 0000000000..d5cc9270f8 Binary files /dev/null and b/lib/gui/.cache/icons/generate.png differ diff --git a/lib/gui/.cache/icons/graph.png b/lib/gui/.cache/icons/graph.png old mode 100755 new mode 100644 index 7056a9a0a9..2b514fd6d4 Binary files a/lib/gui/.cache/icons/graph.png and b/lib/gui/.cache/icons/graph.png differ diff --git a/lib/gui/.cache/icons/landmarks.png b/lib/gui/.cache/icons/landmarks.png new file mode 100755 index 0000000000..a098e1317d Binary files /dev/null and b/lib/gui/.cache/icons/landmarks.png differ diff --git a/lib/gui/.cache/icons/load.png b/lib/gui/.cache/icons/load.png new file mode 100644 index 0000000000..d94d224061 Binary files /dev/null and b/lib/gui/.cache/icons/load.png differ diff --git a/lib/gui/.cache/icons/load2.png b/lib/gui/.cache/icons/load2.png new file mode 100644 index 0000000000..31d4bc0fca Binary files /dev/null and b/lib/gui/.cache/icons/load2.png differ diff --git a/lib/gui/.cache/icons/logo.png b/lib/gui/.cache/icons/logo.png index c1b65812c0..4c8f094327 100755 Binary files a/lib/gui/.cache/icons/logo.png and b/lib/gui/.cache/icons/logo.png differ diff --git a/lib/gui/.cache/icons/mask.png b/lib/gui/.cache/icons/mask.png new file mode 100755 index 0000000000..ffdc2fa1f3 Binary files /dev/null and b/lib/gui/.cache/icons/mask.png differ diff --git a/lib/gui/.cache/icons/mask2.png b/lib/gui/.cache/icons/mask2.png new file mode 100644 index 0000000000..ca6440b66e Binary files /dev/null and b/lib/gui/.cache/icons/mask2.png differ diff --git a/lib/gui/.cache/icons/model.png b/lib/gui/.cache/icons/model.png new file mode 100644 index 0000000000..fd4f356d6f Binary files /dev/null and b/lib/gui/.cache/icons/model.png differ diff --git a/lib/gui/.cache/icons/move.png b/lib/gui/.cache/icons/move.png index 8fb918a725..afccc85852 100755 Binary files a/lib/gui/.cache/icons/move.png and b/lib/gui/.cache/icons/move.png differ diff --git a/lib/gui/.cache/icons/multi_load.png b/lib/gui/.cache/icons/multi_load.png new file mode 100644 index 0000000000..94f648e031 Binary files /dev/null and b/lib/gui/.cache/icons/multi_load.png differ diff --git a/lib/gui/.cache/icons/new.png b/lib/gui/.cache/icons/new.png new file mode 100644 index 0000000000..51e298336d Binary files /dev/null and b/lib/gui/.cache/icons/new.png differ diff --git a/lib/gui/.cache/icons/next.png b/lib/gui/.cache/icons/next.png new file mode 100755 index 0000000000..47d55783a6 Binary files /dev/null and b/lib/gui/.cache/icons/next.png differ diff --git a/lib/gui/.cache/icons/open_file.png b/lib/gui/.cache/icons/open_file.png deleted file mode 100755 index e91a27b603..0000000000 Binary files a/lib/gui/.cache/icons/open_file.png and /dev/null differ diff --git a/lib/gui/.cache/icons/open_folder.png b/lib/gui/.cache/icons/open_folder.png deleted file mode 100755 index 8e4b2aa69d..0000000000 Binary files a/lib/gui/.cache/icons/open_folder.png and /dev/null differ diff --git a/lib/gui/.cache/icons/pause.png b/lib/gui/.cache/icons/pause.png new file mode 100755 index 0000000000..c10c1933e0 Binary files /dev/null and b/lib/gui/.cache/icons/pause.png differ diff --git a/lib/gui/.cache/icons/picture.png b/lib/gui/.cache/icons/picture.png new file mode 100644 index 0000000000..e0bbafd5d3 Binary files /dev/null and b/lib/gui/.cache/icons/picture.png differ diff --git a/lib/gui/.cache/icons/play.png b/lib/gui/.cache/icons/play.png new file mode 100755 index 
0000000000..225f777b76 Binary files /dev/null and b/lib/gui/.cache/icons/play.png differ diff --git a/lib/gui/.cache/icons/prev.png b/lib/gui/.cache/icons/prev.png new file mode 100755 index 0000000000..f5387956b5 Binary files /dev/null and b/lib/gui/.cache/icons/prev.png differ diff --git a/lib/gui/.cache/icons/reload.png b/lib/gui/.cache/icons/reload.png new file mode 100644 index 0000000000..1677233c7f Binary files /dev/null and b/lib/gui/.cache/icons/reload.png differ diff --git a/lib/gui/.cache/icons/reload2.png b/lib/gui/.cache/icons/reload2.png new file mode 100644 index 0000000000..b10fefe757 Binary files /dev/null and b/lib/gui/.cache/icons/reload2.png differ diff --git a/lib/gui/.cache/icons/reload3.png b/lib/gui/.cache/icons/reload3.png new file mode 100755 index 0000000000..5526f5565c Binary files /dev/null and b/lib/gui/.cache/icons/reload3.png differ diff --git a/lib/gui/.cache/icons/reset.png b/lib/gui/.cache/icons/reset.png deleted file mode 100755 index bc5cd44bd6..0000000000 Binary files a/lib/gui/.cache/icons/reset.png and /dev/null differ diff --git a/lib/gui/.cache/icons/save.png b/lib/gui/.cache/icons/save.png index 97dc732531..25b764a93a 100755 Binary files a/lib/gui/.cache/icons/save.png and b/lib/gui/.cache/icons/save.png differ diff --git a/lib/gui/.cache/icons/save2.png b/lib/gui/.cache/icons/save2.png new file mode 100644 index 0000000000..cfc285066c Binary files /dev/null and b/lib/gui/.cache/icons/save2.png differ diff --git a/lib/gui/.cache/icons/save_as.png b/lib/gui/.cache/icons/save_as.png new file mode 100644 index 0000000000..99d014eff9 Binary files /dev/null and b/lib/gui/.cache/icons/save_as.png differ diff --git a/lib/gui/.cache/icons/save_as2.png b/lib/gui/.cache/icons/save_as2.png new file mode 100644 index 0000000000..07cf750ad4 Binary files /dev/null and b/lib/gui/.cache/icons/save_as2.png differ diff --git a/lib/gui/.cache/icons/settings.png b/lib/gui/.cache/icons/settings.png new file mode 100644 index 0000000000..874fe42cb0 Binary files /dev/null and b/lib/gui/.cache/icons/settings.png differ diff --git a/lib/gui/.cache/icons/settings_convert.png b/lib/gui/.cache/icons/settings_convert.png new file mode 100644 index 0000000000..5e4a3472dd Binary files /dev/null and b/lib/gui/.cache/icons/settings_convert.png differ diff --git a/lib/gui/.cache/icons/settings_extract.png b/lib/gui/.cache/icons/settings_extract.png new file mode 100644 index 0000000000..eeab735e4a Binary files /dev/null and b/lib/gui/.cache/icons/settings_extract.png differ diff --git a/lib/gui/.cache/icons/settings_train.png b/lib/gui/.cache/icons/settings_train.png new file mode 100644 index 0000000000..66d100eba1 Binary files /dev/null and b/lib/gui/.cache/icons/settings_train.png differ diff --git a/lib/gui/.cache/icons/start.png b/lib/gui/.cache/icons/start.png new file mode 100644 index 0000000000..5923d35d92 Binary files /dev/null and b/lib/gui/.cache/icons/start.png differ diff --git a/lib/gui/.cache/icons/stop.png b/lib/gui/.cache/icons/stop.png new file mode 100644 index 0000000000..ee7e590ac2 Binary files /dev/null and b/lib/gui/.cache/icons/stop.png differ diff --git a/lib/gui/.cache/icons/video.png b/lib/gui/.cache/icons/video.png new file mode 100644 index 0000000000..1851c2a0ff Binary files /dev/null and b/lib/gui/.cache/icons/video.png differ diff --git a/lib/gui/.cache/icons/view.png b/lib/gui/.cache/icons/view.png new file mode 100755 index 0000000000..879ad44502 Binary files /dev/null and b/lib/gui/.cache/icons/view.png differ diff --git 
a/lib/gui/.cache/icons/zoom.png b/lib/gui/.cache/icons/zoom.png index c2a5653c38..fd71e07a9b 100755 Binary files a/lib/gui/.cache/icons/zoom.png and b/lib/gui/.cache/icons/zoom.png differ diff --git a/lib/gui/.cache/presets/convert/.keep b/lib/gui/.cache/presets/convert/.keep new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lib/gui/.cache/presets/extract/.keep b/lib/gui/.cache/presets/extract/.keep new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lib/gui/.cache/presets/gui/.keep b/lib/gui/.cache/presets/gui/.keep new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lib/gui/.cache/presets/train/model_phaze_a_clipfaker128_preset.json b/lib/gui/.cache/presets/train/model_phaze_a_clipfaker128_preset.json new file mode 100644 index 0000000000..0afede8da7 --- /dev/null +++ b/lib/gui/.cache/presets/train/model_phaze_a_clipfaker128_preset.json @@ -0,0 +1,52 @@ +{ + "output_size": 128, + "shared_fc": "none", + "enable_gblock": false, + "split_fc": false, + "split_gblock": false, + "split_decoders": true, + "enc_architecture": "clipv_farl-b-16-64", + "enc_scaling": 29, + "enc_load_weights": true, + "bottleneck_type": "flatten", + "bottleneck_norm": "none", + "bottleneck_size": 1024, + "bottleneck_in_encoder": true, + "fc_depth": 1, + "fc_min_filters": 1024, + "fc_max_filters": 1024, + "fc_dimensions": 4, + "fc_filter_slope": -0.5, + "fc_dropout": 0.0, + "fc_upsampler": "subpixel", + "fc_upsamples": 1, + "fc_upsample_filters": 512, + "fc_gblock_depth": 3, + "fc_gblock_min_nodes": 512, + "fc_gblock_max_nodes": 512, + "fc_gblock_filter_slope": -0.5, + "fc_gblock_dropout": 0.0, + "dec_upscale_method": "subpixel", + "dec_upscales_in_fc": 0, + "dec_norm": "none", + "dec_min_filters": 64, + "dec_max_filters": 512, + "dec_slope_mode": "cap_min", + "dec_filter_slope": 0.5, + "dec_res_blocks": 1, + "dec_output_kernel": 5, + "dec_gaussian": false, + "dec_skip_last_residual": true, + "freeze_layers": "keras_encoder", + "load_layers": "encoder", + "fs_original_depth": 4, + "fs_original_min_filters": 128, + "fs_original_max_filters": 1024, + "fs_original_use_alt": false, + "mobilenet_width": 1.0, + "mobilenet_depth": 1, + "mobilenet_dropout": 0.001, + "mobilenet_minimalistic": false, + "__filetype": "faceswap_preset", + "__section": "train|model|phaze_a" +} \ No newline at end of file diff --git a/lib/gui/.cache/presets/train/model_phaze_a_clipfaker256_preset.json b/lib/gui/.cache/presets/train/model_phaze_a_clipfaker256_preset.json new file mode 100644 index 0000000000..974b614b97 --- /dev/null +++ b/lib/gui/.cache/presets/train/model_phaze_a_clipfaker256_preset.json @@ -0,0 +1,52 @@ +{ + "output_size": 256, + "shared_fc": "none", + "enable_gblock": false, + "split_fc": false, + "split_gblock": false, + "split_decoders": true, + "enc_architecture": "clipv_farl-b-16-64", + "enc_scaling": 58, + "enc_load_weights": true, + "bottleneck_type": "flatten", + "bottleneck_norm": "none", + "bottleneck_size": 1024, + "bottleneck_in_encoder": true, + "fc_depth": 1, + "fc_min_filters": 1024, + "fc_max_filters": 1024, + "fc_dimensions": 4, + "fc_filter_slope": -0.5, + "fc_dropout": 0.0, + "fc_upsampler": "subpixel", + "fc_upsamples": 1, + "fc_upsample_filters": 512, + "fc_gblock_depth": 3, + "fc_gblock_min_nodes": 512, + "fc_gblock_max_nodes": 512, + "fc_gblock_filter_slope": -0.5, + "fc_gblock_dropout": 0.0, + "dec_upscale_method": "subpixel", + "dec_upscales_in_fc": 0, + "dec_norm": "none", + "dec_min_filters": 64, + "dec_max_filters": 1024, + "dec_slope_mode": "cap_min", + 
"dec_filter_slope": 0.5, + "dec_res_blocks": 1, + "dec_output_kernel": 5, + "dec_gaussian": false, + "dec_skip_last_residual": true, + "freeze_layers": "keras_encoder", + "load_layers": "encoder", + "fs_original_depth": 4, + "fs_original_min_filters": 128, + "fs_original_max_filters": 1024, + "fs_original_use_alt": false, + "mobilenet_width": 1.0, + "mobilenet_depth": 1, + "mobilenet_dropout": 0.001, + "mobilenet_minimalistic": false, + "__filetype": "faceswap_preset", + "__section": "train|model|phaze_a" +} \ No newline at end of file diff --git a/lib/gui/.cache/presets/train/model_phaze_a_clipfaker448_preset.json b/lib/gui/.cache/presets/train/model_phaze_a_clipfaker448_preset.json new file mode 100644 index 0000000000..59bfedce6b --- /dev/null +++ b/lib/gui/.cache/presets/train/model_phaze_a_clipfaker448_preset.json @@ -0,0 +1,52 @@ +{ + "output_size": 448, + "shared_fc": "none", + "enable_gblock": false, + "split_fc": false, + "split_gblock": false, + "split_decoders": true, + "enc_architecture": "clipv_farl-b-16-64", + "enc_scaling": 100, + "enc_load_weights": true, + "bottleneck_type": "flatten", + "bottleneck_norm": "none", + "bottleneck_size": 1024, + "bottleneck_in_encoder": true, + "fc_depth": 1, + "fc_min_filters": 384, + "fc_max_filters": 384, + "fc_dimensions": 7, + "fc_filter_slope": -0.5, + "fc_dropout": 0.0, + "fc_upsampler": "subpixel", + "fc_upsamples": 1, + "fc_upsample_filters": 1024, + "fc_gblock_depth": 3, + "fc_gblock_min_nodes": 512, + "fc_gblock_max_nodes": 512, + "fc_gblock_filter_slope": -0.5, + "fc_gblock_dropout": 0.0, + "dec_upscale_method": "subpixel", + "dec_upscales_in_fc": 0, + "dec_norm": "none", + "dec_min_filters": 64, + "dec_max_filters": 1024, + "dec_slope_mode": "cap_min", + "dec_filter_slope": 0.5, + "dec_res_blocks": 1, + "dec_output_kernel": 5, + "dec_gaussian": false, + "dec_skip_last_residual": true, + "freeze_layers": "keras_encoder", + "load_layers": "encoder", + "fs_original_depth": 4, + "fs_original_min_filters": 128, + "fs_original_max_filters": 1024, + "fs_original_use_alt": false, + "mobilenet_width": 1.0, + "mobilenet_depth": 1, + "mobilenet_dropout": 0.001, + "mobilenet_minimalistic": false, + "__filetype": "faceswap_preset", + "__section": "train|model|phaze_a" +} \ No newline at end of file diff --git a/lib/gui/.cache/presets/train/model_phaze_a_dfaker_preset.json b/lib/gui/.cache/presets/train/model_phaze_a_dfaker_preset.json new file mode 100644 index 0000000000..2d9803c3dd --- /dev/null +++ b/lib/gui/.cache/presets/train/model_phaze_a_dfaker_preset.json @@ -0,0 +1,51 @@ +{ + "output_size": 128, + "shared_fc": "none", + "enable_gblock": false, + "split_fc": false, + "split_gblock": false, + "split_decoders": true, + "enc_architecture": "fs_original", + "enc_scaling": 7, + "enc_load_weights": false, + "bottleneck_type": "dense", + "bottleneck_norm": "none", + "bottleneck_size": 1024, + "bottleneck_in_encoder": true, + "fc_depth": 1, + "fc_min_filters": 1024, + "fc_max_filters": 1024, + "fc_dimensions": 4, + "fc_filter_slope": -0.5, + "fc_dropout": 0.0, + "fc_upsampler": "subpixel", + "fc_upsamples": 1, + "fc_upsample_filters": 512, + "fc_gblock_depth": 3, + "fc_gblock_min_nodes": 512, + "fc_gblock_max_nodes": 512, + "fc_gblock_filter_slope": -0.5, + "fc_gblock_dropout": 0.0, + "dec_upscale_method": "subpixel", + "dec_upscales_in_fc": 0, + "dec_norm": "none", + "dec_min_filters": 64, + "dec_max_filters": 512, + "dec_slope_mode": "full", + "dec_filter_slope": -0.45, + "dec_res_blocks": 1, + "dec_output_kernel": 5, + "dec_gaussian": 
false, + "dec_skip_last_residual": true, + "freeze_layers": "encoder", + "load_layers": "encoder", + "fs_original_depth": 4, + "fs_original_min_filters": 128, + "fs_original_max_filters": 1024, + "fs_original_use_alt": false, + "mobilenet_width": 1.0, + "mobilenet_depth": 1, + "mobilenet_dropout": 0.001, + "__filetype": "faceswap_preset", + "__section": "train|model|phaze_a" +} \ No newline at end of file diff --git a/lib/gui/.cache/presets/train/model_phaze_a_dfl-h128_preset.json b/lib/gui/.cache/presets/train/model_phaze_a_dfl-h128_preset.json new file mode 100644 index 0000000000..69c879a501 --- /dev/null +++ b/lib/gui/.cache/presets/train/model_phaze_a_dfl-h128_preset.json @@ -0,0 +1,51 @@ +{ + "output_size": 128, + "shared_fc": "none", + "enable_gblock": false, + "split_fc": false, + "split_gblock": false, + "split_decoders": true, + "enc_architecture": "fs_original", + "enc_scaling": 13, + "enc_load_weights": false, + "bottleneck_type": "dense", + "bottleneck_norm": "none", + "bottleneck_size": 512, + "bottleneck_in_encoder": true, + "fc_depth": 1, + "fc_min_filters": 512, + "fc_max_filters": 512, + "fc_dimensions": 8, + "fc_filter_slope": -0.5, + "fc_dropout": 0.0, + "fc_upsampler": "subpixel", + "fc_upsamples": 1, + "fc_upsample_filters": 512, + "fc_gblock_depth": 3, + "fc_gblock_min_nodes": 512, + "fc_gblock_max_nodes": 512, + "fc_gblock_filter_slope": -0.5, + "fc_gblock_dropout": 0.0, + "dec_upscale_method": "subpixel", + "dec_upscales_in_fc": 0, + "dec_norm": "none", + "dec_min_filters": 128, + "dec_max_filters": 512, + "dec_slope_mode": "full", + "dec_filter_slope": -0.33, + "dec_res_blocks": 0, + "dec_output_kernel": 5, + "dec_gaussian": false, + "dec_skip_last_residual": false, + "freeze_layers": "encoder", + "load_layers": "encoder", + "fs_original_depth": 4, + "fs_original_min_filters": 128, + "fs_original_max_filters": 1024, + "fs_original_use_alt": false, + "mobilenet_width": 1.0, + "mobilenet_depth": 1, + "mobilenet_dropout": 0.001, + "__filetype": "faceswap_preset", + "__section": "train|model|phaze_a" +} \ No newline at end of file diff --git a/lib/gui/.cache/presets/train/model_phaze_a_dfl-sae-df_preset.json b/lib/gui/.cache/presets/train/model_phaze_a_dfl-sae-df_preset.json new file mode 100644 index 0000000000..af9ec50b5f --- /dev/null +++ b/lib/gui/.cache/presets/train/model_phaze_a_dfl-sae-df_preset.json @@ -0,0 +1,51 @@ +{ + "output_size": 128, + "shared_fc": "none", + "enable_gblock": false, + "split_fc": false, + "split_gblock": false, + "split_decoders": true, + "enc_architecture": "fs_original", + "enc_scaling": 13, + "enc_load_weights": false, + "bottleneck_type": "dense", + "bottleneck_norm": "none", + "bottleneck_size": 512, + "bottleneck_in_encoder": true, + "fc_depth": 1, + "fc_min_filters": 512, + "fc_max_filters": 512, + "fc_dimensions": 8, + "fc_filter_slope": -0.5, + "fc_dropout": 0.0, + "fc_upsampler": "subpixel", + "fc_upsamples": 1, + "fc_upsample_filters": 512, + "fc_gblock_depth": 3, + "fc_gblock_min_nodes": 512, + "fc_gblock_max_nodes": 512, + "fc_gblock_filter_slope": -0.5, + "fc_gblock_dropout": 0.0, + "dec_upscale_method": "subpixel", + "dec_upscales_in_fc": 0, + "dec_norm": "none", + "dec_min_filters": 128, + "dec_max_filters": 504, + "dec_slope_mode": "full", + "dec_filter_slope": -0.33, + "dec_res_blocks": 2, + "dec_output_kernel": 5, + "dec_gaussian": false, + "dec_skip_last_residual": false, + "freeze_layers": "encoder", + "load_layers": "encoder", + "fs_original_depth": 4, + "fs_original_min_filters": 126, + "fs_original_max_filters": 
1008, + "fs_original_use_alt": false, + "mobilenet_width": 1.0, + "mobilenet_depth": 1, + "mobilenet_dropout": 0.001, + "__filetype": "faceswap_preset", + "__section": "train|model|phaze_a" +} \ No newline at end of file diff --git a/lib/gui/.cache/presets/train/model_phaze_a_dfl-sae-liae_preset.json b/lib/gui/.cache/presets/train/model_phaze_a_dfl-sae-liae_preset.json new file mode 100644 index 0000000000..c1d9901201 --- /dev/null +++ b/lib/gui/.cache/presets/train/model_phaze_a_dfl-sae-liae_preset.json @@ -0,0 +1,51 @@ +{ + "output_size": 128, + "shared_fc": "half", + "enable_gblock": false, + "split_fc": true, + "split_gblock": false, + "split_decoders": false, + "enc_architecture": "fs_original", + "enc_scaling": 13, + "enc_load_weights": false, + "bottleneck_type": "dense", + "bottleneck_norm": "none", + "bottleneck_size": 256, + "bottleneck_in_encoder": false, + "fc_depth": 1, + "fc_min_filters": 512, + "fc_max_filters": 512, + "fc_dimensions": 8, + "fc_filter_slope": -0.5, + "fc_dropout": 0.0, + "fc_upsampler": "subpixel", + "fc_upsamples": 1, + "fc_upsample_filters": 512, + "fc_gblock_depth": 3, + "fc_gblock_min_nodes": 512, + "fc_gblock_max_nodes": 512, + "fc_gblock_filter_slope": -0.5, + "fc_gblock_dropout": 0.0, + "dec_upscale_method": "subpixel", + "dec_upscales_in_fc": 0, + "dec_norm": "none", + "dec_min_filters": 128, + "dec_max_filters": 504, + "dec_slope_mode": "full", + "dec_filter_slope": -0.33, + "dec_res_blocks": 2, + "dec_output_kernel": 5, + "dec_gaussian": false, + "dec_skip_last_residual": false, + "freeze_layers": "encoder", + "load_layers": "encoder", + "fs_original_depth": 4, + "fs_original_min_filters": 126, + "fs_original_max_filters": 1008, + "fs_original_use_alt": false, + "mobilenet_width": 1.0, + "mobilenet_depth": 1, + "mobilenet_dropout": 0.001, + "__filetype": "faceswap_preset", + "__section": "train|model|phaze_a" +} \ No newline at end of file diff --git a/lib/gui/.cache/presets/train/model_phaze_a_dfl-saehd-df_preset.json b/lib/gui/.cache/presets/train/model_phaze_a_dfl-saehd-df_preset.json new file mode 100644 index 0000000000..df94b9180f --- /dev/null +++ b/lib/gui/.cache/presets/train/model_phaze_a_dfl-saehd-df_preset.json @@ -0,0 +1,51 @@ +{ + "output_size": 128, + "shared_fc": "none", + "enable_gblock": false, + "split_fc": false, + "split_gblock": false, + "split_decoders": true, + "enc_architecture": "fs_original", + "enc_scaling": 13, + "enc_load_weights": false, + "bottleneck_type": "dense", + "bottleneck_norm": "none", + "bottleneck_size": 256, + "bottleneck_in_encoder": false, + "fc_depth": 1, + "fc_min_filters": 256, + "fc_max_filters": 256, + "fc_dimensions": 8, + "fc_filter_slope": -0.5, + "fc_dropout": 0.0, + "fc_upsampler": "subpixel", + "fc_upsamples": 1, + "fc_upsample_filters": 256, + "fc_gblock_depth": 3, + "fc_gblock_min_nodes": 512, + "fc_gblock_max_nodes": 512, + "fc_gblock_filter_slope": -0.5, + "fc_gblock_dropout": 0.0, + "dec_upscale_method": "subpixel", + "dec_upscales_in_fc": 0, + "dec_norm": "none", + "dec_min_filters": 128, + "dec_max_filters": 512, + "dec_slope_mode": "full", + "dec_filter_slope": -0.33, + "dec_res_blocks": 1, + "dec_output_kernel": 1, + "dec_gaussian": false, + "dec_skip_last_residual": false, + "freeze_layers": "encoder", + "load_layers": "encoder", + "fs_original_depth": 4, + "fs_original_min_filters": 64, + "fs_original_max_filters": 512, + "fs_original_use_alt": false, + "mobilenet_width": 1.0, + "mobilenet_depth": 1, + "mobilenet_dropout": 0.001, + "__filetype": "faceswap_preset", + "__section": 
"train|model|phaze_a" +} \ No newline at end of file diff --git a/lib/gui/.cache/presets/train/model_phaze_a_dfl-saehd-liae_preset.json b/lib/gui/.cache/presets/train/model_phaze_a_dfl-saehd-liae_preset.json new file mode 100644 index 0000000000..b43a33e431 --- /dev/null +++ b/lib/gui/.cache/presets/train/model_phaze_a_dfl-saehd-liae_preset.json @@ -0,0 +1,51 @@ +{ + "output_size": 128, + "shared_fc": "half", + "enable_gblock": false, + "split_fc": true, + "split_gblock": false, + "split_decoders": false, + "enc_architecture": "fs_original", + "enc_scaling": 13, + "enc_load_weights": false, + "bottleneck_type": "dense", + "bottleneck_norm": "none", + "bottleneck_size": 256, + "bottleneck_in_encoder": false, + "fc_depth": 1, + "fc_min_filters": 512, + "fc_max_filters": 512, + "fc_dimensions": 8, + "fc_filter_slope": -0.5, + "fc_dropout": 0.0, + "fc_upsampler": "subpixel", + "fc_upsamples": 1, + "fc_upsample_filters": 512, + "fc_gblock_depth": 3, + "fc_gblock_min_nodes": 512, + "fc_gblock_max_nodes": 512, + "fc_gblock_filter_slope": -0.5, + "fc_gblock_dropout": 0.0, + "dec_upscale_method": "subpixel", + "dec_upscales_in_fc": 0, + "dec_norm": "none", + "dec_min_filters": 128, + "dec_max_filters": 512, + "dec_slope_mode": "full", + "dec_filter_slope": -0.33, + "dec_res_blocks": 1, + "dec_output_kernel": 1, + "dec_gaussian": false, + "dec_skip_last_residual": false, + "freeze_layers": "encoder", + "load_layers": "encoder", + "fs_original_depth": 4, + "fs_original_min_filters": 64, + "fs_original_max_filters": 512, + "fs_original_use_alt": false, + "mobilenet_width": 1.0, + "mobilenet_depth": 1, + "mobilenet_dropout": 0.001, + "__filetype": "faceswap_preset", + "__section": "train|model|phaze_a" +} \ No newline at end of file diff --git a/lib/gui/.cache/presets/train/model_phaze_a_dny1024_preset.json b/lib/gui/.cache/presets/train/model_phaze_a_dny1024_preset.json new file mode 100644 index 0000000000..161d53336c --- /dev/null +++ b/lib/gui/.cache/presets/train/model_phaze_a_dny1024_preset.json @@ -0,0 +1,52 @@ +{ + "output_size": 1024, + "shared_fc": "none", + "enable_gblock": false, + "split_fc": false, + "split_gblock": false, + "split_decoders": true, + "enc_architecture": "fs_original", + "enc_scaling": 100, + "enc_load_weights": false, + "bottleneck_type": "dense", + "bottleneck_norm": "none", + "bottleneck_size": 512, + "bottleneck_in_encoder": true, + "fc_depth": 0, + "fc_min_filters": 512, + "fc_max_filters": 512, + "fc_dimensions": 1, + "fc_filter_slope": 0.0, + "fc_dropout": 0.0, + "fc_upsampler": "upsample2d", + "fc_upsamples": 2, + "fc_upsample_filters": 128, + "fc_gblock_depth": 1, + "fc_gblock_min_nodes": 128, + "fc_gblock_max_nodes": 128, + "fc_gblock_filter_slope": 0.0, + "fc_gblock_dropout": 0.0, + "dec_upscale_method": "upscale_dny", + "dec_upscales_in_fc": 2, + "dec_norm": "none", + "dec_min_filters": 16, + "dec_max_filters": 512, + "dec_slope_mode": "cap_max", + "dec_filter_slope": 0.5, + "dec_res_blocks": 0, + "dec_output_kernel": 1, + "dec_gaussian": false, + "dec_skip_last_residual": false, + "freeze_layers": "encoder", + "load_layers": "encoder", + "fs_original_depth": 9, + "fs_original_min_filters": 16, + "fs_original_max_filters": 512, + "fs_original_use_alt": true, + "mobilenet_width": 1.0, + "mobilenet_depth": 1, + "mobilenet_dropout": 0.001, + "mobilenet_minimalistic": false, + "__filetype": "faceswap_preset", + "__section": "train|model|phaze_a" +} \ No newline at end of file diff --git a/lib/gui/.cache/presets/train/model_phaze_a_dny256_preset.json 
b/lib/gui/.cache/presets/train/model_phaze_a_dny256_preset.json new file mode 100644 index 0000000000..e19e61dcf7 --- /dev/null +++ b/lib/gui/.cache/presets/train/model_phaze_a_dny256_preset.json @@ -0,0 +1,52 @@ +{ + "output_size": 256, + "shared_fc": "none", + "enable_gblock": false, + "split_fc": false, + "split_gblock": false, + "split_decoders": true, + "enc_architecture": "fs_original", + "enc_scaling": 25, + "enc_load_weights": false, + "bottleneck_type": "dense", + "bottleneck_norm": "none", + "bottleneck_size": 512, + "bottleneck_in_encoder": true, + "fc_depth": 0, + "fc_min_filters": 512, + "fc_max_filters": 512, + "fc_dimensions": 1, + "fc_filter_slope": 0.0, + "fc_dropout": 0.0, + "fc_upsampler": "upsample2d", + "fc_upsamples": 2, + "fc_upsample_filters": 128, + "fc_gblock_depth": 1, + "fc_gblock_min_nodes": 128, + "fc_gblock_max_nodes": 128, + "fc_gblock_filter_slope": 0.0, + "fc_gblock_dropout": 0.0, + "dec_upscale_method": "upscale_dny", + "dec_upscales_in_fc": 1, + "dec_norm": "none", + "dec_min_filters": 16, + "dec_max_filters": 512, + "dec_slope_mode": "cap_max", + "dec_filter_slope": 0.5, + "dec_res_blocks": 0, + "dec_output_kernel": 1, + "dec_gaussian": false, + "dec_skip_last_residual": false, + "freeze_layers": "encoder", + "load_layers": "encoder", + "fs_original_depth": 7, + "fs_original_min_filters": 16, + "fs_original_max_filters": 512, + "fs_original_use_alt": true, + "mobilenet_width": 1.0, + "mobilenet_depth": 1, + "mobilenet_dropout": 0.001, + "mobilenet_minimalistic": false, + "__filetype": "faceswap_preset", + "__section": "train|model|phaze_a" +} \ No newline at end of file diff --git a/lib/gui/.cache/presets/train/model_phaze_a_dny512_preset.json b/lib/gui/.cache/presets/train/model_phaze_a_dny512_preset.json new file mode 100644 index 0000000000..9e0534d5f9 --- /dev/null +++ b/lib/gui/.cache/presets/train/model_phaze_a_dny512_preset.json @@ -0,0 +1,52 @@ +{ + "output_size": 512, + "shared_fc": "none", + "enable_gblock": false, + "split_fc": false, + "split_gblock": false, + "split_decoders": true, + "enc_architecture": "fs_original", + "enc_scaling": 50, + "enc_load_weights": false, + "bottleneck_type": "dense", + "bottleneck_norm": "none", + "bottleneck_size": 512, + "bottleneck_in_encoder": true, + "fc_depth": 0, + "fc_min_filters": 512, + "fc_max_filters": 512, + "fc_dimensions": 1, + "fc_filter_slope": 0.0, + "fc_dropout": 0.0, + "fc_upsampler": "upsample2d", + "fc_upsamples": 2, + "fc_upsample_filters": 128, + "fc_gblock_depth": 1, + "fc_gblock_min_nodes": 128, + "fc_gblock_max_nodes": 128, + "fc_gblock_filter_slope": 0.0, + "fc_gblock_dropout": 0.0, + "dec_upscale_method": "upscale_dny", + "dec_upscales_in_fc": 2, + "dec_norm": "none", + "dec_min_filters": 16, + "dec_max_filters": 512, + "dec_slope_mode": "cap_max", + "dec_filter_slope": 0.5, + "dec_res_blocks": 0, + "dec_output_kernel": 1, + "dec_gaussian": false, + "dec_skip_last_residual": false, + "freeze_layers": "encoder", + "load_layers": "encoder", + "fs_original_depth": 8, + "fs_original_min_filters": 16, + "fs_original_max_filters": 512, + "fs_original_use_alt": true, + "mobilenet_width": 1.0, + "mobilenet_depth": 1, + "mobilenet_dropout": 0.001, + "mobilenet_minimalistic": false, + "__filetype": "faceswap_preset", + "__section": "train|model|phaze_a" +} \ No newline at end of file diff --git a/lib/gui/.cache/presets/train/model_phaze_a_iae_preset.json b/lib/gui/.cache/presets/train/model_phaze_a_iae_preset.json new file mode 100644 index 0000000000..304fc70eeb --- /dev/null +++ 
b/lib/gui/.cache/presets/train/model_phaze_a_iae_preset.json @@ -0,0 +1,51 @@ +{ + "output_size": 64, + "shared_fc": "full", + "enable_gblock": false, + "split_fc": true, + "split_gblock": false, + "split_decoders": false, + "enc_architecture": "fs_original", + "enc_scaling": 7, + "enc_load_weights": false, + "bottleneck_type": "dense", + "bottleneck_norm": "none", + "bottleneck_size": 1024, + "bottleneck_in_encoder": false, + "fc_depth": 1, + "fc_min_filters": 512, + "fc_max_filters": 512, + "fc_dimensions": 4, + "fc_filter_slope": -0.5, + "fc_dropout": 0.0, + "fc_upsampler": "subpixel", + "fc_upsamples": 0, + "fc_upsample_filters": 512, + "fc_gblock_depth": 3, + "fc_gblock_min_nodes": 512, + "fc_gblock_max_nodes": 512, + "fc_gblock_filter_slope": -0.5, + "fc_gblock_dropout": 0.0, + "dec_upscale_method": "subpixel", + "dec_upscales_in_fc": 0, + "dec_norm": "none", + "dec_min_filters": 64, + "dec_max_filters": 512, + "dec_slope_mode": "full", + "dec_filter_slope": -0.45, + "dec_res_blocks": 0, + "dec_output_kernel": 5, + "dec_gaussian": false, + "dec_skip_last_residual": false, + "freeze_layers": "encoder", + "load_layers": "encoder", + "fs_original_depth": 4, + "fs_original_min_filters": 128, + "fs_original_max_filters": 1024, + "fs_original_use_alt": false, + "mobilenet_width": 1.0, + "mobilenet_depth": 1, + "mobilenet_dropout": 0.001, + "__filetype": "faceswap_preset", + "__section": "train|model|phaze_a" +} \ No newline at end of file diff --git a/lib/gui/.cache/presets/train/model_phaze_a_lightweight_preset.json b/lib/gui/.cache/presets/train/model_phaze_a_lightweight_preset.json new file mode 100644 index 0000000000..f47590f5c5 --- /dev/null +++ b/lib/gui/.cache/presets/train/model_phaze_a_lightweight_preset.json @@ -0,0 +1,51 @@ +{ + "output_size": 64, + "shared_fc": "none", + "enable_gblock": false, + "split_fc": false, + "split_gblock": false, + "split_decoders": true, + "enc_architecture": "fs_original", + "enc_scaling": 7, + "enc_load_weights": false, + "bottleneck_type": "dense", + "bottleneck_norm": "none", + "bottleneck_size": 512, + "bottleneck_in_encoder": true, + "fc_depth": 1, + "fc_min_filters": 512, + "fc_max_filters": 512, + "fc_dimensions": 4, + "fc_filter_slope": -0.5, + "fc_dropout": 0.0, + "fc_upsampler": "subpixel", + "fc_upsamples": 1, + "fc_upsample_filters": 256, + "fc_gblock_depth": 3, + "fc_gblock_min_nodes": 512, + "fc_gblock_max_nodes": 512, + "fc_gblock_filter_slope": -0.5, + "fc_gblock_dropout": 0.0, + "dec_upscale_method": "subpixel", + "dec_upscales_in_fc": 0, + "dec_norm": "none", + "dec_min_filters": 128, + "dec_max_filters": 512, + "dec_slope_mode": "full", + "dec_filter_slope": -0.33, + "dec_res_blocks": 0, + "dec_output_kernel": 5, + "dec_gaussian": false, + "dec_skip_last_residual": false, + "freeze_layers": "encoder", + "load_layers": "encoder", + "fs_original_depth": 3, + "fs_original_min_filters": 128, + "fs_original_max_filters": 512, + "fs_original_use_alt": false, + "mobilenet_width": 1.0, + "mobilenet_depth": 1, + "mobilenet_dropout": 0.001, + "__filetype": "faceswap_preset", + "__section": "train|model|phaze_a" +} \ No newline at end of file diff --git a/lib/gui/.cache/presets/train/model_phaze_a_original_preset.json b/lib/gui/.cache/presets/train/model_phaze_a_original_preset.json new file mode 100644 index 0000000000..a4efe963dc --- /dev/null +++ b/lib/gui/.cache/presets/train/model_phaze_a_original_preset.json @@ -0,0 +1,51 @@ +{ + "output_size": 64, + "shared_fc": "none", + "enable_gblock": false, + "split_fc": false, + "split_gblock": 
false, + "split_decoders": true, + "enc_architecture": "fs_original", + "enc_scaling": 7, + "enc_load_weights": false, + "bottleneck_type": "dense", + "bottleneck_norm": "none", + "bottleneck_size": 1024, + "bottleneck_in_encoder": true, + "fc_depth": 1, + "fc_min_filters": 1024, + "fc_max_filters": 1024, + "fc_dimensions": 4, + "fc_filter_slope": -0.5, + "fc_dropout": 0.0, + "fc_upsampler": "subpixel", + "fc_upsamples": 1, + "fc_upsample_filters": 512, + "fc_gblock_depth": 3, + "fc_gblock_min_nodes": 512, + "fc_gblock_max_nodes": 512, + "fc_gblock_filter_slope": -0.5, + "fc_gblock_dropout": 0.0, + "dec_upscale_method": "subpixel", + "dec_upscales_in_fc": 0, + "dec_norm": "none", + "dec_min_filters": 64, + "dec_max_filters": 256, + "dec_slope_mode": "full", + "dec_filter_slope": -0.33, + "dec_res_blocks": 0, + "dec_output_kernel": 5, + "dec_gaussian": false, + "dec_skip_last_residual": false, + "freeze_layers": "encoder", + "load_layers": "encoder", + "fs_original_depth": 4, + "fs_original_min_filters": 128, + "fs_original_max_filters": 1024, + "fs_original_use_alt": false, + "mobilenet_width": 1.0, + "mobilenet_depth": 1, + "mobilenet_dropout": 0.001, + "__filetype": "faceswap_preset", + "__section": "train|model|phaze_a" +} \ No newline at end of file diff --git a/lib/gui/.cache/presets/train/model_phaze_a_stojo_preset.json b/lib/gui/.cache/presets/train/model_phaze_a_stojo_preset.json new file mode 100644 index 0000000000..479e322568 --- /dev/null +++ b/lib/gui/.cache/presets/train/model_phaze_a_stojo_preset.json @@ -0,0 +1,51 @@ +{ + "output_size": 256, + "shared_fc": "none", + "enable_gblock": true, + "split_fc": true, + "split_gblock": false, + "split_decoders": false, + "enc_architecture": "efficientnet_b4", + "enc_scaling": 60, + "enc_load_weights": true, + "bottleneck_type": "dense", + "bottleneck_norm": "none", + "bottleneck_size": 512, + "bottleneck_in_encoder": true, + "fc_depth": 1, + "fc_min_filters": 1280, + "fc_max_filters": 1280, + "fc_dimensions": 8, + "fc_filter_slope": -0.5, + "fc_dropout": 0.0, + "fc_upsampler": "upsample2d", + "fc_upsamples": 1, + "fc_upsample_filters": 1280, + "fc_gblock_depth": 3, + "fc_gblock_min_nodes": 512, + "fc_gblock_max_nodes": 512, + "fc_gblock_filter_slope": -0.5, + "fc_gblock_dropout": 0.0, + "dec_upscale_method": "resize_images", + "dec_upscales_in_fc": 0, + "dec_norm": "none", + "dec_min_filters": 160, + "dec_max_filters": 640, + "dec_slope_mode": "full", + "dec_filter_slope": -0.33, + "dec_res_blocks": 1, + "dec_output_kernel": 3, + "dec_gaussian": true, + "dec_skip_last_residual": false, + "freeze_layers": "keras_encoder", + "load_layers": "encoder", + "fs_original_depth": 4, + "fs_original_min_filters": 128, + "fs_original_max_filters": 1024, + "fs_original_use_alt": false, + "mobilenet_width": 1.0, + "mobilenet_depth": 1, + "mobilenet_dropout": 0.001, + "__filetype": "faceswap_preset", + "__section": "train|model|phaze_a" +} \ No newline at end of file diff --git a/lib/gui/.cache/presets/train/model_phaze_a_sym384_preset.json b/lib/gui/.cache/presets/train/model_phaze_a_sym384_preset.json new file mode 100644 index 0000000000..e836a856f8 --- /dev/null +++ b/lib/gui/.cache/presets/train/model_phaze_a_sym384_preset.json @@ -0,0 +1,52 @@ +{ + "output_size": 384, + "shared_fc": "none", + "enable_gblock": false, + "split_fc": false, + "split_gblock": false, + "split_decoders": true, + "enc_architecture": "efficientnet_v2_s", + "enc_scaling": 100, + "enc_load_weights": true, + "bottleneck_type": "max_pooling", + "bottleneck_norm": "none", 
+ "bottleneck_size": 1280, + "bottleneck_in_encoder": true, + "fc_depth": 1, + "fc_min_filters": 1536, + "fc_max_filters": 1536, + "fc_dimensions": 3, + "fc_filter_slope": -0.5, + "fc_dropout": 0.0, + "fc_upsampler": "subpixel", + "fc_upsamples": 0, + "fc_upsample_filters": 1280, + "fc_gblock_depth": 3, + "fc_gblock_min_nodes": 512, + "fc_gblock_max_nodes": 512, + "fc_gblock_filter_slope": -0.5, + "fc_gblock_dropout": 0.0, + "dec_upscale_method": "upscale_dny", + "dec_upscales_in_fc": 2, + "dec_norm": "none", + "dec_min_filters": 24, + "dec_max_filters": 1536, + "dec_slope_mode": "cap_max", + "dec_filter_slope": 0.5, + "dec_res_blocks": 1, + "dec_output_kernel": 3, + "dec_gaussian": true, + "dec_skip_last_residual": true, + "freeze_layers": "keras_encoder", + "load_layers": "encoder", + "fs_original_depth": 4, + "fs_original_min_filters": 128, + "fs_original_max_filters": 1024, + "fs_original_use_alt": false, + "mobilenet_width": 1.0, + "mobilenet_depth": 1, + "mobilenet_dropout": 0.001, + "mobilenet_minimalistic": false, + "__filetype": "faceswap_preset", + "__section": "train|model|phaze_a" +} diff --git a/lib/gui/.cache/themes/default.json b/lib/gui/.cache/themes/default.json new file mode 100644 index 0000000000..1f41610268 --- /dev/null +++ b/lib/gui/.cache/themes/default.json @@ -0,0 +1,139 @@ +{ + "info": "Initial default theme configuration whilst migrating from default ttk OS widgets", + "group_panel": { + "info": { + "info1": "The 'group_panel' section are any section which contains items for user input, such as the left hand options panel in the main GUI or the Settings pop-up", + "info2": "Anything which uses a 'group_panel' will use the theme specified here as default. Panels can be overriden (see below).", + + "panel_background": "The background color of the main panel that holds all of the group options.", + + "info_color": "The background color of the information header box at the top of each control panel", + "info_font": "The color of the font inside the information header box at the top of each control panel", + "info_border": "The color of the border around the outside of the information header box at the top of each control panel", + + "header_color": "The color to use for the option group boxes header backgrounds, the group box border and for labels on options groups.", + "header_font": "The color to use for the option group boxes header font.", + "group_background": "This is the color used for the background of each group of options, as well as the background color used for any label which resides inside a group box", + "group_font": "The font color used inside each group box for labels", + + "control_color": "The color of controls (e.g. Slider knob, combo pull-down arrow, scrollbar slider + arrows etc.)", + "control_active": "Selected/hovered over color of controls (e.g. Slider knob, combo pull-down arrow, scrollbar slider + arrows etc.)", + "control_disabled": "The color of controls when they are disabled (specifically scrollbars when there is no page to scroll).", + + "input_color": "The background color of input boxes (e.g. text entry)", + "input_font": "The font color of input boxes (e.g. 
text entry)", + "button_background": "The background color of buttons", + + "scrollbar_border": "Border color of scrollbar", + "scrollbar_trough": "Trough color of scrollbar" + }, + "panel_background": "#CDD3D5", + + "info_color": "#FFFFFF", + "info_font": "#000000", + "info_border": "#000000", + + "header_color": "#176087", + "header_font": "#FFFFFF", + "group_background": "#FFFFFF", + "group_border": "#176087", + "group_font": "#000000", + + "control_color": "#75929C", + "control_active": "#176087", + "control_disabled": "#CDD3D5", + + "input_color": "#FFFFFF", + "input_font": "#000000", + "button_background": "#FFFFFF", + + "scrollbar_border": "#176087", + "scrollbar_trough": "#CDD3D5" + }, + "group_settings": { + "info": { + "info1": "Override default colors for the settings pop-up. See 'group_panel' for allowable options", + "info2": "Options same as 'group_panel' with the following additions:", + + "tree_select": "The color of the selected item in the left hand nav frame", + "link_color": "The color of links on pages where there are no configuration options" + }, + "panel_background": "#DAD2D8", + + "header_color": "#9B1D20", + "group_border": "#9B1D20", + + "control_color": "#B090A8", + "control_active": "#9B1D20", + "control_disabled": "#DAD2D8", + + "scrollbar_border": "#9B1D20", + "scrollbar_trough": "#DAD2D8", + + "tree_select": "#9B1D20", + "link_color": "#9B1D20" + }, + "command_tabs": { + "frame_border": "#176087", + "tab_color": "#CDD3D5", + "tab_selected": "#75929C", + "tab_hover": "#176087" + }, + "console": { + "info": { + "info1": "The colors of the console output box", + + "background_color": "The background color of the console output", + "border_color": "The color of the border around the console box and scrollbar", + + "stdout_color": "The text color for standard print message output (non Faceswap Logging messages)", + "stderr_color": "The text color for messages that are printed to sterr (non Faceswap Logging messages)", + "info_color": "The text color for Faceswap INFO log messages", + "verbose_color": "The text color for Faceswap VERBOSE log messages", + "warning_color": "The text color for Faceswap WARNING log messages", + "critical_color": "The text color for Faceswap CRITICAL log messages", + "error_color": "The text color for Faceswap ERROR log messages", + + "scrollbar_border": "The color of the overall scrollbar border", + "scrollbar_trough": "The color of the scrollbar trough", + + "scrollbar_background_": "The main color of the up/down buttons and the slider of the scrollbar, for active (pressed/hovered), normal and disabled (no scrollbar required)", + "scrollbar_foreground_": "The foreground color for the up/down buttons of the scrollbar, for active (pressed/hovered), normal and disabled (no scrollbar required)", + "scrollbar_border_": "The border color of the up/down buttons and the slider of the scrollbar, for active (pressed/hovered), normal and disabled (no scrollbar required)" + }, + "background_color": "#CDD3D5", + "border_color": "#176087", + + "stdout_color": "#172c87", + "stderr_color": "#78162f", + "info_color": "#176087", + "verbose_color": "#1D9B32", + "warning_color": "#9B701D", + "critical_color": "#9B381D", + "error_color": "#9B381D", + + "scrollbar_border": "#176087", + "scrollbar_trough": "#CDD3D5", + "scrollbar_background_normal": "#75929C", + "scrollbar_background_disabled": "#CDD3D5", + "scrollbar_background_active": "#176087", + "scrollbar_foreground_normal": "#CDD3D5", + "scrollbar_foreground_disabled": "#75929C", + 
"scrollbar_foreground_active": "#CDD3D5", + "scrollbar_border_normal": "#176087", + "scrollbar_border_disabled": "#75929C", + "scrollbar_border_active": "#176087" + }, + "tooltip": { + "info": { + "info1": "The colors of the tool-tip pop ups", + + "background_color": "Tool-tip background color", + "border_color": "Tool-tip border color", + "font_color": "Tool-tip font color" + }, + "background_color": "#FFFFEA", + "border_color": "#FFFFEA", + "font_color": "#000000" + } +} diff --git a/lib/gui/__init__.py b/lib/gui/__init__.py index dca41f2ae0..22697f72e6 100644 --- a/lib/gui/__init__.py +++ b/lib/gui/__init__.py @@ -1,9 +1,12 @@ +#!/usr/bin python3 +""" The Faceswap GUI """ + from lib.gui.command import CommandNotebook +from lib.gui.custom_widgets import ConsoleOut, StatusBar from lib.gui.display import DisplayNotebook from lib.gui.options import CliOptions -from lib.gui.menu import MainMenuBar -from lib.gui.popup_configure import popup_config -from lib.gui.stats import Session -from lib.gui.statusbar import StatusBar -from lib.gui.utils import ConsoleOut, get_config, get_images, initialize_config, initialize_images +from lib.gui.menu import MainMenuBar, TaskBar +from lib.gui.project import LastSession +from lib.gui.utils import (get_config, get_images, initialize_config, initialize_images, + preview_trigger) from lib.gui.wrapper import ProcessWrapper diff --git a/lib/gui/analysis/__init__.py b/lib/gui/analysis/__init__.py new file mode 100644 index 0000000000..ed1b38c142 --- /dev/null +++ b/lib/gui/analysis/__init__.py @@ -0,0 +1,4 @@ +#!/usr/bin/env python3 +""" Methods for querying and compiling statistical data for the Faceswap GUI Analysis tab. """ + +from .stats import Calculations, _SESSION as Session # noqa diff --git a/lib/gui/analysis/event_reader.py b/lib/gui/analysis/event_reader.py new file mode 100644 index 0000000000..1b348c69bf --- /dev/null +++ b/lib/gui/analysis/event_reader.py @@ -0,0 +1,832 @@ +#!/usr/bin/env python3 +""" Handles the loading and collation of events from Tensorboard event log files. """ +from __future__ import annotations +import logging +import os +import re +import typing as T +import zlib + +from dataclasses import dataclass, field + +import numpy as np +from tensorboard.compat.proto import event_pb2 # type:ignore[import-untyped] + +from lib.logger import parse_class_init +from lib.serializer import get_serializer +from lib.training.tensorboard import RecordIterator +from lib.utils import get_module_objects + +if T.TYPE_CHECKING: + from collections.abc import Generator, Iterator + +logger = logging.getLogger(__name__) + + +@dataclass +class EventData: + """ Holds data collected from Tensorboard Event Files + + Parameters + ---------- + timestamp: float + The timestamp of the event step (iteration) + loss: list[float] + The loss values collected for A and B sides for the event step + """ + timestamp: float = 0.0 + loss: list[float] = field(default_factory=list) + + +class _LogFiles(): + """ Holds the filenames of the Tensorboard Event logs that require parsing. + + Parameters + ---------- + logs_folder: str + The folder that contains the Tensorboard log files + """ + def __init__(self, logs_folder: str) -> None: + logger.debug(parse_class_init(locals())) + self._logs_folder = logs_folder + self._filenames = self._get_log_filenames() + logger.debug("Initialized %s", self.__class__.__name__) + + @property + def session_ids(self) -> list[int]: + """ list[int]: Sorted list of `ints` of available session ids. 
""" + return list(sorted(self._filenames)) + + def _get_log_filenames(self) -> dict[int, str]: + """ Get the Tensorboard event filenames for all existing sessions. + + Returns + ------- + dict[int, str] + The full path of each log file for each training session id that has been run + """ + logger.debug("Loading log filenames. base_dir: '%s'", self._logs_folder) + retval: dict[int, str] = {} + for dirpath, _, filenames in os.walk(self._logs_folder): + if not any(filename.startswith("events.out.tfevents") for filename in filenames): + continue + session_id = self._get_session_id(dirpath) + if session_id is None: + logger.warning("Unable to load session data for model") + return retval + retval[session_id] = self._get_log_filename(dirpath, filenames) + logger.debug("logfiles: %s", retval) + return retval + + @classmethod + def _get_session_id(cls, folder: str) -> int | None: + """ Obtain the session id for the given folder. + + Parameters + ---------- + folder: str + The full path to the folder that contains the session's Tensorboard Event Log + + Returns + ------- + int or ``None`` + The session ID for the given folder. If no session id can be determined, return + ``None`` + """ + session = os.path.split(os.path.split(folder)[0])[1] + session_id = session[session.rfind("_") + 1:] + retval = None if not session_id.isdigit() else int(session_id) + logger.debug("folder: '%s', session_id: %s", folder, retval) + return retval + + @classmethod + def _get_log_filename(cls, folder: str, filenames: list[str]) -> str: + """ Obtain the session log file for the given folder. If multiple log files exist for the + given folder, then the most recent log file is used, as earlier files are assumed to be + obsolete. + + Parameters + ---------- + folder: str + The full path to the folder that contains the session's Tensorboard Event Log + filenames: list[str] + List of filenames that exist within the given folder + + Returns + ------- + str + The full path of the selected log file + """ + logfiles = [fname for fname in filenames if fname.startswith("events.out.tfevents")] + retval = os.path.join(folder, sorted(logfiles)[-1]) # Take last item if multi matches + logger.debug("logfiles: %s, selected: '%s'", logfiles, retval) + return retval + + def refresh(self) -> bool: + """ Refresh the list of log filenames. + + Returns + ------- + bool + ``True`` if the pre-existing log files are a subset of the new log files, otherwise + ``False`` + """ + logger.debug("Refreshing log filenames") + old_filenames = self._filenames + new_filenames = self._get_log_filenames() + retval = set(old_filenames.values()).issubset(set(new_filenames.values())) + self._filenames = new_filenames + logger.debug("old filenames are %sa subset of new filenames %s", + "" if retval else "not ", self._filenames) + return retval + + def get(self, session_id: int) -> str: + """ Obtain the log filename for the given session id. 
+ + Parameters + ---------- + session_id: int + The session id to obtain the log filename for + + Returns + ------- + str + The full path to the log file for the requested session id + """ + retval = self._filenames.get(session_id, "") + logger.debug("session_id: %s, log_filename: '%s'", session_id, retval) + return retval + + +class _CacheData(): + """ Holds cached data that has been retrieved from Tensorboard Event Files and is compressed + in memory for a single or live training session + + Parameters + ---------- + labels: list[str] + The labels for the loss values + timestamps: :class:`np.ndarray` + The timestamp of the event step (iteration) + loss: :class:`np.ndarray` + The loss values collected for A and B sides for the session + """ + def __init__(self, labels: list[str], timestamps: np.ndarray, loss: np.ndarray) -> None: + self.labels = labels + self._loss = zlib.compress(T.cast(bytes, loss)) + self._timestamps = zlib.compress(T.cast(bytes, timestamps)) + self._timestamps_shape = timestamps.shape + self._loss_shape = loss.shape + + @property + def loss(self) -> np.ndarray: + """ :class:`numpy.ndarray`: The loss values for this session """ + retval: np.ndarray = np.frombuffer(zlib.decompress(self._loss), dtype="float32") + if len(self._loss_shape) > 1: + retval = retval.reshape(-1, *self._loss_shape[1:]) + return retval + + @property + def timestamps(self) -> np.ndarray: + """ :class:`numpy.ndarray`: The timestamps for this session """ + retval: np.ndarray = np.frombuffer(zlib.decompress(self._timestamps), dtype="float64") + if len(self._timestamps_shape) > 1: + retval = retval.reshape(-1, *self._timestamps_shape[1:]) + return retval + + def add_live_data(self, timestamps: np.ndarray, loss: np.ndarray) -> None: + """ Add live data to the end of the stored data + + Parameters + ---------- + timestamps: :class:`numpy.ndarray` + The latest timestamps to add to the cache + loss: :class:`numpy.ndarray` + The latest loss values to add to the cache + """ + new_buffer: list[bytes] = [] + new_shapes: list[tuple[int, ...]] = [] + for data, buffer, dtype, shape in zip([timestamps, loss], + [self._timestamps, self._loss], + ["float64", "float32"], + [self._timestamps_shape, self._loss_shape]): + + old = np.frombuffer(zlib.decompress(buffer), dtype=dtype) + if data.ndim > 1: + old = old.reshape(-1, *data.shape[1:]) + + new = np.concatenate((old, data)) + + logger.debug("old_shape: %s new_shape: %s", shape, new.shape) + new_buffer.append(zlib.compress(new)) + new_shapes.append(new.shape) + del old + + self._timestamps = new_buffer[0] + self._loss = new_buffer[1] + self._timestamps_shape = new_shapes[0] + self._loss_shape = new_shapes[1] + + +class _Cache(): + """ Holds parsed Tensorboard log event data in a compressed cache in memory. """ + def __init__(self) -> None: + logger.debug(parse_class_init(locals())) + self._data: dict[int, _CacheData] = {} + self._carry_over: dict[int, EventData] = {} + self._loss_labels: list[str] = [] + logger.debug("Initialized %s", self.__class__.__name__) + + def is_cached(self, session_id: int) -> bool: + """ Check if the given session_id's data is already cached + + Parameters + ---------- + session_id: int + The session ID to check + + Returns + ------- + bool + ``True`` if the data already exists in the cache otherwise ``False``. + """ + return self._data.get(session_id) is not None + + def cache_data(self, + session_id: int, + data: dict[int, EventData], + labels: list[str], + is_live: bool = False) -> None: + """ Add a full session's worth of event data to :attr:`_data`. 
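+ + A hedged sketch of the expected call shape (session id, step, values and labels are all + illustrative): + + >>> cache = _Cache() + >>> cache.cache_data(1, {100: EventData(1234.5, [0.01, 0.02])}, ["loss_a", "loss_b"]) + >>> cache.is_cached(1) + True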
+ + Parameters + ---------- + session_id: int + The session id to add the data for + data: dict[int, :class:`EventData`] + The extracted event data dictionary generated from :class:`_EventParser` + labels: list[str] + List of `str` for the labels of each loss value output + is_live: bool, optional + ``True`` if the data to be cached is from a live training session otherwise ``False``. + Default: ``False`` + """ + logger.debug("Caching event data: (session_id: %s, labels: %s, data points: %s, " + "is_live: %s)", session_id, labels, len(data), is_live) + + if labels: + logger.debug("Setting loss labels: %s", labels) + self._loss_labels = labels + + if not data: + logger.debug("No data to cache") + return + + timestamps, loss = self._to_numpy(data, is_live) + + if not is_live or (is_live and not self._data.get(session_id)): + self._data[session_id] = _CacheData(self._loss_labels, timestamps, loss) + else: + self._add_latest_live(session_id, loss, timestamps) + + def _to_numpy(self, + data: dict[int, EventData], + is_live: bool) -> tuple[np.ndarray, np.ndarray]: + """ Extract each individual step data into separate numpy arrays for loss and timestamps. + + Timestamps are stored as float64 as the extra accuracy is needed for correct timings. Arrays + are returned at the length of the shortest available data (i.e. truncated records are + dropped) + + Parameters + ---------- + data: dict + The incoming Tensorboard event data in dictionary form per step + is_live: bool + ``True`` if the data to be cached is from a live training session otherwise ``False``. + + Returns + ------- + timestamps: :class:`numpy.ndarray` + float64 array of all iteration's timestamps + loss: :class:`numpy.ndarray` + float32 array of all iteration's loss + """ + if is_live and self._carry_over: + logger.debug("Processing carry over: %s", self._carry_over) + self._collect_carry_over(data) + + times, loss = self._process_data(data, is_live) + + if is_live and not all(len(val) == len(self._loss_labels) for val in loss): + # TODO Many attempts have been made to fix this for live graph logging, and the issue + # of non-consistent loss record sizes keeps coming up. In the meantime we shall swallow + # any loss values that are of incorrect length so the graph remains functional. This will, + # most likely, lead to a mismatch on iteration count so a proper fix should be + # implemented. + + # Timestamps and loss appear to remain consistent with each other, but sometimes loss + # appears non-consistent. eg (lengths): + # [2, 2, 2, 2, 2, 2, 2, 0] - last loss collection has zero length + # [1, 2, 2, 2, 2, 2, 2, 2] - 1st loss collection has 1 length + # [2, 2, 2, 3, 2, 2, 2] - 4th loss collection has 3 length + + logger.debug("Inconsistent loss found in collection: %s", loss) + for idx in reversed(range(len(loss))): + if len(loss[idx]) != len(self._loss_labels): + logger.debug("Removing loss/timestamps at position %s", idx) + del loss[idx] + del times[idx] + + n_times, n_loss = (np.array(times, dtype="float64"), np.array(loss, dtype="float32")) + logger.debug("Converted to numpy: (data points: %s, timestamps shape: %s, loss shape: %s)", + len(data), n_times.shape, n_loss.shape) + + return n_times, n_loss + + def _collect_carry_over(self, data: dict[int, EventData]) -> None: + """ For live data, collect carried over data from the previous update and merge into the + current data dictionary. 
+ + Parameters + ---------- + data: dict[int, :class:`EventData`] + The latest raw data dictionary + """ + logger.debug("Carry over keys: %s, data keys: %s", list(self._carry_over), list(data)) + for key in list(self._carry_over): + if key not in data: + logger.debug("Carry over found for item %s which does not exist in current " + "data: %s. Skipping.", key, list(data)) + continue + carry_over = self._carry_over.pop(key) + update = data[key] + logger.debug("Merging carry over data: %s in to %s", carry_over, update) + timestamp = update.timestamp + update.timestamp = carry_over.timestamp if not timestamp else timestamp + update.loss = carry_over.loss + update.loss + logger.debug("Merged carry over data: %s", update) + + def _process_data(self, + data: dict[int, EventData], + is_live: bool) -> tuple[list[float], list[list[float]]]: + """ Process live update data. + + Live data requires different processing as often we will only have partial data for the + current step, so we need to cache carried over partial data to be picked up at the next + query. In addition to this, if training is unexpectedly interrupted, there may also be + partial data which needs to be cleansed prior to creating a numpy array + + Parameters + ---------- + data: dict + The incoming Tensorboard event data in dictionary form per step + is_live: bool + ``True`` if the data to be cached is from a live training session otherwise ``False``. + + Returns + ------- + timestamps: list + Cleaned list of complete timestamps for the latest live query + loss: list + Cleaned list of complete loss for the latest live query + """ + timestamps, loss = zip(*[(data[idx].timestamp, data[idx].loss) + for idx in sorted(data)]) + + l_loss: list[list[float]] = list(loss) + l_timestamps: list[float] = list(timestamps) + + if len(l_loss[-1]) != len(self._loss_labels): + logger.debug("Truncated loss found. loss count: %s", len(l_loss)) + idx = sorted(data)[-1] + if is_live: + logger.debug("Setting carried over data: %s", data[idx]) + self._carry_over[idx] = data[idx] + logger.debug("Removing truncated loss: (timestamp: %s, loss: %s)", + l_timestamps[-1], loss[-1]) + del l_loss[-1] + del l_timestamps[-1] + + return l_timestamps, l_loss + + def _add_latest_live(self, session_id: int, loss: np.ndarray, timestamps: np.ndarray) -> None: + """ Append the latest received live training data to the cached data. + + Parameters + ---------- + session_id: int + The training session ID to update the cache for + loss: :class:`numpy.ndarray` + The latest loss values returned from the iterator + timestamps: :class:`numpy.ndarray` + The latest time stamps returned from the iterator + """ + logger.debug("Adding live data to cache: (session_id: %s, loss: %s, timestamps: %s)", + session_id, loss.shape, timestamps.shape) + if not np.any(loss) and not np.any(timestamps): + return + + self._data[session_id].add_live_data(timestamps, loss) + + def get_data(self, session_id: int, metric: T.Literal["loss", "timestamps"] + ) -> dict[int, dict[str, np.ndarray | list[str]]] | None: + """ Retrieve the decompressed cached data from the cache for the given session id. + + Parameters + ---------- + session_id: int or ``None`` + If session_id is provided, then the cached data for that session is returned. If + session_id is ``None`` then the cached data for all sessions is returned + metric: ['loss', 'timestamps'] + The metric to return the data for. 
+ + Returns + ------- + dict or ``None`` + The `session_id`(s) as key, the values are a dictionary containing the requested + metric information for each session returned. ``None`` if no data is stored for the + given session_id + """ + if session_id is None: + raw = self._data + else: + data = self._data.get(session_id) + if not data: + return None + raw = {session_id: data} + + retval: dict[int, dict[str, np.ndarray | list[str]]] = {} + for idx, data in raw.items(): + array = data.loss if metric == "loss" else data.timestamps + val: dict[str, np.ndarray | list[str]] = {str(metric): array} + if metric == "loss": + val["labels"] = data.labels + retval[idx] = val + + logger.debug("Obtained cached data: %s", + {session_id: {k: v.shape if isinstance(v, np.ndarray) else v + for k, v in data.items()} + for session_id, data in retval.items()}) + return retval + + def reset(self) -> None: + """ Remove all information stored within the cache and reset to default """ + logger.debug("Resetting cache") + del self._data + del self._carry_over + del self._loss_labels + self._data = {} + self._carry_over = {} + self._loss_labels = [] + + +class TensorBoardLogs(): + """ Parse data from TensorBoard logs. + + Process the input logs folder and stores the individual filenames per session. + + Caches timestamp and loss data on request and returns this data from the cache. + + Parameters + ---------- + logs_folder: str + The folder that contains the Tensorboard log files + is_training: bool + ``True`` if the events are being read whilst Faceswap is training otherwise ``False`` + """ + def __init__(self, logs_folder: str, is_training: bool) -> None: + logger.debug(parse_class_init(locals())) + self._is_training = False + self._training_iterator: RecordIterator | None = None + + self._log_files = _LogFiles(logs_folder) + self.set_training(is_training) + + self._cache = _Cache() + + logger.debug("Initialized %s", self.__class__.__name__) + + @property + def session_ids(self) -> list[int]: + """ list[int]: Sorted list of integers of available session ids. """ + return self._log_files.session_ids + + def set_training(self, is_training: bool) -> bool: + """ Set the internal training flag to the given `is_training` value. + + If a new training session is being instigated, refresh the log filenames + + Parameters + ---------- + is_training: bool + ``True`` to indicate that the logs to be read are from the currently training + session otherwise ``False`` + + Returns + ------- + bool + ``True`` if the session that is starting training belongs to the session already loaded + otherwise ``False`` + """ + retval = True + if self._is_training == is_training: + logger.debug("Training flag already set to %s. Returning", is_training) + return retval + + logger.debug("Setting is_training to %s", is_training) + self._is_training = is_training + if is_training: + retval = self._log_files.refresh() + if not retval: + self._cache.reset() + log_file = self._log_files.get(self.session_ids[-1]) + logger.debug("Setting training iterator for log file: '%s'", log_file) + self._training_iterator = RecordIterator(log_file, is_live=True) + else: + logger.debug("Removing training iterator") + del self._training_iterator + self._training_iterator = None + return retval + + def _cache_data(self, session_id: int) -> None: + """ Cache TensorBoard logs for the given session ID on first access. + + Populates :attr:`_cache` with timestamps and loss data. 
+ + If this is a training session and the data is being queried for the training session ID + then get the latest available data and append to the cache + + Parameters + ---------- + session_id: int + The session ID to cache the data for + """ + live_data = self._is_training and session_id == max(self.session_ids) + iterator = self._training_iterator if live_data else RecordIterator( + self._log_files.get(session_id)) + assert iterator is not None + parser = _EventParser(iterator, self._cache, live_data) + parser.cache_events(session_id) + + def _check_cache(self, session_id: int | None = None) -> None: + """ Check if the given session_id has been cached and if not, cache it. + + Parameters + ---------- + session_id: int, optional + The Session ID to return the data for. Set to ``None`` to return all session + data. Default ``None`` + """ + if session_id is not None and not self._cache.is_cached(session_id): + self._cache_data(session_id) + elif self._is_training and session_id == self.session_ids[-1]: + self._cache_data(session_id) + elif session_id is None: + for idx in self.session_ids: + if not self._cache.is_cached(idx): + self._cache_data(idx) + + def get_loss(self, session_id: int | None = None) -> dict[int, dict[str, np.ndarray]]: + """ Read the loss from the TensorBoard event logs + + Parameters + ---------- + session_id: int, optional + The Session ID to return the loss for. Set to ``None`` to return all session + losses. Default ``None`` + + Returns + ------- + dict + The session id(s) as key, with a further dictionary as value containing the loss name + and list of loss values for each step + """ + logger.debug("Getting loss: (session_id: %s)", session_id) + retval: dict[int, dict[str, np.ndarray]] = {} + for idx in [session_id] if session_id else self.session_ids: + self._check_cache(idx) + full_data = self._cache.get_data(idx, "loss") + if not full_data: + continue + data = full_data[idx] + loss = data["loss"] + assert isinstance(loss, np.ndarray) + retval[idx] = {title: loss[:, idx] for idx, title in enumerate(data["labels"])} + + logger.debug({key: {k: v.shape for k, v in val.items()} + for key, val in retval.items()}) + return retval + + def get_timestamps(self, session_id: int | None = None) -> dict[int, np.ndarray]: + """ Read the timestamps from the TensorBoard logs. + + As loss timestamps are slightly different for each loss, we collect the timestamp from the + `batch_loss` key. + + Parameters + ---------- + session_id: int, optional + The Session ID to return the timestamps for. Set to ``None`` to return all session + timestamps. Default ``None`` + + Returns + ------- + dict + The session id(s) as key with list of timestamps per step as value + """ + + logger.debug("Getting timestamps: (session_id: %s, is_training: %s)", + session_id, self._is_training) + retval: dict[int, np.ndarray] = {} + for idx in [session_id] if session_id else self.session_ids: + self._check_cache(idx) + data = self._cache.get_data(idx, "timestamps") + if not data: + continue + timestamps = data[idx]["timestamps"] + assert isinstance(timestamps, np.ndarray) + retval[idx] = timestamps + logger.debug({k: v.shape for k, v in retval.items()}) + return retval + + +class _EventParser(): + """ Parses Tensorboard events and populates data to :class:`_Cache`. 
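+ + A minimal usage sketch (the ``records`` iterable and session id are illustrative): + + >>> parser = _EventParser(iter(records), _Cache(), live_data=False) + >>> parser.cache_events(session_id=1)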
+ + Parameters + ---------- + iterator: :class:`lib.training.tensorboard.RecordIterator` + The iterator to use for reading Tensorboard event logs + cache: :class:`_Cache` + The cache object to store the collected parsed events to + live_data: bool + ``True`` if the iterator to be loaded is a training iterator for reading live data + otherwise ``False`` + """ + def __init__(self, iterator: Iterator[bytes], cache: _Cache, live_data: bool) -> None: + logger.debug(parse_class_init(locals())) + self._live_data = live_data + self._cache = cache + self._iterator = self._get_latest_live(iterator) if live_data else iterator + self._loss_labels: list[str] = [] + self._num_strip = re.compile(r"_\d+$") + logger.debug("Initialized %s", self.__class__.__name__) + + @classmethod + def _get_latest_live(cls, iterator: Iterator[bytes]) -> Generator[bytes, None, None]: + """ Obtain the latest event logs for live training data. + + The live data iterator remains open so that it can be re-queried + + Parameters + ---------- + iterator: :class:`lib.training.tensorboard.RecordIterator` + The live training iterator to use for reading Tensorboard event logs + + Yields + ------ + bytes + A serialized Tensorboard event record for a single step + """ + i = 0 + while True: + try: + yield next(iterator) + i += 1 + except StopIteration: + logger.debug("End of data reached") + break + logger.debug("Collected %s records from live log file", i) + + def cache_events(self, session_id: int) -> None: + """ Parse the Tensorboard event logs and add to :attr:`_cache`. + + Parameters + ---------- + session_id: int + The session id that the data is being cached for + """ + assert self._iterator is not None + data: dict[int, EventData] = {} + for record in self._iterator: + event = event_pb2.Event.FromString(record) # pylint:disable=no-member + if not event.summary.value: + continue + if event.summary.value[0].tag.split("/", maxsplit=1)[0] == "keras": + self._parse_outputs(event) + if event.summary.value[0].tag.startswith("batch_"): + data[event.step] = self._process_event(event, + data.get(event.step, EventData())) + + self._cache.cache_data(session_id, data, self._loss_labels, is_live=self._live_data) + + def _parse_outputs(self, event: event_pb2.Event) -> None: + """ Parse the outputs from the stored model structure for mapping loss names to + model outputs. 
+ + Loss names are added to :attr:`_loss_labels` + + Notes + ----- + The master model does not actually contain the specified output name, so we dig into the + sub-model to obtain the name of the output layers + + Parameters + ---------- + event: :class:`tensorboard.compat.proto.event_pb2` + The event data containing the keras model structure to be parsed + """ + serializer = get_serializer("json") + structure = event.summary.value[0].tensor.string_val[0] + + config = serializer.unmarshal(structure)["config"] + model_outputs = self._get_outputs(config, False) + + for side_outputs, side in zip(model_outputs, ("a", "b")): + logger.debug("side: '%s', outputs: %s", side, side_outputs) + layer_name = side_outputs[0][0] + + output_config = next(layer for layer in config["layers"] + if layer["name"] == layer_name)["config"] + layer_outputs = self._get_outputs(output_config, True) + logger.debug("Layer name: %s, layer_outputs: %s", layer_name, layer_outputs) + for output in layer_outputs[0]: # Drill into sub-model to get the actual output names + logger.debug("Parsing output: %s", output) + loss_name = self._num_strip.sub("", output[0]) # strip trailing numbers + if loss_name[-2:] not in ("_a", "_b"): # Rename losses to reflect the side output + new_name = f"{loss_name.replace('_both', '')}_{side}" + logger.debug("Renaming loss output from '%s' to '%s'", loss_name, new_name) + loss_name = new_name + if loss_name not in self._loss_labels: + logger.debug("Adding loss name: '%s'", loss_name) + self._loss_labels.append(loss_name) + logger.debug("Collated loss labels: %s", self._loss_labels) + + @classmethod + def _get_outputs(cls, model_config: dict[str, T.Any], is_sub_model: bool) -> np.ndarray: + """ Obtain the output names, instance index and output index for the given model. + + If there is only a single output, the shape of the array is expanded to remain consistent + with multi-model outputs + + Parameters + ---------- + model_config: dict + The saved Keras model configuration dictionary + is_sub_model: bool + ``True`` if the model_config is for a sub-model. ``False`` if it is for the main + faceswap model. + + Returns + ------- + :class:`numpy.ndarray` + The layer output names, their instance index and their output index + """ + outputs = np.array(model_config["output_layers"]) + logger.debug("Obtained model outputs. is_sub_model: %s, outputs: %s, shape: %s", + is_sub_model, outputs, outputs.shape) + # Reshape the outputs to (side, outputs per side, output info) + outputs = outputs.reshape((1 if is_sub_model else 2, -1, outputs.shape[-1])) + logger.debug("Reshaped model outputs: %s, shape: %s", outputs, outputs.shape) + return outputs + + @classmethod + def _process_event(cls, event: event_pb2.Event, step: EventData) -> EventData: + """ Process a single Tensorboard event. + + Adds the timestamp to the step's :class:`EventData` if a total loss value is received, + processes the labels for any new loss entries and adds the side loss value to the step's + :class:`EventData`. + + Parameters + ---------- + event: :class:`tensorboard.compat.proto.event_pb2` + The event data to be processed + step: :class:`EventData` + The currently processing step data to be populated with the extracted data from the + Tensorboard event for this step + + Returns + ------- + :class:`EventData` + The given step :class:`EventData` with the given event data added to it. 
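+ + Example + ------- + A sketch assuming a hypothetical, pre-built ``event`` whose only summary value carries + the ``batch_total`` tag: + + >>> step = _EventParser._process_event(event, EventData()) + >>> step.timestamp == event.wall_time + True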
+ """ + summary = event.summary.value[0] + + if summary.tag == "batch_total": + step.timestamp = event.wall_time + return step + + loss = summary.simple_value + if not loss: + # Need to convert a tensor to a float for TF2.8 logged data. This maybe due to change + # in logging or may be due to work around put in place in FS training function for the + # following bug in TF 2.8/2.9 when writing records: + # https://github.com/keras-team/keras/issues/16173 + loss = float(np.frombuffer(summary.tensor.tensor_content, dtype="float32")) + + step.loss.append(loss) + + return step + + +__all__ = get_module_objects(__name__) diff --git a/lib/gui/analysis/moving_average.py b/lib/gui/analysis/moving_average.py new file mode 100644 index 0000000000..b4bb447561 --- /dev/null +++ b/lib/gui/analysis/moving_average.py @@ -0,0 +1,179 @@ +#!/usr/bin python3 +""" Calculate Exponential Moving Average for faceswap GUI Stats. """ + +import logging + +import numpy as np + +from lib.logger import parse_class_init +from lib.utils import get_module_objects + + +logger = logging.getLogger(__name__) + + +class ExponentialMovingAverage: + """ Reshapes data before calculating exponential moving average, then iterates once over the + rows to calculate the offset without precision issues. + + Parameters + ---------- + data : :class:`numpy.ndarray` + A 1 dimensional numpy array to obtain smoothed data for + amount : float + in the range (0.0, 1.0) The alpha parameter (smoothing amount) for the moving average. + + Notes + ----- + Adapted from: https://stackoverflow.com/questions/42869495 + """ + def __init__(self, data: np.ndarray, amount: float) -> None: + logger.debug(parse_class_init(locals())) + assert data.ndim == 1 + amount = min(max(amount, 0.001), 0.999) + + self._data = np.nan_to_num(data) + self._alpha = 1. - amount + self._dtype = "float32" if data.dtype == np.float32 else "float64" + self._row_size = self._get_max_row_size() + self._out = np.empty_like(data, dtype=self._dtype) + logger.debug("Initialized %s", self.__class__.__name__) + + def __call__(self) -> np.ndarray: + """ Perform the exponential moving average calculation. + + Returns + ------- + :class:`numpy.ndarray` + The smoothed data + """ + if self._data.size <= self._row_size: + self._ewma_vectorized(self._data, self._out) # Normal function can handle this input + else: + self._ewma_vectorized_safe() # Use the safe version + return self._out + + def _get_max_row_size(self) -> int: + """ Calculate the maximum row size for the running platform for the given dtype. + + Returns + ------- + int + The maximum row size possible on the running platform for the given :attr:`_dtype` + + Notes + ----- + Might not be the optimal value for speed, which is hard to predict due to numpy + optimizations. + """ + # Use :func:`np.finfo(dtype).eps` if you are worried about accuracy and want to be safe. + epsilon = np.finfo(self._dtype).tiny + # If this produces an OverflowError, make epsilon larger: + retval = int(np.log(epsilon) / np.log(1 - self._alpha)) + 1 + logger.debug("row_size: %s", retval) + return retval + + def _ewma_vectorized_safe(self) -> None: + """ Perform the vectorized exponential moving average in a safe way. 
""" + num_rows = int(self._data.size // self._row_size) # the number of rows to use + leftover = int(self._data.size % self._row_size) # the amount of data leftover + first_offset = self._data[0] + + if leftover > 0: + # set temporary results to slice view of out parameter + out_main_view = np.reshape(self._out[:-leftover], (num_rows, self._row_size)) + data_main_view = np.reshape(self._data[:-leftover], (num_rows, self._row_size)) + else: + out_main_view = self._out.reshape(-1, self._row_size) + data_main_view = self._data.reshape(-1, self._row_size) + + self._ewma_vectorized_2d(data_main_view, out_main_view) # get the scaled cumulative sums + + scaling_factors = (1 - self._alpha) ** np.arange(1, self._row_size + 1) + last_scaling_factor = scaling_factors[-1] + + # create offset array + offsets = np.empty(out_main_view.shape[0], dtype=self._dtype) + offsets[0] = first_offset + # iteratively calculate offset for each row + + for i in range(1, out_main_view.shape[0]): + offsets[i] = offsets[i - 1] * last_scaling_factor + out_main_view[i - 1, -1] + + # add the offsets to the result + out_main_view += offsets[:, np.newaxis] * scaling_factors[np.newaxis, :] + + if leftover > 0: + # process trailing data in the 2nd slice of the out parameter + self._ewma_vectorized(self._data[-leftover:], + self._out[-leftover:], + offset=out_main_view[-1, -1]) + + def _ewma_vectorized(self, + data: np.ndarray, + out: np.ndarray, + offset: float | None = None) -> None: + """ Calculates the exponential moving average over a vector. Will fail for large inputs. + + The result is processed in place into the array passed to the `out` parameter + + Parameters + ---------- + data : :class:`numpy.ndarray` + A 1 dimensional numpy array to obtain smoothed data for + out : :class:`numpy.ndarray` + A location into which the result is stored. It must have the same shape and dtype as + the input data + offset : float, optional + The offset for the moving average, scalar. Default: the value held in data[0]. + """ + if data.size < 1: # empty input, return empty array + return + + offset = data[0] if offset is None else offset + + # scaling_factors -> 0 as len(data) gets large. This leads to divide-by-zeros below + scaling_factors = np.power(1. - self._alpha, np.arange(data.size + 1, dtype=self._dtype), + dtype=self._dtype) + # create cumulative sum array + np.multiply(data, (self._alpha * scaling_factors[-2]) / scaling_factors[:-1], + dtype=self._dtype, out=out) + np.cumsum(out, dtype=self._dtype, out=out) + + out /= scaling_factors[-2::-1] # cumulative sums / scaling + + if offset != 0: + noffset = np.asarray(offset).astype(self._dtype, copy=False) + out += noffset * scaling_factors[1:] + + def _ewma_vectorized_2d(self, data: np.ndarray, out: np.ndarray) -> None: + """ Calculates the exponential moving average over the last axis. + + The result is processed in place into the array passed to the `out` parameter + + Parameters + ---------- + data : :class:`numpy.ndarray` + A 1 or 2 dimensional numpy array to obtain smoothed data for. + out : :class:`numpy.ndarray` + A location into which the result is stored. It must have the same shape and dtype as + the input data + """ + if data.size < 1: # empty input, return empty array + return + + # calculate the moving average + scaling_factors = np.power(1. 
- self._alpha, np.arange(data.shape[1] + 1, + dtype=self._dtype), + dtype=self._dtype) + # create a scaled cumulative sum array + np.multiply(data, + np.multiply(self._alpha * scaling_factors[-2], + np.ones((data.shape[0], 1), dtype=self._dtype), + dtype=self._dtype) / scaling_factors[np.newaxis, :-1], + dtype=self._dtype, out=out) + np.cumsum(out, axis=1, dtype=self._dtype, out=out) + out /= scaling_factors[np.newaxis, -2::-1] + + +__all__ = get_module_objects(__name__) diff --git a/lib/gui/analysis/stats.py b/lib/gui/analysis/stats.py new file mode 100644 index 0000000000..cb78e1f7e3 --- /dev/null +++ b/lib/gui/analysis/stats.py @@ -0,0 +1,873 @@ +#!/usr/bin python3 +""" Stats functions for the GUI. + +Holds the globally loaded training session. This will either be a user selected session (loaded in +the analysis tab) or the currently training session. + +""" +from __future__ import annotations +import logging +import os +import time +import typing as T +import warnings + +from math import ceil +from threading import Event + +import numpy as np + +from lib.logger import parse_class_init +from lib.serializer import get_serializer +from lib.utils import get_module_objects + +from .moving_average import ExponentialMovingAverage + +from .event_reader import TensorBoardLogs + +logger = logging.getLogger(__name__) + + +class GlobalSession(): + """ Holds information about a loaded or current training session by accessing a model's state + file and Tensorboard logs. This class should not be accessed directly, rather through + :attr:`lib.gui.analysis.Session` + """ + def __init__(self) -> None: + logger.debug(parse_class_init(locals())) + self._state: dict[str, T.Any] = {} + self._model_dir = "" + self._model_name = "" + + self._tb_logs: TensorBoardLogs | None = None + self._summary: SessionsSummary | None = None + + self._is_training = False + self._is_querying = Event() + + logger.debug("Initialized %s", self.__class__.__name__) + + @property + def is_loaded(self) -> bool: + """ bool: ``True`` if session data is loaded otherwise ``False`` """ + return bool(self._model_dir) + + @property + def is_training(self) -> bool: + """ bool: ``True`` if the loaded session is the currently training model, otherwise + ``False`` """ + return self._is_training + + @property + def model_filename(self) -> str: + """ str: The full model filename """ + return os.path.join(self._model_dir, self._model_name) + + @property + def have_session_data(self) -> bool: + """ bool : ``True`` if session data is available otherwise ``False`` """ + return bool(self._state and self._state["sessions"]) + + @property + def batch_sizes(self) -> dict[int, int]: + """ dict: The batch sizes for each session_id for the model. """ + if not self.have_session_data: + return {} + return {int(sess_id): sess["batchsize"] + for sess_id, sess in self._state.get("sessions", {}).items()} + + @property + def full_summary(self) -> list[dict]: + """ list: List of dictionaries containing summary statistics for each session id. """ + assert self._summary is not None + return self._summary.get_summary_stats() + + @property + def logging_disabled(self) -> bool: + """ bool: ``True`` if logging is disabled for the currently training session otherwise + ``False``. 
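+ + For example, a state dictionary of ``{"sessions": {"1": {"no_logs": True}}}`` + (illustrative) would return ``True``, as the most recent session was run with logging + disabled. 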
""" + if not self.have_session_data: + return True + max_id = str(max(int(idx) for idx in self._state["sessions"])) + return self._state["sessions"][max_id]["no_logs"] + + @property + def session_ids(self) -> list[int]: + """ list: The sorted list of all existing session ids in the state file """ + if self._tb_logs is None: + return [] + return self._tb_logs.session_ids + + def _load_state_file(self) -> None: + """ Load the current state file to :attr:`_state`. """ + state_file = os.path.join(self._model_dir, f"{self._model_name}_state.json") + logger.debug("Loading State: '%s'", state_file) + serializer = get_serializer("json") + self._state = serializer.load(state_file) + logger.debug("Loaded state: %s", self._state) + + def initialize_session(self, + model_folder: str, + model_name: str, + is_training: bool = False) -> None: + """ Initialize a Session. + + Load's the model's state file, and sets the paths to any underlying Tensorboard logs, ready + for access on request. + + Parameters + ---------- + model_folder: str, + If loading a session manually (e.g. for the analysis tab), then the path to the model + folder must be provided. For training sessions, this should be passed through from the + launcher + model_name: str, optional + If loading a session manually (e.g. for the analysis tab), then the model filename + must be provided. For training sessions, this should be passed through from the + launcher + is_training: bool, optional + ``True`` if the session is being initialized for a training session, otherwise + ``False``. Default: ``False`` + """ + logger.debug("Initializing session: (is_training: %s)", is_training) + + if self._model_dir == model_folder and self._model_name == model_name: + if is_training: + assert self._tb_logs is not None + if not self._tb_logs.set_training(is_training): + logger.debug("Resetting summary for updated log files") + self._summary = SessionsSummary(self) + self._load_state_file() + self._is_training = is_training + logger.debug("Requested session is already loaded. Not initializing: " + "(model_folder: %s, model_name: %s)", model_folder, model_name) + return + + self._is_training = is_training + self._model_dir = model_folder + self._model_name = model_name + self._load_state_file() + if not self.logging_disabled: + self._tb_logs = TensorBoardLogs(os.path.join(self._model_dir, + f"{self._model_name}_logs"), + is_training) + + self._summary = SessionsSummary(self) + logger.debug("Initialized session. Session_IDS: %s", self.session_ids) + + def stop_training(self) -> None: + """ Clears the internal training flag. To be called when training completes. """ + self._is_training = False + if self._tb_logs is not None: + self._tb_logs.set_training(False) + + def clear(self) -> None: + """ Clear the currently loaded session. """ + self._state = {} + self._model_dir = "" + self._model_name = "" + + del self._tb_logs + self._tb_logs = None + + del self._summary + self._summary = None + + self._is_training = False + + def get_loss(self, session_id: int | None) -> dict[str, np.ndarray]: + """ Obtain the loss values for the given session_id. + + Parameters + ---------- + session_id: int or ``None`` + The session ID to return loss for. Pass ``None`` to return loss for all sessions. + + Returns + ------- + dict + Loss names as key, :class:`numpy.ndarray` as value. 
If no session ID was provided, + all sessions' losses are collated + """ + self._wait_for_thread() + + if self._is_training: + self._is_querying.set() + + assert self._tb_logs is not None + loss_dict = self._tb_logs.get_loss(session_id=session_id) + if session_id is None: + all_loss: dict[str, list[float]] = {} + for key in sorted(loss_dict): + for loss_key, loss in loss_dict[key].items(): + all_loss.setdefault(loss_key, []).extend(loss) + retval: dict[str, np.ndarray] = {key: np.array(val, dtype="float32") + for key, val in all_loss.items()} + else: + retval = loss_dict.get(session_id, {}) + + if self._is_training: + self._is_querying.clear() + return retval + + @T.overload + def get_timestamps(self, session_id: None) -> dict[int, np.ndarray]: + ... + + @T.overload + def get_timestamps(self, session_id: int) -> np.ndarray: + ... + + def get_timestamps(self, session_id): + """ Obtain the time stamps for the given session_id. + + Parameters + ---------- + session_id: int or ``None`` + The session ID to return the time stamps for. Pass ``None`` to return time stamps for + all sessions. + + Returns + ------- + dict[int] or :class:`numpy.ndarray` + If a session ID has been given then a single :class:`numpy.ndarray` will be returned + with the session's time stamps. Otherwise a 'dict' will be returned with the session + IDs as key with :class:`numpy.ndarray` of timestamps as values + """ + self._wait_for_thread() + + if self._is_training: + self._is_querying.set() + + assert self._tb_logs is not None + retval = self._tb_logs.get_timestamps(session_id=session_id) + if session_id is not None: + retval = retval[session_id] + + if self._is_training: + self._is_querying.clear() + + return retval + + def _wait_for_thread(self) -> None: + """ If a thread is querying the log files for live data, then block until task clears. """ + while True: + if self._is_training and self._is_querying.is_set(): + logger.debug("Waiting for available thread") + time.sleep(1) + continue + break + + def get_loss_keys(self, session_id: int | None) -> list[str]: + """ Obtain the loss keys for the given session_id. + + Parameters + ---------- + session_id: int or ``None`` + The session ID to return the loss keys for. Pass ``None`` to return loss keys for + all sessions. + + Returns + ------- + list + The loss keys for the given session. If ``None`` is passed as session_id then a unique + list of all loss keys for all sessions is returned + """ + assert self._tb_logs is not None + loss_keys = {sess_id: list(logs.keys()) + for sess_id, logs + in self._tb_logs.get_loss(session_id=session_id).items()} + + if session_id is None: + retval: list[str] = list(set(loss_key + for session in loss_keys.values() + for loss_key in session)) + else: + retval = loss_keys.get(session_id, []) + return retval + + +_SESSION = GlobalSession() + + +class SessionsSummary(): + """ Performs top level summary calculations for each session ID within the loaded or currently + training Session for display in the Analysis tree view. 
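+ + A minimal usage sketch against the module's global session (illustrative): + + >>> summary = SessionsSummary(_SESSION) + >>> rows = summary.get_summary_stats() # one dict per session id, plus a totals row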
+ + Parameters + ---------- + session: :class:`GlobalSession` + The loaded or currently training session + """ + def __init__(self, session: GlobalSession) -> None: + logger.debug(parse_class_init(locals())) + self._session = session + self._state = session._state + + self._time_stats: dict[int, dict[str, float | int]] = {} + self._per_session_stats: list[dict[str, T.Any]] = [] + logger.debug("Initialized %s", self.__class__.__name__) + + def get_summary_stats(self) -> list[dict]: + """ Compile the individual session statistics and calculate the total. + + Format the stats for display + + Returns + ------- + list + A list of summary statistics dictionaries containing the Session ID, start time, end + time, elapsed time, rate, batch size and number of iterations for each session id + within the loaded data as well as the totals. + """ + logger.debug("Compiling sessions summary data") + if not self._session.have_session_data: + logger.debug("Session data doesn't exist. Most likely task has been " + "terminated during compilation, or is from LR finder") + return [] + self._get_time_stats() + self._get_per_session_stats() + if not self._per_session_stats: + return self._per_session_stats + + total_stats = self._total_stats() + retval = self._per_session_stats + [total_stats] + retval = self._format_stats(retval) + logger.debug("Final stats: %s", retval) + return retval + + def _get_time_stats(self) -> None: + """ Populates the attribute :attr:`_time_stats` with the start time, end time and + data points for each session id within the loaded session if it has not already been + calculated. + + If the main Session is currently training, then the training session ID is updated with the + latest stats. + """ + if not self._time_stats: + logger.debug("Collating summary time stamps") + + self._time_stats = { + sess_id: {"start_time": np.min(timestamps) if np.any(timestamps) else 0, + "end_time": np.max(timestamps) if np.any(timestamps) else 0, + "iterations": timestamps.shape[0] if np.any(timestamps) else 0} + for sess_id, timestamps in T.cast(dict[int, np.ndarray], + self._session.get_timestamps(None)).items()} + + elif _SESSION.is_training: + logger.debug("Updating summary time stamps for training session") + + session_id = _SESSION.session_ids[-1] + latest = T.cast(np.ndarray, self._session.get_timestamps(session_id)) + + self._time_stats[session_id] = { + "start_time": np.min(latest) if np.any(latest) else 0, + "end_time": np.max(latest) if np.any(latest) else 0, + "iterations": latest.shape[0] if np.any(latest) else 0} + + logger.debug("time_stats: %s", self._time_stats) + + def _get_per_session_stats(self) -> None: + """ Populate the attribute :attr:`_per_session_stats` with a sorted list by session ID + of each ID in the training/loaded session. Stats contain the session ID, start, end and + elapsed times, the training rate, batch size and number of iterations for each session. + + If a training session is running, then updates the training session's stats only. + """ + if not self._per_session_stats: + logger.debug("Collating per session stats") + compiled = [] + for session_id in self._time_stats: + logger.debug("Compiling session ID: %s", session_id) + if not self._session.have_session_data: + logger.debug("Session data doesn't exist. 
Most likely task has been "
+                                 "terminated during compilation, or is from LR finder")
+                    return
+                compiled.append(self._collate_stats(session_id))
+
+            self._per_session_stats = list(sorted(compiled, key=lambda k: k["session"]))
+
+        elif self._session.is_training:
+            logger.debug("Collating per session stats for latest training data")
+            session_id = self._session.session_ids[-1]
+            ts_data = self._time_stats[session_id]
+
+            if session_id > len(self._per_session_stats):
+                self._per_session_stats.append(self._collate_stats(session_id))
+
+            stats = self._per_session_stats[-1]
+
+            start = np.nan_to_num(ts_data["start_time"])
+            end = np.nan_to_num(ts_data["end_time"])
+            stats["start"] = start
+            stats["end"] = end
+            stats["elapsed"] = int(end - start)
+            stats["iterations"] = ts_data["iterations"]
+            stats["rate"] = (((stats["batch"] * 2) * stats["iterations"])
+                             / stats["elapsed"] if stats["elapsed"] > 0 else 0)
+        logger.debug("per_session_stats: %s", self._per_session_stats)
+
+    def _collate_stats(self, session_id: int) -> dict[str, int | float]:
+        """ Collate the session summary statistics for the given session ID.
+
+        Parameters
+        ----------
+        session_id: int
+            The session id to compile the stats for
+
+        Returns
+        -------
+        dict
+            The collated session summary statistics
+        """
+        timestamps = self._time_stats[session_id]
+        start = np.nan_to_num(timestamps["start_time"])
+        end = np.nan_to_num(timestamps["end_time"])
+        elapsed = int(end - start)
+        batchsize = self._session.batch_sizes.get(session_id, 0)
+        retval = {
+            "session": session_id,
+            "start": start,
+            "end": end,
+            "elapsed": elapsed,
+            "rate": (((batchsize * 2) * timestamps["iterations"]) / elapsed
+                     if elapsed != 0 else 0),
+            "batch": batchsize,
+            "iterations": timestamps["iterations"]}
+        logger.debug(retval)
+        return retval
+
+    def _total_stats(self) -> dict[str, str | int | float]:
+        """ Compile the Totals stats.
+        Totals are fully calculated each time as they will change as the training session
+        progresses.
+
+        Returns
+        -------
+        dict
+            The Session name, start time, end time, elapsed time, rate, batch size and number of
+            iterations for all session ids within the loaded data.
+        """
+        logger.debug("Compiling Totals")
+        starttime = 0.0
+        endtime = 0.0
+        elapsed = 0
+        examples = 0
+        iterations = 0
+        batchset = set()
+        total_summaries = len(self._per_session_stats)
+        for idx, summary in enumerate(self._per_session_stats):
+            if idx == 0:
+                starttime = summary["start"]
+            if idx == total_summaries - 1:
+                endtime = summary["end"]
+            elapsed += summary["elapsed"]
+            examples += ((summary["batch"] * 2) * summary["iterations"])
+            batchset.add(summary["batch"])
+            iterations += summary["iterations"]
+        batch = ",".join(str(bs) for bs in batchset)
+        totals: dict[str, str | int | float] = {
+            "session": "Total",
+            "start": starttime,
+            "end": endtime,
+            "elapsed": elapsed,
+            "rate": examples / elapsed if elapsed != 0 else 0,
+            "batch": batch,
+            "iterations": iterations}
+        logger.debug(totals)
+        return totals
+
+    def _format_stats(self, compiled_stats: list[dict]) -> list[dict]:
+        """ Format the incoming list of statistics for display. 
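+
+        For example (illustrative values), an elapsed time of 3661 seconds is rendered as
+        ``01:01:01`` and a rate of 95.2456 as ``95.2``.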
+
+        Parameters
+        ----------
+        compiled_stats: list
+            List of summary statistics dictionaries to be formatted for display
+
+        Returns
+        -------
+        list
+            The original statistics formatted for display
+        """
+        logger.debug("Formatting stats")
+        retval = []
+        for summary in compiled_stats:
+            hrs, mins, secs = self._convert_time(summary["elapsed"])
+            stats = {}
+            for key in summary:
+                if key not in ("start", "end", "elapsed", "rate"):
+                    stats[key] = summary[key]
+                    continue
+                stats["start"] = time.strftime("%x %X", time.localtime(summary["start"]))
+                stats["end"] = time.strftime("%x %X", time.localtime(summary["end"]))
+                stats["elapsed"] = f"{hrs}:{mins}:{secs}"
+                stats["rate"] = f"{summary['rate']:.1f}"
+            retval.append(stats)
+        return retval
+
+    @classmethod
+    def _convert_time(cls, timestamp: float) -> tuple[str, str, str]:
+        """ Convert an elapsed time in seconds to total hours, minutes and seconds.
+
+        Parameters
+        ----------
+        timestamp: float
+            The elapsed time, in seconds, to be converted
+
+        Returns
+        -------
+        tuple
+            (`hours`, `minutes`, `seconds`) as strings
+        """
+        ihrs = int(timestamp // 3600)
+        hrs = f"{ihrs:02d}" if ihrs < 10 else str(ihrs)
+        mins = f"{(int(timestamp % 3600) // 60):02d}"
+        secs = f"{(int(timestamp % 3600) % 60):02d}"
+        return hrs, mins, secs
+
+
+class Calculations():
+    """ Class that performs calculations on the :class:`GlobalSession` raw data for the given
+    session id.
+
+    Parameters
+    ----------
+    session_id: int or ``None``
+        The session id number for the selected session from the Analysis tab. Should be ``None``
+        if all sessions are being calculated
+    display: {"loss", "rate"}, optional
+        Whether to display a graph for loss or training rate. Default: `"loss"`
+    loss_keys: list, optional
+        The list of loss keys to display on the graph. Default: `["loss"]`
+    selections: list, optional
+        The selected annotations to display. Default: `["raw"]`
+    avg_samples: int, optional
+        The number of samples to use for performing moving average calculation. Default: `500`.
+    smooth_amount: float, optional
+        The amount of smoothing to apply for performing smoothing calculation. Default: `0.9`.
+    flatten_outliers: bool, optional
+        ``True`` if values significantly away from the average should be excluded, otherwise
+        ``False``. Default: ``False``
+    """
+    def __init__(self, session_id,  # pylint:disable=too-many-positional-arguments
+                 display: str = "loss",
+                 loss_keys: list[str] | str = "loss",
+                 selections: list[str] | str = "raw",
+                 avg_samples: int = 500,
+                 smooth_amount: float = 0.90,
+                 flatten_outliers: bool = False) -> None:
+        logger.debug(parse_class_init(locals()))
+        warnings.simplefilter("ignore", np.exceptions.RankWarning)
+
+        self._session_id = session_id
+
+        self._display = display
+        self._loss_keys = loss_keys if isinstance(loss_keys, list) else [loss_keys]
+        self._selections = selections if isinstance(selections, list) else [selections]
+        self._is_totals = session_id is None
+        self._args: dict[str, int | float] = {"avg_samples": avg_samples,
+                                              "smooth_amount": smooth_amount,
+                                              "flatten_outliers": flatten_outliers}
+        self._iterations = 0
+        self._limit = 0
+        self._start_iteration = 0
+        self._stats: dict[str, np.ndarray] = {}
+        self.refresh()
+        logger.debug("Initialized %s", self.__class__.__name__)
+
+    @property
+    def iterations(self) -> int:
+        """ int: The number of iterations in the data set. """
+        return self._iterations
+
+    @property
+    def start_iteration(self) -> int:
+        """ int: The starting iteration number if a limit has been set on the amount of data. 
""" + return self._start_iteration + + @property + def stats(self) -> dict[str, np.ndarray]: + """ dict: The final calculated statistics """ + return self._stats + + def refresh(self) -> Calculations | None: + """ Refresh the stats """ + logger.debug("Refreshing") + if not _SESSION.is_loaded: + logger.warning("Session data is not initialized. Not refreshing") + return None + self._iterations = 0 + self._get_raw() + self._get_calculations() + self._remove_raw() + logger.debug("Refreshed: %s", {k: f"Total: {len(v)}, Min: {np.nanmin(v)}, " + f"Max: {np.nanmax(v)}, " + f"nans: {np.count_nonzero(np.isnan(v))}" + for k, v in self.stats.items()}) + return self + + def set_smooth_amount(self, amount: float) -> None: + """ Set the amount of smoothing to apply to smoothed graph. + + Parameters + ---------- + amount: float + The amount of smoothing to apply to smoothed graph + """ + update = max(min(amount, 0.999), 0.001) + logger.debug("Setting smooth amount to: %s (provided value: %s)", update, amount) + self._args["smooth_amount"] = update + + def update_selections(self, selection: str, option: bool) -> None: + """ Update the type of selected data. + + Parameters + ---------- + selection: str + The selection to update (as can exist in :attr:`_selections`) + option: bool + ``True`` if the selection should be included, ``False`` if it should be removed + """ + # TODO Somewhat hacky, to ensure values are inserted in the correct order. Fine for + # now as this is only called from Live Graph and selections can only be "raw" and + # smoothed. + if option: + if selection not in self._selections: + if selection == "raw": + self._selections.insert(0, selection) + else: + self._selections.append(selection) + else: + if selection in self._selections: + self._selections.remove(selection) + + def set_iterations_limit(self, limit: int) -> None: + """ Set the number of iterations to display in the calculations. + + If a value greater than 0 is passed, then the latest iterations up to the given + limit will be calculated. + + Parameters + ---------- + limit: int + The number of iterations to calculate data for. `0` to calculate for all data + """ + limit = max(0, limit) + logger.debug("Setting iteration limit to: %s", limit) + self._limit = limit + + def _get_raw(self) -> None: + """ Obtain the raw loss values and add them to a new :attr:`stats` dictionary. 
""" + logger.debug("Getting Raw Data") + self.stats.clear() + iterations = set() + + if self._display.lower() == "loss": + loss_dict = _SESSION.get_loss(self._session_id) + for loss_name, loss in loss_dict.items(): + if loss_name not in self._loss_keys: + continue + iterations.add(loss.shape[0]) + + if self._limit > 0: + loss = loss[-self._limit:] + + if self._args["flatten_outliers"]: + loss = self._flatten_outliers(loss) + + self.stats[f"raw_{loss_name}"] = loss + + self._iterations = 0 if not iterations else min(iterations) + if self._limit > 1: + self._start_iteration = max(0, self._iterations - self._limit) + self._iterations = min(self._iterations, self._limit) + else: + self._start_iteration = 0 + + if len(iterations) > 1: + # Crop all losses to the same number of items + if self._iterations == 0: + self._stats = {lossname: np.array([], dtype=loss.dtype) + for lossname, loss in self.stats.items()} + else: + self._stats = {lossname: loss[:self._iterations] + for lossname, loss in self.stats.items()} + + else: # Rate calculation + data = self._calc_rate_total() if self._is_totals else self._calc_rate() + if self._args["flatten_outliers"]: + data = self._flatten_outliers(data) + self._iterations = data.shape[0] + self.stats["raw_rate"] = data + + logger.debug("Got Raw Data: %s", {k: f"Total: {len(v)}, Min: {np.nanmin(v)}, " + f"Max: {np.nanmax(v)}, " + f"nans: {np.count_nonzero(np.isnan(v))}" + for k, v in self.stats.items()}) + + @classmethod + def _flatten_outliers(cls, data: np.ndarray) -> np.ndarray: + """ Remove the outliers from a provided list. + + Removes data more than 1 Standard Deviation from the mean. + + Parameters + ---------- + data: :class:`numpy.ndarray` + The data to remove the outliers from + + Returns + ------- + :class:`numpy.ndarray` + The data with outliers removed + """ + logger.debug("Flattening outliers: %s", data.shape) + mean = np.mean(np.nan_to_num(data)) + limit = np.std(np.nan_to_num(data)) + logger.debug("mean: %s, limit: %s", mean, limit) + retdata = np.where(abs(data - mean) < limit, data, mean) + logger.debug("Flattened outliers") + return retdata + + def _remove_raw(self) -> None: + """ Remove raw values from :attr:`stats` if they are not requested. """ + if "raw" in self._selections: + return + logger.debug("Removing Raw Data from output") + for key in list(self._stats.keys()): + if key.startswith("raw"): + del self._stats[key] + logger.debug("Removed Raw Data from output") + + def _calc_rate(self) -> np.ndarray: + """ Calculate rate per iteration. + + Returns + ------- + :class:`numpy.ndarray` + The training rate for each iteration of the selected session + """ + logger.debug("Calculating rate") + batch_size = _SESSION.batch_sizes[self._session_id] * 2 + retval = batch_size / np.diff(T.cast(np.ndarray, + _SESSION.get_timestamps(self._session_id))) + logger.debug("Calculated rate: Item_count: %s", len(retval)) + return retval + + @classmethod + def _calc_rate_total(cls) -> np.ndarray: + """ Calculate rate per iteration for all sessions. + + Returns + ------- + :class:`numpy.ndarray` + The training rate for each iteration in all sessions + + Notes + ----- + For totals, gaps between sessions can be large so the time difference has to be reset for + each session's rate calculation. 
+ """ + logger.debug("Calculating totals rate") + batchsizes = _SESSION.batch_sizes + total_timestamps = _SESSION.get_timestamps(None) + rate: list[float] = [] + for sess_id in sorted(total_timestamps.keys()): + batchsize = batchsizes[sess_id] + timestamps = total_timestamps[sess_id] + rate.extend((batchsize * 2) / np.diff(timestamps)) + retval = np.array(rate) + logger.debug("Calculated totals rate: Item_count: %s", len(retval)) + return retval + + def _get_calculations(self) -> None: + """ Perform the required calculations and populate :attr:`stats`. """ + for selection in self._selections: + if selection == "raw": + continue + logger.debug("Calculating: %s", selection) + method = getattr(self, f"_calc_{selection}") + raw_keys = [key for key in self._stats if key.startswith("raw_")] + for key in raw_keys: + selected_key = f"{selection}_{key.replace('raw_', '')}" + self._stats[selected_key] = method(self._stats[key]) + logger.debug("Got calculations: %s", {k: f"Total: {len(v)}, Min: {np.nanmin(v)}, " + f"Max: {np.nanmax(v)}, " + f"nans: {np.count_nonzero(np.isnan(v))}" + for k, v in self.stats.items() + if not k.startswith("raw")}) + + def _calc_avg(self, data: np.ndarray) -> np.ndarray: + """ Calculate moving average. + + Parameters + ---------- + data: :class:`numpy.ndarray` + The data to calculate the moving average for + + Returns + ------- + :class:`numpy.ndarray` + The moving average for the given data + """ + logger.debug("Calculating Average. Data points: %s", len(data)) + window = T.cast(int, self._args["avg_samples"]) + pad = ceil(window / 2) + datapoints = data.shape[0] + + if datapoints <= (self._args["avg_samples"] * 2): + logger.info("Not enough data to compile rolling average") + return np.array([], dtype="float64") + + avgs = np.cumsum(np.nan_to_num(data), dtype="float64") + avgs[window:] = avgs[window:] - avgs[:-window] + avgs = avgs[window - 1:] / window + avgs = np.pad(avgs, (pad, datapoints - (avgs.shape[0] + pad)), constant_values=(np.nan,)) + logger.debug("Calculated Average: shape: %s", avgs.shape) + return avgs + + def _calc_smoothed(self, data: np.ndarray) -> np.ndarray: + """ Smooth the data. + + Parameters + ---------- + data: :class:`numpy.ndarray` + The data to smooth + + Returns + ------- + :class:`numpy.ndarray` + The smoothed data + """ + retval = ExponentialMovingAverage(data, self._args["smooth_amount"])() + logger.debug("Calculated Smoothed data: shape: %s", retval.shape) + return retval + + @classmethod + def _calc_trend(cls, data: np.ndarray) -> np.ndarray: + """ Calculate polynomial trend of the given data. 
+ + Parameters + ---------- + data: :class:`numpy.ndarray` + The data to calculate the trend for + + Returns + ------- + :class:`numpy.ndarray` + The trend for the given data + """ + logger.debug("Calculating Trend") + points = data.shape[0] + if points < 10: + dummy = np.empty((points, ), dtype=data.dtype) + dummy[:] = np.nan + return dummy + x_range = range(points) + trend = np.poly1d(np.polyfit(x_range, np.nan_to_num(data), 3))(x_range) + logger.debug("Calculated Trend: shape: %s", trend.shape) + return trend + + +__all__ = get_module_objects(__name__) diff --git a/lib/gui/command.py b/lib/gui/command.py index 2b6e608be6..1f1dbccd18 100644 --- a/lib/gui/command.py +++ b/lib/gui/command.py @@ -2,49 +2,71 @@ """ The command frame for Faceswap GUI """ import logging +import gettext import tkinter as tk from tkinter import ttk -from .tooltip import Tooltip -from .utils import ContextMenu, FileHandler, get_images, get_config, set_slider_rounding +from lib.utils import get_module_objects -logger = logging.getLogger(__name__) # pylint: disable=invalid-name +from .control_helper import ControlPanel +from .custom_widgets import Tooltip +from .utils import get_images, get_config +from .options import CliOption +logger = logging.getLogger(__name__) -class CommandNotebook(ttk.Notebook): # pylint: disable=too-many-ancestors +# LOCALES +_LANG = gettext.translation("gui.tooltips", localedir="locales", fallback=True) +_ = _LANG.gettext + + +class CommandNotebook(ttk.Notebook): # pylint:disable=too-many-ancestors """ Frame to hold each individual tab of the command notebook """ def __init__(self, parent): logger.debug("Initializing %s: (parent: %s)", self.__class__.__name__, parent) - scaling_factor = get_config().scaling_factor - width = int(420 * scaling_factor) - height = int(500 * scaling_factor) - ttk.Notebook.__init__(self, parent, width=width, height=height) + self.actionbtns = {} + super().__init__(parent) parent.add(self) - self.actionbtns = dict() + self.tools_notebook = ToolsNotebook(self) self.set_running_task_trace() self.build_tabs() - get_config().command_notebook = self + self.modified_vars = self._set_modified_vars() + get_config().set_command_notebook(self) logger.debug("Initialized %s", self.__class__.__name__) + @property + def tab_names(self): + """ dict: Command tab titles with their IDs """ + return {self.tab(tab_id, "text").lower(): tab_id + for tab_id in range(0, self.index("end"))} + + @property + def tools_tab_names(self): + """ dict: Tools tab titles with their IDs """ + return {self.tools_notebook.tab(tab_id, "text").lower(): tab_id + for tab_id in range(0, self.tools_notebook.index("end"))} + def set_running_task_trace(self): """ Set trigger action for the running task to change the action buttons text and command """ logger.debug("Set running trace") tk_vars = get_config().tk_vars - tk_vars["runningtask"].trace("w", self.change_action_button) + tk_vars.running_task.trace("w", self.change_action_button) def build_tabs(self): """ Build the tabs for the relevant command """ logger.debug("Build Tabs") cli_opts = get_config().cli_opts for category in cli_opts.categories: + book = self.tools_notebook if category == "tools" else self cmdlist = cli_opts.commands[category] for command in cmdlist: title = command.title() - commandtab = CommandTab(self, category, command) - self.add(commandtab, text=title) + commandtab = CommandTab(book, category, command) + book.add(commandtab, text=title) + self.add(self.tools_notebook, text="Tools") logger.debug("Built Tabs") def 
change_action_button(self, *args): @@ -52,26 +74,52 @@ def change_action_button(self, *args): logger.debug("Update Action Buttons: (args: %s", args) tk_vars = get_config().tk_vars - for cmd in self.actionbtns.keys(): - btnact = self.actionbtns[cmd] - if tk_vars["runningtask"].get(): - ttl = "Terminate" + for cmd, action in self.actionbtns.items(): + btnact = action + if tk_vars.running_task.get(): + ttl = " Stop" + img = get_images().icons["stop"] hlp = "Exit the running process" else: - ttl = cmd.title() - hlp = "Run the {} script".format(cmd.title()) + ttl = f" {cmd.title()}" + img = get_images().icons["start"] + hlp = f"Run the {cmd.title()} script" logger.debug("Updated Action Button: '%s'", ttl) - btnact.config(text=ttl) - Tooltip(btnact, text=hlp, wraplength=200) + btnact.config(text=ttl, image=img) + Tooltip(btnact, text=hlp, wrap_length=200) + + def _set_modified_vars(self): + """ Set the tkinter variable for each tab to indicate whether contents + have been modified """ + tkvars = {} + for tab in self.tab_names: + if tab == "tools": + for ttab in self.tools_tab_names: + var = tk.BooleanVar() + var.set(False) + tkvars[ttab] = var + continue + var = tk.BooleanVar() + var.set(False) + tkvars[tab] = var + logger.debug("Set modified vars: %s", tkvars) + return tkvars + + +class ToolsNotebook(ttk.Notebook): # pylint:disable=too-many-ancestors + """ Tools sit in their own tab, but need to inherit objects from the main command notebook """ + def __init__(self, parent): + super().__init__(parent) + self.actionbtns = parent.actionbtns -class CommandTab(ttk.Frame): # pylint: disable=too-many-ancestors +class CommandTab(ttk.Frame): # pylint:disable=too-many-ancestors """ Frame to hold each individual tab of the command notebook """ def __init__(self, parent, category, command): logger.debug("Initializing %s: (category: '%s', command: '%s')", self.__class__.__name__, category, command) - ttk.Frame.__init__(self, parent) + super().__init__(parent, name=f"tab_{command.lower()}") self.category = category self.actionbtns = parent.actionbtns @@ -83,10 +131,16 @@ def __init__(self, parent, category, command): def build_tab(self): """ Build the tab """ logger.debug("Build Tab: '%s'", self.command) - OptionsFrame(self) - + options = get_config().cli_opts.opts[self.command] + cp_opts = [val.cpanel_option for val in options.values() if isinstance(val, CliOption)] + ControlPanel(self, + cp_opts, + label_width=16, + option_columns=3, + columns=1, + header_text=options.get("helptext", None), + style="CPanel") self.add_frame_separator() - ActionFrame(self) logger.debug("Built Tab: '%s'", self.command) @@ -98,313 +152,12 @@ def add_frame_separator(self): logger.debug("Added frame seperator") -class OptionsFrame(ttk.Frame): # pylint: disable=too-many-ancestors - """ Options Frame - Holds the Options for each command """ - - def __init__(self, parent): - logger.debug("Initializing %s", self.__class__.__name__) - ttk.Frame.__init__(self, parent) - self.pack(side=tk.TOP, fill=tk.BOTH, expand=True) - - self.command = parent.command - - self.canvas = tk.Canvas(self, bd=0, highlightthickness=0) - self.canvas.pack(side=tk.LEFT, fill=tk.BOTH, expand=True) - - self.optsframe = ttk.Frame(self.canvas) - self.optscanvas = self.canvas.create_window((0, 0), - window=self.optsframe, - anchor=tk.NW) - self.chkbtns = self.checkbuttons_frame() - - self.build_frame() - cli_opts = get_config().cli_opts - cli_opts.set_context_option(self.command) - logger.debug("Initialized %s", self.__class__.__name__) - - def 
checkbuttons_frame(self): - """ Build and format frame for holding the check buttons """ - logger.debug("Add Options CheckButtons Frame") - container = ttk.Frame(self.optsframe) - - lbl = ttk.Label(container, text="Options", width=16, anchor=tk.W) - lbl.pack(padx=5, pady=5, side=tk.LEFT, anchor=tk.N) - - chkframe = ttk.Frame(container) - chkframe.pack(side=tk.BOTTOM, expand=True) - - chkleft = ttk.Frame(chkframe, name="leftFrame") - chkleft.pack(side=tk.LEFT, anchor=tk.N, expand=True) - - chkright = ttk.Frame(chkframe, name="rightFrame") - chkright.pack(side=tk.RIGHT, anchor=tk.N, expand=True) - logger.debug("Added Options CheckButtons Frame") - - return container, chkframe - - def build_frame(self): - """ Build the options frame for this command """ - logger.debug("Add Options Frame") - self.add_scrollbar() - self.canvas.bind("", self.resize_frame) - - cli_opts = get_config().cli_opts - for option in cli_opts.gen_command_options(self.command): - optioncontrol = OptionControl(self.command, - option, - self.optsframe, - self.chkbtns[1]) - optioncontrol.build_full_control() - - if self.chkbtns[1].winfo_children(): - self.chkbtns[0].pack(side=tk.BOTTOM, fill=tk.X, expand=True) - logger.debug("Added Options Frame") - - def add_scrollbar(self): - """ Add a scrollbar to the options frame """ - logger.debug("Add Options Scrollbar") - scrollbar = ttk.Scrollbar(self, command=self.canvas.yview) - scrollbar.pack(side=tk.RIGHT, fill=tk.Y) - self.canvas.config(yscrollcommand=scrollbar.set) - self.optsframe.bind("", self.update_scrollbar) - logger.debug("Added Options Scrollbar") - - def update_scrollbar(self, event): # pylint: disable=unused-argument - """ Update the options frame scrollbar """ - self.canvas.configure(scrollregion=self.canvas.bbox("all")) - - def resize_frame(self, event): - """ Resize the options frame to fit the canvas """ - logger.debug("Resize Options Frame") - canvas_width = event.width - self.canvas.itemconfig(self.optscanvas, width=canvas_width) - logger.debug("Resized Options Frame") - - -class OptionControl(): - """ Build the correct control for the option parsed and place it on the - frame """ - - def __init__(self, command, option, option_frame, checkbuttons_frame): - logger.debug("Initializing %s", self.__class__.__name__) - self.command = command - self.option = option - self.option_frame = option_frame - self.chkbtns = checkbuttons_frame - logger.debug("Initialized %s", self.__class__.__name__) - - def build_full_control(self): - """ Build the correct control type for the option passed through """ - logger.debug("Build option control") - ctl = self.option["control"] - ctltitle = self.option["control_title"] - sysbrowser = self.option["filesystem_browser"] - ctlhelp = self.format_help(ctltitle) - dflt = self.option.get("default", "") - if self.option.get("nargs", None) and isinstance(dflt, (list, tuple)): - dflt = ' '.join(str(val) for val in dflt) - if ctl == ttk.Checkbutton: - dflt = self.option.get("default", False) - choices = self.option["choices"] if ctl == ttk.Combobox else None - min_max = self.option["min_max"] if ctl == ttk.Scale else None - - ctlframe = self.build_one_control_frame() - - if ctl != ttk.Checkbutton: - self.build_one_control_label(ctlframe, ctltitle) - - ctlvars = (ctl, ctltitle, dflt, ctlhelp) - self.option["value"] = self.build_one_control(ctlframe, - ctlvars, - choices, - min_max, - sysbrowser) - logger.debug("Built option control") - - def format_help(self, ctltitle): - """ Format the help text for tooltips """ - logger.debug("Format control 
help: '%s'", ctltitle) - ctlhelp = self.option.get("help", "") - if ctlhelp.startswith("R|"): - ctlhelp = ctlhelp[2:].replace("\n\t", " ").replace("\n'", "\n\n'") - else: - ctlhelp = " ".join(ctlhelp.split()) - ctlhelp = ctlhelp.replace("%%", "%") - ctlhelp = ". ".join(i.capitalize() for i in ctlhelp.split(". ")) - ctlhelp = ctltitle + " - " + ctlhelp - logger.debug("Formatted control help: (title: '%s', help: '%s'", ctltitle, ctlhelp) - return ctlhelp - - def build_one_control_frame(self): - """ Build the frame to hold the control """ - logger.debug("Build control frame") - frame = ttk.Frame(self.option_frame) - frame.pack(fill=tk.X, expand=True) - logger.debug("Built control frame") - return frame - - @staticmethod - def build_one_control_label(frame, control_title): - """ Build and place the control label """ - logger.debug("Build control label: '%s'", control_title) - lbl = ttk.Label(frame, text=control_title, width=16, anchor=tk.W) - lbl.pack(padx=5, pady=5, side=tk.LEFT, anchor=tk.N) - logger.debug("Built control label: '%s'", control_title) - - def build_one_control(self, frame, controlvars, choices, min_max, sysbrowser): - """ Build and place the option controls """ - logger.debug("Build control: (controlvars: %s, choices: %s, min_max: %s, sysbrowser: %s", - controlvars, choices, min_max, sysbrowser) - control, control_title, default, helptext = controlvars - default = default if default is not None else "" - - var = tk.BooleanVar(frame) if control == ttk.Checkbutton else tk.StringVar(frame) - var.set(default) - - if sysbrowser: - self.add_browser_buttons(frame, sysbrowser, var) - - if control == ttk.Checkbutton: - self.checkbutton_to_checkframe(control, - control_title, - var, - helptext) - elif control == ttk.Scale: - self.slider_control(control, - frame, - var, - min_max, - helptext) - else: - self.control_to_optionsframe(control, - frame, - var, - choices, - helptext) - logger.debug("Built control: '%s'", control_title) - return var - - def checkbutton_to_checkframe(self, control, control_title, var, helptext): - """ Add checkbuttons to the checkbutton frame """ - logger.debug("Add control checkframe: '%s'", control_title) - leftframe = self.chkbtns.children["leftFrame"] - rightframe = self.chkbtns.children["rightFrame"] - chkbtn_count = len({**leftframe.children, **rightframe.children}) - - frame = leftframe if chkbtn_count % 2 == 0 else rightframe - - ctl = control(frame, variable=var, text=control_title) - ctl.pack(side=tk.TOP, padx=5, pady=5, anchor=tk.W) - - Tooltip(ctl, text=helptext, wraplength=200) - logger.debug("Added control checkframe: '%s'", control_title) - - def slider_control(self, control, frame, tk_var, min_max, helptext): - """ A slider control with corresponding Entry box """ - logger.debug("Add slider control to Options Frame: %s", control) - d_type = self.option.get("type", float) - rnd = self.option.get("rounding", 2) if d_type == float else self.option.get("rounding", 1) - - tbox = ttk.Entry(frame, width=8, textvariable=tk_var, justify=tk.RIGHT) - tbox.pack(padx=(0, 5), side=tk.RIGHT) - ctl = control( - frame, - variable=tk_var, - command=lambda val, var=tk_var, dt=d_type, rn=rnd, mm=min_max: - set_slider_rounding(val, var, dt, rn, mm)) - ctl.pack(padx=5, pady=5, fill=tk.X, expand=True) - rc_menu = ContextMenu(ctl) - rc_menu.cm_bind() - ctl["from_"] = min_max[0] - ctl["to"] = min_max[1] - - Tooltip(ctl, text=helptext, wraplength=720) - Tooltip(tbox, text=helptext, wraplength=720) - logger.debug("Added slider control to Options Frame: %s", control) - - 
@staticmethod - def control_to_optionsframe(control, frame, var, choices, helptext): - """ Standard non-check buttons sit in the main options frame """ - logger.debug("Add control to Options Frame: %s", control) - ctl = control(frame, textvariable=var) - ctl.pack(padx=5, pady=5, fill=tk.X, expand=True) - rc_menu = ContextMenu(ctl) - rc_menu.cm_bind() - if control == ttk.Combobox: - logger.debug("Adding combo choices: %s", choices) - ctl["values"] = [choice for choice in choices] - Tooltip(ctl, text=helptext, wraplength=920) - logger.debug("Added control to Options Frame: %s", control) - - def add_browser_buttons(self, frame, sysbrowser, filepath): - """ Add correct file browser button for control """ - logger.debug("Adding browser buttons: (sysbrowser: '%s', filepath: '%s'", - sysbrowser, filepath) - for browser in sysbrowser: - img = get_images().icons[browser] - action = getattr(self, "ask_" + browser) - filetypes = self.option.get("filetypes", "default") - fileopn = ttk.Button(frame, - image=img, - command=lambda cmd=action: cmd(filepath, filetypes)) - fileopn.pack(padx=(0, 5), side=tk.RIGHT) - logger.debug("Added browser buttons: (action: %s, filetypes: %s", - action, filetypes) - - @staticmethod - def ask_folder(filepath, filetypes=None): - """ Pop-up to get path to a directory - :param filepath: tkinter StringVar object - that will store the path to a directory. - :param filetypes: Unused argument to allow - filetypes to be given in ask_load(). """ - dirname = FileHandler("dir", filetypes).retfile - if dirname: - logger.debug(dirname) - filepath.set(dirname) - - @staticmethod - def ask_load(filepath, filetypes): - """ Pop-up to get path to a file """ - filename = FileHandler("filename", filetypes).retfile - if filename: - logger.debug(filename) - filepath.set(filename) - - @staticmethod - def ask_save(filepath, filetypes=None): - """ Pop-up to get path to save a new file """ - filename = FileHandler("savefilename", filetypes).retfile - if filename: - logger.debug(filename) - filepath.set(filename) - - @staticmethod - def ask_nothing(filepath, filetypes=None): # pylint: disable=unused-argument - """ Method that does nothing, used for disabling open/save pop up """ - return - - def ask_context(self, filepath, filetypes): - """ Method to pop the correct dialog depending on context """ - logger.debug("Getting context filebrowser") - selected_action = self.option["action_option"].get() - selected_variable = self.option["dest"] - filename = FileHandler("context", - filetypes, - command=self.command, - action=selected_action, - variable=selected_variable).retfile - if filename: - logger.debug(filename) - filepath.set(filename) - - -class ActionFrame(ttk.Frame): # pylint: disable=too-many-ancestors +class ActionFrame(ttk.Frame): # pylint:disable=too-many-ancestors """Action Frame - Displays action controls for the command tab """ def __init__(self, parent): logger.debug("Initializing %s: (command: '%s')", self.__class__.__name__, parent.command) - ttk.Frame.__init__(self, parent) + super().__init__(parent) self.pack(fill=tk.BOTH, padx=5, pady=5, side=tk.BOTTOM, anchor=tk.N) self.command = parent.command @@ -412,55 +165,41 @@ def __init__(self, parent): self.add_action_button(parent.category, parent.actionbtns) - self.add_util_buttons() logger.debug("Initialized %s", self.__class__.__name__) def add_action_button(self, category, actionbtns): """ Add the action buttons for page """ logger.debug("Add action buttons: '%s'", self.title) actframe = ttk.Frame(self) - actframe.pack(fill=tk.X, 
side=tk.LEFT) + actframe.pack(fill=tk.X, side=tk.RIGHT) + tk_vars = get_config().tk_vars + var_value = f"{category},{self.command}" - var_value = "{},{}".format(category, self.command) + btngen = ttk.Button(actframe, + image=get_images().icons["generate"], + text=" Generate", + compound=tk.LEFT, + width=14, + command=lambda: tk_vars.generate_command.set(var_value)) + btngen.pack(side=tk.LEFT, padx=5) + Tooltip(btngen, + text=_("Output command line options to the console"), + wrap_length=200) btnact = ttk.Button(actframe, - text=self.title, - width=10, - command=lambda: tk_vars["action"].set(var_value)) - btnact.pack(side=tk.LEFT) + image=get_images().icons["start"], + text=f" {self.title}", + compound=tk.LEFT, + width=14, + command=lambda: tk_vars.action_command.set(var_value)) + btnact.pack(side=tk.LEFT, fill=tk.X, expand=True) Tooltip(btnact, - text="Run the {} script".format(self.title), - wraplength=200) + text=_("Run the {} script").format(self.title), + wrap_length=200) actionbtns[self.command] = btnact - btngen = ttk.Button(actframe, - text="Generate", - width=10, - command=lambda: tk_vars["generate"].set(var_value)) - btngen.pack(side=tk.RIGHT, padx=5) - Tooltip(btngen, - text="Output command line options to the console", - wraplength=200) logger.debug("Added action buttons: '%s'", self.title) - def add_util_buttons(self): - """ Add the section utility buttons """ - logger.debug("Add util buttons") - utlframe = ttk.Frame(self) - utlframe.pack(side=tk.RIGHT) - - config = get_config() - for utl in ("load", "save", "clear", "reset"): - logger.debug("Adding button: '%s'", utl) - img = get_images().icons[utl] - action_cls = config if utl in (("save", "load")) else config.cli_opts - action = getattr(action_cls, utl) - btnutl = ttk.Button(utlframe, - image=img, - command=lambda cmd=action: cmd(self.command)) - btnutl.pack(padx=2, side=tk.LEFT) - Tooltip(btnutl, - text=utl.capitalize() + " " + self.title + " config", - wraplength=200) - logger.debug("Added util buttons") + +__all__ = get_module_objects(__name__) diff --git a/lib/gui/control_helper.py b/lib/gui/control_helper.py new file mode 100644 index 0000000000..46c26a4017 --- /dev/null +++ b/lib/gui/control_helper.py @@ -0,0 +1,1521 @@ +#!/usr/bin/env python3 +""" Helper functions and classes for GUI controls """ +from __future__ import annotations +import gettext +import logging +import re +import tkinter as tk +import types + +from tkinter import colorchooser, ttk +from itertools import zip_longest +from functools import partial +from typing import Any, cast, get_args, Literal, Self, TYPE_CHECKING + +from _tkinter import Tcl_Obj, TclError + +from lib.logger import parse_class_init +from lib.utils import get_module_objects + +from .custom_widgets import ContextMenu, MultiOption, ToggledFrame, Tooltip +from .utils import FileHandler, get_config, get_images +from . 
import gui_config as cfg
+
+if TYPE_CHECKING:
+    from lib.config import ConfigItem
+
+
+logger = logging.getLogger(__name__)
+
+# LOCALES
+_LANG = gettext.translation("gui.tooltips", localedir="locales", fallback=True)
+_ = _LANG.gettext
+
+# We store Tooltips, ContextMenus and Commands globally when they are created, because we need
+# to add them back to newly cloned widgets (they are not easily accessible from the
+# original config or are prone to getting destroyed when the original widget is destroyed)
+_RECREATE_OBJECTS: dict[str, dict[str, Any]] = {"tooltips": {},
+                                                "commands": {},
+                                                "contextmenus": {}}
+
+
+def _get_tooltip(widget, text=None, text_variable=None):
+    """ Store the tooltip layout and widget id in _RECREATE_OBJECTS["tooltips"] and return a
+    tooltip.
+
+    Auto adjust tooltip width based on amount of text.
+
+    """
+    _RECREATE_OBJECTS["tooltips"][str(widget)] = {"text": text,
+                                                  "text_variable": text_variable}
+    logger.debug("Adding to tooltips dict: (widget: %s. text: '%s')", widget, text)
+
+    wrap_length = 400
+    if text is not None:
+        while True:
+            if len(text) < wrap_length * 5:
+                break
+            if wrap_length > 800:
+                break
+            wrap_length = int(wrap_length * 1.10)
+
+    return Tooltip(widget, text=text, text_variable=text_variable, wrap_length=wrap_length)
+
+
+def _get_contextmenu(widget):
+    """ Create a context menu, store its mapping and return """
+    rc_menu = ContextMenu(widget)
+    _RECREATE_OBJECTS["contextmenus"][str(widget)] = rc_menu
+    logger.debug("Adding to Context menu: (widget: %s. rc_menu: %s)",
+                 widget, rc_menu)
+    return rc_menu
+
+
+def _add_command(name, func):
+    """ For controls that execute commands, the command must be added to
+    _RECREATE_OBJECTS["commands"] so that it can be added back to the widget during cloning """
+    logger.debug("Adding to commands: %s - %s", name, func)
+    _RECREATE_OBJECTS["commands"][str(name)] = func
+
+
+def set_slider_rounding(value, var, d_type, round_to, min_max):
+    """ Set the value of a slider's underlying variable based on its datatype,
+    rounding value and min/max.
+
+    Parameters
+    ----------
+    value: str
+        The value to set, as passed in from the slider
+    var: tkinter.Var
+        The variable to set the value for
+    d_type: [:class:`int`, :class:`float`]
+        The type of value that is stored in :attr:`var`
+    round_to: int or list
+        If :attr:`d_type` is :class:`float` then this is the decimal place rounding for
+        :attr:`var`. If :attr:`d_type` is :class:`int` then this is the number of steps between
+        each increment for :attr:`var`. If a list is provided, then this must be a list of
+        discrete values that are of the correct :attr:`d_type`.
+    min_max: tuple (`int`, `int`)
+        The (``min``, ``max``) values that this slider accepts
+    """
+    if isinstance(round_to, list):
+        # Lock to nearest item
+        var.set(min(round_to, key=lambda x: abs(x-float(value))))
+    elif d_type == float:
+        var.set(round(float(value), round_to))
+    else:
+        steps = range(min_max[0], min_max[1] + round_to, round_to)
+        value = min(steps, key=lambda x: abs(x - int(float(value))))
+        var.set(value)
+
+
+class ControlPanelOption():
+    """ A class to hold a control panel option. A list of these is expected to be passed to the
+    ControlPanel object.
+
+    Parameters
+    ----------
+    title : str
+        Title of the control. Will be used for label text and control naming
+    dtype : type
+        Datatype of the control.
+    group : str | None, optional
+        The group that this control should sit with. If provided, all controls in the same
+        group will be placed together. Default: ``None``
+    subgroup : str | None, optional
+        The subgroup that this option belongs to. 
If provided, will group options in the same + subgroups together for the same layout as option/check boxes. Default: ``None`` + default : str | bool | float | int | list[str] | None, optional + Default value for the control. If None is provided, then action will be dictated by + whether "blank_nones" is set in ControlPanel. Default: ``None`` + initial_value : str | bool | float | int | list[str] | None, optional + Initial value for the control. If ``None``, default will be used. Default: ``None`` + choices : list[str] | tuple[str, ...] | Literal["colorchooser"] | None, optional + Used for combo boxes and radio control option setting. Set to `"colorchooser"` for a color + selection dialog. Default: ``None`` + is_radio : bool, optional + Specifies to use a Radio control instead of combobox if choices are passed. + Default: ``False`` + is_multi_option : bool, optional + Specifies to use a Multi Check Button option group for the specified control. + Default: ``False`` + rounding : int | float | None, optional + For slider controls. Sets the stepping. Default: ``None`` + min_max : tuple[int, int] | tuple[float, float] | None, optional + For slider controls. Sets the min and max values. Default: ``None`` + sysbrowser : dict[Literal["filetypes", "browser", "command", "destination", "action_option"], str | list[str]] | None, optional + Adds Filesystem browser buttons to ttk.Entry options. Default: ``None`` + helptext : str | None, optional + Sets the tooltip text. Default: ``None`` + track_modified : bool, optional + Set whether to set a callback trace indicating that the parameter has been modified. + Default: ``False`` + command : str | None, optional + Required if tracking modified. The command that this option belongs to. Default: ``None`` + """ # noqa[E501] # pylint:disable=line-too-long + def __init__(self, # pylint:disable=too-many-arguments,too-many-positional-arguments,too-many-locals # noqa[E501] + title: str, + dtype: type, + group: str | None = None, + subgroup: str | None = None, + default: str | bool | float | int | None = None, + initial_value: str | bool | float | int | None = None, + choices: list[str] | tuple[str, ...] 
| Literal["colorchooser"] | None = None, + is_radio: bool = False, + is_multi_option: bool = False, + rounding: int | float | None = None, + min_max: tuple[int, int] | tuple[float, float] | None = None, + sysbrowser: dict[Literal["filetypes", + "browser", + "command", + "destination", + "action_option"], str | list[str]] | None = None, + helptext: str | None = None, + track_modified: bool = False, + command: str | None = None) -> None: + logger.debug(parse_class_init(locals())) + self.dtype = dtype + self.sysbrowser = sysbrowser + self._command = command + self._track_modified = track_modified + self._options = {"title": title, + "subgroup": subgroup, + "group": group, + "default": default, + "initial_value": initial_value, + "choices": choices, + "is_radio": is_radio, + "is_multi_option": is_multi_option, + "rounding": rounding, + "min_max": min_max, + "helptext": helptext} + self.control = self.get_control() + initial_value = default if initial_value is None else initial_value + initial_value = "" if initial_value is None else initial_value + self.tk_var = self.get_tk_var(initial_value) + logger.debug("Initialized %s", self.__class__.__name__) + + def __repr__(self) -> str: + """ Pretty printed representation for logging """ + non_opts = {"dtype": self.dtype, + "sysbrowser": self.sysbrowser, + "track_modified": self._track_modified} + params = non_opts | self._options + str_params = ", ".join(f"{k}={repr(v)}" for k, v in params.items()) + return f"{self.__class__.__name__}({str_params})" + + @property + def name(self) -> str: + """ str : Lowered title for naming """ + title = self._options["title"] + assert isinstance(title, str) + return title.lower() + + @property + def title(self): + """ str : Title case title for naming with underscores removed """ + title = self._options["title"] + assert isinstance(title, str) + return title.replace("_", " ").title() + + @property + def group(self) -> str: + """ str : Option group or "_master" if no group set """ + group = self._options["group"] + if group is None: + group = "_master" + assert isinstance(group, str) + return group + + @property + def subgroup(self) -> str | None: + """ str | None : Option subgroup, or ``None`` if none provided. """ + retval = self._options["subgroup"] + if retval is not None: + assert isinstance(retval, str) + return retval + + @property + def default(self) -> str | bool | float | int | None: + """ str | bool | float | int | list[str] : Either the currently selected value or the + default """ + retval = self._options["default"] + assert isinstance(retval, (str, bool, float, int, types.NoneType)) + return retval + + @property + def value(self) -> str | bool | float | int | None: + """ str | bool | float | int | list[str] : Either the initial value or default """ + retval = self._options["initial_value"] + retval = self.default if retval is None else retval + assert isinstance(retval, (str, bool, float, int, types.NoneType)) + return retval + + @property + def choices(self) -> list[str] | tuple[str, ...] | Literal["colorchooser"] | None: + """ list[str] | tuple[str, ...] | Literal["colorchooser"] : The option choices """ + retval = self._options["choices"] + if retval is not None: + assert isinstance(retval, (list, tuple, str)) + if isinstance(retval, str): + assert retval in get_args(Literal["colorchooser"]) + else: + assert all(isinstance(x, str) for x in retval) + return cast(list[str] | tuple[str, ...] 
| Literal["colorchooser"] | None, retval)
+
+    @property
+    def is_radio(self) -> bool:
+        """ bool : If the option should be a radio control """
+        retval = self._options["is_radio"]
+        assert isinstance(retval, bool)
+        return retval
+
+    @property
+    def is_multi_option(self) -> bool:
+        """ bool : ``True`` if the control should be contained in a multi check button group,
+        otherwise ``False``. """
+        retval = self._options["is_multi_option"]
+        assert isinstance(retval, bool)
+        return retval
+
+    @property
+    def rounding(self) -> int | float | None:
+        """ int | float | None : Rounding for numeric controls """
+        retval = self._options["rounding"]
+        assert retval is None or isinstance(retval, (int, float))
+        return retval
+
+    @property
+    def min_max(self) -> tuple[int, int] | tuple[float, float] | None:
+        """ tuple[int, int] | tuple[float, float] | None : minimum and maximum values for numeric
+        controls """
+        retval = self._options["min_max"]
+        if retval is not None:
+            assert isinstance(retval, tuple)
+            assert len(retval) == 2
+            assert isinstance(retval[0], (int, float)) and isinstance(retval[1], (int, float))
+        return retval
+
+    @property
+    def helptext(self) -> str | None:
+        """ str | None : The formatted option help text for tooltips """
+        helptext = self._options["helptext"]
+        if helptext is None:
+            return helptext
+        assert isinstance(helptext, str)
+        logger.debug("Format control help: '%s'", self.name)
+        if helptext.startswith("R|"):
+            helptext = helptext[2:].replace("\nL|", "\n - ").replace("\n", "\n\n")
+        else:
+            helptext = helptext.replace("\n\t", "\n - ").replace("%%", "%")
+        helptext = self.title + " - " + helptext
+        logger.debug("Formatted control help: (name: '%s', help: '%s'", self.name, helptext)
+        return helptext
+
+    def get(self) -> str | bool | int | float:
+        """ Return the option value from the tk_var
+
+        Returns
+        -------
+        str | bool | float | int
+            The value selected for this option
+
+        Notes
+        -----
+        tk variables don't like empty values unless they are StringVars. This seems to be pretty
+        much the only reason that a get() call would fail, so replace any numerical variable
+        with its numerical zero equivalent on a TclError. Only impacts variables linked
+        to Entry widgets.
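+
+        Illustratively (hypothetical values; a Tk root must exist)::
+
+            opt = ControlPanelOption("Iterations", int, default=1)
+            opt.tk_var.set("")  # e.g. the user blanks the linked Entry widget
+            opt.get()           # the TclError is caught and 0 is returned
+        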
+ """ + try: + val = self.tk_var.get() + except TclError: + if isinstance(self.tk_var, tk.IntVar): + val = 0 + elif isinstance(self.tk_var, tk.DoubleVar): + val = 0.0 + else: + raise + return val + + def set(self, value: str | bool | int | float | None) -> None: + """ Set the variable for the config option with the given value + + Parameters + ---------- + value : str | bool | float | int | None + The value to set the config option variable to + """ + self.tk_var.set(value) + + def set_initial_value(self, value: str | bool | int | float): + """ Set the initial_value to the given value + + Parameters + ---------- + value : str | bool | int | float + The value to set the initial value attribute to + """ + logger.debug("Setting inital value for %s to %s", self.name, value) + self._options["initial_value"] = value + + def get_control(self) -> Literal["radio", "multi", "colorchooser", "scale"] | type[ + ttk.Combobox] | type[ttk.Checkbutton] | type[tk.Entry]: + """ Set the correct control type based on the datatype or for this option """ + control: Literal["radio", + "multi", + "colorchooser", + "scale"] | type[ttk.Combobox] | type[ttk.Checkbutton] | type[tk.Entry] + if self.choices and self.is_radio: + control = "radio" + elif self.choices and self.is_multi_option: + control = "multi" + elif self.choices and self.choices == "colorchooser": + control = "colorchooser" + elif self.choices: + control = ttk.Combobox + elif self.dtype == bool: + control = ttk.Checkbutton + elif self.dtype in (int, float): + control = "scale" + else: + control = tk.Entry + logger.debug("Setting control '%s' to %s", self.title, control) + return control + + def get_tk_var(self, initial_value: str | bool | int | float) -> tk.Variable: + """ Correct variable type for control + + Parameters + ---------- + initial value : str | bool | int | float + The initial value to set the tk.Variable to + + Returns + ------- + :class:`tk.BooleanVar` | :class:`tk.IntVar` | :class:`tk.DoubleVar` | :class:`tk.StringVar` + The correct tk.Variable for the given initial value + """ + var: tk.Variable + if self.dtype == bool: + assert isinstance(initial_value, bool) + var = tk.BooleanVar() + var.set(initial_value) + elif self.dtype == int: + assert isinstance(initial_value, int) + var = tk.IntVar() + var.set(initial_value) + elif self.dtype == float: + assert isinstance(initial_value, float) + var = tk.DoubleVar() + var.set(initial_value) + else: + var = tk.StringVar() + var.set(cast(str, initial_value)) + logger.debug("Setting tk variable: (name: '%s', dtype: %s, tk_var: %s, initial_value: %s)", + self.name, self.dtype, var, initial_value) + if self._track_modified and self._command is not None: + logger.debug("Tracking variable modification: %s", self.name) + var.trace("w", + lambda name, index, mode, cmd=self._command: self._modified_callback(cmd)) + + if self._track_modified and self._command == "train" and self.title == "Model Dir": + var.trace("w", lambda name, index, mode, v=var: self._model_callback(v)) + + return var + + @staticmethod + def _modified_callback(command: str) -> None: + """ Set the modified variable for this tab to TRUE + + On initial setup the notebook won't yet exist, and we don't want to track the changes + for initial variables anyway, so make sure notebook exists prior to performing the callback + + Parameters + ---------- + command : str + The command to set the modified variable callback for + """ + config = get_config() + if config.command_notebook is None: + return + config.set_modified_true(command) + + 
@staticmethod + def _model_callback(tk_var: tk.StringVar) -> None: + """ Set a callback to load model stats for existing models when a model folder is selected + + Parameters + ---------- + tk_var : :class:`tkinter.StringVar` + The Tk variable to set the callback on + """ + config = get_config() + if not cfg.auto_load_model_stats(): + logger.debug("Session updating disabled by user config") + return + if config.tk_vars.running_task.get(): + logger.debug("Task running. Not updating session") + return + folder = tk_var.get() + logger.debug("Setting analysis model folder callback: '%s'", folder) + get_config().tk_vars.analysis_folder.set(folder) + + @classmethod + def from_config_object(cls, title: str, option: ConfigItem) -> Self: + """ Create a GUI control panel option from a Faceswap ConfigItem + + Parameters + ---------- + title : str + The option title (that displays as a label in the GUI) + option : :class:`~lib.config.ConfigItem` + The faceswap object to create the Control Panel option from + + Returns + ------- + :class:`ControlPanelOption` + A GUI ControlPanelOption instance + """ + initial_value = option.value + if option.datatype == list and isinstance(initial_value, list): + # Split multi-select lists into space separated strings for tk variables + initial_value = " ".join(initial_value) + + default = ", ".join(option.default) if isinstance(option.default, list) else option.default + + logger.debug("Creating Gui Option '%s' from: %s", title, option) + + retval = cls( + title=title, + dtype=option.datatype, + group=option.group, + default=default, + initial_value=initial_value, + choices=option.choices, + is_radio=option.gui_radio, + is_multi_option=option.datatype == list, + rounding=option.rounding, + min_max=option.min_max, + helptext=option.helptext) + logger.debug("Created GUI option '%s': %s", title, retval) + return retval + + +class ControlPanel(ttk.Frame): # pylint:disable=too-many-ancestors,too-many-instance-attributes + """ + A Control Panel to hold control panel options. + This class handles all of the formatting, placing and TK_Variables + in a consistent manner. + + It can also provide dynamic columns for resizing widgets + + Parameters + ---------- + parent: tkinter object + Parent widget that should hold this control panel + options: list of ControlPanelOptions objects + The list of controls that are to be built into this control panel + label_width: int, optional + The width that labels for controls should be set to. + Defaults to 20 + columns: int, optional + The initial number of columns to set the layout for. Default: 1 + max_columns: int, optional + The maximum number of columns that this control panel should be able + to accommodate. Setting to 1 means that there will only be 1 column + regardless of how wide the control panel is. Higher numbers will + dynamically fill extra columns if space permits. Defaults to 4 + option_columns: int, optional + For check-button and radio-button containers, how many options should + be displayed on each row. Defaults to 4 + header_text: str, optional + If provided, will place an information box at the top of the control + panel with these contents. + style: str, optional + The name of the style to use for the control panel. Styles are configured when TkInter + initializes. The style name is the common prefix prior to the widget name. Default: + ``None`` (use the OS style) + blank_nones: bool, optional + How the control panel should handle None values. If set to True then None values will be + converted to empty strings. 
Default: False + scrollbar: bool, optional + ``True`` if a scrollbar should be added to the control panel, otherwise ``False``. + Default: ``True`` + """ + + def __init__(self, parent, options, # pylint:disable=too-many-arguments,too-many-positional-arguments # noqa[E501] + label_width=20, columns=1, max_columns=4, option_columns=4, header_text=None, + style=None, blank_nones=True, scrollbar=True): + logger.debug("Initializing %s: (parent: '%s', options: %s, label_width: %s, columns: %s, " + "max_columns: %s, option_columns: %s, header_text: %s, style: %s, " + "blank_nones: %s, scrollbar: %s)", + self.__class__.__name__, parent, options, label_width, columns, max_columns, + option_columns, header_text, style, blank_nones, scrollbar) + self._style = "" if style is None else f"{style}." + super().__init__(parent, style=f"{self._style}.Group.TFrame") + + self.pack(side=tk.TOP, fill=tk.BOTH, expand=True) + + self.options = options + self.controls = [] + self.label_width = label_width + self.columns = columns + self.max_columns = max_columns + self.option_columns = option_columns + + self.header_text = header_text + self._theme = get_config().user_theme["group_panel"] + if self._style.startswith("SPanel"): + self._theme = {**self._theme, **get_config().user_theme["group_settings"]} + + self.group_frames = {} + self._sub_group_frames = {} + + canvas_kwargs = {"bd": 0, "highlightthickness": 0, "bg": self._theme["panel_background"]} + + self._canvas = tk.Canvas(self, **canvas_kwargs) + self._canvas.pack(side=tk.LEFT, fill=tk.BOTH, expand=True) + + self.mainframe, self.optsframe = self.get_opts_frame() + self._optscanvas = self._canvas.create_window((0, 0), window=self.mainframe, anchor=tk.NW) + self.build_panel(blank_nones, scrollbar) + + logger.debug("Initialized %s", self.__class__.__name__) + + @staticmethod + def _adjust_wraplength(event): + """ dynamically adjust the wrap length of a label on event """ + label = event.widget + label.configure(wraplength=event.width - 1) + + def get_opts_frame(self): + """ Return an auto-fill container for the options inside a main frame """ + style = f"{self._style}Holder." 
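+        # The widgets are nested canvas -> mainframe -> optsframe: the canvas supplies
+        # scrolling and resizing, while the AutoFillContainer wrapped around optsframe
+        # redistributes the option widgets across columns when the panel is resized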
+        mainframe = ttk.Frame(self._canvas, style=f"{style}TFrame")
+        if self.header_text is not None:
+            self.add_info(mainframe)
+        optsframe = ttk.Frame(mainframe, name="opts_frame", style=f"{style}TFrame")
+        optsframe.pack(expand=True, fill=tk.BOTH)
+        holder = AutoFillContainer(optsframe, self.columns, self.max_columns, style=style)
+        logger.debug("Opts frames: '%s'", holder)
+        return mainframe, holder
+
+    def add_info(self, frame):
+        """ Plugin information """
+        info_frame = ttk.Frame(frame, style=f"{self._style}InfoHeader.TFrame")
+        info_frame.pack(fill=tk.X, side=tk.TOP, expand=True, padx=10, pady=(10, 0))
+        label_frame = ttk.Frame(info_frame, style=f"{self._style}InfoHeader.TFrame")
+        label_frame.pack(padx=5, pady=5, fill=tk.X, expand=True)
+        for idx, line in enumerate(self.header_text.splitlines()):
+            if not line:
+                continue
+            style = f"{self._style}InfoHeader" if idx == 0 else f"{self._style}InfoBody"
+            info = ttk.Label(label_frame, text=line, style=f"{style}.TLabel", anchor=tk.W)
+            info.bind("<Configure>", self._adjust_wraplength)
+            info.pack(fill=tk.X, padx=0, pady=0, expand=True, side=tk.TOP)
+
+    def build_panel(self, blank_nones, scrollbar):
+        """ Build the options frame for this command """
+        logger.debug("Add Config Frame")
+        if scrollbar:
+            self.add_scrollbar()
+        self._canvas.bind("<Configure>", self.resize_frame)
+
+        for option in self.options:
+            group_frame = self.get_group_frame(option.group)
+            sub_group_frame = self._get_subgroup_frame(group_frame["frame"], option.subgroup)
+            frame = group_frame["frame"] if sub_group_frame is None else sub_group_frame.subframe
+
+            ctl = ControlBuilder(frame,
+                                 option,
+                                 label_width=self.label_width,
+                                 checkbuttons_frame=group_frame["chkbtns"],
+                                 option_columns=self.option_columns,
+                                 style=self._style,
+                                 blank_nones=blank_nones)
+            if group_frame["chkbtns"].items > 0:
+                group_frame["chkbtns"].parent.pack(side=tk.BOTTOM, fill=tk.X, anchor=tk.NW)
+
+            self.controls.append(ctl)
+        for control in self.controls:
+            filebrowser = control.filebrowser
+            if filebrowser is not None:
+                filebrowser.set_context_action_option(self.options)
+        logger.debug("Added Config Frame")
+
+    def get_group_frame(self, group):
+        """ Return a group frame.
+
+        If a group frame has already been created for the given group, then it will be returned,
+        otherwise it will be created and returned.
+
+        Parameters
+        ----------
+        group: str
+            The name of the group to obtain the group frame for
+
+        Returns
+        -------
+        :class:`ttk.Frame` or :class:`ToggledFrame`
+            If this is a 'master' group frame then returns a standard frame. 
+    def get_group_frame(self, group):
+        """ Return a group frame.
+
+        If a group frame has already been created for the given group, then it will be returned,
+        otherwise it will be created and returned.
+
+        Parameters
+        ----------
+        group: str
+            The name of the group to obtain the group frame for
+
+        Returns
+        -------
+        :class:`ttk.Frame` or :class:`ToggledFrame`
+            If this is a 'master' group frame then returns a standard frame. If this is any
+            other group, then returns the ToggledFrame for that group
+        """
+        group = group.lower()
+
+        if self.group_frames.get(group, None) is None:
+            logger.debug("Creating new group frame for: %s", group)
+            is_master = group == "_master"
+            opts_frame = self.optsframe.subframe
+            if is_master:
+                group_frame = ttk.Frame(opts_frame, style=f"{self._style}Group.TFrame")
+                retval = group_frame
+            else:
+                group_frame = ToggledFrame(opts_frame, text=group.title(), theme=self._style)
+                retval = group_frame.sub_frame
+
+            group_frame.pack(side=tk.TOP, fill=tk.X, padx=5, pady=5, anchor=tk.NW)
+
+            self.group_frames[group] = {"frame": retval,
+                                        "chkbtns": self.checkbuttons_frame(retval)}
+        group_frame = self.group_frames[group]
+        return group_frame
+
+    def add_scrollbar(self):
+        """ Add a scrollbar to the options frame """
+        logger.debug("Add Config Scrollbar")
+        scrollbar = ttk.Scrollbar(self,
+                                  command=self._canvas.yview,
+                                  style=f"{self._style}Vertical.TScrollbar")
+        scrollbar.pack(side=tk.RIGHT, fill=tk.Y)
+        self._canvas.config(yscrollcommand=scrollbar.set)
+        self.mainframe.bind("<Configure>", self.update_scrollbar)
+        logger.debug("Added Config Scrollbar")
+
+    def update_scrollbar(self, event):  # pylint:disable=unused-argument
+        """ Update the options frame scrollbar """
+        self._canvas.configure(scrollregion=self._canvas.bbox("all"))
+
+    def resize_frame(self, event):
+        """ Resize the options frame to fit the canvas """
+        logger.debug("Resize Config Frame")
+        canvas_width = event.width
+        self._canvas.itemconfig(self._optscanvas, width=canvas_width)
+        self.optsframe.rearrange_columns(canvas_width)
+        logger.debug("Resized Config Frame")
+
+    def checkbuttons_frame(self, frame):
+        """ Build and format the frame for holding the check buttons. If this is the 'master'
+        group, check buttons will be placed in a LabelFrame, otherwise in a standard frame """
+        logger.debug("Add Options CheckButtons Frame")
+        chk_frame = ttk.Frame(frame, name="chkbuttons", style=f"{self._style}Group.TFrame")
+        holder = AutoFillContainer(chk_frame,
+                                   self.option_columns,
+                                   self.option_columns,
+                                   style=f"{self._style}Group.")
+        logger.debug("Added Options CheckButtons Frame")
+        return holder
+
+    def _get_subgroup_frame(self, parent, subgroup):
+        if subgroup is None:
+            return subgroup
+        if subgroup not in self._sub_group_frames:
+            sub_frame = ttk.Frame(parent, style=f"{self._style}Group.TFrame")
+            self._sub_group_frames[subgroup] = AutoFillContainer(sub_frame,
+                                                                 self.option_columns,
+                                                                 self.option_columns,
+                                                                 style=f"{self._style}Group.")
+            sub_frame.pack(anchor=tk.W, expand=True, fill=tk.X)
+            logger.debug("Added Subgroup Frame: %s", subgroup)
+        return self._sub_group_frames[subgroup]
+
+
+class AutoFillContainer():
+    """ A container object that auto-fills columns.
+
+    Parameters
+    ----------
+    parent: :class:`ttk.Frame`
+        The parent widget that holds this container
+    initial_columns: int
+        The initial number of columns that this container should display
+    max_columns: int
+        The maximum number of columns that this container is permitted to display
+    style: str, optional
+        The name of the style to use for the control panel. Styles are configured when TkInter
+        initializes. The style name is the common prefix prior to the widget name.
Default: + empty string (use the OS style) + """ + def __init__(self, parent, initial_columns, max_columns, style=""): + logger.debug("Initializing: %s: (parent: %s, initial_columns: %s, max_columns: %s)", + self.__class__.__name__, parent, initial_columns, max_columns) + self.max_columns = max_columns + self.columns = initial_columns + self.parent = parent + self._style = style +# self.columns = min(columns, self.max_columns) + self.single_column_width = self.scale_column_width(288, 9) + self.max_width = self.max_columns * self.single_column_width + self._items = 0 + self._idx = 0 + self._widget_config = [] # Master list of all children in order + self.subframes = self.set_subframes() + logger.debug("Initialized: %s", self.__class__.__name__) + + @staticmethod + def scale_column_width(original_size, original_fontsize): + """ Scale the column width based on selected font size """ + font_size = cfg.font_size() + if font_size == original_fontsize: + return original_size + scale = 1 + (((font_size / original_fontsize) - 1) / 2) + retval = round(original_size * scale) + logger.debug("scaled column width: (old_width: %s, scale: %s, new_width:%s)", + original_size, scale, retval) + return retval + + @property + def items(self): + """ Returns the number of items held in this container """ + return self._items + + @property + def subframe(self): + """ Returns the next sub-frame to be populated """ + frame = self.subframes[self._idx] + next_idx = self._idx + 1 if self._idx + 1 < self.columns else 0 + logger.debug("current_idx: %s, next_idx: %s", self._idx, next_idx) + self._idx = next_idx + self._items += 1 + return frame + + def set_subframes(self): + """ Set a sub-frame for each possible column """ + subframes = [] + for idx in range(self.max_columns): + name = f"af_subframe_{idx}" + subframe = ttk.Frame(self.parent, name=name, style=f"{self._style}TFrame") + if idx < self.columns: + # Only pack visible columns + subframe.pack(padx=5, pady=5, side=tk.LEFT, anchor=tk.N, expand=True, fill=tk.X) + subframes.append(subframe) + logger.debug("Added subframe: %s", name) + return subframes + + def rearrange_columns(self, width): + """ On column number change redistribute widgets """ + if not self.validate(width): + return + + new_columns = min(self.max_columns, max(1, width // self.single_column_width)) + logger.debug("Rearranging columns: (width: %s, old_columns: %s, new_columns: %s)", + width, self.columns, new_columns) + self.columns = new_columns + if not self._widget_config: + self.compile_widget_config() + self.destroy_children() + self.repack_columns() + # Reset counters + self._items = 0 + self._idx = 0 + self.pack_widget_clones(self._widget_config) + + def validate(self, width): + """ Validate that passed in width should trigger column re-arranging """ + if ((width < self.single_column_width and self.columns == 1) or + (width > self.max_width and self.columns == self.max_columns)): + logger.debug("width outside min/max thresholds: (min: %s, width: %s, max: %s)", + self.single_column_width, width, self.max_width) + return False + range_min = self.columns * self.single_column_width + range_max = (self.columns + 1) * self.single_column_width + if range_min < width < range_max: + logger.debug("width outside next step refresh threshold: (step down: %s, width: %s," + "step up: %s)", range_min, width, range_max) + return False + return True + + def compile_widget_config(self): + """ Compile all children recursively in correct order if not already compiled and add + to :attr:`_widget_config` """ + 
zipped = zip_longest(*(subframe.winfo_children() for subframe in self.subframes))
+        children = [child for group in zipped for child in group if child is not None]
+        self._widget_config = [{"class": child.__class__,
+                                "id": str(child),
+                                "tooltip": _RECREATE_OBJECTS["tooltips"].get(str(child), None),
+                                "rc_menu": _RECREATE_OBJECTS["contextmenus"].get(str(child), None),
+                                "pack_info": self.pack_config_cleaner(child),
+                                "name": child.winfo_name(),
+                                "config": self.config_cleaner(child),
+                                "children": self.get_all_children_config(child, []),
+                                # Some children have custom kwargs, so keep dicts in sync
+                                "custom_kwargs": self._custom_kwargs(child)}
+                               for idx, child in enumerate(children)]
+        logger.debug("Compiled AutoFillContainer children: %s", self._widget_config)
+
+    @classmethod
+    def _custom_kwargs(cls, widget):
+        """ For custom widgets some custom arguments need to be passed from the old widget to the
+        newly created widget.
+
+        Parameters
+        ----------
+        widget: tkinter widget
+            The widget to be checked for custom keyword arguments
+
+        Returns
+        -------
+        dict
+            The custom keyword arguments required for recreating the given widget
+        """
+        retval = {}
+        if widget.__class__.__name__ == "MultiOption":
+            retval = {"value": widget._value,  # pylint:disable=protected-access
+                      "variable": widget._master_variable}  # pylint:disable=protected-access
+        elif widget.__class__.__name__ == "ToggledFrame":
+            # Toggled Frames need to have their variable tracked
+            retval = {"text": widget._text,  # pylint:disable=protected-access
+                      "toggle_var": widget._toggle_var}  # pylint:disable=protected-access
+        return retval
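
Because tkinter widgets cannot be re-parented, the container snapshots each child's class and configuration, destroys the original, and recreates it later under the new parent (see pack_widget_clones below). A stripped-down sketch of that snapshot-and-clone round trip; the empty-value filter here is a hypothetical simplification of config_cleaner:

    import tkinter as tk
    from tkinter import ttk

    root = tk.Tk()
    old_parent = ttk.Frame(root)
    new_parent = ttk.Frame(root)

    original = ttk.Label(old_parent, text="hello", anchor="w")

    # Snapshot the class and the non-empty configuration values
    snapshot = {"class": original.__class__,
                "config": {key: original.cget(key)
                           for key in original.config()
                           if key != "class" and str(original.cget(key)) != ""}}

    # Widgets cannot be moved between parents, so destroy and recreate
    original.destroy()
    clone = snapshot["class"](new_parent)
    clone.configure(**snapshot["config"])
    clone.pack(fill="x")
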
+
+    def get_all_children_config(self, widget, child_list):
+        """ Return all children, recursively, of given widget.
+
+        Parameters
+        ----------
+        widget: tkinter widget
+            The widget to recursively obtain the configurations of each child
+        child_list: list
+            The list of child configurations already collected
+
+        Returns
+        -------
+        list
+            The list of configurations for all recursive children of the given widget
+        """
+        unpack = set()
+        for child in widget.winfo_children():
+            # Hidden Toggle Frame boxes need to be mapped
+            if child.winfo_ismapped() or "toggledframe_subframe" in str(child):
+                not_mapped = not child.winfo_ismapped()
+                # ToggledFrame is a custom widget that creates its own children and handles
+                # bindings on the headers, to auto-hide the contents. To ensure that all child
+                # information (specifically pack information) can be collected, we need to pack
+                # any hidden sub-frames. These are then hidden again once collected.
+                if not_mapped and (child.winfo_name() == "toggledframe_subframe" or
+                                   child.winfo_name() == "chkbuttons"):
+                    child.pack(fill=tk.X, expand=True)
+                    child.update_idletasks()  # Updates the packing info of children
+                    unpack.add(child)
+
+                if child.winfo_name().startswith("toggledframe_header"):
+                    # Headers should be entirely handled by parent widget
+                    continue
+
+                child_list.append({
+                    "class": child.__class__,
+                    "id": str(child),
+                    "tooltip": _RECREATE_OBJECTS["tooltips"].get(str(child), None),
+                    "rc_menu": _RECREATE_OBJECTS["contextmenus"].get(str(child), None),
+                    "pack_info": self.pack_config_cleaner(child),
+                    "name": child.winfo_name(),
+                    "config": self.config_cleaner(child),
+                    "parent": child.winfo_parent(),
+                    "custom_kwargs": self._custom_kwargs(child)})
+                self.get_all_children_config(child, child_list)
+
+        # Re-hide any toggle frames that were expanded
+        for hide in unpack:
+            hide.pack_forget()
+            hide.update_idletasks()
+        return child_list
+
+    @staticmethod
+    def config_cleaner(widget):
+        """ Some options don't like to be copied, so this returns a cleaned
+        configuration from a widget.
+        We use config() instead of configure() because some items (ttk Scale) do
+        not populate configure() """
+        new_config = {}
+        for key in widget.config():
+            if key == "class":
+                continue
+            val = widget.cget(key)
+            # Some keys default to "" but tkinter doesn't like to set config to this value
+            # so skip them to use default value.
+            if key in ("anchor", "justify", "compound") and val == "":
+                continue
+            # Following keys cannot be defined after widget is created:
+            if key in ("colormap", "container", "visual"):
+                continue
+            val = str(val) if isinstance(val, Tcl_Obj) else val
+            # Return correct command from master command dict
+            val = _RECREATE_OBJECTS["commands"][val] if key == "command" and val != "" else val
+            new_config[key] = val
+        return new_config
+
+    @staticmethod
+    def pack_config_cleaner(widget):
+        """ Some options don't like to be copied, so this returns a cleaned
+        pack configuration from a widget """
+        return {key: val for key, val in widget.pack_info().items() if key != "in"}
+
+    def destroy_children(self):
+        """ Destroy the currently existing widgets """
+        for subframe in self.subframes:
+            for child in subframe.winfo_children():
+                child.destroy()
+
+    def repack_columns(self):
+        """ Repack or unpack columns based on display columns """
+        for idx, subframe in enumerate(self.subframes):
+            logger.trace("Processing subframe: %s", subframe)
+            if idx < self.columns and not subframe.winfo_ismapped():
+                logger.trace("Packing subframe: %s", subframe)
+                subframe.pack(padx=5, pady=5, side=tk.LEFT, anchor=tk.N, expand=True, fill=tk.X)
+            elif idx >= self.columns and subframe.winfo_ismapped():
+                logger.trace("Forgetting subframe: %s", subframe)
+                subframe.pack_forget()
+
+    def pack_widget_clones(self, widget_dicts, old_children=None, new_children=None):
+        """ Recursively pass through the list of widgets creating clones and packing all
+        children.
+
+        Widgets cannot be given a new parent so we need to clone them and then pack the
+        new widgets.
+
+        Parameters
+        ----------
+        widget_dicts: list
+            List of dictionaries, in appearance order, of widget information for cloning widgets
+        old_children: list, optional
+            Used for recursion. Leave at ``None``
+        new_children: list, optional
+            Used for recursion.
Leave at ``None`` + """ + for widget_dict in widget_dicts: + logger.debug("Cloning widget: %s", widget_dict) + old_children = [] if old_children is None else old_children + new_children = [] if new_children is None else new_children + if widget_dict.get("parent", None) is not None: + parent = new_children[old_children.index(widget_dict["parent"])] + logger.trace("old parent: '%s', new_parent: '%s'", widget_dict["parent"], parent) + else: + # Get the next sub-frame if this doesn't have a logged parent + parent = self.subframe + clone = widget_dict["class"](parent, + name=widget_dict["name"], + **widget_dict["custom_kwargs"]) + if widget_dict["config"] is not None: + clone.configure(**widget_dict["config"]) + if widget_dict["tooltip"] is not None: + Tooltip(clone, **widget_dict["tooltip"]) + rc_menu = widget_dict["rc_menu"] + if rc_menu is not None: + # Re-initialize for new widget and bind + rc_menu.__init__(widget=clone) # pylint:disable=unnecessary-dunder-call + rc_menu.cm_bind() + clone.pack(**widget_dict["pack_info"]) + + # Handle ToggledFrame sub-frames. If the parent is not set to expanded, then we need to + # hide the sub-frame + if clone.winfo_name() == "toggledframe_subframe": + toggle_frame = clone.nametowidget(clone.winfo_parent()) + if not toggle_frame.is_expanded: + logger.debug("Hiding minimized toggle box: %s", clone) + clone.pack_forget() + + old_children.append(widget_dict["id"]) + new_children.append(clone) + if widget_dict.get("children", None) is not None: + self.pack_widget_clones(widget_dict["children"], old_children, new_children) + + +class ControlBuilder(): + """ + Builds and returns a frame containing a tkinter control with label + This should only be called from the ControlPanel class + + Parameters + ---------- + parent: tkinter object + Parent tkinter object + option: ControlPanelOption object + Holds all of the required option information + option_columns: int + Number of options to put on a single row for check-buttons/radio-buttons + label_width: int + Sets the width of the control label + checkbuttons_frame: tkinter.frame + If a check-button frame is passed in, then check-buttons will be placed in this frame + rather than the main options frame + style: str + The name of the style to use for the control panel. Styles are configured when TkInter + initializes. The style name is the common prefix prior to the widget name. Provide an empty + string to use the OS style + blank_nones: bool + Sets selected values to an empty string rather than None if this is true. + """ + def __init__(self, parent, option, option_columns, # pylint:disable=too-many-arguments + label_width, checkbuttons_frame, style, blank_nones): + logger.debug("Initializing %s: (parent: %s, option: %s, option_columns: %s, " + "label_width: %s, checkbuttons_frame: %s, style: %s, blank_nones: %s)", + self.__class__.__name__, parent, option, option_columns, label_width, + checkbuttons_frame, style, blank_nones) + + self.option = option + self.option_columns = option_columns + self.helpset = False + self.label_width = label_width + self.filebrowser = None + # Default to Control Panel Style + self._style = style = style if style else "CPanel." 
+        self._theme = get_config().user_theme["group_panel"]
+        if self._style.startswith("SPanel"):
+            self._theme = {**self._theme, **get_config().user_theme["group_settings"]}
+
+        self.frame = self.control_frame(parent)
+        self.chkbtns = checkbuttons_frame
+
+        self.set_tk_var(blank_nones)
+        self.build_control()
+        logger.debug("Initialized: %s", self.__class__.__name__)
+
+    # Frame, control type and variable
+    def control_frame(self, parent):
+        """ Frame to hold control and its label """
+        logger.debug("Build control frame")
+        frame = ttk.Frame(parent,
+                          name=f"fr_{self.option.name}",
+                          style=f"{self._style}Group.TFrame")
+        frame.pack(fill=tk.X)
+        logger.debug("Built control frame")
+        return frame
+
+    def set_tk_var(self, blank_nones):
+        """ Set the correct variable type for the control """
+        val = "" if self.option.value is None and blank_nones else self.option.value
+        self.option.tk_var.set(val)
+        logger.debug("Set tk variable: (option: '%s', variable: %s, value: '%s')",
+                     self.option.name, self.option.tk_var, val)
+
+    # Build the full control
+    def build_control(self):
+        """ Build the correct control type for the option passed through """
+        logger.debug("Build config option control")
+        if self.option.control not in (ttk.Checkbutton, "radio", "multi", "colorchooser"):
+            self.build_control_label()
+        self.build_one_control()
+        logger.debug("Built option control")
+
+    def build_control_label(self):
+        """ Label for control """
+        logger.debug("Build control label: (option: '%s')", self.option.name)
+        lbl = ttk.Label(self.frame,
+                        text=self.option.title,
+                        width=self.label_width,
+                        anchor=tk.W,
+                        style=f"{self._style}Group.TLabel")
+        lbl.pack(padx=5, pady=5, side=tk.LEFT, anchor=tk.N)
+        if self.option.helptext is not None:
+            _get_tooltip(lbl, text=self.option.helptext)
+        logger.debug("Built control label: (widget: '%s', title: '%s')",
+                     self.option.name, self.option.title)
+
+    def build_one_control(self):
+        """ Build and place the option controls """
+        logger.debug("Build control: '%s'", self.option.name)
+        if self.option.control == "scale":
+            ctl = self.slider_control()
+        elif self.option.control in ("radio", "multi"):
+            ctl = self._multi_option_control(self.option.control)
+        elif self.option.control == "colorchooser":
+            ctl = self._color_control()
+        elif self.option.control == ttk.Checkbutton:
+            ctl = self.control_to_checkframe()
+        else:
+            ctl = self.control_to_optionsframe()
+        if self.option.control != ttk.Checkbutton:
+            ctl.pack(padx=5, pady=5, fill=tk.X, expand=True)
+        if self.option.helptext is not None and not self.helpset:
+            tooltip_kwargs = {"text": self.option.helptext}
+            if self.option.sysbrowser is not None:
+                tooltip_kwargs["text_variable"] = self.option.tk_var
+            _get_tooltip(ctl, **tooltip_kwargs)
+
+        logger.debug("Built control: '%s'", self.option.name)
+
+    def _multi_option_control(self, option_type):
+        """ Create a group of buttons for single or multi-select.
+
+        Parameters
+        ----------
+        option_type: {"radio", "multi"}
+            The type of boxes that this control should hold. "radio" for single item select,
+            "multi" for multi item select.
+ + """ + logger.debug("Adding %s group: %s", option_type, self.option.name) + help_intro, help_items = self._get_multi_help_items(self.option.helptext) + ctl = ttk.LabelFrame(self.frame, + text=self.option.title, + name=f"{option_type}_labelframe", + style=f"{self._style}Group.TLabelframe") + holder = AutoFillContainer(ctl, + self.option_columns, + self.option_columns, + style=f"{self._style}Group.") + for choice in self.option.choices: + if option_type == "radio": + ctl = ttk.Radiobutton + style = f"{self._style}Group.TRadiobutton" + else: + ctl = MultiOption + style = f"{self._style}Group.TCheckbutton" + + ctl = ctl(holder.subframe, + text=choice.replace("_", " ").title(), + value=choice, + variable=self.option.tk_var, + style=style) + if choice.lower() in help_items: + self.helpset = True + helptext = help_items[choice.lower()] + helptext = f"{helptext}\n\n - {help_intro}" + _get_tooltip(ctl, text=helptext) + ctl.pack(anchor=tk.W, fill=tk.X) + logger.debug("Added %s option %s", option_type, choice) + return holder.parent + + @staticmethod + def _get_multi_help_items(helptext): + """ Split the help text up, for formatted help text, into the individual options + for multi/radio buttons. + + Parameters + ---------- + helptext: str + The raw help text for this cli. option + + Returns + ------- + tuple (`str`, `dict`) + The help text intro and a dictionary containing the help text split into separate + entries for each option choice + """ + logger.debug("raw help: %s", helptext) + all_help = helptext.splitlines() + intro = "" + if any(line.startswith(" - ") for line in all_help): + intro = all_help[0] + retval = (intro, + {re.sub(r"[^\w\-\_]+", "", + line.split()[1].lower()): " ".join(line.replace("_", " ").split()[1:]) + for line in all_help if line.startswith(" - ")}) + logger.debug("help items: %s", retval) + return retval + + def slider_control(self): + """ A slider control with corresponding Entry box """ + logger.debug("Add slider control to Options Frame: (widget: '%s', dtype: %s, " + "rounding: %s, min_max: %s)", self.option.name, self.option.dtype, + self.option.rounding, self.option.min_max) + validate = self.slider_check_int if self.option.dtype == int else self.slider_check_float + vcmd = self.frame.register(validate) + tbox = tk.Entry(self.frame, + width=8, + textvariable=self.option.tk_var, + justify=tk.RIGHT, + font=get_config().default_font, + validate="all", + validatecommand=(vcmd, "%P"), + bg=self._theme["input_color"], + fg=self._theme["input_font"], + highlightbackground=self._theme["input_font"], + highlightthickness=1, + bd=0) + tbox.pack(padx=(0, 5), side=tk.RIGHT) + cmd = partial(set_slider_rounding, + var=self.option.tk_var, + d_type=self.option.dtype, + round_to=self.option.rounding, + min_max=self.option.min_max) + ctl = ttk.Scale(self.frame, + variable=self.option.tk_var, + command=cmd, + style=f"{self._style}Horizontal.TScale") + _add_command(ctl.cget("command"), cmd) + rc_menu = _get_contextmenu(tbox) + rc_menu.cm_bind() + ctl["from_"] = self.option.min_max[0] + ctl["to"] = self.option.min_max[1] + logger.debug("Added slider control to Options Frame: %s", self.option.name) + return ctl + + @staticmethod + def slider_check_int(value): + """ Validate a slider's text entry box for integer values. + + Parameters + ---------- + value: str + The slider text entry value to validate + """ + if value.isdigit() or value == "": + return True + return False + + @staticmethod + def slider_check_float(value): + """ Validate a slider's text entry box for float values. 
+
+        Parameters
+        ----------
+        value: str
+            The slider text entry value to validate
+        """
+        if value:
+            try:
+                float(value)
+            except ValueError:
+                return False
+        return True
+
+    def control_to_optionsframe(self):
+        """ Standard non-check buttons sit in the main options frame """
+        logger.debug("Add control to Options Frame: (widget: '%s', control: %s, choices: %s)",
+                     self.option.name, self.option.control, self.option.choices)
+        if self.option.sysbrowser is not None:
+            self.filebrowser = FileBrowser(self.option.name,
+                                           self.option.tk_var,
+                                           self.frame,
+                                           self.option.sysbrowser,
+                                           self._style)
+
+        if self.option.control == tk.Entry:
+            ctl = self.option.control(self.frame,
+                                      textvariable=self.option.tk_var,
+                                      font=get_config().default_font,
+                                      bg=self._theme["input_color"],
+                                      fg=self._theme["input_font"],
+                                      highlightbackground=self._theme["input_font"],
+                                      highlightthickness=1,
+                                      bd=0)
+        else:  # Combobox
+            ctl = self.option.control(self.frame,
+                                      textvariable=self.option.tk_var,
+                                      font=get_config().default_font,
+                                      state="readonly",
+                                      style=f"{self._style}TCombobox")
+
+            # Style for combo list boxes needs to be set directly on widget as no style parameter
+            cmd = f"[ttk::combobox::PopdownWindow {ctl}].f.l configure -"
+            ctl.tk.eval(f"{cmd}foreground {self._theme['input_font']}")
+            ctl.tk.eval(f"{cmd}background {self._theme['input_color']}")
+            ctl.tk.eval(f"{cmd}selectforeground {self._theme['control_active']}")
+            ctl.tk.eval(f"{cmd}selectbackground {self._theme['control_disabled']}")
+
+        rc_menu = _get_contextmenu(ctl)
+        rc_menu.cm_bind()
+
+        if self.option.choices:
+            logger.debug("Adding combo choices: %s", self.option.choices)
+            ctl["values"] = self.option.choices
+            ctl["state"] = "readonly"
+        logger.debug("Added control to Options Frame: %s", self.option.name)
+        return ctl
+
+    def _color_control(self):
+        """ A clickable label holding the currently selected color """
+        logger.debug("Add control to Options Frame: (widget: '%s', control: %s, choices: %s)",
+                     self.option.name, self.option.control, self.option.choices)
+        frame = ttk.Frame(self.frame, style=f"{self._style}Group.TFrame")
+        lbl = ttk.Label(frame,
+                        text=self.option.title,
+                        width=self.label_width,
+                        anchor=tk.W,
+                        style=f"{self._style}Group.TLabel")
+        ctl = tk.Frame(frame,
+                       bg=self.option.tk_var.get(),
+                       bd=2,
+                       cursor="hand2",
+                       relief=tk.SUNKEN,
+                       width=round(int(20 * get_config().scaling_factor)),
+                       height=round(int(14 * get_config().scaling_factor)))
+        ctl.bind("<Button-1>", lambda *e, c=ctl, t=self.option.title: self._ask_color(c, t))
+        lbl.pack(side=tk.LEFT, anchor=tk.N)
+        ctl.pack(side=tk.RIGHT, anchor=tk.W)
+        frame.pack(padx=5, side=tk.LEFT, anchor=tk.W)
+        if self.option.helptext is not None:
+            _get_tooltip(frame, text=self.option.helptext)
+        # Callback to set the color chooser background on an update (e.g.
reset) + self.option.tk_var.trace("w", lambda *e: ctl.config(bg=self.option.tk_var.get())) + logger.debug("Added control to Options Frame: %s", self.option.name) + return ctl + + def _ask_color(self, frame, title): + """ Pop ask color dialog set to variable and change frame color """ + color = self.option.tk_var.get() + chosen = colorchooser.askcolor(parent=frame, color=color, title=f"{title} Color")[1] + if chosen is None: + return + self.option.tk_var.set(chosen) + + def control_to_checkframe(self): + """ Add check-buttons to the check-button frame """ + logger.debug("Add control checkframe: '%s'", self.option.name) + chkframe = self.chkbtns.subframe + ctl = self.option.control(chkframe, + variable=self.option.tk_var, + text=self.option.title, + name=self.option.name, + style=f"{self._style}Group.TCheckbutton") + _get_tooltip(ctl, text=self.option.helptext) + ctl.pack(side=tk.TOP, anchor=tk.W, fill=tk.X) + logger.debug("Added control checkframe: '%s'", self.option.name) + return ctl + + +class FileBrowser(): + """ Add FileBrowser buttons to control and handle routing """ + def __init__(self, opt_name, tk_var, control_frame, sysbrowser_dict, style): + logger.debug("Initializing: %s: (tk_var: %s, control_frame: %s, sysbrowser_dict: %s, " + "style: %s)", self.__class__.__name__, tk_var, control_frame, + sysbrowser_dict, style) + self._opt_name = opt_name + self.tk_var = tk_var + self.frame = control_frame + self._style = style + self.browser = sysbrowser_dict["browser"] + self.filetypes = sysbrowser_dict["filetypes"] + self.action_option = self.format_action_option(sysbrowser_dict.get("action_option", None)) + self.command = sysbrowser_dict.get("command", None) + self.destination = sysbrowser_dict.get("destination", None) + self.add_browser_buttons() + logger.debug("Initialized: %s", self.__class__.__name__) + + @property + def helptext(self): + """ Dict containing tooltip text for buttons """ + retval = {"folder": _("Select a folder..."), + "load": _("Select a file..."), + "load2": _("Select a file..."), + "picture": _("Select a folder of images..."), + "video": _("Select a video..."), + "model": _("Select a model folder..."), + "multi_load": _("Select one or more files..."), + "context": _("Select a file or folder..."), + "save_as": _("Select a save location...")} + return retval + + @staticmethod + def format_action_option(action_option): + """ Format the action option to remove any dashes at the start """ + if action_option is None: + return action_option + if action_option.startswith("--"): + return action_option[2:] + if action_option.startswith("-"): + return action_option[1:] + return action_option + + def add_browser_buttons(self): + """ Add correct file browser button for control """ + logger.debug("Adding browser buttons: (sysbrowser: %s", self.browser) + frame = ttk.Frame(self.frame, style=f"{self._style}Group.TFrame") + frame.pack(side=tk.RIGHT, padx=(0, 5)) + + for browser in self.browser: + if browser == "save": + lbl = "save_as" + elif browser == "load" and self.filetypes == "video": + lbl = self.filetypes + elif browser == "load": + lbl = "load2" + elif browser == "folder" and (self._opt_name.startswith(("frames", "faces")) + or "input" in self._opt_name): + lbl = "picture" + elif browser == "folder" and "model" in self._opt_name: + lbl = "model" + else: + lbl = browser + img = get_images().icons[lbl] + action = getattr(self, "ask_" + browser) + cmd = partial(action, filepath=self.tk_var, filetypes=self.filetypes) + fileopn = tk.Button(frame, + image=img, + command=cmd, + 
relief=tk.SOLID, + bd=1, + bg=get_config().user_theme["group_panel"]["button_background"], + cursor="hand2") + _add_command(fileopn.cget("command"), cmd) + fileopn.pack(padx=1, side=tk.RIGHT) + _get_tooltip(fileopn, text=self.helptext[lbl]) + logger.debug("Added browser buttons: (action: %s, filetypes: %s", + action, self.filetypes) + + def set_context_action_option(self, options): + """ Set the tk_var for the source action option + that dictates the context sensitive file browser. """ + if self.browser != ["context"]: + return + actions = {opt.name: opt.tk_var for opt in options} + logger.debug("Settiong action option for opt %s", self.action_option) + self.action_option = actions[self.action_option] + + @staticmethod + def ask_folder(filepath, filetypes=None): + """ Pop-up to get path to a directory + :param filepath: tkinter StringVar object + that will store the path to a directory. + :param filetypes: Unused argument to allow + filetypes to be given in ask_load(). """ + dirname = FileHandler("dir", filetypes).return_file + if dirname: + logger.debug(dirname) + filepath.set(dirname) + + @staticmethod + def ask_load(filepath, filetypes): + """ Pop-up to get path to a file """ + filename = FileHandler("filename", filetypes).return_file + if filename: + logger.debug(filename) + filepath.set(filename) + + @staticmethod + def ask_multi_load(filepath, filetypes): + """ Pop-up to get path to a file """ + filenames = FileHandler("filename_multi", filetypes).return_file + if filenames: + final_names = " ".join(f"\"{fname}\"" for fname in filenames) + logger.debug(final_names) + filepath.set(final_names) + + @staticmethod + def ask_save(filepath, filetypes=None): + """ Pop-up to get path to save a new file """ + filename = FileHandler("save_filename", filetypes).return_file + if filename: + logger.debug(filename) + filepath.set(filename) + + @staticmethod + def ask_nothing(filepath, filetypes=None): # pylint:disable=unused-argument + """ Method that does nothing, used for disabling open/save pop up """ + return + + def ask_context(self, filepath, filetypes): + """ Method to pop the correct dialog depending on context """ + logger.debug("Getting context filebrowser") + selected_action = self.action_option.get() + selected_variable = self.destination + filename = FileHandler("context", + filetypes, + command=self.command, + action=selected_action, + variable=selected_variable).return_file + if filename: + logger.debug(filename) + filepath.set(filename) + + +__all__ = get_module_objects(__name__) diff --git a/lib/gui/custom_widgets.py b/lib/gui/custom_widgets.py new file mode 100644 index 0000000000..90652376eb --- /dev/null +++ b/lib/gui/custom_widgets.py @@ -0,0 +1,1023 @@ +#!/usr/bin/env python3 +""" Custom widgets for Faceswap GUI """ + +import logging +import platform +import re +import sys +import typing as T +import tkinter as tk +from tkinter import ttk, TclError + +import numpy as np + +from lib.utils import get_module_objects + +from .utils import get_config + +logger = logging.getLogger(__name__) + + +class ContextMenu(tk.Menu): # pylint:disable=too-many-ancestors + """ A Pop up menu to be triggered when right clicking on widgets that this menu has been + applied to. + + This widget provides a simple right click pop up menu to the widget passed in with `Cut`, + `Copy`, `Paste` and `Select all` menu items. 
+
+    Parameters
+    ----------
+    widget: tkinter object
+        The widget to apply the :class:`ContextMenu` to
+
+    Example
+    -------
+    >>> text_box = ttk.Entry(parent)
+    >>> text_box.pack()
+    >>> right_click_menu = ContextMenu(text_box)
+    >>> right_click_menu.cm_bind()
+    """
+    def __init__(self, widget):
+        logger.debug("Initializing %s: (widget_class: '%s')",
+                     self.__class__.__name__, widget.winfo_class())
+        super().__init__(tearoff=0)
+        self._widget = widget
+        self._standard_actions()
+        logger.debug("Initialized %s", self.__class__.__name__)
+
+    def _standard_actions(self):
+        """ Standard menu actions """
+        self.add_command(label="Cut", command=lambda: self._widget.event_generate("<<Cut>>"))
+        self.add_command(label="Copy", command=lambda: self._widget.event_generate("<<Copy>>"))
+        self.add_command(label="Paste", command=lambda: self._widget.event_generate("<<Paste>>"))
+        self.add_separator()
+        self.add_command(label="Select all", command=self._select_all)
+
+    def cm_bind(self):
+        """ Bind the menu to the given widget's right click event.
+
+        After associating a widget with this :class:`ContextMenu` this function should be called
+        to bind it to the right click button
+        """
+        button = "<Button-2>" if platform.system() == "Darwin" else "<Button-3>"
+        logger.debug("Binding '%s' to '%s'", button, self._widget.winfo_class())
+        self._widget.bind(button, lambda event: self.tk_popup(event.x_root, event.y_root))
+
+    def _select_all(self):
+        """ Select all for Text or Entry widgets """
+        logger.debug("Selecting all for '%s'", self._widget.winfo_class())
+        if self._widget.winfo_class() == "Text":
+            self._widget.focus_force()
+            self._widget.tag_add("sel", "1.0", "end")
+        else:
+            self._widget.focus_force()
+            self._widget.select_range(0, tk.END)
+
+
+class RightClickMenu(tk.Menu):  # pylint:disable=too-many-ancestors
+    """ A pop up menu that can be bound to a right click mouse event to bring up a context menu
+
+    Parameters
+    ----------
+    labels: list
+        A list of label titles that will appear in the right click menu
+    actions: list
+        A list of python functions that are called when the corresponding label is clicked on
+    hotkeys: list, optional
+        The hotkeys corresponding to the labels. If using hotkeys, then there must be an entry in
+        the list for every label even if they don't all use hotkeys. Labels without a hotkey can be
+        an empty string or ``None``. Passing ``None`` instead of a list means that no actions will
+        be given hotkeys. NB: The hotkey is not bound by this class; that needs to be done in code.
+        Giving hotkeys here means that they will be displayed in the menu though. Default: ``None``
+    """
+    # TODO This should probably be merged with ContextMenu
+    def __init__(self, labels, actions, hotkeys=None):
+        logger.debug("Initializing %s: (labels: %s, actions: %s)", self.__class__.__name__, labels,
+                     actions)
+        super().__init__(tearoff=0)
+        self._labels = labels
+        self._actions = actions
+        self._hotkeys = hotkeys
+        self._create_menu()
+        logger.debug("Initialized %s", self.__class__.__name__)
+
+    def _create_menu(self):
+        """ Create the menu based on :attr:`_labels` and :attr:`_actions`. """
+        for idx, (label, action) in enumerate(zip(self._labels, self._actions)):
+            kwargs = {"label": label, "command": action}
+            if isinstance(self._hotkeys, (list, tuple)) and self._hotkeys[idx]:
+                kwargs["accelerator"] = self._hotkeys[idx]
+            self.add_command(**kwargs)
+
+    def popup(self, event):
+        """ Pop up the right click menu.
+ + Parameters + ---------- + event: class:`tkinter.Event` + The tkinter mouse event calling this popup + """ + self.tk_popup(event.x_root, event.y_root) + + +class ConsoleOut(ttk.Frame): # pylint:disable=too-many-ancestors + """ The Console out section of the GUI. + + A Read only text box for displaying the output from stdout/stderr. + + All handling is internal to this method. To clear the console, the stored tkinter variable in + :attr:`~lib.gui.Config.tk_vars` ``console_clear`` should be triggered. + + Parameters + ---------- + parent: tkinter object + The Console's parent widget + debug: bool + ``True`` if console output should not be directed to this widget otherwise ``False`` + """ + + def __init__(self, parent, debug): + logger.debug("Initializing %s: (parent: %s, debug: %s)", + self.__class__.__name__, parent, debug) + super().__init__(parent, relief=tk.SOLID, padding=1, style="Console.TFrame") + self._theme = get_config().user_theme["console"] + self._console = _ReadOnlyText(self, relief=tk.FLAT) + rc_menu = ContextMenu(self._console) + rc_menu.cm_bind() + self._console_clear = get_config().tk_vars.console_clear + self._set_console_clear_var_trace() + self._debug = debug + self._build_console() + self._add_tags() + self.pack(side=tk.TOP, anchor=tk.W, padx=10, pady=(2, 0), + fill=tk.BOTH, expand=True) + logger.debug("Initialized %s", self.__class__.__name__) + + def _set_console_clear_var_trace(self): + """ Set a trace on the console clear tkinter variable to trigger :func:`_clear` """ + logger.debug("Set clear trace") + self._console_clear.trace("w", self._clear) + + def _build_console(self): + """ Build and place the console and add stdout/stderr redirection """ + logger.debug("Build console") + self._console.config(width=100, + height=6, + bg=self._theme["background_color"], + fg=self._theme["stdout_color"]) + + scrollbar = ttk.Scrollbar(self, + command=self._console.yview, + style="Console.Vertical.TScrollbar") + self._console.configure(yscrollcommand=scrollbar.set) + + scrollbar.pack(side=tk.RIGHT, fill="y") + self._console.pack(side=tk.LEFT, anchor=tk.N, fill=tk.BOTH, expand=True) + self._redirect_console() + logger.debug("Built console") + + def _add_tags(self): + """ Add tags to text widget to color based on output """ + logger.debug("Adding text color tags") + self._console.tag_config("default", foreground=self._theme["stdout_color"]) + self._console.tag_config("stderr", foreground=self._theme["stderr_color"]) + self._console.tag_config("info", foreground=self._theme["info_color"]) + self._console.tag_config("verbose", foreground=self._theme["verbose_color"]) + self._console.tag_config("warning", foreground=self._theme["warning_color"]) + self._console.tag_config("critical", foreground=self._theme["critical_color"]) + self._console.tag_config("error", foreground=self._theme["error_color"]) + + def _redirect_console(self): + """ Redirect stdout/stderr to console Text Box """ + logger.debug("Redirect console") + if self._debug: + logger.info("Console debug activated. Outputting to main terminal") + else: + sys.stdout = _SysOutRouter(self._console, "stdout") + sys.stderr = _SysOutRouter(self._console, "stderr") + logger.debug("Redirected console") + + def _clear(self, *args): # pylint:disable=unused-argument + """ Clear the console output screen """ + logger.debug("Clear console") + if not self._console_clear.get(): + logger.debug("Console not set for clearing. 
Skipping")
+            return
+        self._console.delete(1.0, tk.END)
+        self._console_clear.set(False)
+        logger.debug("Cleared console")
+
+
+class _ReadOnlyText(tk.Text):  # pylint:disable=too-many-ancestors
+    """ A read only text widget.
+
+    Standard tkinter Text widgets are read/write by default. As we want to make the console
+    display writable by the Faceswap process but not the user, we need to redirect its insert and
+    delete attributes.
+
+    Source: https://stackoverflow.com/questions/3842155
+    """
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.redirector = _WidgetRedirector(self)
+        self.insert = self.redirector.register("insert", lambda *args, **kw: "break")
+        self.delete = self.redirector.register("delete", lambda *args, **kw: "break")
+
+
+class _SysOutRouter():
+    """ Route stdout/stderr to the given text box.
+
+    Parameters
+    ----------
+    console: tkinter Object
+        The widget that will receive the output from stderr/stdout
+    out_type: ['stdout', 'stderr']
+        The output type to redirect
+    """
+
+    def __init__(self, console, out_type):
+        logger.debug("Initializing %s: (console: %s, out_type: '%s')",
+                     self.__class__.__name__, console, out_type)
+        self._console = console
+        self._out_type = out_type
+        self._recolor = re.compile(r".+?(\s\d+:\d+:\d+\s)(?P<lvl>[A-Z]+)\s")
+        self._ansi_escape = re.compile(r"\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])")
+        logger.debug("Initialized %s", self.__class__.__name__)
+
+    def _get_tag(self, string):
+        """ Set the tag based on regex of log output """
+        if self._out_type == "stderr":
+            # Output all stderr in red
+            return self._out_type
+
+        output = self._recolor.match(string)
+        if not output:
+            return "default"
+        tag = output.groupdict()["lvl"].strip().lower()
+        return tag
+
+    def write(self, string):
+        """ Capture stdout/stderr """
+        string = self._ansi_escape.sub("", string)
+        self._console.insert(tk.END, string, self._get_tag(string))
+        self._console.see(tk.END)
+
+    @staticmethod
+    def flush():
+        """ If flush is forced, send it to normal terminal """
+        sys.__stdout__.flush()
+
+
+class _WidgetRedirector:
+    """Support for redirecting arbitrary widget sub-commands.
+
+    Some Tk operations don't normally pass through tkinter. For example, if a
+    character is inserted into a Text widget by pressing a key, a default Tk
+    binding to the widget's 'insert' operation is activated, and the Tk library
+    processes the insert without calling back into tkinter.
+
+    Although a binding to <Key> could be made via tkinter, what we really want
+    to do is to hook the Tk 'insert' operation itself. For one thing, we want
+    a text.insert call in idle code to have the same effect as a key press.
+
+    When a widget is instantiated, a Tcl command is created whose name is the
+    same as the path name widget._w. This command is used to invoke the various
+    widget operations, e.g. insert (for a Text widget). We are going to hook
+    this command and provide a facility ('register') to intercept the widget
+    operation. We will also intercept method calls on the tkinter class
+    instance that represents the tk widget.
+
+    In IDLE, WidgetRedirector is used in Percolator to intercept Text
+    commands. The function being registered provides access to the top
+    of a Percolator chain. At the bottom of the chain is a call to the
+    original Tk widget operation.
+
+    Attributes
+    ----------
+    _operations: dict
+        Dictionary mapping operation name to new function
+    widget: tkinter widget
+        The widget whose tcl command is to be intercepted
+    tk_: widget.tk
+        A convenience attribute, probably not needed.
+ orig: str + new name of the original tcl command. + + Notes + ----- + Since renaming to orig fails with TclError when orig already exists, only one + WidgetDirector can exist for a given widget. + """ + def __init__(self, widget): + self._operations = {} + self.widget = widget # widget instance + self.tk_ = tk_ = widget.tk # widget's root + wgt = widget._w # pylint:disable=protected-access # widget's (full) Tk pathname + self.orig = wgt + "_orig" + # Rename the Tcl command within Tcl: + tk_.call("rename", wgt, self.orig) + # Create a new Tcl command whose name is the widget's path name, and + # whose action is to dispatch on the operation passed to the widget: + tk_.createcommand(wgt, self.dispatch) + + def __repr__(self): + return (f"{self.__class__.__name__}({self.widget.__class__.__name__}" + f"<{self.widget._w}>)") # pylint:disable=protected-access + + def close(self): + "de-register operations and revert redirection created by .__init__." + for operation in list(self._operations): + self.unregister(operation) + widget = self.widget + tk_ = widget.tk + wgt = widget._w # pylint:disable=protected-access + # Restore the original widget Tcl command. + tk_.deletecommand(wgt) + tk_.call("rename", self.orig, wgt) + del self.widget, self.tk_ # Should not be needed + # if instance is deleted after close, as in Percolator. + + def register(self, operation, function): + """Return _OriginalCommand(operation) after registering function. + + Registration adds an operation: function pair to ._operations. + It also adds a widget function attribute that masks the tkinter + class instance method. Method masking operates independently + from command dispatch. + + If a second function is registered for the same operation, the + first function is replaced in both places. + """ + self._operations[operation] = function + setattr(self.widget, operation, function) + return _OriginalCommand(self, operation) + + def unregister(self, operation): + """Return the function for the operation, or None. + + Deleting the instance attribute unmasks the class attribute. + """ + if operation in self._operations: + function = self._operations[operation] + del self._operations[operation] + try: + delattr(self.widget, operation) + except AttributeError: + pass + return function + return None + + def dispatch(self, operation, *args): + """Callback from Tcl which runs when the widget is referenced. + + If an operation has been registered in self._operations, apply the + associated function to the args passed into Tcl. Otherwise, pass the + operation through to Tk via the original Tcl function. + + Note that if a registered function is called, the operation is not + passed through to Tk. Apply the function returned by self.register() + to *args to accomplish that. + + """ + op_ = self._operations.get(operation) + try: + if op_: + return op_(*args) + return self.tk_.call((self.orig, operation) + args) + except TclError: + return "" + + +class _OriginalCommand: + """Callable for original tk command that has been redirected. + + Returned by .register; can be used in the function registered. + redirect = WidgetRedirector(text) + def my_insert(*args): + print("insert", args) + original_insert(*args) + original_insert = redirect.register("insert", my_insert) + """ + + def __init__(self, redirect, operation): + """Create .tk_call and .orig_and_operation for .__call__ method. + + .redirect and .operation store the input args for __repr__. + .tk and .orig copy attributes of .redirect (probably not needed). 
+ """ + self.redirect = redirect + self.operation = operation + self.tk_ = redirect.tk_ # redundant with self.redirect + self.orig = redirect.orig # redundant with self.redirect + # These two could be deleted after checking recipient code. + self.tk_call = redirect.tk_.call + self.orig_and_operation = (redirect.orig, operation) + + def __repr__(self): + return f"{self.__class__.__name__}({self.redirect}, {self.operation})" + + def __call__(self, *args): + return self.tk_call(self.orig_and_operation + args) + + +class StatusBar(ttk.Frame): # pylint:disable=too-many-ancestors + """ Status Bar for displaying the Status Message and Progress Bar at the bottom of the GUI. + + Parameters + ---------- + parent: tkinter object + The parent tkinter widget that will hold the status bar + hide_status: bool, optional + ``True`` to hide the status message that appears at the far left hand side of the status + frame otherwise ``False``. Default: ``False`` + """ + + def __init__(self, parent: ttk.Frame, hide_status: bool = False) -> None: + super().__init__(parent) + self._frame = ttk.Frame(self) + self._message = tk.StringVar() + self._pbar_message = tk.StringVar() + self._pbar_position = tk.IntVar() + self._mode: T.Literal["indeterminate", "determinate"] = "determinate" + + self._message.set("Ready") + + self._status(hide_status) + self._pbar = self._progress_bar() + self.pack(side=tk.BOTTOM, fill=tk.X, expand=False) + self._frame.pack(padx=10, pady=2, fill=tk.X, expand=False) + + @property + def message(self) -> tk.StringVar: + """:class:`tkinter.StringVar`: The variable to hold the status bar message on the left + hand side of the status bar. """ + return self._message + + def _status(self, hide_status: bool) -> None: + """ Place Status label into left of the status bar. + + Parameters + ---------- + hide_status: bool, optional + ``True`` to hide the status message that appears at the far left hand side of the + status frame otherwise ``False`` + """ + if hide_status: + return + + statusframe = ttk.Frame(self._frame) + statusframe.pack(side=tk.LEFT, anchor=tk.W, fill=tk.X, expand=False) + + lbltitle = ttk.Label(statusframe, text="Status:", width=6, anchor=tk.W) + lbltitle.pack(side=tk.LEFT, expand=False) + + lblstatus = ttk.Label(statusframe, + width=40, + textvariable=self._message, + anchor=tk.W) + lblstatus.pack(side=tk.LEFT, anchor=tk.W, fill=tk.X, expand=True) + + def _progress_bar(self) -> ttk.Progressbar: + """ Place progress bar into right of the status bar. 
+
+        Returns
+        -------
+        :class:`tkinter.ttk.Progressbar`
+            The progress bar object
+        """
+        progressframe = ttk.Frame(self._frame)
+        progressframe.pack(side=tk.RIGHT, anchor=tk.E, fill=tk.X)
+
+        lblmessage = ttk.Label(progressframe, textvariable=self._pbar_message)
+        lblmessage.pack(side=tk.LEFT, padx=3, fill=tk.X, expand=True)
+
+        pbar = ttk.Progressbar(progressframe,
+                               length=200,
+                               variable=self._pbar_position,
+                               maximum=100,
+                               mode=self._mode)
+        pbar.pack(side=tk.LEFT, padx=2, fill=tk.X, expand=True)
+        pbar.pack_forget()
+        return pbar
+
+    def start(self, mode: T.Literal["indeterminate", "determinate"]) -> None:
+        """ Set progress bar mode and display.
+
+        Parameters
+        ----------
+        mode: ["indeterminate", "determinate"]
+            The mode that the progress bar should be executed in
+        """
+        self._set_mode(mode)
+        self._pbar.pack()
+
+    def stop(self) -> None:
+        """ Reset progress bar and hide """
+        self._pbar_message.set("")
+        self._pbar_position.set(0)
+        self._mode = "determinate"
+        self._set_mode(self._mode)
+        self._pbar.pack_forget()
+
+    def _set_mode(self, mode: T.Literal["indeterminate", "determinate"]) -> None:
+        """ Set the progress bar mode
+
+        Parameters
+        ----------
+        mode: ["indeterminate", "determinate"]
+            The mode that the progress bar should be executed in
+        """
+        self._mode = mode
+        self._pbar.config(mode=self._mode)
+        if mode == "indeterminate":
+            self._pbar.config(maximum=100)
+            self._pbar.start()
+        else:
+            self._pbar.stop()
+            self._pbar.config(maximum=100)
+
+    def set_mode(self, mode: T.Literal["indeterminate", "determinate"]) -> None:
+        """ Set the mode of a currently displayed progress bar and reset position to 0.
+
+        If the given mode is the same as the currently configured mode, returns without performing
+        any action.
+
+        Parameters
+        ----------
+        mode: ["indeterminate", "determinate"]
+            The mode that the progress bar should be set to
+        """
+        if mode == self._mode:
+            return
+        self.stop()
+        self.start(mode)
+
+    def progress_update(self, message: str, position: int, update_position: bool = True) -> None:
+        """ Update the GUI's progress bar and position.
+
+        Parameters
+        ----------
+        message: str
+            The message to display next to the progress bar
+        position: int
+            The position that the progress bar should be set to
+        update_position: bool, optional
+            If ``True`` then the progress bar will be updated to the position given in
+            :attr:`position`. If ``False`` the progress bar will not be updated. Default: ``True``
+        """
+        self._pbar_message.set(message)
+        if update_position:
+            self._pbar_position.set(position)
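
The intended call sequence for the class above is start, then progress_update, then stop. A short usage sketch, assuming a plain Tk root standing in for the real GUI parent:

    import tkinter as tk

    root = tk.Tk()
    status_bar = StatusBar(root)              # StatusBar as defined above
    status_bar.message.set("Extracting...")

    status_bar.start("determinate")           # show the hidden progress bar
    for pct in (0, 25, 50, 75, 100):
        status_bar.progress_update(f"{pct}%", pct)
    status_bar.stop()                         # reset position/message and hide
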
+
+
+class Tooltip:  # pylint:disable=too-few-public-methods
+    """ Create a tooltip for a given widget as the mouse goes on it.
+
+    Parameters
+    ----------
+    widget: tkinter object
+        The widget to apply the tool-tip to
+    pad: tuple, optional
+        (left, top, right, bottom) padding for the tool-tip. Default: (5, 3, 5, 3)
+    text: str, optional
+        The text to be displayed in the tool-tip. Default: 'widget info'
+    text_variable: :class:`tkinter.StringVar`, optional
+        The text variable to use for dynamic help text. Appended after the contents of :attr:`text`
+        if provided. Default: ``None``
+    wait_time: int, optional
+        The time in milliseconds to wait before showing the tool-tip. Default: 400
+    wrap_length: int, optional
+        The text length for each line before wrapping. Default: 250
+
+    Example
+    -------
+    >>> button = ttk.Button(parent, text="Exit")
+    >>> Tooltip(button, text="Click to exit")
+    >>> button.pack()
+
+    Notes
+    -----
+    Adapted from StackOverflow: http://stackoverflow.com/questions/3221956 and
+    http://www.daniweb.com/programming/software-development/code/484591/a-tooltip-class-for-tkinter
+    """
+    def __init__(self, widget, *, pad=(5, 3, 5, 3), text="widget info",
+                 text_variable=None, wait_time=400, wrap_length=250):
+
+        self._waittime = wait_time  # in milliseconds, originally 500
+        self.wrap_length = wrap_length  # in pixels, originally 180
+        self._widget = widget
+        self._text = text
+        self._text_variable = text_variable
+        self._widget.bind("<Enter>", self._on_enter)
+        self._widget.bind("<Leave>", self._on_leave)
+        self._widget.bind("<ButtonPress>", self._on_leave)
+        self._theme = get_config().user_theme["tooltip"]
+        self._pad = pad
+        self._ident = None
+        self._topwidget = None
+
+    def _on_enter(self, event=None):  # pylint:disable=unused-argument
+        """ Schedule the tooltip on an enter event """
+        self._schedule()
+
+    def _on_leave(self, event=None):  # pylint:disable=unused-argument
+        """ Remove the schedule on a leave event """
+        self._unschedule()
+        self._hide()
+
+    def _schedule(self):
+        """ Show the tooltip after the wait period """
+        self._unschedule()
+        self._ident = self._widget.after(self._waittime, self._show)
+
+    def _unschedule(self):
+        """ Cancel a pending tooltip """
+        id_ = self._ident
+        self._ident = None
+        if id_:
+            self._widget.after_cancel(id_)
+
+    def _show(self):
+        """ Show the tooltip """
+        def tip_pos_calculator(widget, label,  # pylint:disable=too-many-locals
+                               *,
+                               tip_delta=(10, 5), pad=(5, 3, 5, 3)):
+            """ Calculate the tooltip position """
+
+            s_width, s_height = widget.winfo_screenwidth(), widget.winfo_screenheight()
+
+            width, height = (pad[0] + label.winfo_reqwidth() + pad[2],
+                             pad[1] + label.winfo_reqheight() + pad[3])
+
+            mouse_x, mouse_y = widget.winfo_pointerxy()
+
+            x_1, y_1 = mouse_x + tip_delta[0], mouse_y + tip_delta[1]
+            x_2, y_2 = x_1 + width, y_1 + height
+
+            x_delta = max(x_2 - s_width, 0)
+            y_delta = max(y_2 - s_height, 0)
+
+            offscreen = (x_delta, y_delta) != (0, 0)
+
+            if offscreen:
+                if x_delta:
+                    x_1 = mouse_x - tip_delta[0] - width
+                if y_delta:
+                    y_1 = mouse_y - tip_delta[1] - height
+
+            offscreen_again = y_1 < 0  # out on the top
+
+            if offscreen_again:
+                # No further checks will be done.
+                # TIP:
+                # A further mod might auto-magically augment the wrap length when the tooltip is
+                # too high to be kept inside the screen.
+                y_1 = 0
+
+            return x_1, y_1
+
+        pad = self._pad
+        widget = self._widget
+
+        # Creates a top level window
+        self._topwidget = tk.Toplevel(widget)
+        if platform.system() == "Darwin":
+            # For Mac OS
+            self._topwidget.tk.call("::tk::unsupported::MacWindowStyle",
+                                    "style", self._topwidget._w,  # pylint:disable=protected-access
+                                    "help", "none")
+
+        # Leaves only the label and removes the app window
+        self._topwidget.wm_overrideredirect(True)
+
+        win = tk.Frame(self._topwidget,
+                       background=self._theme["background_color"],
+                       highlightbackground=self._theme["border_color"],
+                       highlightcolor=self._theme["border_color"],
+                       highlightthickness=1,
+                       borderwidth=0)
+
+        text = self._text
+        if self._text_variable and self._text_variable.get():
+            text += f"\n\nCurrent value: '{self._text_variable.get()}'"
+        label = tk.Label(win,
+                         text=text,
+                         justify=tk.LEFT,
+                         background=self._theme["background_color"],
+                         foreground=self._theme["font_color"],
+                         relief=tk.SOLID,
+                         borderwidth=0,
+                         wraplength=self.wrap_length)
+
+        label.grid(padx=(pad[0], pad[2]),
+                   pady=(pad[1], pad[3]),
+                   sticky=tk.NSEW)
+        win.grid()
+
+        xpos, ypos = tip_pos_calculator(widget, label)
+
+        self._topwidget.wm_geometry(f"+{xpos}+{ypos}")
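
The arithmetic in tip_pos_calculator boils down to: offset from the mouse; if the tip would cross the right or bottom screen edge, flip it to the other side of the pointer; never go above the top. The same logic in isolation (a hypothetical helper, not part of the class), with a worked call:

    def clamp_tip_position(mouse_xy, tip_size, screen_size, delta=(10, 5)):
        """Return a top-left position for a tooltip, flipped back on-screen
        when the default offset would run off the right/bottom edge."""
        x = mouse_xy[0] + delta[0]
        y = mouse_xy[1] + delta[1]
        if x + tip_size[0] > screen_size[0]:   # off the right edge: flip left
            x = mouse_xy[0] - delta[0] - tip_size[0]
        if y + tip_size[1] > screen_size[1]:   # off the bottom edge: flip up
            y = mouse_xy[1] - delta[1] - tip_size[1]
        return x, max(y, 0)                    # never above the top edge

    # e.g. a 200x50 tip at mouse (1900, 1060) on a 1920x1080 screen
    print(clamp_tip_position((1900, 1060), (200, 50), (1920, 1080)))  # (1690, 1005)
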
+
+    def _hide(self):
+        """ Hide the tooltip """
+        topwidget = self._topwidget
+        if topwidget:
+            topwidget.destroy()
+        self._topwidget = None
+
+
+class MultiOption(ttk.Checkbutton):  # pylint:disable=too-many-ancestors
+    """ Similar to the standard :class:`ttk.Radio` widget, but with the ability to select
+    multiple pre-defined options. Selected options are generated as `nargs` for the argument
+    parser to consume.
+
+    Parameters
+    ----------
+    parent: :class:`ttk.Frame`
+        The tkinter parent widget for the check button
+    value: str
+        The raw option value for this check button
+    variable: :class:`tkinter.StringVar`
+        The master variable for the group of check buttons that this check button will belong to.
+        The output of this variable will be a string containing a space separated list of the
+        selected check button options
+    """
+    def __init__(self, parent, value, variable, **kwargs):
+        self._tk_var = tk.BooleanVar()
+        self._tk_var.set(value in variable.get().split())
+        super().__init__(parent, variable=self._tk_var, **kwargs)
+        self._value = value
+        self._master_variable = variable
+        self._tk_var.trace("w", self._on_update)
+        self._master_variable.trace("w", self._on_master_update)
+
+    @property
+    def _master_list(self):
+        """ list: The contents of the check box group's :attr:`_master_variable` in list form.
+        Selected check boxes will appear in this list. """
+        retval = self._master_variable.get().split()
+        logger.trace(retval)
+        return retval
+
+    @property
+    def _master_needs_update(self):
+        """ bool: ``True`` if :attr:`_master_variable` requires updating otherwise ``False``. """
+        active = self._tk_var.get()
+        retval = ((active and self._value not in self._master_list) or
+                  (not active and self._value in self._master_list))
+        logger.trace(retval)
+        return retval
+
+    def _on_update(self, *args):  # pylint:disable=unused-argument
+        """ Update the master variable on a check button change.
+
+        The value for this checked option is added or removed from the :attr:`_master_variable`
+        on a ``True``, ``False`` change for this check button.
+
+        Parameters
+        ----------
+        args: tuple
+            Required for variable callback, but unused
+        """
+        if not self._master_needs_update:
+            return
+        new_vals = self._master_list + [self._value] if self._tk_var.get() else [
+            val
+            for val in self._master_list
+            if val != self._value]
+        val = " ".join(new_vals)
+        logger.trace("Setting master variable to: %s", val)
+        self._master_variable.set(val)
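
Together with _on_master_update below, this gives a two-way binding between each check box and the space separated master string. A small round-trip sketch using MultiOption as defined above:

    import tkinter as tk

    root = tk.Tk()
    master = tk.StringVar(value="alpha gamma")

    box_a = MultiOption(root, "alpha", master, text="Alpha")   # starts ticked
    box_b = MultiOption(root, "beta", master, text="Beta")     # starts unticked

    box_b.invoke()          # ticking "beta" appends it to the master list
    print(master.get())     # -> "alpha gamma beta"

    master.set("beta")      # loading a saved value unticks "alpha" again
    print(box_a.instate(["selected"]))   # -> False
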
+
+        Parameters
+        ----------
+        args: tuple
+            Required for variable callback, but unused
+        """
+        if not self._master_needs_update:
+            return
+        new_vals = self._master_list + [self._value] if self._tk_var.get() else [
+            val
+            for val in self._master_list
+            if val != self._value]
+        val = " ".join(new_vals)
+        logger.trace("Setting master variable to: %s", val)
+        self._master_variable.set(val)
+
+    def _on_master_update(self, *args):  # pylint:disable=unused-argument
+        """ Update the check button on a master variable change (e.g. load .fsw file in the GUI).
+
+        The value for this option is set to ``True`` or ``False`` depending on its existence in
+        the :attr:`_master_variable`
+
+        Parameters
+        ----------
+        args: tuple
+            Required for variable callback, but unused
+        """
+        if not self._master_needs_update:
+            return
+        state = self._value in self._master_list
+        logger.trace("Setting '%s' to %s", self._value, state)
+        self._tk_var.set(state)
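The update flow above is symmetrical: ticking a box appends its value to the master variable,
and writing the master variable ticks or unticks the matching boxes. A minimal sketch of a group
of these check buttons (assuming `MultiOption` is importable from this module with faceswap's
logging configured; the option values are invented for illustration):

    import tkinter as tk

    root = tk.Tk()
    master_var = tk.StringVar(value="gpu")  # "gpu" starts checked

    # Three check buttons sharing one space separated master variable
    for option in ("cpu", "gpu", "tpu"):
        MultiOption(root, option, master_var, text=option).pack(anchor=tk.W)

    # Checking "cpu" as well would set the master variable to e.g. "gpu cpu"
    master_var.trace("w", lambda *args: print(master_var.get()))
    root.mainloop()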
+
+
+class PopupProgress(tk.Toplevel):
+    """ A simple pop up progress bar that appears at the center of the root window.
+
+    When this is called, the root will be disabled until the :func:`stop` method is called.
+
+    Parameters
+    ----------
+    title: str
+        The title to appear above the progress bar
+    total: int or float
+        The total count of items for the progress bar
+
+    Example
+    -------
+    >>> total = 100
+    >>> progress = PopupProgress("My title...", total)
+    >>> for i in range(total):
+    >>>     progress.step(1)
+    >>> progress.stop()
+    """
+    def __init__(self, title, total):
+        super().__init__()
+        self._total = total
+        if platform.system() == "Darwin":  # For Mac OS
+            self.tk.call("::tk::unsupported::MacWindowStyle",
+                         "style", self._w,  # pylint:disable=protected-access
+                         "help", "none")
+        # Leaves only the label and removes the app window
+        self.wm_overrideredirect(True)
+        self.attributes('-topmost', 'true')
+        self.transient()
+
+        self._lbl_title = self._set_title(title)
+        self._progress_bar = self._get_progress_bar()
+
+        offset = np.array((self.master.winfo_rootx(), self.master.winfo_rooty()))
+        # TODO find way to get dimensions of the pop up without it flicking onto the screen
+        self.update_idletasks()
+        center = np.array((
+            (self.master.winfo_width() // 2) - (self.winfo_width() // 2),
+            (self.master.winfo_height() // 2) - (self.winfo_height() // 2))) + offset
+        self.wm_geometry(f"+{center[0]}+{center[1]}")
+        get_config().set_cursor_busy()
+        self.grab_set()
+
+    @property
+    def progress_bar(self):
+        """ :class:`tkinter.ttk.Progressbar`: The progress bar object within the pop up window. """
+        return self._progress_bar
+
+    def _set_title(self, title):
+        """ Set the initial title of the pop up progress bar.
+
+        Parameters
+        ----------
+        title: str
+            The title to appear above the progress bar
+
+        Returns
+        -------
+        :class:`tkinter.ttk.Label`
+            The heading label for the progress bar
+        """
+        frame = ttk.Frame(self)
+        frame.pack(side=tk.TOP, padx=5, pady=5)
+        lbl = ttk.Label(frame, text=title)
+        lbl.pack(side=tk.TOP, pady=(5, 0), expand=True, fill=tk.X)
+        return lbl
+
+    def _get_progress_bar(self):
+        """ Set up the progress bar with the supplied total.
+
+        Returns
+        -------
+        :class:`tkinter.ttk.Progressbar`
+            The configured progress bar for the pop up window
+        """
+        frame = ttk.Frame(self)
+        frame.pack(side=tk.BOTTOM, padx=5, pady=(0, 5))
+        pbar = ttk.Progressbar(frame,
+                               length=400,
+                               maximum=self._total,
+                               mode="determinate")
+        pbar.pack(side=tk.LEFT)
+        return pbar
+
+    def step(self, amount):
+        """ Increment the progress bar.
+
+        Parameters
+        ----------
+        amount: int or float
+            The amount to increment the progress bar by
+        """
+        self._progress_bar.step(amount)
+        self._progress_bar.update_idletasks()
+
+    def stop(self):
+        """ Stop the progress bar, re-enable the root window and destroy the pop up window. """
+        self._progress_bar.stop()
+        get_config().set_cursor_default()
+        self.grab_release()
+        self.destroy()
+
+    def update_title(self, title):
+        """ Update the title that displays above the progress bar.
+
+        Parameters
+        ----------
+        title: str
+            The title to appear above the progress bar
+        """
+        self._lbl_title.config(text=title)
+        self._lbl_title.update_idletasks()
+
+
+class ToggledFrame(ttk.Frame):  # pylint:disable=too-many-ancestors
+    """ A collapsible and expandable frame.
+
+    The frame contains a header given in the text argument, and adds an expand/contract button.
+    Clicking on the header will expand and contract the sub-frame below.
+
+    Parameters
+    ----------
+    parent: :class:`ttk.Frame`
+        The parent widget that holds the toggle frame
+    text: str
+        The text to appear in the Toggle Frame header
+    theme: str, optional
+        The theme to use for the panel header. Default: `"CPanel"`
+    toggle_var: :class:`tk.BooleanVar`, optional
+        If provided, this variable will control the expanded (``True``) and minimized (``False``)
+        state of the widget. Set to None to create the variable internally. Default: ``None``
+    """
+    def __init__(self, parent, *args, text="", theme="CPanel", toggle_var=None, **kwargs):
+        logger.debug("Initializing %s: (parent: %s, text: %s, theme: %s, toggle_var: %s)",
+                     self.__class__.__name__, parent, text, theme, toggle_var)
+
+        theme = "CPanel" if not theme else theme
+        theme = theme[:-1] if theme[-1] == "." else theme
+        super().__init__(parent, *args, style=f"{theme}.Group.TFrame", **kwargs)
+        self._text = text
+
+        if toggle_var:
+            self._toggle_var = toggle_var
+        else:
+            self._toggle_var = tk.BooleanVar()
+            self._toggle_var.set(1)
+        self._icon_var = tk.StringVar()
+        self._icon_var.set("-" if self.is_expanded else "+")
+
+        self._build_header(theme)
+
+        self.sub_frame = ttk.Frame(self, style=f"{theme}.Subframe.Group.TFrame", padding=1)
+
+        if self.is_expanded:
+            self.sub_frame.pack(fill=tk.X, expand=True)
+
+        logger.debug("Initialized %s", self.__class__.__name__)
+
+    @property
+    def is_expanded(self):
+        """ bool: ``True`` if the Toggle Frame is expanded. ``False`` if it is minimized. """
+        return self._toggle_var.get()
+
+    def _build_header(self, theme):
+        """ The Header row. Contains the title text and is made clickable to expand and contract
+        the sub-frame.
+
+        Parameters
+        ----------
+        theme: str
+            The theme to use for the panel header
+        """
+        header_frame = ttk.Frame(self, name="toggledframe_header")
+
+        text_label = ttk.Label(header_frame,
+                               name="toggledframe_headerlbl",
+                               text=self._text,
+                               style=f"{theme}.Groupheader.TLabel",
+                               cursor="hand2")
+        toggle_button = ttk.Label(header_frame,
+                                  name="toggledframe_headerbtn",
+                                  textvariable=self._icon_var,
+                                  style=f"{theme}.Groupheader.TLabel",
+                                  cursor="hand2",
+                                  width=2)
+        text_label.bind("<Button-1>", self._toggle)
+        toggle_button.bind("<Button-1>", self._toggle)
+
+        text_label.pack(side=tk.LEFT, fill=tk.X, expand=True)
+        toggle_button.pack(side=tk.RIGHT)
+        header_frame.pack(fill=tk.X, expand=True)
+
+    def _toggle(self, event):  # pylint:disable=unused-argument
+        """ Toggle the sub-frame between contracted or expanded, and update the toggle icon
+        appropriately.
+
+        Parameters
+        ----------
+        event: :class:`tkinter.Event`
+            Required but unused
+        """
+        if self.is_expanded:
+            self.sub_frame.forget()
+            self._icon_var.set("+")
+            self._toggle_var.set(0)
+        else:
+            self.sub_frame.pack(fill=tk.X, expand=True)
+            self._icon_var.set("-")
+            self._toggle_var.set(1)
+
+
+__all__ = get_module_objects(__name__)
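A quick sketch of how this collapsible frame is typically wired up (a hedged example: it assumes
the application has already registered the ttk styles the widget requests, e.g.
`CPanel.Group.TFrame`, and the parent widget and contents are invented):

    import tkinter as tk
    from tkinter import ttk

    root = tk.Tk()
    section = ToggledFrame(root, text="Advanced Options")
    section.pack(fill=tk.X)

    # Children go into the exposed sub_frame, not the ToggledFrame itself
    ttk.Label(section.sub_frame, text="Some setting").pack(anchor=tk.W)

    # Clicking the header text or the +/- button collapses and expands sub_frame
    root.mainloop()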
diff --git a/lib/gui/display.py b/lib/gui/display.py
index bb70b63423..0c9dd07bde 100644
--- a/lib/gui/display.py
+++ b/lib/gui/display.py
@@ -1,121 +1,195 @@
 #!/usr/bin python3
 """ Display Frame of the Faceswap GUI
 
-    What is displayed in the Display Frame varies
-    depending on what tasked is being run """
+This is the large right hand area of the GUI. At default, the Analysis tab is always displayed
+here. Further optional tabs will also be displayed depending on the currently executing Faceswap
+task. """
 import logging
+import gettext
 import tkinter as tk
 from tkinter import ttk
 
+from lib.logger import parse_class_init
+from lib.utils import get_module_objects
+
 from .display_analysis import Analysis
 from .display_command import GraphDisplay, PreviewExtract, PreviewTrain
 from .utils import get_config
 
-logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
+logger = logging.getLogger(__name__)
+
+# LOCALES
+_LANG = gettext.translation("gui.tooltips", localedir="locales", fallback=True)
+_ = _LANG.gettext
 
-class DisplayNotebook(ttk.Notebook):  # pylint: disable=too-many-ancestors
-    """ The display tabs """
+class DisplayNotebook(ttk.Notebook):  # pylint:disable=too-many-ancestors
+    """ The tkinter Notebook that holds the display items.
+
+    Parameters
+    ----------
+    parent: :class:`tk.PanedWindow`
+        The paned window that holds the Display Notebook
+    """
 
     def __init__(self, parent):
-        logger.debug("Initializing %s", self.__class__.__name__)
-        ttk.Notebook.__init__(self, parent, width=780)
+        logger.debug(parse_class_init(locals()))
+        super().__init__(parent)
         parent.add(self)
         tk_vars = get_config().tk_vars
-        self.wrapper_var = tk_vars["display"]
-        self.runningtask = tk_vars["runningtask"]
-
-        self.set_wrapper_var_trace()
-        self.add_static_tabs()
-        self.static_tabs = [child for child in self.tabs()]
+        self._wrapper_var = tk_vars.display
+        self._running_task = tk_vars.running_task
+
+        self._set_wrapper_var_trace()
+        self._add_static_tabs()
+        # pylint:disable=unnecessary-comprehension
+        self._static_tabs = [child for child in self.tabs()]
+        self.bind("<<NotebookTabChanged>>", self._on_tab_change)
         logger.debug("Initialized %s", self.__class__.__name__)
 
-    def set_wrapper_var_trace(self):
-        """ Set the trigger actions for the display vars
-            when they have been triggered in the Process Wrapper """
+    @property
+    def running_task(self):
+        """ :class:`tkinter.BooleanVar`: The global tkinter variable that indicates whether a
+        Faceswap task is currently running or not. """
+        return self._running_task
+
+    def _set_wrapper_var_trace(self):
+        """ Sets the trigger to update the displayed notebook's pages when the global tkinter
+        variable `display` is updated in the :class:`~lib.gui.wrapper.ProcessWrapper`. """
         logger.debug("Setting wrapper var trace")
-        self.wrapper_var.trace("w", self.update_displaybook)
+        self._wrapper_var.trace("w", self._update_displaybook)
+
+    def _add_static_tabs(self):
+        """ Add the tabs to the Display Notebook that are permanently displayed.
 
-    def add_static_tabs(self):
-        """ Add tabs that are permanently available """
+        Currently this is just the `Analysis` tab.
+        """
         logger.debug("Adding static tabs")
         for tab in ("job queue", "analysis"):
             if tab == "job queue":
                 continue    # Not yet implemented
             if tab == "analysis":
                 helptext = {"stats":
-                            "Summary statistics for each training session"}
+                            _("Summary statistics for each training session")}
                 frame = Analysis(self, tab, helptext)
             else:
-                frame = self.add_frame()
+                frame = self._add_frame()
             self.add(frame, text=tab.title())
 
-    def add_frame(self):
-        """ Add a single frame for holding tab's contents """
+    def _add_frame(self):
+        """ Add a single frame for holding a static tab's contents.
+
+        Returns
+        -------
+        ttk.Frame
+            The frame, packed into position
+        """
         logger.debug("Adding frame")
         frame = ttk.Frame(self)
         frame.pack(side=tk.LEFT, fill=tk.BOTH, expand=True, padx=5, pady=5)
         return frame
 
-    def command_display(self, command):
-        """ Select what to display based on incoming
-            command """
-        build_tabs = getattr(self, "{}_tabs".format(command))
+    def _command_display(self, command):
+        """ Build the relevant command specific tabs based on the incoming Faceswap command.
+
+        Parameters
+        ----------
+        command: str
+            The Faceswap command that is being executed
+        """
+        build_tabs = getattr(self, f"_{command}_tabs")
         build_tabs()
 
-    def extract_tabs(self):
-        """ Build the extract tabs """
+    def _extract_tabs(self, command="extract"):
+        """ Build the display tabs that are used for Faceswap extract and convert tasks.
+
+        Parameters
+        ----------
+        command: [`"extract"`, `"convert"`], optional
+            The command that the display tabs are being built for. Default: `"extract"`
+
+        Notes
+        -----
+        The same display tabs are used for both convert and extract tasks.
+        """
         logger.debug("Build extract tabs")
-        helptext = ("Updates preview from output every 5 "
-                    "seconds to limit disk contention")
-        PreviewExtract(self, "preview", helptext, 5000)
+        helptext = _("Preview updates every 5 seconds")
+        PreviewExtract(self, "preview", helptext, 5000, command)
         logger.debug("Built extract tabs")
 
-    def train_tabs(self):
-        """ Build the train tabs """
+    def _train_tabs(self):
+        """ Build the display tabs that are used for the Faceswap train task. """
        logger.debug("Build train tabs")
         for tab in ("graph", "preview"):
             if tab == "graph":
-                helptext = "Graph showing Loss vs Iterations"
+                helptext = _("Graph showing Loss vs Iterations")
                 GraphDisplay(self, "graph", helptext, 5000)
             elif tab == "preview":
-                helptext = "Training preview. Updated on every save iteration"
+                helptext = _("Training preview. Updated on every save iteration")
                 PreviewTrain(self, "preview", helptext, 1000)
         logger.debug("Built train tabs")
 
-    def convert_tabs(self):
-        """ Build the convert tabs
-            Currently identical to Extract, so just call that """
+    def _convert_tabs(self):
+        """ Build the display tabs that are used for the Faceswap convert task.
+
+        Notes
+        -----
+        The tabs displayed are the same as used for extract, so :func:`_extract_tabs` is called.
+        """
         logger.debug("Build convert tabs")
-        self.extract_tabs()
+        self._extract_tabs(command="convert")
         logger.debug("Built convert tabs")
 
-    def remove_tabs(self):
-        """ Remove all command specific tabs """
+    def _remove_tabs(self):
+        """ Remove all optional displayed command specific tabs from the notebook.
""" for child in self.tabs(): - if child in self.static_tabs: + if child in self._static_tabs: continue logger.debug("removing child: %s", child) child_name = child.split(".")[-1] - child_object = self.children[child_name] - self.destroy_tabs_children(child_object) + child_object = self.children.get(child_name) # returns the OptionalDisplayPage object + if not child_object: + continue + child_object.close() # Call the OptionalDisplayPage close() method self.forget(child) - @staticmethod - def destroy_tabs_children(tab): - """ Destroy all tabs children - Children must be destroyed as forget only hides display + def _update_displaybook(self, *args): # pylint:disable=unused-argument + """ Callback to be executed when the global tkinter variable `display` + (:attr:`wrapper_var`) is updated when a Faceswap task is executed. + + Currently only updates when a core faceswap task (extract, train or convert) is executed. + + Parameters + ---------- + args: tuple + Required for tkinter callback events, but unused. + """ - logger.debug("Destroying children for tab: %s", tab) - for child in tab.winfo_children(): - logger.debug("Destroying child: %s", child) - child.destroy() - - def update_displaybook(self, *args): # pylint: disable=unused-argument - """ Set the display tabs based on executing task """ - command = self.wrapper_var.get() - self.remove_tabs() + command = self._wrapper_var.get() + self._remove_tabs() if not command or command not in ("extract", "train", "convert"): return - self.command_display(command) + self._command_display(command) + + def _on_tab_change(self, event): # pylint:disable=unused-argument + """ Event trigger for tab change events. + + Calls the selected tabs :func:`on_tab_select` method, if it exists, otherwise returns. + + Parameters + ---------- + event: tkinter callback event + Required, but unused + """ + selected = self.select().split(".")[-1] + logger.debug("Selected tab: %s", selected) + selected_object = self.children[selected] + if hasattr(selected_object, "on_tab_select"): + logger.debug("Calling on_tab_select for '%s'", selected_object) + selected_object.on_tab_select() + else: + logger.debug("Object does not have on_tab_select method. Returning: '%s'", + selected_object) + + +__all__ = get_module_objects(__name__) diff --git a/lib/gui/display_analysis.py b/lib/gui/display_analysis.py index 4646fe902a..127bf72a74 100644 --- a/lib/gui/display_analysis.py +++ b/lib/gui/display_analysis.py @@ -2,239 +2,457 @@ """ Analysis tab of Display Frame of the Faceswap GUI """ import csv +import gettext import logging import os import tkinter as tk from tkinter import ttk -from .display_graph import SessionGraph +from lib.logger import parse_class_init +from lib.utils import get_module_objects + +from .custom_widgets import Tooltip from .display_page import DisplayPage -from .stats import Calculations, Session -from .tooltip import Tooltip -from .utils import FileHandler, get_config, get_images +from .popup_session import SessionPopUp +from .analysis import Session +from .utils import FileHandler, get_config, get_images, LongRunningTask + +logger = logging.getLogger(__name__) + +# LOCALES +_LANG = gettext.translation("gui.tooltips", localedir="locales", fallback=True) +_ = _LANG.gettext + + +class Analysis(DisplayPage): # pylint:disable=too-many-ancestors + """ Session Analysis Tab. + + The area of the GUI that holds the session summary stats for model training sessions. 
diff --git a/lib/gui/display_analysis.py b/lib/gui/display_analysis.py
index 4646fe902a..127bf72a74 100644
--- a/lib/gui/display_analysis.py
+++ b/lib/gui/display_analysis.py
@@ -2,239 +2,457 @@
 """ Analysis tab of Display Frame of the Faceswap GUI """
 import csv
+import gettext
 import logging
 import os
 import tkinter as tk
 from tkinter import ttk
 
-from .display_graph import SessionGraph
+from lib.logger import parse_class_init
+from lib.utils import get_module_objects
+
+from .custom_widgets import Tooltip
 from .display_page import DisplayPage
-from .stats import Calculations, Session
-from .tooltip import Tooltip
-from .utils import FileHandler, get_config, get_images
+from .popup_session import SessionPopUp
+from .analysis import Session
+from .utils import FileHandler, get_config, get_images, LongRunningTask
+
+logger = logging.getLogger(__name__)
+
+# LOCALES
+_LANG = gettext.translation("gui.tooltips", localedir="locales", fallback=True)
+_ = _LANG.gettext
+
+
+class Analysis(DisplayPage):  # pylint:disable=too-many-ancestors
+    """ Session Analysis Tab.
+
+    The area of the GUI that holds the session summary stats for model training sessions.
+
+    Parameters
+    ----------
+    parent: :class:`lib.gui.display.DisplayNotebook`
+        The :class:`ttk.Notebook` that holds this session summary statistics page
+    tab_name: str
+        The name of the tab to be displayed in the notebook
+    helptext: str
+        The help text to display for the summary statistics page
+    """
+    def __init__(self, parent, tab_name, helptext):
+        logger.debug(parse_class_init(locals()))
+        super().__init__(parent, tab_name, helptext)
+        self._summary = None
+
+        self._reset_session_info()
+        _Options(self)
+        self._stats = self._get_main_frame()
+
+        self._thread = None  # Thread for compiling stats data in background
+        self._set_callbacks()
+        logger.debug("Initialized: %s", self.__class__.__name__)
 
-logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
 
+    def set_vars(self):
+        """ Set the analysis specific tkinter variables to :attr:`vars`.
+
+        The tracked variables are the global variables that:
+            * Trigger when a graph refresh has been requested.
+            * Trigger when training is commenced or halted.
+            * Hold the location of the current Tensorboard log folder.
+
+        Returns
+        -------
+        dict
+            The dictionary of variable names to tkinter variables
+        """
+        return {"selected_id": tk.StringVar(),
+                "refresh_graph": get_config().tk_vars.refresh_graph,
+                "is_training": get_config().tk_vars.is_training,
+                "analysis_folder": get_config().tk_vars.analysis_folder}
+
+    def on_tab_select(self):
+        """ Callback for when the analysis tab is selected.
+
+        Update the statistics with the latest values.
+        """
+        logger.debug("Analysis update callback received")
+        self._reset_session()
+
+    def _get_main_frame(self):
+        """ Get the main frame of the sub-notebook that holds the stats and session data.
+
+        Returns
+        -------
+        :class:`StatsData`
+            The frame that holds the analysis statistics for the Analysis notebook page
+        """
+        logger.debug("Getting main stats frame")
+        mainframe = self.subnotebook_add_page("stats")
+        retval = StatsData(mainframe, self.vars["selected_id"], self.helptext["stats"])
+        logger.debug("got main frame: %s", retval)
+        return retval
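Note the `*args` signature on the callback methods throughout this module: tkinter's `trace`
mechanism always passes three arguments (the variable's internal name, the element index and
the access mode), none of which the GUI needs, so the handlers accept and discard them. A
standalone illustration:

    import tkinter as tk

    root = tk.Tk()
    var = tk.StringVar()

    def on_write(*args):  # receives (name, index, mode) from tkinter
        print(f"variable written: {var.get()!r} (callback args: {args})")

    var.trace("w", on_write)
    var.set("train")  # fires on_write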
+
+    def _set_callbacks(self):
+        """ Add callbacks to update the analysis summary statistics and add them to :attr:`vars`
 
-class Analysis(DisplayPage):  # pylint: disable=too-many-ancestors
-    """ Session analysis tab """
-    def __init__(self, parent, tabname, helptext):
-        logger.debug("Initializing: %s: (parent, %s, tabname: '%s', helptext: '%s')",
-                     self.__class__.__name__, parent, tabname, helptext)
-        super().__init__(parent, tabname, helptext)
+
+        Training graph refresh - Updates the stats for the current training session when the graph
+        has been updated.
 
-        self.summary = None
-        self.session = None
-        self.add_options()
-        self.add_main_frame()
-        logger.debug("Initialized: %s", self.__class__.__name__)
+
+        When the analysis folder has been populated - Updates the stats from that folder.
+        """
+        self.vars["refresh_graph"].trace("w", self._update_current_session)
+        self.vars["analysis_folder"].trace("w", self._populate_from_folder)
 
-    def set_vars(self):
-        """ Analysis specific vars """
-        selected_id = tk.StringVar()
-        return {"selected_id": selected_id}
-
-    def add_main_frame(self):
-        """ Add the main frame to the subnotebook
-            to hold stats and session data """
-        logger.debug("Adding main frame")
-        mainframe = self.subnotebook_add_page("stats")
-        self.stats = StatsData(mainframe,
-                               self.vars["selected_id"],
-                               self.helptext["stats"])
-        logger.debug("Added main frame")
-
-    def add_options(self):
-        """ Add the options bar """
-        logger.debug("Adding options")
-        self.reset_session_info()
-        options = Options(self)
-        options.add_options()
-        logger.debug("Added options")
-
-    def reset_session_info(self):
+    def _update_current_session(self, *args):  # pylint:disable=unused-argument
+        """ Update the currently training session data on a graph update callback. """
+        if not self.vars["refresh_graph"].get():
+            return
+        if not self._tab_is_active:
+            logger.debug("Analysis tab not selected. Not updating stats")
+            return
+        logger.debug("Analysis update callback received")
+        self._reset_session()
+
+    def _reset_session_info(self):
         """ Reset the session info status to default """
         logger.debug("Resetting session info")
         self.set_info("No session data loaded")
 
-    def load_session(self):
-        """ Load previously saved sessions """
-        logger.debug("Loading session")
-        self.clear_session()
-        fullpath = FileHandler("filename", "state").retfile
-        if not fullpath:
+    def _populate_from_folder(self, *args):  # pylint:disable=unused-argument
+        """ Populate the Analysis tab from a model folder.
+
+        Triggered when the :attr:`vars` ``analysis_folder`` variable is set.
+        """
+        if Session.is_training:
             return
-        logger.debug("state_file: '%s'", fullpath)
-        model_dir, state_file = os.path.split(fullpath)
-        logger.debug("model_dir: '%s'", model_dir)
-        model_name = self.get_model_name(model_dir, state_file)
-        if not model_name:
+
+        folder = self.vars["analysis_folder"].get()
+        if not folder or not os.path.isdir(folder):
+            logger.debug("Not a valid folder")
+            self._clear_session()
             return
-        self.session = Session(model_dir=model_dir, model_name=model_name)
-        self.session.initialize_session(is_training=False)
-        msg = os.path.split(state_file)[0]
-        if len(msg) > 70:
-            msg = "...{}".format(msg[-70:])
-        self.set_session_summary(msg)
 
-    @staticmethod
-    def get_model_name(model_dir, state_file):
-        """ Get the state file from the model directory """
+        state_files = [fname
+                       for fname in os.listdir(folder)
+                       if fname.endswith("_state.json")]
+        if not state_files:
+            logger.debug("No state files found in folder: '%s'", folder)
+            self._clear_session()
+            return
+
+        state_file = state_files[0]
+        if len(state_files) > 1:
+            logger.debug("Multiple models found. Selecting: '%s'", state_file)
+
+        if self._thread is None:
+            self._load_session(full_path=os.path.join(folder, state_file))
+
+    @classmethod
+    def _get_model_name(cls, model_dir, state_file):
+        """ Obtain the model name from a state file's file name.
+
+        Parameters
+        ----------
+        model_dir: str
+            The folder that the model's state file resides in
+        state_file: str
+            The filename of the model's state file
+
+        Returns
+        -------
+        str or ``None``
+            The name of the model extracted from the state file's file name or ``None`` if no
+            log folders were found in the model folder
+        """
         logger.debug("Getting model name")
         model_name = state_file.replace("_state.json", "")
         logger.debug("model_name: %s", model_name)
-        logs_dir = os.path.join(model_dir, "{}_logs".format(model_name))
+        logs_dir = os.path.join(model_dir, f"{model_name}_logs")
         if not os.path.isdir(logs_dir):
             logger.warning("No logs folder found in folder: '%s'", logs_dir)
             return None
         return model_name
 
-    def reset_session(self):
-        """ Reset currently training sessions """
+    def _set_session_summary(self, message):
+        """ Set the summary data and info message.
+
+        Parameters
+        ----------
+        message: str
+            The information message to set
+        """
+        if self._thread is None:
+            logger.debug("Setting session summary. (message: '%s')", message)
+            self._thread = LongRunningTask(target=self._summarise_data,
+                                           args=(Session, ),
+                                           widget=self)
+            self._thread.start()
+            self.after(1000, lambda msg=message: self._set_session_summary(msg))
+        elif not self._thread.complete.is_set():
+            logger.debug("Data not yet available")
+            self.after(1000, lambda msg=message: self._set_session_summary(msg))
+        else:
+            logger.debug("Retrieving data from thread")
+            result = self._thread.get_result()
+            del self._thread
+            self._thread = None
+            if not result:
+                logger.debug("No result from session summary. Clearing analysis view")
+                self._clear_session()
+                return
+            self._summary = result
+            self.set_info(f"Session: {message}")
+            self._stats.tree_insert_data(self._summary)
+
+    @classmethod
+    def _summarise_data(cls, session):
+        """ Summarize data in a :class:`LongRunningTask` as it can take a while.
+
+        Parameters
+        ----------
+        session: :class:`lib.gui.analysis.Session`
+            The session object to generate the summary for
+
+        Returns
+        -------
+        list
+            The full summary statistics for the given session
+        """
+        return session.full_summary
+
+    def _clear_session(self):
+        """ Clear the currently displayed analysis data from the Tree-View. """
+        logger.debug("Clearing session")
+        if not Session.is_loaded:
+            logger.trace("No session loaded. Returning")
+            return
+        self._summary = None
+        self._stats.tree_clear()
+        if not Session.is_training:
+            self._reset_session_info()
+            Session.clear()
+
+    def _load_session(self, full_path=None):
+        """ Load the session statistics from a model's state file into the Analysis tab of the GUI
+        display window.
+
+        If a model's log files cannot be found within the model folder then the session is cleared.
+
+        Parameters
+        ----------
+        full_path: str, optional
+            The path to the state file to load session information from. If this is ``None`` then
+            a file dialog is launched to enable the user to choose a state file. Default: ``None``
+        """
+        logger.debug("Loading session")
+        if full_path is None:
+            full_path = FileHandler("filename", "state").return_file
+            if not full_path:
+                return
+        self._clear_session()
+        logger.debug("state_file: '%s'", full_path)
+        model_dir, state_file = os.path.split(full_path)
+        logger.debug("model_dir: '%s'", model_dir)
+        model_name = self._get_model_name(model_dir, state_file)
+        if not model_name:
+            return
+        Session.initialize_session(model_dir, model_name, is_training=False)
+        msg = full_path
+        if len(msg) > 70:
+            msg = f"...{msg[-70:]}"
+        self._set_session_summary(msg)
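`_set_session_summary` above illustrates the pattern this tab uses for anything slow: hand the
work to a background thread, then poll for completion with `after()` so the tkinter mainloop is
never blocked. Stripped to its essentials (`slow_job` and `some_widget` are hypothetical
stand-ins; `LongRunningTask` is the helper imported from `.utils`):

    def poll_for_result(widget, thread):
        """ Re-check once a second until the background task completes """
        if not thread.complete.is_set():
            widget.after(1000, lambda: poll_for_result(widget, thread))
            return
        print("background task returned:", thread.get_result())

    thread = LongRunningTask(target=slow_job, args=(), widget=some_widget)
    thread.start()
    poll_for_result(some_widget, thread)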
+
+    def _reset_session(self):
+        """ Reset the currently training session. Clears the current session and loads in the
+        latest data.
+        """
         logger.debug("Reset current training session")
-        self.clear_session()
-        session = get_config().session
-        if not session.initialized:
+        if not Session.is_training:
             logger.debug("Training not running")
-            print("Training not running")
             return
-        msg = "Currently running training session"
-        self.session = session
-        self.set_session_summary(msg)
-
-    def set_session_summary(self, message):
-        """ Set the summary data and info message """
-        logger.debug("Setting session summary. (message: '%s')", message)
-        self.summary = self.session.full_summary
-        self.set_info("Session: {}".format(message))
-        self.stats.session = self.session
-        self.stats.tree_insert_data(self.summary)
-
-    def clear_session(self):
-        """ Clear sessions stats """
-        logger.debug("Clearing session")
-        self.summary = None
-        self.stats.session = None
-        self.stats.tree_clear()
-        self.reset_session_info()
+        if Session.logging_disabled:
+            logger.trace("Logging disabled. Not triggering analysis update")
+            return
+        self._clear_session()
+        self._set_session_summary("Currently running training session")
 
-    def save_session(self):
-        """ Save sessions stats to csv """
+    def _save_session(self):
+        """ Launch a file dialog pop-up to save the current analysis data to a CSV file. """
         logger.debug("Saving session")
-        if not self.summary:
+        if not self._summary:
             logger.debug("No summary data loaded. Nothing to save")
             print("No summary data loaded. Nothing to save")
             return
-        savefile = FileHandler("save", "csv").retfile
+        savefile = FileHandler("save", "csv").return_file
         if not savefile:
             logger.debug("No save file. Returning")
             return
 
-        write_dicts = [val for val in self.summary.values()]
-        fieldnames = sorted(key for key in write_dicts[0].keys())
-        logger.debug("Saving to: '%s'", savefile)
+        fieldnames = sorted(key for key in self._summary[0].keys())
         with savefile as outfile:
             csvout = csv.DictWriter(outfile, fieldnames)
             csvout.writeheader()
-            for row in write_dicts:
+            for row in self._summary:
                 csvout.writerow(row)
 
 
-class Options():
-    """ Options bar of Analysis tab """
+class _Options():  # pylint:disable=too-few-public-methods
+    """ Options buttons for the Analysis tab.
+
+    Parameters
+    ----------
+    parent: :class:`Analysis`
+        The Analysis Display Tab that holds the options buttons
+    """
     def __init__(self, parent):
-        logger.debug("Initializing: %s", self.__class__.__name__)
-        self.optsframe = parent.optsframe
-        self.parent = parent
+        logger.debug(parse_class_init(locals()))
+        self._parent = parent
+        self._buttons = self._add_buttons()
+        self._add_training_callback()
         logger.debug("Initialized: %s", self.__class__.__name__)
 
-    def add_options(self):
-        """ Add the display tab options """
-        self.add_buttons()
+    def _add_buttons(self):
+        """ Add the option buttons.
-    def add_buttons(self):
-        """ Add the option buttons """
-        for btntype in ("reset", "clear", "save", "load"):
+
+        Returns
+        -------
+        dict
+            The button names to button objects
+        """
+        buttons = {}
+        for btntype in ("clear", "save", "load"):
             logger.debug("Adding button: '%s'", btntype)
-            cmd = getattr(self.parent, "{}_session".format(btntype))
-            btn = ttk.Button(self.optsframe,
+            cmd = getattr(self._parent, f"_{btntype}_session")
+            btn = ttk.Button(self._parent.optsframe,
                              image=get_images().icons[btntype],
                              command=cmd)
             btn.pack(padx=2, side=tk.RIGHT)
-            hlp = self.set_help(btntype)
-            Tooltip(btn, text=hlp, wraplength=200)
-
-    @staticmethod
-    def set_help(btntype):
-        """ Set the helptext for option buttons """
+            hlp = self._set_help(btntype)
+            Tooltip(btn, text=hlp, wrap_length=200)
+            buttons[btntype] = btn
+        logger.debug("buttons: %s", buttons)
+        return buttons
+
+    @classmethod
+    def _set_help(cls, button_type):
+        """ Set the help text for option buttons.
+
+        Parameters
+        ----------
+        button_type: {"reload", "clear", "save", "load"}
+            The type of button to set the help text for
+
+        Returns
+        -------
+        str
+            The help text for the given button type
+        """
         logger.debug("Setting help")
         hlp = ""
-        if btntype == "reset":
-            hlp = "Load/Refresh stats for the currently training session"
-        elif btntype == "clear":
-            hlp = "Clear currently displayed session stats"
-        elif btntype == "save":
-            hlp = "Save session stats to csv"
-        elif btntype == "load":
-            hlp = "Load saved session stats"
+        if button_type == "reload":
+            hlp = _("Load/Refresh stats for the currently training session")
+        elif button_type == "clear":
+            hlp = _("Clear currently displayed session stats")
+        elif button_type == "save":
+            hlp = _("Save session stats to csv")
+        elif button_type == "load":
+            hlp = _("Load saved session stats")
         return hlp
 
-
-class StatsData(ttk.Frame):  # pylint: disable=too-many-ancestors
-    """ Stats frame of analysis tab """
+    def _add_training_callback(self):
+        """ Add a callback to the training tkinter variable to disable save and clear buttons
+        when a model is training. """
+        var = self._parent.vars["is_training"]
+        var.trace("w", self._set_buttons_state)
+
+    def _set_buttons_state(self, *args):  # pylint:disable=unused-argument
+        """ Callback to enable/disable buttons when training is commenced and stopped. """
+        is_training = self._parent.vars["is_training"].get()
+        state = "disabled" if is_training else "!disabled"
+        for name, button in self._buttons.items():
+            if name not in ("load", "clear"):
+                continue
+            logger.debug("Setting %s button state to %s", name, state)
+            button.state([state])
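`_set_buttons_state` relies on ttk's state-flag syntax: passing a flag name sets it and a `!`
prefix clears it, with `instate()` available to query. A standalone illustration (widget and
labels invented):

    import tkinter as tk
    from tkinter import ttk

    root = tk.Tk()
    button = ttk.Button(root, text="Save")
    button.state(["disabled"])            # set the flag: grey the button out
    button.state(["!disabled"])           # "!" clears the flag: re-enable
    print(button.instate(["!disabled"]))  # -> True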
+
+
+class StatsData(ttk.Frame):  # pylint:disable=too-many-ancestors
+    """ Stats frame of analysis tab.
+
+    Holds the tree-view containing the summarized session statistics in the Analysis tab.
+
+    Parameters
+    ----------
+    parent: :class:`tkinter.Frame`
+        The frame within the Analysis Notebook that will hold the statistics
+    selected_id: :class:`tkinter.StringVar`
+        The tkinter variable that holds the currently selected session ID
+    helptext: str
+        The help text to display for the summary statistics page
+    """
     def __init__(self, parent, selected_id, helptext):
-        logger.debug("Initializing: %s: (parent, %s, selected_id: %s, helptext: '%s')",
-                     self.__class__.__name__, parent, selected_id, helptext)
+        logger.debug(parse_class_init(locals()))
         super().__init__(parent)
-        self.pack(side=tk.TOP, padx=5, pady=5, fill=tk.BOTH, expand=True)
-        self.session = None  # set when loading or clearing from parent
-        self.selected_id = selected_id
-        self.popup_positions = list()
+        self._selected_id = selected_id
+
+        self._canvas = tk.Canvas(self, bd=0, highlightthickness=0)
+        tree_frame = ttk.Frame(self._canvas)
+        self._tree_canvas = self._canvas.create_window((0, 0), window=tree_frame, anchor=tk.NW)
+        self._sub_frame = ttk.Frame(tree_frame)
 
-        self.canvas = tk.Canvas(self, bd=0, highlightthickness=0)
-        self.canvas.pack(side=tk.LEFT, fill=tk.BOTH, expand=True)
+        self._add_label()
 
-        self.tree_frame = ttk.Frame(self.canvas)
-        self.tree_canvas = self.canvas.create_window((0, 0), window=self.tree_frame, anchor=tk.NW)
-        self.sub_frame = ttk.Frame(self.tree_frame)
-        self.sub_frame.pack(side=tk.LEFT, fill=tk.X, anchor=tk.N, expand=True)
+        self._tree = ttk.Treeview(self._sub_frame, height=1, selectmode=tk.BROWSE)
+        self._scrollbar = ttk.Scrollbar(tree_frame, orient="vertical", command=self._tree.yview)
 
-        self.add_label()
-        self.tree = ttk.Treeview(self.sub_frame, height=1, selectmode=tk.BROWSE)
-        self.scrollbar = ttk.Scrollbar(self.tree_frame, orient="vertical", command=self.tree.yview)
-        self.scrollbar.pack(side=tk.RIGHT, fill=tk.Y)
+        self._columns = self._tree_configure(helptext)
+        self._canvas.bind("<Configure>", self._resize_frame)
+
+        self._scrollbar.pack(side=tk.RIGHT, fill=tk.Y)
+        self._tree.pack(side=tk.TOP, fill=tk.X)
+        self._sub_frame.pack(side=tk.LEFT, fill=tk.X, anchor=tk.N, expand=True)
+        self._canvas.pack(side=tk.LEFT, fill=tk.BOTH, expand=True)
+        self.pack(side=tk.TOP, padx=5, pady=5, fill=tk.BOTH, expand=True)
 
-        self.columns = self.tree_configure(helptext)
-        self.canvas.bind("<Configure>", self.resize_frame)
         logger.debug("Initialized: %s", self.__class__.__name__)
 
-    def add_label(self):
-        """ Add Treeview Title """
+    def _add_label(self):
+        """ Add the title above the tree-view. """
         logger.debug("Adding Treeview title")
-        lbl = ttk.Label(self.sub_frame, text="Session Stats", anchor=tk.CENTER)
+        lbl = ttk.Label(self._sub_frame, text="Session Stats", anchor=tk.CENTER)
         lbl.pack(side=tk.TOP, fill=tk.X, padx=5, pady=5)
 
-    def resize_frame(self, event):
-        """ Resize the options frame to fit the canvas """
+    def _resize_frame(self, event):
+        """ Resize the options frame to fit the canvas.
+
+        Parameters
+        ----------
+        event: :class:`tkinter.Event`
+            The tkinter resize event
+        """
         logger.debug("Resize Analysis Frame")
         canvas_width = event.width
         canvas_height = event.height
-        self.canvas.itemconfig(self.tree_canvas, width=canvas_width, height=canvas_height)
+        self._canvas.itemconfig(self._tree_canvas, width=canvas_width, height=canvas_height)
         logger.debug("Resized Analysis Frame")
 
-    def tree_configure(self, helptext):
-        """ Build a treeview widget to hold the sessions stats """
+    def _tree_configure(self, helptext):
+        """ Build a tree-view widget to hold the session stats.
+
+        Parameters
+        ----------
+        helptext: str
+            The helptext to display when the mouse is over the tree-view
+
+        Returns
+        -------
+        list
+            The list of tree-view columns
+        """
         logger.debug("Configuring Treeview")
-        self.tree.configure(yscrollcommand=self.scrollbar.set)
-        self.tree.tag_configure("total", background="black", foreground="white")
-        self.tree.pack(side=tk.TOP, fill=tk.X)
-        self.tree.bind("<ButtonRelease-1>", self.select_item)
-        Tooltip(self.tree, text=helptext, wraplength=200)
-        return self.tree_columns()
-
-    def tree_columns(self):
-        """ Add the columns to the totals treeview """
+        self._tree.configure(yscrollcommand=self._scrollbar.set)
+        self._tree.tag_configure("total", background="black", foreground="white")
+        self._tree.bind("<ButtonRelease-1>", self._select_item)
+        Tooltip(self._tree, text=helptext, wrap_length=200)
+        return self._tree_columns()
+
+    def _tree_columns(self):
+        """ Add the columns to the totals tree-view.
+
+        Returns
+        -------
+        list
+            The list of tree-view columns
+        """
         logger.debug("Adding Treeview columns")
         columns = (("session", 40, "#"),
                    ("start", 130, None),
@@ -243,407 +461,130 @@ def tree_columns(self):
                    ("batch", 50, None),
                    ("iterations", 90, None),
                    ("rate", 60, "EGs/sec"))
-        self.tree["columns"] = [column[0] for column in columns]
+        self._tree["columns"] = [column[0] for column in columns]
         for column in columns:
             text = column[2] if column[2] else column[0].title()
             logger.debug("Adding heading: '%s'", text)
-            self.tree.heading(column[0], text=text)
-            self.tree.column(column[0], width=column[1], anchor=tk.E, minwidth=40)
-            self.tree.column("#0", width=40)
-            self.tree.heading("#0", text="Graphs")
+            self._tree.heading(column[0], text=text)
+            self._tree.column(column[0], width=column[1], anchor=tk.E, minwidth=40)
+        self._tree.column("#0", width=40)
+        self._tree.heading("#0", text="Graphs")
         return [column[0] for column in columns]
 
     def tree_insert_data(self, sessions_summary):
-        """ Insert the data into the totals treeview """
+        """ Insert the summary data into the statistics tree-view.
+
+        Parameters
+        ----------
+        sessions_summary: list
+            List of session summary dicts for populating into the tree-view
+        """
         logger.debug("Inserting treeview data")
-        self.tree.configure(height=len(sessions_summary))
+        self._tree.configure(height=len(sessions_summary))
         for item in sessions_summary:
-            values = [item[column] for column in self.columns]
-            kwargs = {"values": values, "image": get_images().icons["graph"]}
+            values = [item[column] for column in self._columns]
+            kwargs = {"values": values}
+            if self._check_valid_data(values):
+                # Don't show graph icon for non-existent sessions
+                kwargs["image"] = get_images().icons["graph"]
             if values[0] == "Total":
                 kwargs["tags"] = "total"
-            self.tree.insert("", "end", **kwargs)
+            self._tree.insert("", "end", **kwargs)
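`tree_insert_data` expects one dict per row, keyed by the column names configured above. A
hedged illustration of the expected shape (keys limited to the columns visible in this hunk;
all values invented, and `stats` standing in for an existing `StatsData` instance):

    sessions_summary = [
        {"session": 1, "start": "2024-01-01 10:00", "batch": 16,
         "iterations": 5000, "rate": 48.2},
        {"session": "Total", "start": "2024-01-01 10:00", "batch": 16,
         "iterations": 5000, "rate": 48.2},  # the "Total" row gets the "total" tag
    ]
    stats.tree_insert_data(sessions_summary)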
""" logger.debug("Clearing treeview data") - self.tree.delete(* self.tree.get_children()) - self.tree.configure(height=1) - - def select_item(self, event): - """ Update the session summary info with - the selected item or launch graph """ - region = self.tree.identify("region", event.x, event.y) - selection = self.tree.focus() - values = self.tree.item(selection, "values") + try: + self._tree.delete(* self._tree.get_children()) + self._tree.configure(height=1) + except tk.TclError: + # Catch non-existent tree view when rebuilding the GUI + pass + + def _select_item(self, event): + """ Update the session summary info with the selected item or launch graph. + + If the mouse is clicked on the graph icon, then the session summary pop-up graph is + launched. Otherwise the selected ID is stored. + + Parameters + ---------- + event: :class:`tkinter.Event` + The tkinter mouse button release event + """ + region = self._tree.identify("region", event.x, event.y) + selection = self._tree.focus() + values = self._tree.item(selection, "values") if values: logger.debug("Selected values: %s", values) - self.selected_id.set(values[0]) - if region == "tree": - self.data_popup() + self._selected_id.set(values[0]) + if region == "tree" and self._check_valid_data(values): + data_points = int(values[self._columns.index("iterations")]) + self._data_popup(data_points) + + def _check_valid_data(self, values): + """ Check there is valid data available for popping up a graph. + + Parameters + ---------- + values: list + The values that exist for a single session that are to be validated + """ + col_indices = [self._columns.index("batch"), self._columns.index("iterations")] + for idx in col_indices: + if (isinstance(values[idx], int) or values[idx].isdigit()) and int(values[idx]) == 0: + logger.warning("No data to graph for selected session") + return False + return True + + def _data_popup(self, data_points): + """ Pop up a window and control it's position - def data_popup(self): - """ Pop up a window and control it's position """ + The default view is rolling average over 500 points. If there are fewer data points than + this, switch the default to smoothed, + + Parameters + ---------- + data_points: int + The number of iterations that are to be plotted + """ logger.debug("Popping up data window") scaling_factor = get_config().scaling_factor - toplevel = SessionPopUp(self.session.modeldir, - self.session.modelname, - self.selected_id.get()) - toplevel.title(self.data_popup_title()) - toplevel.tk.call('wm', 'iconphoto', toplevel._w, get_images().icons["favicon"]) - position = self.data_popup_get_position() - height = int(720 * scaling_factor) - width = int(400 * scaling_factor) - toplevel.geometry("{}x{}+{}+{}".format(str(height), - str(width), - str(position[0]), - str(position[1]))) + toplevel = SessionPopUp(self._selected_id.get(), + data_points) + toplevel.title(self._data_popup_title()) + toplevel.tk.call( + 'wm', + 'iconphoto', + toplevel._w, get_images().icons["favicon"]) # pylint:disable=protected-access + + root = get_config().root + offset = (root.winfo_x() + 20, root.winfo_y() + 20) + height = int(900 * scaling_factor) + width = int(480 * scaling_factor) + toplevel.geometry(f"{height}x{width}+{offset[0]}+{offset[1]}") + toplevel.update() - def data_popup_title(self): - """ Set the data popup title """ + def _data_popup_title(self): + """ Get the summary graph popup title. 
+ + Returns + ------- + str + The title to display at the top of the pop-up graph window + """ logger.debug("Setting poup title") - selected_id = self.selected_id.get() + selected_id = self._selected_id.get() + model_dir, model_name = os.path.split(Session.model_filename) title = "All Sessions" if selected_id != "Total": - title = "{} Model: Session #{}".format(self.session.modelname.title(), selected_id) + title = f"{model_name.title()} Model: Session #{selected_id}" logger.debug("Title: '%s'", title) - return "{} - {}".format(title, self.session.modeldir) - - def data_popup_get_position(self): - """ Get the position of the next window """ - logger.debug("getting poup position") - init_pos = [120, 120] - pos = init_pos - while True: - if pos not in self.popup_positions: - self.popup_positions.append(pos) - break - pos = [item + 200 for item in pos] - init_pos, pos = self.data_popup_check_boundaries(init_pos, pos) - logger.debug("Position: %s", pos) - return pos - - def data_popup_check_boundaries(self, initial_position, position): - """ Check that the popup remains within the screen boundaries """ - logger.debug("Checking poup boundaries: (initial_position: %s, position: %s)", - initial_position, position) - boundary_x = self.winfo_screenwidth() - 120 - boundary_y = self.winfo_screenheight() - 120 - if position[0] >= boundary_x or position[1] >= boundary_y: - initial_position = [initial_position[0] + 50, initial_position[1]] - position = initial_position - logger.debug("Returning poup boundaries: (initial_position: %s, position: %s)", - initial_position, position) - return initial_position, position - - -class SessionPopUp(tk.Toplevel): - """ Pop up for detailed graph/stats for selected session """ - def __init__(self, model_dir, model_name, session_id): - logger.debug("Initializing: %s: (model_dir: %s, model_name: %s, session_id: %s)", - self.__class__.__name__, model_dir, model_name, session_id) - super().__init__() - - self.session_id = session_id - self.session = Session(model_dir=model_dir, model_name=model_name) - self.initialize_session() - - self.graph = None - self.display_data = None - - self.vars = dict() - self.graph_initialised = False - self.build() - logger.debug("Initialized: %s", self.__class__.__name__) + return f"{title} - {model_dir}" - @property - def is_totals(self): - """ Return True if these are totals else False """ - return bool(self.session_id == "Total") - - def initialize_session(self): - """ Initialize the session """ - logger.debug("Initializing session") - kwargs = dict(is_training=False) - if not self.is_totals: - kwargs["session_id"] = int(self.session_id) - logger.debug("Session kwargs: %s", kwargs) - self.session.initialize_session(**kwargs) - - def build(self): - """ Build the popup window """ - logger.debug("Building popup") - optsframe, graphframe = self.layout_frames() - - self.opts_build(optsframe) - self.compile_display_data() - self.graph_build(graphframe) - logger.debug("Built popup") - - def layout_frames(self): - """ Top level container frames """ - logger.debug("Layout frames") - leftframe = ttk.Frame(self) - leftframe.pack(side=tk.LEFT, expand=False, fill=tk.BOTH, pady=5) - - sep = ttk.Frame(self, width=2, relief=tk.RIDGE) - sep.pack(fill=tk.Y, side=tk.LEFT) - - rightframe = ttk.Frame(self) - rightframe.pack(side=tk.RIGHT, fill=tk.BOTH, pady=5, expand=True) - logger.debug("Laid out frames") - - return leftframe, rightframe - - def opts_build(self, frame): - """ Build Options into the options frame """ - logger.debug("Building Options") - 
self.opts_combobox(frame) - self.opts_checkbuttons(frame) - self.opts_loss_keys(frame) - self.opts_entry(frame) - self.opts_buttons(frame) - sep = ttk.Frame(frame, height=2, relief=tk.RIDGE) - sep.pack(fill=tk.X, pady=(5, 0), side=tk.BOTTOM) - logger.debug("Built Options") - - def opts_combobox(self, frame): - """ Add the options combo boxes """ - logger.debug("Building Combo boxes") - choices = {"Display": ("Loss", "Rate"), - "Scale": ("Linear", "Log")} - - for item in ["Display", "Scale"]: - var = tk.StringVar() - cmd = self.optbtn_reset if item == "Display" else self.graph_scale - var.trace("w", cmd) - - cmbframe = ttk.Frame(frame) - cmbframe.pack(fill=tk.X, pady=5, padx=5, side=tk.TOP) - lblcmb = ttk.Label(cmbframe, - text="{}:".format(item), - width=7, - anchor=tk.W) - lblcmb.pack(padx=(0, 2), side=tk.LEFT) - - cmb = ttk.Combobox(cmbframe, textvariable=var, width=10) - cmb["values"] = choices[item] - cmb.current(0) - cmb.pack(fill=tk.X, side=tk.RIGHT) - - self.vars[item.lower().strip()] = var - - hlp = self.set_help(item) - Tooltip(cmbframe, text=hlp, wraplength=200) - logger.debug("Built Combo boxes") - - def opts_checkbuttons(self, frame): - """ Add the options check buttons """ - logger.debug("Building Check Buttons") - for item in ("raw", "trend", "avg", "outliers"): - if item == "avg": - text = "Show Rolling Average" - elif item == "outliers": - text = "Flatten Outliers" - else: - text = "Show {}".format(item.title()) - var = tk.BooleanVar() - - if item == "raw": - var.set(True) - var.trace("w", self.optbtn_reset) - self.vars[item] = var - - ctl = ttk.Checkbutton(frame, variable=var, text=text) - ctl.pack(side=tk.TOP, padx=5, pady=5, anchor=tk.W) - - hlp = self.set_help(item) - Tooltip(ctl, text=hlp, wraplength=200) - logger.debug("Built Check Buttons") - - def opts_loss_keys(self, frame): - """ Add loss key selections """ - logger.debug("Building Loss Key Check Buttons") - loss_keys = self.session.loss_keys - lk_vars = dict() - for loss_key in sorted(loss_keys): - text = loss_key.replace("_", " ").title() - helptext = "Display {}".format(text) - var = tk.BooleanVar() - var.set(True) - var.trace("w", self.optbtn_reset) - lk_vars[loss_key] = var - - if len(loss_keys) == 1: - # Don't display if there's only one item - break - - ctl = ttk.Checkbutton(frame, variable=var, text=text) - ctl.pack(side=tk.TOP, padx=5, pady=5, anchor=tk.W) - Tooltip(ctl, text=helptext, wraplength=200) - - self.vars["loss_keys"] = lk_vars - logger.debug("Built Loss Key Check Buttons") - - def opts_entry(self, frame): - """ Add the options entry boxes """ - logger.debug("Building Entry Boxes") - for item in ("avgiterations", ): - if item == "avgiterations": - text = "Iterations to Average:" - default = "10" - - entframe = ttk.Frame(frame) - entframe.pack(fill=tk.X, pady=5, padx=5, side=tk.TOP) - lbl = ttk.Label(entframe, text=text, anchor=tk.W) - lbl.pack(padx=(0, 2), side=tk.LEFT) - - ctl = ttk.Entry(entframe, width=4, justify=tk.RIGHT) - ctl.pack(side=tk.RIGHT, anchor=tk.W) - ctl.insert(0, default) - - hlp = self.set_help(item) - Tooltip(entframe, text=hlp, wraplength=200) - - self.vars[item] = ctl - logger.debug("Built Entry Boxes") - - def opts_buttons(self, frame): - """ Add the option buttons """ - logger.debug("Building Buttons") - btnframe = ttk.Frame(frame) - btnframe.pack(fill=tk.X, pady=5, padx=5, side=tk.BOTTOM) - - for btntype in ("reset", "save"): - cmd = getattr(self, "optbtn_{}".format(btntype)) - btn = ttk.Button(btnframe, - image=get_images().icons[btntype], - command=cmd) - 
btn.pack(padx=2, side=tk.RIGHT) - hlp = self.set_help(btntype) - Tooltip(btn, text=hlp, wraplength=200) - logger.debug("Built Buttons") - - def optbtn_save(self): - """ Action for save button press """ - logger.debug("Saving File") - savefile = FileHandler("save", "csv").retfile - if not savefile: - logger.debug("Save Cancelled") - return - logger.debug("Saving to: %s", savefile) - save_data = self.display_data.stats - fieldnames = sorted(key for key in save_data.keys()) - - with savefile as outfile: - csvout = csv.writer(outfile, delimiter=",") - csvout.writerow(fieldnames) - csvout.writerows(zip(*[save_data[key] for key in fieldnames])) - - def optbtn_reset(self, *args): # pylint: disable=unused-argument - """ Action for reset button press and checkbox changes""" - logger.debug("Refreshing Graph") - if not self.graph_initialised: - return - valid = self.compile_display_data() - if not valid: - logger.debug("Invalid data") - return - self.graph.refresh(self.display_data, - self.vars["display"].get(), - self.vars["scale"].get()) - logger.debug("Refreshed Graph") - - def graph_scale(self, *args): # pylint: disable=unused-argument - """ Action for changing graph scale """ - if not self.graph_initialised: - return - self.graph.set_yscale_type(self.vars["scale"].get()) - - @staticmethod - def set_help(control): - """ Set the helptext for option buttons """ - hlp = "" - control = control.lower() - if control == "reset": - hlp = "Refresh graph" - elif control == "save": - hlp = "Save display data to csv" - elif control == "avgiterations": - hlp = "Number of data points to sample for rolling average" - elif control == "outliers": - hlp = "Flatten data points that fall more than 1 standard " \ - "deviation from the mean to the mean value." - elif control == "avg": - hlp = "Display rolling average of the data" - elif control == "raw": - hlp = "Display raw data" - elif control == "trend": - hlp = "Display polynormal data trend" - elif control == "display": - hlp = "Set the data to display" - elif control == "scale": - hlp = "Change y-axis scale" - return hlp - - def compile_display_data(self): - """ Compile the data to be displayed """ - logger.debug("Compiling Display Data") - - loss_keys = [key for key, val in self.vars["loss_keys"].items() - if val.get()] - logger.debug("Selected loss_keys: %s", loss_keys) - - selections = self.selections_to_list() - - if not self.check_valid_selection(loss_keys, selections): - return False - self.display_data = Calculations(session=self.session, - display=self.vars["display"].get(), - loss_keys=loss_keys, - selections=selections, - avg_samples=self.vars["avgiterations"].get(), - flatten_outliers=self.vars["outliers"].get(), - is_totals=self.is_totals) - logger.debug("Compiled Display Data") - return True - - def check_valid_selection(self, loss_keys, selections): - """ Check that there will be data to display """ - display = self.vars["display"].get().lower() - logger.debug("Validating selection. (loss_keys: %s, selections: %s, display: %s)", - loss_keys, selections, display) - if not selections or (display == "loss" and not loss_keys): - msg = "No data to display. 
Not refreshing" - logger.debug(msg) - print(msg) - return False - return True - def selections_to_list(self): - """ Compile checkbox selections to list """ - logger.debug("Compiling selections to list") - selections = list() - for key, val in self.vars.items(): - if (isinstance(val, tk.BooleanVar) - and key != "outliers" - and val.get()): - selections.append(key) - logger.debug("Compiling selections to list: %s", selections) - return selections - - def graph_build(self, frame): - """ Build the graph in the top right paned window """ - logger.debug("Building Graph") - self.graph = SessionGraph(frame, - self.display_data, - self.vars["display"].get(), - self.vars["scale"].get()) - self.graph.pack(expand=True, fill=tk.BOTH) - self.graph.build() - self.graph_initialised = True - logger.debug("Built Graph") +__all__ = get_module_objects(__name__) diff --git a/lib/gui/display_command.py b/lib/gui/display_command.py index 01af8537d3..785085228f 100644 --- a/lib/gui/display_command.py +++ b/lib/gui/display_command.py @@ -1,234 +1,474 @@ #!/usr/bin python3 """ Command specific tabs of Display Frame of the Faceswap GUI """ import datetime +import gettext import logging import os import tkinter as tk +import typing as T from tkinter import ttk +from lib.logger import parse_class_init +from lib.training.preview_tk import PreviewTk +from lib.utils import get_module_objects from .display_graph import TrainingGraph from .display_page import DisplayOptionalPage -from .tooltip import Tooltip -from .stats import Calculations -from .utils import FileHandler, get_config, get_images +from .custom_widgets import Tooltip +from .analysis import Calculations, Session +from .control_helper import set_slider_rounding +from .utils import FileHandler, get_config, get_images, preview_trigger -logger = logging.getLogger(__name__) # pylint: disable=invalid-name +logger = logging.getLogger(__name__) +# LOCALES +_LANG = gettext.translation("gui.tooltips", localedir="locales", fallback=True) +_ = _LANG.gettext -class PreviewExtract(DisplayOptionalPage): # pylint: disable=too-many-ancestors + +class PreviewExtract(DisplayOptionalPage): # pylint:disable=too-many-ancestors """ Tab to display output preview images for extract and convert """ + def __init__(self, *args, **kwargs) -> None: + logger.debug(parse_class_init(locals())) + self._preview = get_images().preview_extract + super().__init__(*args, **kwargs) + logger.debug("Initialized %s", self.__class__.__name__) - def display_item_set(self): + def display_item_set(self) -> None: """ Load the latest preview if available """ - logger.trace("Loading latest preview") - get_images().load_latest_preview() - self.display_item = get_images().previewoutput + logger.trace("Loading latest preview") # type:ignore[attr-defined] + size = int(256 if self.command == "convert" else 128 * get_config().scaling_factor) + if not self._preview.load_latest_preview(thumbnail_size=size, + frame_dims=(self.winfo_width(), + self.winfo_height())): + logger.trace("Preview not updated") # type:ignore[attr-defined] + return + logger.debug("Preview loaded") + self.display_item = True - def display_item_process(self): + def display_item_process(self) -> None: """ Display the preview """ - logger.trace("Displaying preview") + logger.trace("Displaying preview") # type:ignore[attr-defined] if not self.subnotebook.children: self.add_child() else: self.update_child() - def add_child(self): + def add_child(self) -> None: """ Add the preview label child """ logger.debug("Adding child") preview = 
self.subnotebook_add_page(self.tabname, widget=None) - lblpreview = ttk.Label(preview, image=get_images().previewoutput[1]) + lblpreview = ttk.Label(preview, image=self._preview.image) # type:ignore[arg-type] lblpreview.pack(side=tk.TOP, anchor=tk.NW) - Tooltip(lblpreview, text=self.helptext, wraplength=200) + Tooltip(lblpreview, text=self.helptext, wrap_length=200) - def update_child(self): + def update_child(self) -> None: """ Update the preview image on the label """ - logger.trace("Updating preview") + logger.trace("Updating preview") # type:ignore[attr-defined] for widget in self.subnotebook_get_widgets(): - widget.configure(image=get_images().previewoutput[1]) + widget.configure(image=self._preview.image) - def save_items(self): + def save_items(self) -> None: """ Open save dialogue and save preview """ - location = FileHandler("dir", None).retfile + location = FileHandler("dir", None).return_file if not location: return filename = "extract_convert_preview" now = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") - filename = os.path.join(location, - "{}_{}.{}".format(filename, - now, - "png")) - get_images().previewoutput[0].save(filename) - logger.debug("Saved preview to %s", filename) - print("Saved preview to {}".format(filename)) + filename = os.path.join(location, f"{filename}_{now}.png") + self._preview.save(filename) + print(f"Saved preview to {filename}") -class PreviewTrain(DisplayOptionalPage): # pylint: disable=too-many-ancestors +class PreviewTrain(DisplayOptionalPage): # pylint:disable=too-many-ancestors """ Training preview image(s) """ - def __init__(self, *args, **kwargs): - self.update_preview = get_config().tk_vars["updatepreview"] + def __init__(self, *args, **kwargs) -> None: + logger.debug(parse_class_init(locals())) + self._preview = get_images().preview_train + self._display: PreviewTk | None = None super().__init__(*args, **kwargs) + logger.debug("Initialized %s", self.__class__.__name__) - def display_item_set(self): + def add_options(self) -> None: + """ Add the additional options """ + self._add_option_refresh() + self._add_option_mask_toggle() + super().add_options() + + def subnotebook_hide(self) -> None: + """ Override default subnotebook hide action to also remove the embedded option bar + control and reset the training image buffer """ + if self.subnotebook and self.subnotebook.winfo_ismapped(): + logger.debug("Removing preview controls from options bar") + if self._display is not None: + self._display.remove_option_controls() + super().subnotebook_hide() + del self._display + self._display = None + self._preview.reset() + + def _add_option_refresh(self) -> None: + """ Add refresh button to refresh preview immediately """ + logger.debug("Adding refresh option") + btnrefresh = ttk.Button( + self.optsframe, + image=get_images().icons["reload"], # type:ignore[arg-type] + command=lambda x="update": preview_trigger().set(x)) # type:ignore[misc] + btnrefresh.pack(padx=2, side=tk.RIGHT) + Tooltip(btnrefresh, + text=_("Preview updates at every model save. 
Click to refresh now."), + wrap_length=200) + logger.debug("Added refresh option") + + def _add_option_mask_toggle(self) -> None: + """ Add button to toggle mask display on and off """ + logger.debug("Adding mask toggle option") + btntoggle = ttk.Button( + self.optsframe, + image=get_images().icons["mask2"], # type:ignore[arg-type] + command=lambda x="mask_toggle": preview_trigger().set(x)) # type:ignore[misc] + btntoggle.pack(padx=2, side=tk.RIGHT) + Tooltip(btntoggle, + text=_("Click to toggle mask overlay on and off."), + wrap_length=200) + logger.debug("Added mask toggle option") + + def display_item_set(self) -> None: """ Load the latest preview if available """ - logger.trace("Loading latest preview") - if not self.update_preview.get(): - logger.trace("Preview not updated") + # TODO This seems to be triggering faster than the waittime + logger.trace("Loading latest preview") # type:ignore[attr-defined] + if not self._preview.load(): + logger.trace("Preview not updated") # type:ignore[attr-defined] return - get_images().load_training_preview() - self.display_item = get_images().previewtrain + logger.debug("Preview loaded") + self.display_item = True - def display_item_process(self): + def display_item_process(self) -> None: """ Display the preview(s) resized as appropriate """ - logger.trace("Displaying preview") - sortednames = sorted(list(get_images().previewtrain.keys())) - existing = self.subnotebook_get_titles_ids() - should_update = self.update_preview.get() - - for name in sortednames: - if name not in existing.keys(): - self.add_child(name) - elif should_update: - tab_id = existing[name] - self.update_child(tab_id, name) - - if should_update: - self.update_preview.set(False) - - def add_child(self, name): - """ Add the preview canvas child """ - logger.debug("Adding child") - preview = PreviewTrainCanvas(self.subnotebook, name) - preview = self.subnotebook_add_page(name, widget=preview) - Tooltip(preview, text=self.helptext, wraplength=200) - self.vars["modified"].set(get_images().previewtrain[name][2]) - - def update_child(self, tab_id, name): - """ Update the preview canvas """ - logger.debug("Updating preview") - if self.vars["modified"].get() != get_images().previewtrain[name][2]: - self.vars["modified"].set(get_images().previewtrain[name][2]) - widget = self.subnotebook_page_from_id(tab_id) - widget.reload() - - def save_items(self): + if self.subnotebook.children: + return + + logger.debug("Displaying preview") + self._display = PreviewTk(self._preview.buffer, self.subnotebook, self.optsframe, None) + self.subnotebook_add_page(self.tabname, widget=self._display.master_frame) + + def save_items(self) -> None: """ Open save dialogue and save preview """ - location = FileHandler("dir", None).retfile + if self._display is None: + return + + location = FileHandler("dir", None).return_file if not location: return - for preview in self.subnotebook.children.values(): - preview.save_preview(location) - - -class PreviewTrainCanvas(ttk.Frame): # pylint: disable=too-many-ancestors - """ Canvas to hold a training preview image """ - def __init__(self, parent, previewname): - logger.debug("Initializing %s: (previewname: '%s')", self.__class__.__name__, previewname) - ttk.Frame.__init__(self, parent) - - self.name = previewname - get_images().resize_image(self.name, None) - self.previewimage = get_images().previewtrain[self.name][1] - - self.canvas = tk.Canvas(self, bd=0, highlightthickness=0) - self.canvas.pack(side=tk.TOP, fill=tk.BOTH, expand=True) - self.imgcanvas = 
self.canvas.create_image(0, - 0, - image=self.previewimage, - anchor=tk.NW) - self.bind("<Configure>", self.resize) - logger.debug("Initialized %s:", self.__class__.__name__) - - def resize(self, event): - """ Resize the image to fit the frame, maintaining aspect ratio """ - logger.trace("Resizing preview image") - framesize = (event.width, event.height) - # Sometimes image is resized before frame is drawn - framesize = None if framesize == (1, 1) else framesize - get_images().resize_image(self.name, framesize) - self.reload() - - def reload(self): - """ Reload the preview image """ - logger.trace("Reloading preview image") - self.previewimage = get_images().previewtrain[self.name][1] - self.canvas.itemconfig(self.imgcanvas, image=self.previewimage) - - def save_preview(self, location): - """ Save the figure to file """ - filename = self.name - now = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") - filename = os.path.join(location, - "{}_{}.{}".format(filename, - now, - "png")) - get_images().previewtrain[self.name][0].save(filename) - logger.debug("Saved preview to %s", filename) - print("Saved preview to {}".format(filename)) + self._display.save(location) -class GraphDisplay(DisplayOptionalPage): # pylint: disable=too-many-ancestors - """ The Graph Tab of the Display section """ - def add_options(self): +class GraphDisplay(DisplayOptionalPage): # pylint:disable=too-many-ancestors + """ The Graph Tab of the Display section """ + def __init__(self, + parent: ttk.Notebook, + tab_name: str, + helptext: str, + wait_time: int, + command: str | None = None) -> None: + logger.debug(parse_class_init(locals())) + self._trace_vars: dict[T.Literal["smoothgraph", "display_iterations"], + tuple[tk.Variable, str]] = {} + super().__init__(parent, tab_name, helptext, wait_time, command) + logger.debug("Initialized %s", self.__class__.__name__) + + def set_vars(self) -> dict: + """ Add graphing specific variables to the default variables. + + Overrides original method. + + Returns + ------- + dict + The variable names with their corresponding tkinter variable + """ + tk_vars = super().set_vars() + + smoothgraph = tk.DoubleVar() + smoothgraph.set(0.900) + tk_vars["smoothgraph"] = smoothgraph + + raw_var = tk.BooleanVar() + raw_var.set(True) + tk_vars["raw_data"] = raw_var + + smooth_var = tk.BooleanVar() + smooth_var.set(True) + tk_vars["smooth_data"] = smooth_var + + iterations_var = tk.IntVar() + iterations_var.set(10000) + tk_vars["display_iterations"] = iterations_var + + logger.debug(tk_vars) + return tk_vars + + def on_tab_select(self) -> None: + """ Callback for when the graph tab is selected. + + Pull latest data and run the tab's update code when the tab is selected.
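As an aside on the pattern here: `on_tab_select` (below) signals a refresh by flipping a shared tkinter variable rather than calling the graph directly, and the graph observes that variable. A minimal, self-contained sketch of that trigger-variable pattern (illustrative names, not the faceswap API):

    import tkinter as tk

    root = tk.Tk()
    refresh_flag = tk.BooleanVar(master=root, value=False)  # shared trigger, akin to refresh_graph

    def _on_refresh(*args) -> None:
        # Observer fires on every write; only act when the flag was set to True
        if refresh_flag.get():
            print("pulling latest data and redrawing...")
            refresh_flag.set(False)  # reset so the next trigger fires again

    refresh_flag.trace_add("write", _on_refresh)
    refresh_flag.set(True)  # what a tab-select callback would do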
+ """ + logger.debug("Callback received for '%s' tab (display_item: %s)", + self.tabname, self.display_item) + if self.display_item is not None: + get_config().tk_vars.refresh_graph.set(True) + self._update_page() + + def add_options(self) -> None: """ Add the additional options """ - self.add_option_refresh() + self._add_option_refresh() super().add_options() + self._add_option_raw() + self._add_option_smoothed() + self._add_option_smoothing() + self._add_option_iterations() - def add_option_refresh(self): + def _add_option_refresh(self) -> None: """ Add refresh button to refresh graph immediately """ logger.debug("Adding refresh option") - tk_var = get_config().tk_vars["refreshgraph"] + tk_var = get_config().tk_vars.refresh_graph btnrefresh = ttk.Button(self.optsframe, - image=get_images().icons["reset"], + image=get_images().icons["reload"], # type:ignore[arg-type] command=lambda: tk_var.set(True)) btnrefresh.pack(padx=2, side=tk.RIGHT) Tooltip(btnrefresh, - text="Graph updates every 100 iterations. Click to refresh now.", - wraplength=200) - - def display_item_set(self): + text=_("Graph updates at every model save. Click to refresh now."), + wrap_length=200) + logger.debug("Added refresh option") + + def _add_option_raw(self) -> None: + """ Add check-button to hide/display raw data """ + logger.debug("Adding display raw option") + tk_var = self.vars["raw_data"] + chkbtn = ttk.Checkbutton( + self.optsframe, + variable=tk_var, + text="Raw", + command=lambda v=tk_var: self._display_data_callback("raw", v)) # type:ignore + chkbtn.pack(side=tk.RIGHT, padx=5, anchor=tk.W) + Tooltip(chkbtn, text=_("Display the raw loss data"), wrap_length=200) + + def _add_option_smoothed(self) -> None: + """ Add check-button to hide/display smoothed data """ + logger.debug("Adding display smoothed option") + tk_var = self.vars["smooth_data"] + chkbtn = ttk.Checkbutton( + self.optsframe, + variable=tk_var, + text="Smoothed", + command=lambda v=tk_var: self._display_data_callback("smoothed", v)) # type:ignore + chkbtn.pack(side=tk.RIGHT, padx=5, anchor=tk.W) + Tooltip(chkbtn, text=_("Display the smoothed loss data"), wrap_length=200) + + def _add_option_smoothing(self) -> None: + """ Add a slider to adjust the smoothing amount """ + logger.debug("Adding Smoothing Slider") + tk_var = self.vars["smoothgraph"] + min_max = (0, 0.999) + hlp = _("Set the smoothing amount. 0 is no smoothing, 0.99 is maximum smoothing.") + + ctl_frame = ttk.Frame(self.optsframe) + ctl_frame.pack(padx=2, side=tk.RIGHT) + + lbl = ttk.Label(ctl_frame, text="Smoothing:", anchor=tk.W) + lbl.pack(pady=5, side=tk.LEFT, anchor=tk.N, expand=True) + + tbox = ttk.Entry(ctl_frame, width=6, textvariable=tk_var, justify=tk.RIGHT) + tbox.pack(padx=(0, 5), side=tk.RIGHT) + + ctl = ttk.Scale( + ctl_frame, + variable=tk_var, + command=lambda val, var=tk_var, dt=float, rn=3, mm=min_max: # type:ignore + set_slider_rounding(val, var, dt, rn, mm)) + ctl["from_"] = min_max[0] + ctl["to"] = min_max[1] + ctl.pack(padx=5, pady=5, fill=tk.X, expand=True) + for item in (tbox, ctl): + Tooltip(item, + text=hlp, + wrap_length=200) + logger.debug("Added Smoothing Slider") + + def _add_option_iterations(self) -> None: + """ Add a slider to adjust the amount if iterations to display """ + logger.debug("Adding Iterations Slider") + tk_var = self.vars["display_iterations"] + min_max = (0, 100000) + hlp = _("Set the number of iterations to display. 
0 displays the full session.") + + ctl_frame = ttk.Frame(self.optsframe) + ctl_frame.pack(padx=2, side=tk.RIGHT) + + lbl = ttk.Label(ctl_frame, text="Iterations:", anchor=tk.W) + lbl.pack(pady=5, side=tk.LEFT, anchor=tk.N, expand=True) + + tbox = ttk.Entry(ctl_frame, width=6, textvariable=tk_var, justify=tk.RIGHT) + tbox.pack(padx=(0, 5), side=tk.RIGHT) + + ctl = ttk.Scale( + ctl_frame, + variable=tk_var, + command=lambda val, var=tk_var, dt=int, rn=1000, mm=min_max: # type:ignore + set_slider_rounding(val, var, dt, rn, mm)) + ctl["from_"] = min_max[0] + ctl["to"] = min_max[1] + ctl.pack(padx=5, pady=5, fill=tk.X, expand=True) + for item in (tbox, ctl): + Tooltip(item, + text=hlp, + wrap_length=200) + logger.debug("Added Iterations Slider") + + def display_item_set(self) -> None: """ Load the graph(s) if available """ - session = get_config().session - if session.initialized and session.logging_disabled: - logger.trace("Logs disabled. Hiding graph") + if Session.is_training and Session.logging_disabled: + logger.trace("Logs disabled. Hiding graph") # type:ignore[attr-defined] self.set_info("Graph is disabled as 'no-logs' has been selected") self.display_item = None - elif session.initialized: - logger.trace("Loading graph") - self.display_item = session + self._clear_trace_variables() + elif Session.is_training and self.display_item is None: + logger.trace("Loading graph") # type:ignore[attr-defined] + self.display_item = Session + self._add_trace_variables() + elif Session.is_training and self.display_item is not None: + logger.trace("Graph already displayed. Nothing to do.") # type:ignore[attr-defined] else: + logger.trace("Clearing graph") # type:ignore[attr-defined] self.display_item = None + self._clear_trace_variables() - def display_item_process(self): + def display_item_process(self) -> None: """ Add a single graph to the graph window """ - logger.trace("Adding graph") + if not Session.is_training: + logger.debug("Waiting for Session Data to become available to graph") + self.after(1000, self.display_item_process) + return + existing = list(self.subnotebook_get_titles_ids().keys()) - for loss_key in self.display_item.loss_keys: + loss_keys = self.display_item.get_loss_keys(Session.session_ids[-1]) + if not loss_keys: + # Reload if we attempt to get loss keys before data is written + logger.debug("Waiting for Session Data to become available to graph") + self.after(1000, self.display_item_process) + return + + loss_keys = [key for key in loss_keys if key != "total"] + display_tabs = sorted(set(key[:-1].rstrip("_") for key in loss_keys)) + + for loss_key in display_tabs: tabname = loss_key.replace("_", " ").title() if tabname in existing: continue + logger.debug("Adding graph '%s'", tabname) - data = Calculations(session=get_config().session, + display_keys = [key for key in loss_keys if key.startswith(loss_key)] + data = Calculations(session_id=Session.session_ids[-1], display="loss", - loss_keys=[loss_key], - selections=["raw", "trend"]) + loss_keys=display_keys, + selections=["raw", "smoothed"], + smooth_amount=self.vars["smoothgraph"].get()) self.add_child(tabname, data) - def add_child(self, name, data): - """ Add the graph for the selected keys """ + def _smooth_amount_callback(self, *args) -> None: + """ Update each graph's smooth amount on variable change """ + try: + smooth_amount = self.vars["smoothgraph"].get() + except tk.TclError: + # Don't update when there is no value in the variable + return + logger.debug("Updating graph smooth_amount: (new_value: %s, args: %s)", + 
smooth_amount, args) + for graph in self.subnotebook.children.values(): + graph.calcs.set_smooth_amount(smooth_amount) + + def _iteration_limit_callback(self, *args) -> None: + """ Limit the amount of data displayed in the live graph on an iteration slider + variable change. """ + try: + limit = self.vars["display_iterations"].get() + except tk.TclError: + # Don't update when there is no value in the variable + return + logger.debug("Updating graph iteration limit: (new_value: %s, args: %s)", + limit, args) + for graph in self.subnotebook.children.values(): + graph.calcs.set_iterations_limit(limit) + + def _display_data_callback(self, line: str, variable: tk.BooleanVar) -> None: + """ Update the displayed graph lines based on option check button selection. + + Parameters + ---------- + line: str + The line to hide or display + variable: :class:`tkinter.BooleanVar` + The tkinter variable containing the ``True`` or ``False`` data for this display item + """ + var = variable.get() + logger.debug("Updating display %s to %s", line, var) + for graph in self.subnotebook.children.values(): + graph.calcs.update_selections(line, var) + + def add_child(self, name: str, data: Calculations) -> None: + """ Add the graph for the selected keys. + + Parameters + ---------- + name: str + The name of the graph to add to the notebook + data: :class:`~lib.gui.analysis.stats.Calculations` + The object holding the data to be graphed + """ logger.debug("Adding child: %s", name) graph = TrainingGraph(self.subnotebook, data, "Loss") graph.build() graph = self.subnotebook_add_page(name, widget=graph) - Tooltip(graph, text=self.helptext, wraplength=200) + Tooltip(graph, text=self.helptext, wrap_length=200) - def save_items(self): + def save_items(self) -> None: """ Open save dialogue and save graphs """ - graphlocation = FileHandler("dir", None).retfile + graphlocation = FileHandler("dir", None).return_file if not graphlocation: return for graph in self.subnotebook.children.values(): graph.save_fig(graphlocation) + + def _add_trace_variables(self) -> None: + """ Add tracing for when the option sliders are updated, for updating the graph. """ + for name, action in zip(T.get_args(T.Literal["smoothgraph", "display_iterations"]), + (self._smooth_amount_callback, self._iteration_limit_callback)): + var = self.vars[name] + if name not in self._trace_vars: + self._trace_vars[name] = (var, var.trace("w", action)) + + def _clear_trace_variables(self) -> None: + """ Clear all of the trace variables from :attr:`_trace_vars` and reset the dictionary. """ + if self._trace_vars: + for name, (var, trace) in self._trace_vars.items(): + logger.debug("Clearing trace from variable: %s", name) + var.trace_vdelete("w", trace) + self._trace_vars = {} + + def close(self) -> None: + """ Clear the plots from RAM """ + self._clear_trace_variables() + if self.subnotebook is None: + logger.debug("No graphs to clear.
Returning") + return + + for name, graph in self.subnotebook.children.items(): + logger.debug("Clearing: %s", name) + graph.clear() + super().close() + + +__all__ = get_module_objects(__name__) diff --git a/lib/gui/display_graph.py b/lib/gui/display_graph.py index 3abfe8a380..9ae83f74a7 100755 --- a/lib/gui/display_graph.py +++ b/lib/gui/display_graph.py @@ -1,323 +1,600 @@ #!/usr/bin python3 -""" Graph functions for Display Frame of the Faceswap GUI """ +""" Graph functions for Display Frame area of the Faceswap GUI """ +from __future__ import annotations import datetime import logging import os import tkinter as tk +import typing as T from tkinter import ttk from math import ceil, floor +import numpy as np import matplotlib -# pylint: disable=wrong-import-position -matplotlib.use("TkAgg") - -from matplotlib import pyplot as plt, style # noqa +from matplotlib import style +from matplotlib.figure import Figure from matplotlib.backends.backend_tkagg import (FigureCanvasTkAgg, - NavigationToolbar2Tk) # noqa - -from .tooltip import Tooltip # noqa -from .utils import get_config, get_images # noqa + NavigationToolbar2Tk) +from matplotlib.backend_bases import NavigationToolbar2 -logger = logging.getLogger(__name__) # pylint: disable=invalid-name +from lib.logger import parse_class_init +from lib.utils import get_module_objects +from .custom_widgets import Tooltip +from .utils import get_config, get_images, LongRunningTask -class NavigationToolbar(NavigationToolbar2Tk): # pylint: disable=too-many-ancestors - """ Same as default, but only including buttons we need - with custom icons and layout - From: https://stackoverflow.com/questions/12695678 """ - toolitems = [t for t in NavigationToolbar2Tk.toolitems if - t[0] in ("Home", "Pan", "Zoom", "Save")] +if T.TYPE_CHECKING: + from matplotlib.lines import Line2D - @staticmethod - def _Button(frame, text, file, command, extension=".gif"): # pylint: disable=arguments-differ - """ Map Buttons to their own frame. - Use custom button icons, Use ttk buttons pack to the right """ - iconmapping = {"home": "reset", - "filesave": "save", - "zoom_to_rect": "zoom"} - icon = iconmapping[file] if iconmapping.get(file, None) else file - img = get_images().icons[icon] - btn = ttk.Button(frame, text=text, image=img, command=command) - btn.pack(side=tk.RIGHT, padx=2) - return btn +logger: logging.Logger = logging.getLogger(__name__) - def _init_toolbar(self): - """ Same as original but ttk widgets and standard tooltips used. Separator added and - message label packed to the left """ - xmin, xmax = self.canvas.figure.bbox.intervalx - height, width = 50, xmax-xmin - ttk.Frame.__init__(self, master=self.window, width=int(width), height=int(height)) - sep = ttk.Frame(self, height=2, relief=tk.RIDGE) - sep.pack(fill=tk.X, pady=(5, 0), side=tk.TOP) +class GraphBase(ttk.Frame): # pylint:disable=too-many-ancestors + """ Base class for matplotlib line graphs. - self.update() # Make axes menu - - btnframe = ttk.Frame(self) - btnframe.pack(fill=tk.X, padx=5, pady=5, side=tk.RIGHT) - - for text, tooltip_text, image_file, callback in self.toolitems: - if text is None: - # Add a spacer; return value is unused. 
- self._Spacer() - else: - button = self._Button(btnframe, text=text, file=image_file, - command=getattr(self, callback)) - if tooltip_text is not None: - Tooltip(button, text=tooltip_text, wraplength=200) + Parameters + ---------- + parent: :class:`tkinter.ttk.Frame` + The parent frame that holds the graph + data: :class:`lib.gui.analysis.stats.Calculations` + The statistics class that holds the data to be displayed + ylabel: str + The data label for the y-axis + """ + def __init__(self, parent: ttk.Frame, data, ylabel: str) -> None: + super().__init__(parent) + matplotlib.use("TkAgg") # Can't be at module level as breaks Github CI + style.use("ggplot") - self.message = tk.StringVar(master=self) - self._message_label = ttk.Label(master=self, textvariable=self.message) - self._message_label.pack(side=tk.LEFT, padx=5) - self.pack(side=tk.BOTTOM, fill=tk.X) + self._calcs = data + self._ylabel = ylabel + self._colourmaps = ["Reds", "Blues", "Greens", "Purples", "Oranges", "Greys", "copper", + "summer", "bone", "hot", "cool", "pink", "Wistia", "spring", "winter"] + self._lines: list[Line2D] = [] + self._toolbar: "NavigationToolbar" | None = None + self._fig = Figure(figsize=(4, 4), dpi=75) + self._ax1 = self._fig.add_subplot(1, 1, 1) + self._plotcanvas = FigureCanvasTkAgg(self._fig, self) -class GraphBase(ttk.Frame): # pylint: disable=too-many-ancestors - """ Base class for matplotlib line graphs """ - def __init__(self, parent, data, ylabel): - logger.debug("Initializing %s", self.__class__.__name__) - super().__init__(parent) - style.use("ggplot") + self._initiate_graph() + self._update_plot(initiate=True) - self.calcs = data - self.ylabel = ylabel - self.colourmaps = ["Reds", "Blues", "Greens", "Purples", "Oranges", - "Greys", "copper", "summer", "bone"] - self.lines = list() - self.toolbar = None - self.fig = plt.figure(figsize=(4, 4), dpi=75) - self.ax1 = self.fig.add_subplot(1, 1, 1) - self.plotcanvas = FigureCanvasTkAgg(self.fig, self) - - self.initiate_graph() - self.update_plot(initiate=True) - logger.debug("Initialized %s", self.__class__.__name__) + @property + def calcs(self): + """ :class:`lib.gui.analysis.stats.Calculations`. The calculated statistics associated with + this graph. """ + return self._calcs - def initiate_graph(self): + def _initiate_graph(self) -> None: """ Place the graph canvas """ logger.debug("Setting plotcanvas") - self.plotcanvas.get_tk_widget().pack(side=tk.TOP, padx=5, fill=tk.BOTH, expand=True) - plt.subplots_adjust(left=0.100, - bottom=0.100, - right=0.95, - top=0.95, - wspace=0.2, - hspace=0.2) + self._plotcanvas.get_tk_widget().pack(side=tk.TOP, padx=5, fill=tk.BOTH, expand=True) + self._fig.subplots_adjust(left=0.100, + bottom=0.100, + right=0.95, + top=0.95, + wspace=0.2, + hspace=0.2) logger.debug("Set plotcanvas") - def update_plot(self, initiate=True): - """ Update the plot with incoming data """ - logger.trace("Updating plot") + def _update_plot(self, initiate: bool = True) -> None: + """ Update the plot with incoming data + + Parameters + ---------- + initiate: bool, Optional + Whether the graph should be initialized for the first time (``True``) or data is being + updated for an existing graph (``False``). 
Default: ``True`` + """ + logger.trace("Updating plot") # type:ignore[attr-defined] if initiate: logger.debug("Initializing plot") - self.lines = list() - self.ax1.clear() - self.axes_labels_set() + self._lines = [] + self._ax1.clear() + self._axes_labels_set() logger.debug("Initialized plot") - fulldata = [item for item in self.calcs.stats.values()] - self.axes_limits_set(fulldata) + fulldata = list(self._calcs.stats.values()) + self._axes_limits_set(fulldata) + + if self._calcs.start_iteration > 0: + end_iteration = self._calcs.start_iteration + self._calcs.iterations + xrng = list(range(self._calcs.start_iteration, end_iteration)) + else: + xrng = list(range(self._calcs.iterations)) + + keys = list(self._calcs.stats.keys()) - xrng = [x for x in range(self.calcs.iterations)] - keys = list(self.calcs.stats.keys()) - for idx, item in enumerate(self.lines_sort(keys)): + for idx, item in enumerate(self._lines_sort(keys)): if initiate: - self.lines.extend(self.ax1.plot(xrng, self.calcs.stats[item[0]], - label=item[1], linewidth=item[2], color=item[3])) + self._lines.extend(self._ax1.plot(xrng, self._calcs.stats[item[0]], + label=item[1], linewidth=item[2], color=item[3])) else: - self.lines[idx].set_data(xrng, self.calcs.stats[item[0]]) + self._lines[idx].set_data(xrng, self._calcs.stats[item[0]]) if initiate: - self.legend_place() - logger.trace("Updated plot") + self._legend_place() + logger.trace("Updated plot") # type:ignore[attr-defined] - def axes_labels_set(self): - """ Set the axes label and range """ - logger.debug("Setting axes labels. y-label: '%s'", self.ylabel) - self.ax1.set_xlabel("Iterations") - self.ax1.set_ylabel(self.ylabel) + def _axes_labels_set(self) -> None: + """ Set the X and Y axes labels. """ + logger.debug("Setting axes labels. y-label: '%s'", self._ylabel) + self._ax1.set_xlabel("Iterations") + self._ax1.set_ylabel(self._ylabel) - def axes_limits_set_default(self): - """ Set default axes limits """ + def _axes_limits_set_default(self) -> None: + """ Set the default axes limits for the X and Y axes. """ logger.debug("Setting default axes ranges") - self.ax1.set_ylim(0.00, 100.0) - self.ax1.set_xlim(0, 1) + self._ax1.set_ylim(0.00, 100.0) + self._ax1.set_xlim(0, 1) + + def _axes_limits_set(self, data: list[float]) -> None: + """ Set the axes limits. 
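For reference, `GraphBase` builds a `Figure` directly instead of going through `pyplot`, which keeps matplotlib's global state out of the GUI; `matplotlib.use("TkAgg")` has to run before any canvas is created. A minimal standalone embed along the same lines (a sketch, not the faceswap class):

    import tkinter as tk

    import matplotlib
    matplotlib.use("TkAgg")  # select the backend before creating any canvas
    from matplotlib.figure import Figure
    from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg

    root = tk.Tk()
    fig = Figure(figsize=(4, 4), dpi=75)  # owned by us; no pyplot state machine
    ax = fig.add_subplot(1, 1, 1)
    ax.plot([0, 1, 2, 3], [1.0, 0.7, 0.5, 0.4], label="raw loss")
    ax.legend(loc="upper right")
    canvas = FigureCanvasTkAgg(fig, master=root)
    canvas.get_tk_widget().pack(fill=tk.BOTH, expand=True)
    canvas.draw()
    root.mainloop()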
+ + Parameters + ---------- + data: list + The data points for the Y Axis + """ + xmin = self._calcs.start_iteration + if self._calcs.start_iteration > 0: + xmax = self._calcs.iterations + self._calcs.start_iteration + else: + xmax = self._calcs.iterations + xmax = max(1, xmax - 1) - def axes_limits_set(self, data): - """ Set the axes limits """ - xmax = self.calcs.iterations - 1 if self.calcs.iterations > 1 else 1 if data: - ymin, ymax = self.axes_data_get_min_max(data) - self.ax1.set_ylim(ymin, ymax) - self.ax1.set_xlim(0, xmax) + ymin, ymax = self._axes_data_get_min_max(data) + self._ax1.set_ylim(ymin, ymax) + self._ax1.set_xlim(xmin, xmax) + logger.trace("axes ranges: (y: (%s, %s), x:(0, %s)", # type:ignore[attr-defined] + ymin, ymax, xmax) else: - self.axes_limits_set_default() - logger.trace("axes ranges: (y: (%s, %s), x:(0, %s)", ymin, ymax, xmax) + self._axes_limits_set_default() @staticmethod - def axes_data_get_min_max(data): - """ Return the minimum and maximum values from list of lists """ - ymin, ymax = list(), list() - for item in data: - dataset = list(filter(lambda x: x is not None, item)) - if not dataset: - continue - ymin.append(min(dataset) * 1000) - ymax.append(max(dataset) * 1000) - ymin = floor(min(ymin)) / 1000 - ymax = ceil(max(ymax)) / 1000 - logger.trace("ymin: %s, ymax: %s", ymin, ymax) + def _axes_data_get_min_max(data: list[float]) -> tuple[float, float]: + """ Obtain the minimum and maximum values for the y-axis from the given data points. + + Parameters + ---------- + data: list + The data points for the Y Axis + + Returns + ------- + tuple + The minimum and maximum values for the y axis + """ + ymins, ymaxs = [], [] + + for item in data: # TODO Handle as array not loop + ymins.append(np.nanmin(item) * 1000) + ymaxs.append(np.nanmax(item) * 1000) + ymin = floor(min(ymins)) / 1000 + ymax = ceil(max(ymaxs)) / 1000 + logger.trace("ymin: %s, ymax: %s", ymin, ymax) # type:ignore[attr-defined] return ymin, ymax - def axes_set_yscale(self, scale): - """ Set the Y-Scale to log or linear """ + def _axes_set_yscale(self, scale: str) -> None: + """ Set the Y-Scale to log or linear + + Parameters + ---------- + scale: str + Should be one of ``"log"`` or ``"linear"`` + """ logger.debug("yscale: '%s'", scale) - self.ax1.set_yscale(scale) - - def lines_sort(self, keys): - """ Sort the data keys into consistent order - and set line color map and line width """ - logger.trace("Sorting lines") - raw_lines = list() - sorted_lines = list() + self._ax1.set_yscale(scale) + + def _lines_sort(self, + keys: list[str]) -> list[list[str | int | tuple[float, float, float, float]]]: + """ Sort the data keys into consistent order and set line color map and line width. 
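`_axes_data_get_min_max` above still loops per series (hence its TODO); with equal-length series the same result comes from one vectorised pass, including the multiply-by-1000 floor/ceil that snaps the limits outwards to three decimal places. A sketch under that equal-length assumption:

    import numpy as np

    def y_limits(data: np.ndarray) -> tuple[float, float]:
        # NaN-safe min/max over all series, rounded outwards to 3 decimal places
        ymin = np.floor(np.nanmin(data) * 1000) / 1000
        ymax = np.ceil(np.nanmax(data) * 1000) / 1000
        return float(ymin), float(ymax)

    series = np.array([[0.1234, 0.2, np.nan], [0.05678, 0.3, 0.25]])
    print(y_limits(series))  # (0.056, 0.3)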
+ + Parameters + ---------- + keys: list + The list of data point keys + + Returns + ------- + list + list[list[str | int | tuple[float, float, float, float]]] + """ + logger.trace("Sorting lines") # type:ignore[attr-defined] + raw_lines: list[list[str]] = [] + sorted_lines: list[list[str]] = [] for key in sorted(keys): title = key.replace("_", " ").title() - if key.startswith(("avg", "trend")): - sorted_lines.append([key, title]) - else: + if key.startswith("raw"): raw_lines.append([key, title]) + else: + sorted_lines.append([key, title]) - groupsize = self.lines_groupsize(raw_lines, sorted_lines) + groupsize = self._lines_groupsize(raw_lines, sorted_lines) sorted_lines = raw_lines + sorted_lines - lines = self.lines_style(sorted_lines, groupsize) + lines = self._lines_style(sorted_lines, groupsize) return lines @staticmethod - def lines_groupsize(raw_lines, sorted_lines): + def _lines_groupsize(raw_lines: list[list[str]], sorted_lines: list[list[str]]) -> int: """ Get the number of items in each group. - If raw data isn't selected, then check the length of - remaining groups until something is found """ + + If raw data isn't selected, then check the length of remaining groups until something is + found. + + Parameters + ---------- + raw_lines: list + The list of keys for the raw data points + sorted_lines: + The list of sorted line keys to display on the graph + + Returns + ------- + int + The size of each group that exist within the data set. + """ groupsize = 1 if raw_lines: groupsize = len(raw_lines) - else: - for check in ("avg", "trend"): - if any(item[0].startswith(check) for item in sorted_lines): - groupsize = len([item for item in sorted_lines if item[0].startswith(check)]) - break - logger.trace(groupsize) + elif sorted_lines: + keys = [key[0][:key[0].find("_")] for key in sorted_lines] + distinct_keys = set(keys) + groupsize = len(keys) // len(distinct_keys) + logger.trace(groupsize) # type:ignore[attr-defined] return groupsize - def lines_style(self, lines, groupsize): - """ Set the color map and line width for each group """ - logger.trace("Setting lines style") + def _lines_style(self, + lines: list[list[str]], + groupsize: int) -> list[list[str | int | tuple[float, float, float, float]]]: + """ Obtain the color map and line width for each group. + + Parameters + ---------- + lines: list + The list of sorted line keys to display on the graph + groupsize: int + The size of each group to display in the graph + + Returns + ------- + list[list[str | int | tuple[float, float, float, float]]] + A list of loss keys with their corresponding line formatting and color information + """ + logger.trace("Setting lines style") # type:ignore[attr-defined] groups = int(len(lines) / groupsize) - colours = self.lines_create_colors(groupsize, groups) - for idx, item in enumerate(lines): - linewidth = ceil((idx + 1) / groupsize) + colours = self._lines_create_colors(groupsize, groups) + widths = list(range(1, groups + 1)) + retval = T.cast(list[list[str | int | tuple[float, float, float, float]]], lines) + for idx, item in enumerate(retval): + linewidth = widths[idx // groupsize] item.extend((linewidth, colours[idx])) - return lines - - def lines_create_colors(self, groupsize, groups): - """ Create the colors """ - colours = list() + return retval + + def _lines_create_colors(self, + groupsize: int, + groups: int) -> list[tuple[float, float, float, float]]: + """ Create the color maps. 
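The group-size inference above (`len(keys) // len(distinct_keys)`) assumes every loss output contributes the same number of lines, grouped by the text before the first underscore. A quick illustration with hypothetical key names:

    keys = [["face_a"], ["face_b"], ["mask_a"], ["mask_b"]]  # sorted_lines-style entries
    prefixes = [key[0][:key[0].find("_")] for key in keys]   # ["face", "face", "mask", "mask"]
    groupsize = len(prefixes) // len(set(prefixes))          # 4 keys over 2 prefixes
    print(groupsize)  # 2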
+ + Parameters + ---------- + groupsize: int + The size of each group to display in the graph + groups: int + The total number of groups to graph + + Returns + ------- + list[tuple[float, float, float, float]] + The colour map for each group + """ + colours = [] for i in range(1, groups + 1): - for colour in self.colourmaps[0:groupsize]: + for colour in self._colourmaps[0:groupsize]: cmap = matplotlib.cm.get_cmap(colour) cpoint = 1 - (i / 5) colours.append(cmap(cpoint)) - logger.trace(colours) + logger.trace(colours) # type:ignore[attr-defined] return colours - def legend_place(self): - """ Place and format legend """ + def _legend_place(self) -> None: + """ Place and format the graph legend """ logger.debug("Placing legend") - self.ax1.legend(loc="upper right", ncol=2) - - def toolbar_place(self, parent): - """ Add Graph Navigation toolbar """ - logger.debug("Placing toolbar") - self.toolbar = NavigationToolbar(self.plotcanvas, parent) - self.toolbar.pack(side=tk.BOTTOM) - self.toolbar.update() + self._ax1.legend(loc="upper right", ncol=2) + def _toolbar_place(self, parent: ttk.Frame) -> None: + """ Add Graph Navigation toolbar. -class TrainingGraph(GraphBase): # pylint: disable=too-many-ancestors - """ Live graph to be displayed during training. """ - - def __init__(self, parent, data, ylabel): - GraphBase.__init__(self, parent, data, ylabel) - self.add_callback() + Parameters + ---------- + parent: ttk.Frame + The parent graph frame to place the toolbar onto + """ + logger.debug("Placing toolbar") + self._toolbar = NavigationToolbar(self._plotcanvas, parent) + self._toolbar.pack(side=tk.BOTTOM) + self._toolbar.update() + + def clear(self) -> None: + """ Clear the graph plots from RAM """ + logger.debug("Clearing graph from RAM: %s", self) + self._fig.clf() + del self._fig + + +class TrainingGraph(GraphBase): # pylint:disable=too-many-ancestors + """ Live graph to be displayed during training. + + Parameters + ---------- + parent: :class:`tkinter.ttk.Frame` + The parent frame that holds the graph + data: :class:`lib.gui.analysis.stats.Calculations` + The statistics class that holds the data to be displayed + ylabel: str + The data label for the y-axis + """ + + def __init__(self, parent: ttk.Frame, data, ylabel: str) -> None: + logger.debug(parse_class_init(locals())) + super().__init__(parent, data, ylabel) + self._thread: LongRunningTask | None = None # Thread for LongRunningTask + self._displayed_keys: list[str] = [] + self._add_callback() + logger.debug("Initialized %s", self.__class__.__name__) - def add_callback(self): - """ Add the variable trace to update graph on recent button or save iteration """ - get_config().tk_vars["refreshgraph"].trace("w", self.refresh) + def _add_callback(self) -> None: + """ Add the variable trace to update graph on refresh button press or save iteration. """ + get_config().tk_vars.refresh_graph.trace("w", self.refresh) # type:ignore - def build(self): - """ Update the plot area with loss values """ + def build(self) -> None: + """ Build the Training graph.
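`_lines_create_colors` above samples one point from each named colormap per group; note that `matplotlib.cm.get_cmap` is deprecated in recent matplotlib releases. A sketch of the same sampling using the newer registry lookup (version-dependent, so treat the registry call as an assumption):

    import matplotlib

    colour_maps = ["Reds", "Blues", "Greens"]  # one named map per line within a group
    colours = []
    for group in range(1, 3):                  # two groups of three lines
        for name in colour_maps:
            cmap = matplotlib.colormaps[name]  # registry lookup replacing cm.get_cmap(name)
            colours.append(cmap(1 - group / 5))  # RGBA tuple sampled from the map
    print(len(colours), colours[0])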
""" logger.debug("Building training graph") - self.plotcanvas.draw() + self._plotcanvas.draw() logger.debug("Built training graph") - def refresh(self, *args): # pylint: disable=unused-argument - """ Read loss data and apply to graph """ - logger.debug("Updating plot") - self.calcs.refresh() - self.update_plot(initiate=False) - self.plotcanvas.draw() - get_config().tk_vars["refreshgraph"].set(False) + def refresh(self, *args) -> None: # pylint:disable=unused-argument + """ Read the latest loss data and apply to current graph """ + refresh_var = T.cast(tk.BooleanVar, get_config().tk_vars.refresh_graph) + if not refresh_var.get() and self._thread is None: + return + + if self._thread is None: + logger.debug("Updating plot data") + self._thread = LongRunningTask(target=self._calcs.refresh) + self._thread.start() + self.after(1000, self.refresh) + elif not self._thread.complete.is_set(): + logger.debug("Graph Data not yet available") + self.after(1000, self.refresh) + else: + logger.debug("Updating plot with data from background thread") + self._calcs = self._thread.get_result() # Terminate the LongRunningTask object + self._thread = None + + dsp_keys = list(sorted(self._calcs.stats)) + if dsp_keys != self._displayed_keys: + logger.debug("Reinitializing graph for keys change. Old keys: %s New keys: %s", + self._displayed_keys, dsp_keys) + initiate = True + self._displayed_keys = dsp_keys + else: + initiate = False + + self._update_plot(initiate=initiate) + self._plotcanvas.draw() + refresh_var.set(False) + + def save_fig(self, location: str) -> None: + """ Save the current graph to file - def save_fig(self, location): - """ Save the figure to file """ + Parameters + ---------- + location: str + The full path to the folder where the current graph should be saved + """ logger.debug("Saving graph: '%s'", location) - keys = sorted([key.replace("raw_", "") for key in self.calcs.stats.keys() + keys = sorted([key.replace("raw_", "") for key in self._calcs.stats.keys() if key.startswith("raw_")]) filename = " - ".join(keys) now = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") - filename = os.path.join(location, "{}_{}.{}".format(filename, now, "png")) - self.fig.set_size_inches(16, 9) - self.fig.savefig(filename, bbox_inches="tight", dpi=120) - print("Saved graph to {}".format(filename)) + filename = os.path.join(location, f"{filename}_{now}.png") + self._fig.set_size_inches(16, 9) + self._fig.savefig(filename, bbox_inches="tight", dpi=120) + print(f"Saved graph to {filename}") logger.debug("Saved graph: '%s'", filename) - self.resize_fig() + self._resize_fig() - def resize_fig(self): - """ Resize the figure back to the canvas """ - class Event(): # pylint: disable=too-few-public-methods + def _resize_fig(self) -> None: + """ Resize the figure to the current canvas size. 
""" + class Event(): # pylint:disable=too-few-public-methods """ Event class that needs to be passed to plotcanvas.resize """ - pass - Event.width = self.winfo_width() - Event.height = self.winfo_height() - self.plotcanvas.resize(Event) # pylint: disable=no-value-for-parameter - - -class SessionGraph(GraphBase): # pylint: disable=too-many-ancestors - """ Session Graph for session pop-up """ - def __init__(self, parent, data, ylabel, scale): - GraphBase.__init__(self, parent, data, ylabel) - self.scale = scale + pass # pylint:disable=unnecessary-pass + setattr(Event, "width", self.winfo_width()) + setattr(Event, "height", self.winfo_height()) + self._plotcanvas.resize(Event) # pylint:disable=no-value-for-parameter + + +class SessionGraph(GraphBase): # pylint:disable=too-many-ancestors + """ Session Graph for session pop-up. + + Parameters + ---------- + parent: :class:`tkinter.ttk.Frame` + The parent frame that holds the graph + data: :class:`lib.gui.analysis.stats.Calculations` + The statistics class that holds the data to be displayed + ylabel: str + The data label for the y-axis + scale: str + Should be one of ``"log"`` or ``"linear"`` + """ + def __init__(self, parent: ttk.Frame, data, ylabel: str, scale: str) -> None: + logger.debug(parse_class_init(locals())) + super().__init__(parent, data, ylabel) + self._scale = scale + logger.debug("Initialized %s", self.__class__.__name__) - def build(self): + def build(self) -> None: """ Build the session graph """ logger.debug("Building session graph") - self.toolbar_place(self) - self.plotcanvas.draw() + self._toolbar_place(self) + self._plotcanvas.draw() logger.debug("Built session graph") - def refresh(self, data, ylabel, scale): - """ Refresh graph data """ + def refresh(self, data, ylabel: str, scale: str) -> None: + """ Refresh the Session Graph's data. + + Parameters + ---------- + data: :class:`lib.gui.analysis.stats.Calculations` + The statistics class that holds the data to be displayed + ylabel: str + The data label for the y-axis + scale: str + Should be one of ``"log"`` or ``"linear"`` + """ logger.debug("Refreshing session graph: (ylabel: '%s', scale: '%s')", ylabel, scale) - self.calcs = data - self.ylabel = ylabel + self._calcs = data + self._ylabel = ylabel self.set_yscale_type(scale) logger.debug("Refreshed session graph") - def set_yscale_type(self, scale): - """ switch the y-scale and redraw """ + def set_yscale_type(self, scale: str) -> None: + """ Set the scale type for the y-axis and redraw. + + Parameters + ---------- + scale: str + Should be one of ``"log"`` or ``"linear"`` + """ + scale = scale.lower() logger.debug("Updating scale type: '%s'", scale) - self.scale = scale - self.update_plot(initiate=True) - self.axes_set_yscale(self.scale) - self.plotcanvas.draw() + self._scale = scale + self._update_plot(initiate=True) + self._axes_set_yscale(self._scale) + self._plotcanvas.draw() logger.debug("Updated scale type") + + +class NavigationToolbar(NavigationToolbar2Tk): # pylint:disable=too-many-ancestors + """ Overrides the default Navigation Toolbar to provide only the buttons we require + and to layout the items in a consistent manner with the rest of the GUI for the Analysis + Session Graph pop up Window. 
+ + Parameters + ---------- + canvas: :class:`matplotlib.backends.backend_tkagg.FigureCanvasTkAgg` + The canvas that holds the displayed graph and will hold the toolbar + window: :class:`~lib.gui.display_graph.SessionGraph` + The Session Graph canvas + pack_toolbar: bool, Optional + Whether to pack the Tool bar or not. Default: ``True`` + """ + toolitems = tuple(t for t in NavigationToolbar2Tk.toolitems if + t[0] in ("Home", "Pan", "Zoom", "Save")) + + def __init__(self, # pylint:disable=super-init-not-called + canvas: FigureCanvasTkAgg, + window: ttk.Frame, + *, + pack_toolbar: bool = True) -> None: + logger.debug(parse_class_init(locals())) + # Avoid using self.window (prefer self.canvas.get_tk_widget().master), + # so that Tool implementations can reuse the methods. + + ttk.Frame.__init__(T.cast(ttk.Frame, self), # pylint:disable=non-parent-init-called + master=window, + width=int(canvas.figure.bbox.width), + height=50) + + sep = ttk.Frame(self, height=2, relief=tk.RIDGE) + sep.pack(fill=tk.X, pady=(5, 0), side=tk.TOP) + + btnframe = ttk.Frame(self) # Add a button frame to consistently line up GUI + btnframe.pack(fill=tk.X, padx=5, pady=5, side=tk.RIGHT) + + self._buttons = {} + for text, tooltip_text, image_file, callback in self.toolitems: + assert isinstance(text, str) + assert isinstance(image_file, str) + assert isinstance(callback, str) + self._buttons[text] = button = self._Button( + btnframe, + text, + image_file, + toggle=callback in ["zoom", "pan"], + command=getattr(self, callback), + ) + if tooltip_text is not None: + Tooltip(button, text=tooltip_text, wrap_length=200) + + self.message = tk.StringVar(master=self) + self._message_label = ttk.Label(master=self, textvariable=self.message) + self._message_label.pack(side=tk.LEFT, padx=5) # Additional left padding + + NavigationToolbar2.__init__(self, canvas) # pylint:disable=non-parent-init-called + if pack_toolbar: + self.pack(side=tk.BOTTOM, fill=tk.X) + logger.debug("Initialized %s", self.__class__.__name__) + + @staticmethod + def _Button(frame: ttk.Frame, # type:ignore[override] # pylint:disable=arguments-differ,arguments-renamed # noqa:E501 + text: str, + image_file: str, + toggle: bool, + command) -> ttk.Button | ttk.Checkbutton: + """ Override the default button method to use our icons and ttk widgets for + consistent GUI layout. + + Parameters + ---------- + frame: :class:`tkinter.ttk.Frame` + The frame that holds the buttons + text: str + The display text for the button + image_file: str + The name of the image file to use + toggle: bool + Whether to use a checkbutton (``True``) or a regular button (``False``) + command: method + The Navigation Toolbar callback method + + Returns + ------- + :class:`tkinter.ttk.Button` or :class:`tkinter.ttk.Checkbutton` + The widget to use. A button if the option is not toggleable, a checkbutton if the + option is toggleable. + """ + iconmapping = {"home": "reload", + "filesave": "save", + "zoom_to_rect": "zoom"} + icon = iconmapping[image_file] if iconmapping.get(image_file, None) else image_file + img = get_images().icons[icon] + + if not toggle: + btn: ttk.Button | ttk.Checkbutton = ttk.Button(frame, + text=text, + image=img, # type:ignore[arg-type] + command=command) + else: + var = tk.IntVar(master=frame) + btn = ttk.Checkbutton(frame, + text=text, + image=img, # type:ignore[arg-type] + command=command, variable=var) + + # Original implementation uses tk Checkbuttons which have a select and deselect + # method.
These aren't available in ttk Checkbuttons, so we monkey patch the methods + # to update the underlying variable. + setattr(btn, "select", lambda i=1: var.set(i)) + setattr(btn, "deselect", lambda i=0: var.set(i)) + + btn.pack(side=tk.RIGHT, padx=2) + return btn + + +__all__ = get_module_objects(__name__) diff --git a/lib/gui/display_page.py b/lib/gui/display_page.py index a12b95115a..19444e57c3 100644 --- a/lib/gui/display_page.py +++ b/lib/gui/display_page.py @@ -1,28 +1,33 @@ #!/usr/bin python3 """ Display Page parent classes for display section of the Faceswap GUI """ +import gettext import logging import tkinter as tk from tkinter import ttk -from .tooltip import Tooltip +from lib.utils import get_module_objects + +from .custom_widgets import Tooltip from .utils import get_images -logger = logging.getLogger(__name__) # pylint: disable=invalid-name +logger = logging.getLogger(__name__) + +# LOCALES +_LANG = gettext.translation("gui.tooltips", localedir="locales", fallback=True) +_ = _LANG.gettext -class DisplayPage(ttk.Frame): +class DisplayPage(ttk.Frame): # pylint:disable=too-many-ancestors """ Parent frame holder for each tab. Defines uniform structure for each tab to inherit from """ - def __init__(self, parent, tabname, helptext): - logger.debug("Initializing %s: (tabname: '%s', helptext: %s", - self.__class__.__name__, tabname, helptext) - ttk.Frame.__init__(self, parent) - self.pack(fill=tk.BOTH, side=tk.TOP, anchor=tk.NW) + def __init__(self, parent, tab_name, helptext): + super().__init__(parent) - self.runningtask = parent.runningtask + self._parent = parent + self.running_task = parent.running_task self.helptext = helptext - self.tabname = tabname + self.tabname = tab_name self.vars = {"info": tk.StringVar()} self.add_optional_vars(self.set_vars()) @@ -33,8 +38,14 @@ def __init__(self, parent, tabname, helptext): self.add_frame_separator() self.set_mainframe_single_tab_style() + + self.pack(fill=tk.BOTH, side=tk.TOP, anchor=tk.NW) parent.add(self, text=self.tabname.title()) - logger.debug("Initialized %s", self.__class__.__name__,) + + @property + def _tab_is_active(self): + """ bool: ``True`` if the tab currently has focus otherwise ``False`` """ + return self._parent.tab(self._parent.select(), "text").lower() == self.tabname.lower() def add_optional_vars(self, varsdict): """ Add page specific variables """ @@ -43,10 +54,14 @@ def add_optional_vars(self, varsdict): logger.debug("Adding: (%s: %s)", key, val) self.vars[key] = val - @staticmethod - def set_vars(): + def set_vars(self): """ Override to return a dict of page specific variables """ - return dict() + return {} + + def on_tab_select(self): + """ Override for specific actions when the current tab is selected """ + logger.debug("Returning as 'on_tab_select' not implemented for %s", + self.__class__.__name__) def add_subnotebook(self): """ Add the main frame notebook """ @@ -67,9 +82,8 @@ def add_options_info(self): logger.debug("Adding options info") lblinfo = ttk.Label(self.optsframe, textvariable=self.vars["info"], - anchor=tk.W, - width=70) - lblinfo.pack(side=tk.LEFT, padx=5, pady=5, anchor=tk.W) + anchor=tk.W) + lblinfo.pack(side=tk.LEFT, expand=True, padx=5, pady=5, anchor=tk.W) def set_info(self, msg): """ Set the info message """ @@ -129,12 +143,11 @@ def subnotebook_get_widgets(self): subnotebook frame """ logger.debug("Getting subnotebook widgets") for child in self.subnotebook.winfo_children(): - for widget in child.winfo_children(): - yield widget + yield from child.winfo_children() def 
subnotebook_get_titles_ids(self): """ Return tabs ids and titles """ - tabs = dict() + tabs = {} for tab_id in range(0, self.subnotebook.index("end")): tabs[self.subnotebook.tab(tab_id, "text")] = tab_id logger.debug(tabs) @@ -147,12 +160,14 @@ def subnotebook_page_from_id(self, tab_id): return self.subnotebook.children[tab_name] -class DisplayOptionalPage(DisplayPage): +class DisplayOptionalPage(DisplayPage): # pylint:disable=too-many-ancestors """ Parent Context Sensitive Display Tab """ - def __init__(self, parent, tabname, helptext, waittime): - DisplayPage.__init__(self, parent, tabname, helptext) + def __init__(self, parent, tab_name, helptext, wait_time, command=None): + super().__init__(parent, tab_name, helptext) + self._waittime = wait_time + self.command = command self.display_item = None self.set_info_text() @@ -160,10 +175,9 @@ def __init__(self, parent, tabname, helptext, waittime): parent.select(self) self.update_idletasks() - self.update_page(waittime) + self._update_page() - @staticmethod - def set_vars(): + def set_vars(self): """ Analysis specific vars """ enabled = tk.BooleanVar() enabled.set(True) @@ -171,24 +185,28 @@ def set_vars(): ready = tk.BooleanVar() ready.set(False) - modified = tk.DoubleVar() - modified.set(None) - tk_vars = {"enabled": enabled, - "ready": ready, - "modified": modified} + "ready": ready} logger.debug(tk_vars) return tk_vars + def on_tab_select(self): + """ Callback for when the optional tab is selected. + + Run the tab's update code when the tab is selected. + """ + logger.debug("Callback received for '%s' tab", self.tabname) + self._update_page() + # INFO LABEL def set_info_text(self): """ Set waiting for display text """ if not self.vars["enabled"].get(): - msg = "{} disabled".format(self.tabname.title()) + msg = f"{self.tabname.title()} disabled" elif self.vars["enabled"].get() and not self.vars["ready"].get(): - msg = "Waiting for {}...".format(self.tabname) + msg = f"Waiting for {self.tabname}..." else: - msg = "Displaying {}".format(self.tabname) + msg = f"Displaying {self.tabname}" logger.debug(msg) self.set_info(msg) @@ -206,27 +224,27 @@ def add_option_save(self): command=self.save_items) btnsave.pack(padx=2, side=tk.RIGHT) Tooltip(btnsave, - text="Save {}(s) to file".format(self.tabname), - wraplength=200) + text=_(f"Save {self.tabname}(s) to file"), + wrap_length=200) def add_option_enable(self): - """ Add checkbutton to enable/disable page """ + """ Add check-button to enable/disable page """ logger.debug("Adding enable option") chkenable = ttk.Checkbutton(self.optsframe, variable=self.vars["enabled"], - text="Enable {}".format(self.tabname), + text=f"Enable {self.tabname}", command=self.on_chkenable_change) chkenable.pack(side=tk.RIGHT, padx=5, anchor=tk.W) Tooltip(chkenable, - text="Enable or disable {} display".format(self.tabname), - wraplength=200) + text=_(f"Enable or disable {self.tabname} display"), + wrap_length=200) def save_items(self): """ Save items. 
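The `_update_page` loop started above is a plain self-rescheduling `after` call, gated so hidden or disabled tabs stop polling. Its generic shape (illustrative sketch, not the faceswap class):

    import tkinter as tk

    class Poller(tk.Frame):
        def __init__(self, parent: tk.Tk, wait_time: int = 1000) -> None:
            super().__init__(parent)
            self._wait_time = wait_time
            self.enabled = tk.BooleanVar(master=self, value=True)
            self._update_page()

        def _update_page(self) -> None:
            if not self.enabled.get():
                return  # gate: stop rescheduling while disabled
            print("refreshing display item...")
            self.after(self._wait_time, self._update_page)  # reschedule self

    root = tk.Tk()
    Poller(root).pack()
    root.mainloop()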
Override for display specific saving """ raise NotImplementedError() def on_chkenable_change(self): - """ Update the display immediately on a checkbutton change """ + """ Update the display immediately on a check-button change """ logger.debug("Enabled checkbox changed") if self.vars["enabled"].get(): self.subnotebook_show() @@ -234,15 +252,15 @@ def on_chkenable_change(self): self.subnotebook_hide() self.set_info_text() - def update_page(self, waittime): + def _update_page(self): """ Update the latest preview item """ - if not self.runningtask.get(): + if not self.running_task.get() or not self._tab_is_active: return if self.vars["enabled"].get(): - logger.trace("Updating page") + logger.trace("Updating page: %s", self.__class__.__name__) self.display_item_set() self.load_display() - self.after(waittime, lambda t=waittime: self.update_page(t)) + self.after(self._waittime, self._update_page) def display_item_set(self): """ Override for display specific loading """ @@ -250,9 +268,9 @@ def display_item_set(self): def load_display(self): """ Load the display """ - if not self.display_item: + if not self.display_item or not self._tab_is_active: return - logger.debug("Loading display") + logger.debug("Loading display for tab: %s", self.tabname) self.display_item_process() self.vars["ready"].set(True) self.set_info_text() @@ -260,3 +278,14 @@ def load_display(self): def display_item_process(self): """ Override for display specific loading """ raise NotImplementedError() + + def close(self): + """ Called when the parent notebook is shutting down + Children must be destroyed as forget only hides display + Override for page specific shutdown """ + for child in self.winfo_children(): + logger.debug("Destroying child: %s", child) + child.destroy() + + +__all__ = get_module_objects(__name__) diff --git a/lib/gui/gui_config.py b/lib/gui/gui_config.py new file mode 100644 index 0000000000..a752e9ab0c --- /dev/null +++ b/lib/gui/gui_config.py @@ -0,0 +1,181 @@ +#!/usr/bin/env python3 +""" Default configurations for the GUI """ + +import logging +import os + +from tkinter import font as tk_font +from matplotlib import font_manager + +from lib.config import FaceswapConfig +from lib.config import ConfigItem +from lib.utils import get_module_objects, PROJECT_ROOT + +logger = logging.getLogger(__name__) + + +class _Config(FaceswapConfig): + """ Config File for GUI """ + def set_defaults(self, helptext="") -> None: + """ Set the default values for config """ + logger.debug("Setting defaults") + super().set_defaults( + helptext="Faceswap GUI Options.\nConfigure the appearance and behaviour of the GUI") + # Font choices cannot be added until tkinter has been launched + logger.debug("Adding font list from tkinter") + self.sections["global"].options["font"].choices = get_clean_fonts() + + +def get_commands() -> list[str]: + """ Return commands formatted for GUI + + Returns + ------- + list[str] + A list of faceswap and tools commands that can be displayed in Faceswap's GUI + """ + command_path = os.path.join(PROJECT_ROOT, "scripts") + tools_path = os.path.join(PROJECT_ROOT, "tools") + commands = [os.path.splitext(item)[0] for item in os.listdir(command_path) + if os.path.splitext(item)[1] == ".py" + and os.path.splitext(item)[0] not in ("gui", "fsmedia") + and not os.path.splitext(item)[0].startswith("_")] + tools = [os.path.splitext(item)[0] for item in os.listdir(tools_path) + if os.path.splitext(item)[1] == ".py" + and os.path.splitext(item)[0] not in ("gui", "cli") + and not 
os.path.splitext(item)[0].startswith("_")] + return commands + tools + + +def get_clean_fonts() -> list[str]: + """ Return a sane list of fonts for the system that has both regular and bold variants. + + Pre-pend "default" to the beginning of the list. + + Returns + ------- + list[str]: + A list of valid fonts for the system + """ + fmanager = font_manager.FontManager() + fonts: dict[str, dict[str, bool]] = {} + for fnt in fmanager.ttflist: + if str(fnt.weight) in ("400", "normal", "regular"): + fonts.setdefault(fnt.name, {})["regular"] = True + if str(fnt.weight) in ("700", "bold"): + fonts.setdefault(fnt.name, {})["bold"] = True + valid_fonts = {key for key, val in fonts.items() if len(val) == 2} + retval = sorted(list(valid_fonts.intersection(tk_font.families()))) + if not retval: + # Return the font list with any @prefixed or non-Unicode characters stripped and default + # prefixed + logger.debug("No bold/regular fonts found. Running simple filter") + retval = sorted([fnt for fnt in tk_font.families() + if not fnt.startswith("@") and not any(ord(c) > 127 for c in fnt)]) + return ["default"] + retval + + +fullscreen = ConfigItem( + datatype=bool, + default=False, + group="startup", + info="Start Faceswap maximized.") + + +tab = ConfigItem( + datatype=str, + default="extract", + group="startup", + info="Start Faceswap in this tab.", + choices=get_commands()) + + +options_panel_width = ConfigItem( + datatype=int, + default=30, + group="layout", + info="How wide the lefthand option panel is as a percentage of GUI width at " + "startup.", + min_max=(10, 90), + rounding=1) + + +console_panel_height = ConfigItem( + datatype=int, + default=20, + group="layout", + info="How tall the bottom console panel is as a percentage of GUI height at " + "startup.", + min_max=(10, 90), + rounding=1) + + +icon_size = ConfigItem( + datatype=int, + default=14, + group="layout", + info="Pixel size for icons. NB: Size is scaled by DPI.", + min_max=(10, 20), + rounding=1) + + +font = ConfigItem( + datatype=str, + default="default", + group="font", + info="Global font", + choices=["default"]) # Cannot get tk fonts until tk is loaded, so real value populated later + + +font_size = ConfigItem( + datatype=int, + default=9, + group="font", + info="Global font size.", + min_max=(6, 12), + rounding=1) + + +autosave_last_session = ConfigItem( + datatype=str, + default="prompt", + group="startup", + info="Automatically save the current settings on close and reload on startup" + "\n\tnever - Don't autosave session" + "\n\tprompt - Prompt to reload last session on launch" + "\n\talways - Always load last session on launch", + choices=["never", "prompt", "always"], + gui_radio=True) + + +timeout = ConfigItem( + datatype=int, + default=120, + group="behaviour", + info="Training can take some time to save and shutdown. Set the timeout " + "in seconds before giving up and force quitting.", + min_max=(10, 600), + rounding=10) + + +auto_load_model_stats = ConfigItem( + datatype=bool, + default=True, + group="behaviour", + info="Auto load model statistics into the Analysis tab when selecting a model " + "in Train or Convert tabs.") + + +def load_config(config_file: str | None = None) -> None: + """ Load the GUI configuration .ini file + + Parameters + ---------- + config_file : str | None, optional + Path to a custom .ini configuration file to load. 
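`get_clean_fonts` above keeps only families that ship both a regular and a bold face before intersecting with what tkinter reports. The weight cross-check can be exercised standalone (sketch; output depends on the fonts installed locally):

    from matplotlib import font_manager

    fonts: dict[str, set[str]] = {}
    for entry in font_manager.FontManager().ttflist:
        if str(entry.weight) in ("400", "normal", "regular"):
            fonts.setdefault(entry.name, set()).add("regular")
        if str(entry.weight) in ("700", "bold"):
            fonts.setdefault(entry.name, set()).add("bold")

    both = sorted(name for name, weights in fonts.items() if len(weights) == 2)
    print(both[:10])  # families offering both a regular and a bold variant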
Default: ``None`` (use default + configuration file) + """ + _Config(configfile=config_file) + + +__all__ = get_module_objects(__name__) diff --git a/lib/gui/menu.py b/lib/gui/menu.py index c79ec0fe0c..226b6be462 100644 --- a/lib/gui/menu.py +++ b/lib/gui/menu.py @@ -1,149 +1,628 @@ #!/usr/bin python3 """ The Menu Bars for faceswap GUI """ - +from __future__ import annotations +import gettext import logging import os -import sys import tkinter as tk +import typing as T +from tkinter import ttk +import webbrowser + +from lib.git import git +from lib.multithreading import MultiThread +from lib.serializer import get_serializer, Serializer +from lib.utils import FaceswapError, get_module_objects +import update_deps + +from .popup_configure import open_popup +from .custom_widgets import Tooltip +from .utils import get_config, get_images -from importlib import import_module +if T.TYPE_CHECKING: + from scripts.gui import FaceswapGui -from lib.Serializer import JSONSerializer +logger = logging.getLogger(__name__) -from .utils import get_config -from .popup_configure import popup_config +# LOCALES +_LANG = gettext.translation("gui.menu", localedir="locales", fallback=True) +_ = _LANG.gettext +_RESOURCES: list[tuple[str, str]] = [ + (_("faceswap.dev - Guides and Forum"), "https://www.faceswap.dev"), + (_("Patreon - Support this project"), "https://www.patreon.com/faceswap"), + (_("Discord - The FaceSwap Discord server"), "https://discord.gg/VasFUAy"), + (_("Github - Our Source Code"), "https://github.com/deepfakes/faceswap")] -logger = logging.getLogger(__name__) # pylint: disable=invalid-name +class MainMenuBar(tk.Menu): # pylint:disable=too-many-ancestors + """ GUI Main Menu Bar -class MainMenuBar(tk.Menu): - """ GUI Main Menu Bar """ - def __init__(self, master=None): + Parameters + ---------- + master: :class:`tkinter.Tk` + The root tkinter object + """ + def __init__(self, master: FaceswapGui) -> None: logger.debug("Initializing %s", self.__class__.__name__) super().__init__(master) self.root = master - self.config = get_config() - self.file_menu = tk.Menu(self, tearoff=0) - self.recent_menu = tk.Menu(self.file_menu, tearoff=0, postcommand=self.refresh_recent_menu) - self.edit_menu = tk.Menu(self, tearoff=0) - self.tools_menu = tk.Menu(self, tearoff=0) + self.file_menu = FileMenu(self) + self.settings_menu = SettingsMenu(self) + self.help_menu = HelpMenu(self) + + self.add_cascade(label=_("File"), menu=self.file_menu, underline=0) + self.add_cascade(label=_("Settings"), menu=self.settings_menu, underline=0) + self.add_cascade(label=_("Help"), menu=self.help_menu, underline=0) + logger.debug("Initialized %s", self.__class__.__name__) + + +class SettingsMenu(tk.Menu): # pylint:disable=too-many-ancestors + """ Settings menu items and functions + + Parameters + ---------- + parent: :class:`tkinter.Menu` + The main menu bar to hold this menu item + """ + def __init__(self, parent: MainMenuBar) -> None: + logger.debug("Initializing %s", self.__class__.__name__) + super().__init__(parent, tearoff=0) + self.root = parent.root + self._build() + logger.debug("Initialized %s", self.__class__.__name__) + + def _build(self) -> None: + """ Add the settings menu to the menu bar """ + # pylint:disable=cell-var-from-loop + logger.debug("Building settings menu") + self.add_command(label=_("Configure Settings..."), + underline=0, + command=open_popup) + logger.debug("Built settings menu") - self.build_file_menu() - self.build_edit_menu() - self.build_tools_menu() + +class FileMenu(tk.Menu): # 
pylint:disable=too-many-ancestors
+    """ File menu items and functions
+
+    Parameters
+    ----------
+    parent: :class:`tkinter.Menu`
+        The main menu bar to hold this menu item
+    """
+    def __init__(self, parent: MainMenuBar) -> None:
+        logger.debug("Initializing %s", self.__class__.__name__)
+        super().__init__(parent, tearoff=0)
+        self.root = parent.root
+        self._config = get_config()
+        self.recent_menu = tk.Menu(self, tearoff=0, postcommand=self._refresh_recent_menu)
+        self._build()
         logger.debug("Initialized %s", self.__class__.__name__)
 
-    def build_file_menu(self):
+    def _refresh_recent_menu(self) -> None:
+        """ Refresh recent menu on save/load of files """
+        self.recent_menu.delete(0, "end")
+        self._build_recent_menu()
+
+    def _build(self) -> None:
         """ Add the file menu to the menu bar """
         logger.debug("Building File menu")
-        self.file_menu.add_command(
-            label="Load full config...", underline=0, command=self.config.load)
-        self.file_menu.add_command(
-            label="Save full config...", underline=0, command=self.config.save)
-        self.file_menu.add_separator()
-        self.file_menu.add_cascade(label="Open recent", underline=6, menu=self.recent_menu)
-        self.file_menu.add_separator()
-        self.file_menu.add_command(
-            label="Reset all to default", underline=0, command=self.config.cli_opts.reset)
-        self.file_menu.add_command(
-            label="Clear all", underline=0, command=self.config.cli_opts.clear)
-        self.file_menu.add_separator()
-        self.file_menu.add_command(label="Quit", underline=0, command=self.root.close_app)
-        self.add_cascade(label="File", menu=self.file_menu, underline=0)
+        self.add_command(label=_("New Project..."),
+                         underline=0,
+                         accelerator="Ctrl+N",
+                         command=self._config.project.new)
+        self.root.bind_all("<Control-n>", self._config.project.new)
+        self.add_command(label=_("Open Project..."),
+                         underline=0,
+                         accelerator="Ctrl+O",
+                         command=self._config.project.load)
+        self.root.bind_all("<Control-o>", self._config.project.load)
+        self.add_command(label=_("Save Project"),
+                         underline=0,
+                         accelerator="Ctrl+S",
+                         command=lambda: self._config.project.save(save_as=False))
+        self.root.bind_all("<Control-s>", lambda e: self._config.project.save(e, save_as=False))
+        self.add_command(label=_("Save Project as..."),
+                         underline=13,
+                         accelerator="Ctrl+Alt+S",
+                         command=lambda: self._config.project.save(save_as=True))
+        self.root.bind_all("<Control-Alt-s>", lambda e: self._config.project.save(e, save_as=True))
+        self.add_command(label=_("Reload Project from Disk"),
+                         underline=0,
+                         accelerator="F5",
+                         command=self._config.project.reload)
+        self.root.bind_all("<F5>", self._config.project.reload)
+        self.add_command(label=_("Close Project"),
+                         underline=0,
+                         accelerator="Ctrl+W",
+                         command=self._config.project.close)
+        self.root.bind_all("<Control-w>", self._config.project.close)
+        self.add_separator()
+        self.add_command(label=_("Open Task..."),
+                         underline=5,
+                         accelerator="Ctrl+Alt+T",
+                         command=lambda: self._config.tasks.load(current_tab=False))
+        self.root.bind_all("<Control-Alt-t>",
+                           lambda e: self._config.tasks.load(e, current_tab=False))
+        self.add_separator()
+        self.add_cascade(label=_("Open recent"), underline=6, menu=self.recent_menu)
+        self.add_separator()
+        self.add_command(label=_("Quit"),
+                         underline=0,
+                         accelerator="Alt+F4",
+                         command=self.root.close_app)
+        self.root.bind_all("<Alt-F4>", self.root.close_app)
         logger.debug("Built File menu")
 
-    def build_recent_menu(self):
+    @classmethod
+    def _clear_recent_files(cls, serializer: Serializer, menu_file: str) -> None:
+        """ Creates or clears recent file list
+
+        Parameters
+        ----------
+        serializer: :class:`~lib.serializer.Serializer`
+            
The serializer to use for storing files + menu_file: str + The file name holding the recent files + """ + logger.debug("clearing recent files list: '%s'", menu_file) + serializer.save(menu_file, []) + + def _build_recent_menu(self) -> None: """ Load recent files into menu bar """ logger.debug("Building Recent Files menu") - serializer = JSONSerializer - menu_file = os.path.join(self.config.pathcache, ".recent.json") - if not os.path.isfile(menu_file): - self.clear_recent_files(serializer, menu_file) - with open(menu_file, "rb") as inp: - recent_files = serializer.unmarshal(inp.read().decode("utf-8")) - logger.debug("Loaded recent files: %s", recent_files) + serializer = get_serializer("json") + menu_file = os.path.join(self._config.pathcache, ".recent.json") + recent_files = [] + if not os.path.isfile(menu_file) or os.path.getsize(menu_file) == 0: + self._clear_recent_files(serializer, menu_file) + try: + recent_files = serializer.load(menu_file) + except FaceswapError as err: + if "Error unserializing data for type" in str(err): + # Some reports of corruption breaking menus + logger.warning("There was an error opening the recent files list so it has been " + "reset.") + self._clear_recent_files(serializer, menu_file) + + logger.debug("Loaded recent files: %s", recent_files) + removed_files = [] for recent_item in recent_files: filename, command = recent_item - logger.debug("processing: ('%s', %s)", filename, command) if not os.path.isfile(filename): - logger.debug("File does not exist") + logger.debug("File does not exist. Flagging for removal: '%s'", filename) + removed_files.append(recent_item) continue - lbl_command = command if command else "All" + # Legacy project files didn't have a command stored + command = command if command else "project" + logger.debug("processing: ('%s', %s)", filename, command) + if command.lower() == "project": + load_func = self._config.project.load + lbl = command + kwargs = {"filename": filename} + else: + load_func = self._config.tasks.load # type:ignore + lbl = _("{} Task").format(command) + kwargs = {"filename": filename, "current_tab": False} self.recent_menu.add_command( - label="{} ({})".format(filename, lbl_command.title()), - command=lambda fnm=filename, cmd=command: self.config.load(cmd, fnm)) + label=f"{filename} ({lbl.title()})", + command=lambda kw=kwargs, fn=load_func: fn(**kw)) # type:ignore + if removed_files: + for recent_item in removed_files: + logger.debug("Removing from recent files: `%s`", recent_item[0]) + recent_files.remove(recent_item) + serializer.save(menu_file, recent_files) self.recent_menu.add_separator() self.recent_menu.add_command( - label="Clear recent files", + label=_("Clear recent files"), underline=0, - command=lambda srl=serializer, mnu=menu_file: self.clear_recent_files(srl, mnu)) + command=lambda srl=serializer, mnu=menu_file: self._clear_recent_files( # type:ignore + srl, mnu)) logger.debug("Built Recent Files menu") - @staticmethod - def clear_recent_files(serializer, menu_file): - """ Creates or clears recent file list """ - logger.debug("clearing recent files list: '%s'", menu_file) - recent_files = serializer.marshal(list()) - with open(menu_file, "wb") as out: - out.write(recent_files.encode("utf-8")) - def refresh_recent_menu(self): - """ Refresh recent menu on save/load of files """ - self.recent_menu.delete(0, "end") - self.build_recent_menu() - - def build_edit_menu(self): - """ Add the edit menu to the menu bar """ - logger.debug("Building Edit menu") - configs = self.scan_for_configs() - for name in 
sorted(list(configs.keys())): - label = "Configure {} Plugins...".format(name.title()) - config = configs[name] - self.edit_menu.add_command( - label=label, - underline=10, - command=lambda conf=(name, config), root=self.root: popup_config(conf, root)) - self.add_cascade(label="Edit", menu=self.edit_menu, underline=0) - logger.debug("Built Edit menu") - - def scan_for_configs(self): - """ Scan for config.ini file locations """ - root_path = os.path.abspath(os.path.dirname(sys.argv[0])) - plugins_path = os.path.join(root_path, "plugins") - logger.debug("Scanning path: '%s'", plugins_path) - configs = dict() - for dirpath, _, filenames in os.walk(plugins_path): - if "_config.py" in filenames: - plugin_type = os.path.split(dirpath)[-1] - config = self.load_config(plugin_type) - configs[plugin_type] = config - logger.debug("Configs loaded: %s", sorted(list(configs.keys()))) - return configs - - @staticmethod - def load_config(plugin_type): - """ Load the config to generate config file if it doesn't exist and get filename """ - # Load config to generate default if doesn't exist - mod = ".".join(("plugins", plugin_type, "_config")) - module = import_module(mod) - config = module.Config(None) - logger.debug("Found '%s' config at '%s'", plugin_type, config.configfile) - return config - - def build_tools_menu(self): - """ Add the file menu to the menu bar """ - logger.debug("Building Tools menu") - self.tools_menu.add_command( - label="Output System Information", underline=0, command=self.output_sysinfo) - self.add_cascade(label="Tools", menu=self.tools_menu, underline=0) - logger.debug("Built Tools menu") - - @staticmethod - def output_sysinfo(): +class HelpMenu(tk.Menu): # pylint:disable=too-many-ancestors + """ Help menu items and functions + + Parameters + ---------- + parent: :class:`tkinter.Menu` + The main menu bar to hold this menu item + """ + def __init__(self, parent: MainMenuBar) -> None: + logger.debug("Initializing %s", self.__class__.__name__) + super().__init__(parent, tearoff=0) + self.root = parent.root + self.recources_menu = tk.Menu(self, tearoff=0) + self._branches_menu = tk.Menu(self, tearoff=0) + self._build() + logger.debug("Initialized %s", self.__class__.__name__) + + def _in_thread(self, action: str): + """ Perform selected action inside a thread + + Parameters + ---------- + action: str + The action to be performed. 
The action corresponds to the function name to be called
+        """
+        logger.debug("Performing help action: %s", action)
+        thread = MultiThread(getattr(self, action), thread_count=1)
+        thread.start()
+        logger.debug("Performed help action: %s", action)
+
+    def _output_sysinfo(self):
         """ Output system information to console """
-        get_config().tk_vars["consoleclear"].set(True)
-        from lib.sysinfo import SysInfo
-        print(SysInfo().full_info())
+        logger.debug("Obtaining system information")
+        self.root.config(cursor="watch")
+        self._clear_console()
+        try:
+            from lib.system.sysinfo import sysinfo  # pylint:disable=import-outside-toplevel
+            info = sysinfo
+        except Exception as err:  # pylint:disable=broad-except
+            info = f"Error obtaining system info: {str(err)}"
+        self._clear_console()
+        logger.debug("Obtained system information: %s", info)
+        print(info)
+        self.root.config(cursor="")
+
+    @classmethod
+    def _process_status_output(cls, status: list[str]) -> bool:
+        """ Process the output of a git status call and output information
+
+        Parameters
+        ----------
+        status : list[str]
+            The lines returned from a git status call
+
+        Returns
+        -------
+        bool
+            ``True`` if the repo can be updated otherwise ``False``
+        """
+        for line in status:
+            if line.lower().startswith("your branch is ahead"):
+                logger.warning("Your branch is ahead of the remote repo. Not updating")
+                return False
+            if line.lower().startswith("your branch is up to date"):
+                logger.info("Faceswap is up to date.")
+                return False
+            if "have diverged" in line.lower():
+                logger.warning("Your branch has diverged from the remote repo. Not updating")
+                return False
+            if line.lower().startswith("your branch is behind"):
+                return True
+
+        logger.warning("Unable to retrieve status of branch")
+        return False
+
+    def _check_for_updates(self, check: bool = False) -> bool:
+        """ Check whether an update is required
+
+        Parameters
+        ----------
+        check: bool
+            ``True`` if we are just checking for updates, ``False`` if a check and update is to be
+            performed. Default: ``False``
+
+        Returns
+        -------
+        bool
+            ``True`` if an update is required
+        """
+        # Do the check
+        logger.info("Checking for updates...")
+        msg = ("Git is not installed or you are not running a cloned repo. "
+               "Unable to check for updates")
+
+        sync = git.update_remote()
+        if not sync:
+            logger.warning(msg)
+            return False
+
+        status = git.status
+        if not status:
+            logger.warning(msg)
+            return False
+
+        retval = self._process_status_output(status)
+        if retval and check:
+            logger.info("There are updates available")
+        return retval
+
+    def _check(self) -> None:
+        """ Check for updates and report whether any are available """
+        logger.debug("Checking for updates...")
+        self.root.config(cursor="watch")
+        self._check_for_updates(check=True)
+        self.root.config(cursor="")
+
+    def _do_update(self) -> bool:
+        """ Update Faceswap
+
+        Returns
+        -------
+        bool
+            ``True`` if update was successful
+        """
+        logger.info("A new version is available. 
Updating...") + success = git.pull() + if not success: + logger.info("An error occurred during update") + return success + + def _update(self) -> None: + """ Check for updates and clone repository """ + logger.debug("Updating Faceswap...") + self.root.config(cursor="watch") + success = False + if self._check_for_updates(): + success = self._do_update() + update_deps.main(is_gui=True) + if success: + logger.info("Please restart Faceswap to complete the update.") + self.root.config(cursor="") + + def _build(self) -> None: + """ Build the help menu """ + logger.debug("Building Help menu") + + self.add_command(label=_("Check for updates..."), + underline=0, + command=lambda action="_check": self._in_thread(action)) # type:ignore + self.add_command(label=_("Update Faceswap..."), + underline=0, + command=lambda action="_update": self._in_thread(action)) # type:ignore + if self._build_branches_menu(): + self.add_cascade(label=_("Switch Branch"), underline=7, menu=self._branches_menu) + self.add_separator() + self._build_recources_menu() + self.add_cascade(label=_("Resources"), underline=0, menu=self.recources_menu) + self.add_separator() + self.add_command( + label=_("Output System Information"), + underline=0, + command=lambda action="_output_sysinfo": self._in_thread(action)) # type:ignore + logger.debug("Built help menu") + + def _build_branches_menu(self) -> bool: + """ Build branch selection menu. + + Queries git for available branches and builds a menu based on output. + + Returns + ------- + bool + ``True`` if menu was successfully built otherwise ``False`` + """ + branches = git.branches + if not branches: + return False + + branches = self._filter_branches(branches) + if not branches: + return False + + for branch in branches: + self._branches_menu.add_command( + label=branch, + command=lambda b=branch: self._switch_branch(b)) # type:ignore + return True + + @classmethod + def _filter_branches(cls, branches: list[str]) -> list[str]: + """ Filter the branches, remove any non-local branches + + Parameters + ---------- + branches: list[str] + list of available git branches + + Returns + ------- + list[str] + Unique list of available branches sorted in alphabetical order + """ + current = None + unique = set() + for line in branches: + branch = line.strip() + if branch.startswith("remotes"): + continue + if branch.startswith("*"): + branch = branch.replace("*", "").strip() + current = branch + continue + unique.add(branch) + logger.debug("Found branches: %s", unique) + if current in unique: + logger.debug("Removing current branch from output: %s", current) + unique.remove(current) + + retval = sorted(list(unique), key=str.casefold) + logger.debug("Final branches: %s", retval) + return retval + + @classmethod + def _switch_branch(cls, branch: str) -> None: + """ Change the currently checked out branch, and return a notification. + + Parameters + ---------- + str + The branch to switch to + """ + logger.info("Switching branch to '%s'...", branch) + if not git.checkout(branch): + logger.error("Unable to switch branch to '%s'", branch) + return + logger.info("Succesfully switched to '%s'. 
You may want to check for updates to make sure " + "that you have the latest code.", branch) + logger.info("Please restart Faceswap to complete the switch.") + + def _build_recources_menu(self) -> None: + """ Build resources menu """ + # pylint:disable=cell-var-from-loop + logger.debug("Building Resources Files menu") + for resource in _RESOURCES: + self.recources_menu.add_command( + label=resource[0], + command=lambda link=resource[1]: webbrowser.open_new(link)) # type:ignore + logger.debug("Built resources menu") + + @classmethod + def _clear_console(cls) -> None: + """ Clear the console window """ + get_config().tk_vars.console_clear.set(True) + + +class TaskBar(ttk.Frame): # pylint:disable=too-many-ancestors + """ Task bar buttons + + Parameters + ---------- + parent: :class:`tkinter.ttk.Frame` + The frame that holds the task bar + """ + def __init__(self, parent: ttk.Frame) -> None: + super().__init__(parent) + self._config = get_config() + self.pack(side=tk.TOP, anchor=tk.W, fill=tk.X, expand=False) + self._btn_frame = ttk.Frame(self) + self._btn_frame.pack(side=tk.TOP, pady=2, anchor=tk.W, fill=tk.X, expand=False) + + self._project_btns() + self._group_separator() + self._task_btns() + self._group_separator() + self._settings_btns() + self._section_separator() + + @classmethod + def _loader_and_kwargs(cls, btntype: str) -> tuple[str, dict[str, bool]]: + """ Get the loader name and key word arguments for the given button type + + Parameters + ---------- + btntype: str + The button type to obtain the information for + + Returns + ------- + loader: str + The name of the loader to use for the given button type + kwargs: dict[str, bool] + The keyword arguments to use for the returned loader + """ + if btntype == "save": + loader = btntype + kwargs = {"save_as": False} + elif btntype == "save_as": + loader = "save" + kwargs = {"save_as": True} + else: + loader = btntype + kwargs = {} + logger.debug("btntype: %s, loader: %s, kwargs: %s", btntype, loader, kwargs) + return loader, kwargs + + @classmethod + def _set_help(cls, btntype: str) -> str: + """ Set the helptext for option buttons + + Parameters + ---------- + btntype: str + The button type to set the help text for + """ + logger.debug("Setting help") + hlp = "" + task = _("currently selected Task") if btntype[-1] == "2" else _("Project") + if btntype.startswith("reload"): + hlp = _("Reload {} from disk").format(task) + if btntype == "new": + hlp = _("Create a new {}...").format(task) + if btntype.startswith("clear"): + hlp = _("Reset {} to default").format(task) + elif btntype.startswith("save") and "_" not in btntype: + hlp = _("Save {}").format(task) + elif btntype.startswith("save_as"): + hlp = _("Save {} as...").format(task) + elif btntype.startswith("load"): + msg = task + if msg.endswith("Task"): + msg += _(" from a task or project file") + hlp = _("Load {}...").format(msg) + return hlp + + def _project_btns(self) -> None: + """ Place the project buttons """ + frame = ttk.Frame(self._btn_frame) + frame.pack(side=tk.LEFT, anchor=tk.W, expand=False, padx=2) + + for btntype in ("new", "load", "save", "save_as", "reload"): + logger.debug("Adding button: '%s'", btntype) + + loader, kwargs = self._loader_and_kwargs(btntype) + cmd = getattr(self._config.project, loader) + btn = ttk.Button(frame, + image=get_images().icons[btntype], # type:ignore[arg-type] + command=lambda fn=cmd, kw=kwargs: fn(**kw)) # type:ignore[misc] + btn.pack(side=tk.LEFT, anchor=tk.W) + hlp = self._set_help(btntype) + Tooltip(btn, text=hlp, wrap_length=200) + + 
def _task_btns(self) -> None: + """ Place the task buttons """ + frame = ttk.Frame(self._btn_frame) + frame.pack(side=tk.LEFT, anchor=tk.W, expand=False, padx=2) + + for loadtype in ("load", "save", "save_as", "clear", "reload"): + btntype = f"{loadtype}2" + logger.debug("Adding button: '%s'", btntype) + + loader, kwargs = self._loader_and_kwargs(loadtype) + if loadtype == "load": + kwargs["current_tab"] = True + cmd = getattr(self._config.tasks, loader) + btn = ttk.Button( + frame, + image=get_images().icons[btntype], # type:ignore[arg-type] + command=lambda fn=cmd, kw=kwargs: fn(**kw)) # type:ignore[misc] + btn.pack(side=tk.LEFT, anchor=tk.W) + hlp = self._set_help(btntype) + Tooltip(btn, text=hlp, wrap_length=200) + + def _settings_btns(self) -> None: + """ Place the settings buttons """ + # pylint:disable=cell-var-from-loop + frame = ttk.Frame(self._btn_frame) + frame.pack(side=tk.LEFT, anchor=tk.W, expand=False, padx=2) + for name in ("extract", "train", "convert"): + btntype = f"settings_{name}" + btntype = btntype if btntype in get_images().icons else "settings" + logger.debug("Adding button: '%s'", btntype) + btn = ttk.Button( + frame, + image=get_images().icons[btntype], # type:ignore[arg-type] + command=lambda n=name: open_popup(name=n)) # type:ignore[misc] + btn.pack(side=tk.LEFT, anchor=tk.W) + hlp = _("Configure {} settings...").format(name.title()) + Tooltip(btn, text=hlp, wrap_length=200) + + def _group_separator(self) -> None: + """ Place a group separator """ + separator = ttk.Separator(self._btn_frame, orient="vertical") + separator.pack(padx=(2, 1), fill=tk.Y, side=tk.LEFT) + + def _section_separator(self) -> None: + """ Place a section separator """ + frame = ttk.Frame(self) + frame.pack(side=tk.BOTTOM, fill=tk.X) + separator = ttk.Separator(frame, orient="horizontal") + separator.pack(fill=tk.X, side=tk.LEFT, expand=True) + + +__all__ = get_module_objects(__name__) diff --git a/lib/gui/options.py b/lib/gui/options.py index 32eb40e309..c941910a6e 100644 --- a/lib/gui/options.py +++ b/lib/gui/options.py @@ -1,240 +1,659 @@ #!/usr/bin python3 """ Cli Options for the GUI """ +from __future__ import annotations + import inspect from argparse import SUPPRESS +from dataclasses import dataclass +from importlib import import_module import logging -from tkinter import ttk +import os +import re +import sys +import typing as T + +from lib.cli import actions +from lib.utils import get_module_objects -from lib import cli -import tools.cli as ToolsCli from .utils import get_images +from .control_helper import ControlPanelOption + +if T.TYPE_CHECKING: + from tkinter import Variable + from types import ModuleType + from lib.cli.args import FaceSwapArgs + +logger = logging.getLogger(__name__) + + +@dataclass +class CliOption: + """ A parsed command line option -logger = logging.getLogger(__name__) # pylint: disable=invalid-name + Parameters + ---------- + cpanel_option: :class:`~lib.gui.control_helper.ControlPanelOption`: + Object to hold information of a command line item for displaying in a GUI + :class:`~lib.gui.control_helper.ControlPanel` + opts: tuple[str, ...]: + The short switch and long name (if exists) of the command line option + nargs: Literal["+"] | None: + ``None`` for not used. 
"+" for at least 1 argument required with values to be contained + in a list + """ + cpanel_option: ControlPanelOption + """:class:`~lib.gui.control_helper.ControlPanelOption`: Object to hold information of a command + line item for displaying in a GUI :class:`~lib.gui.control_helper.ControlPanel`""" + opts: tuple[str, ...] + """tuple[str, ...]: The short switch and long name (if exists) of cli option """ + nargs: T.Literal["+"] | None + """Literal["+"] | None: ``None`` for not used. "+" for at least 1 argument required with + values to be contained in a list """ class CliOptions(): """ Class and methods for the command line options """ - def __init__(self): + def __init__(self) -> None: logger.debug("Initializing %s", self.__class__.__name__) - self.categories = ("faceswap", "tools") - self.commands = dict() - self.opts = dict() - self.build_options() + self._base_path = os.path.realpath(os.path.dirname(sys.argv[0])) + self._commands: dict[T.Literal["faceswap", "tools"], list[str]] = {"faceswap": [], + "tools": []} + self._opts: dict[str, dict[str, CliOption | str]] = {} + self._build_options() logger.debug("Initialized %s", self.__class__.__name__) - def build_options(self): - """ Get the commands that belong to each category """ - for category in self.categories: - logger.debug("Building '%s'", category) - src = ToolsCli if category == "tools" else cli - mod_classes = self.get_cli_classes(src) - self.commands[category] = self.sort_commands(category, mod_classes) - self.opts.update(self.extract_options(src, mod_classes)) - logger.debug("Built '%s'", category) + @property + def categories(self) -> tuple[T.Literal["faceswap", "tools"], ...]: + """tuple[str, str] The categories for faceswap's GUI """ + return tuple(self._commands) + + @property + def commands(self) -> dict[T.Literal["faceswap", "tools"], list[str]]: + """dict[str, ]""" + return self._commands + + @property + def opts(self) -> dict[str, dict[str, CliOption | str]]: + """dict[str, dict[str, CliOption | str]] The command line options collected from faceswap's + cli files """ + return self._opts + + def _get_modules_tools(self) -> list[ModuleType]: + """ Parse the tools cli python files for the modules that contain the command line + arguments + + Returns + ------- + list[`types.ModuleType`] + The modules for each faceswap tool that exists in the project + """ + tools_dir = os.path.join(self._base_path, "tools") + logger.debug("Scanning '%s' for cli files", tools_dir) + retval: list[ModuleType] = [] + for tool_name in sorted(os.listdir(tools_dir)): + cli_file = os.path.join(tools_dir, tool_name, "cli.py") + if not os.path.exists(cli_file): + logger.debug("File does not exist. 
Skipping: '%s'", cli_file) + continue + + mod = ".".join(("tools", tool_name, "cli")) + retval.append(import_module(mod)) + logger.debug("Collected: %s", retval[-1]) + return retval + + def _get_modules_faceswap(self) -> list[ModuleType]: + """ Parse the faceswap cli python files for the modules that contain the command line + arguments + + Returns + ------- + list[`types.ModuleType`] + The modules for each faceswap command line argument file that exists in the project + """ + base_dir = ["lib", "cli"] + cli_dir = os.path.join(self._base_path, *base_dir) + logger.debug("Scanning '%s' for cli files", cli_dir) + retval: list[ModuleType] = [] + + for fname in os.listdir(cli_dir): + if not fname.startswith("args"): + logger.debug("Skipping file '%s'", fname) + continue + mod = ".".join((*base_dir, os.path.splitext(fname)[0])) + retval.append(import_module(mod)) + logger.debug("Collected: '%s", retval[-1]) + return retval + + def _get_modules(self, category: T.Literal["faceswap", "tools"]) -> list[ModuleType]: + """ Parse the cli files for faceswap and tools and return the imported module + + Parameters + ---------- + category: Literal["faceswap", "tools"] + The faceswap category to obtain the cli modules + + Returns + ------- + list[`types.ModuleType`] + The modules for each faceswap command/tool that exists in the project for the given + category + """ + logger.debug("Getting '%s' cli modules", category) + if category == "tools": + return self._get_modules_tools() + return self._get_modules_faceswap() + + @classmethod + def _get_classes(cls, module: ModuleType) -> list[T.Type[FaceSwapArgs]]: + """ Obtain the classes from the given module that contain the command line + arguments + + Parameters + ---------- + module: :class:`types.ModuleType` + The imported module to parse for command line argument classes + + Returns + ------- + list[:class:`~lib.cli.args.FaceswapArgs`] + The command line argument class objects that exist in the module + """ + retval = [] + for name, obj in inspect.getmembers(module): + if not inspect.isclass(obj) or not name.lower().endswith("args"): + logger.debug("Skipping non-cli class object '%s'", name) + continue + if name.lower() in (("faceswapargs", "extractconvertargs", "guiargs")): + logger.debug("Skipping uneeded object '%s'", name) + continue + logger.debug("Collecting %s", obj) + retval.append(obj) + logger.debug("Collected from '%s': %s", module.__name__, [c.__name__ for c in retval]) + return retval + + def _get_all_classes(self, modules: list[ModuleType]) -> list[T.Type[FaceSwapArgs]]: + """Obtain the the command line options classes for the given modules + + Parameters + ---------- + modules : list[:class:`types.ModuleType`] + The imported modules to extract the command line argument classes from + + Returns + ------- + list[:class:`~lib.cli.args.FaceSwapArgs`] + The valid command line class objects for the given modules + """ + retval = [] + for module in modules: + mod_classes = self._get_classes(module) + if not mod_classes: + logger.debug("module '%s' contains no cli classes. 
Skipping", module) + continue + retval.extend(mod_classes) + logger.debug("Obtained %s cli classes from %s modules", len(retval), len(modules)) + return retval + + @classmethod + def _class_name_to_command(cls, class_name: str) -> str: + """ Format a FaceSwapArgs class name to a standardized command name + + Parameters + ---------- + class_name: str + The name of the class to convert to a command name + + Returns + ------- + str + The formatted command name + """ + return class_name.lower()[:-4] + + def _store_commands(self, + category: T.Literal["faceswap", "tools"], + classes: list[T.Type[FaceSwapArgs]]) -> None: + """ Format classes into command names and sort. Store in :attr:`commands`. + Sorting is in specific workflow order for faceswap and alphabetical for all others + + Parameters + ---------- + category: Literal["faceswap", "tools"] + The category to store the command names for + classes: list[:class:`~lib.cli.args.FaceSwapArgs`] + The valid command line class objects for the category + """ + class_names = [c.__name__ for c in classes] + commands = sorted(self._class_name_to_command(n) for n in class_names) - @staticmethod - def get_cli_classes(cli_source): - """ Parse the cli scripts for the arg classes """ - mod_classes = list() - for name, obj in inspect.getmembers(cli_source): - if inspect.isclass(obj) and name.lower().endswith("args") \ - and name.lower() not in (("faceswapargs", - "extractconvertargs", - "guiargs")): - mod_classes.append(name) - logger.debug(mod_classes) - return mod_classes - - def sort_commands(self, category, classes): - """ Format classes into command names and sort: - Specific workflow order for faceswap. - Alphabetical for all others """ - commands = sorted(self.format_command_name(command) - for command in classes) if category == "faceswap": ordered = ["extract", "train", "convert"] commands = ordered + [command for command in commands if command not in ordered] - logger.debug(commands) - return commands - - @staticmethod - def format_command_name(classname): - """ Format args class name to command """ - return classname.lower()[:-4] - - def extract_options(self, cli_source, mod_classes): - """ Extract the existing ArgParse Options - into master options Dictionary """ - subopts = dict() - for classname in mod_classes: - logger.debug("Processing: (classname: '%s')", classname) - command = self.format_command_name(classname) - options = self.get_cli_arguments(cli_source, classname, command) - options = self.process_options(options) - logger.debug("Processed: (classname: '%s', command: '%s', options: %s)", - classname, command, options) - subopts[command] = options - return subopts - - @staticmethod - def get_cli_arguments(cli_source, classname, command): - """ Extract the options from the main and tools cli files """ - meth = getattr(cli_source, classname)(None, command) - return meth.argument_list + meth.optional_arguments + meth.global_arguments - - def process_options(self, command_options): - """ Process the options for a single command """ - final_options = list() + self._commands[category].extend(commands) + logger.debug("Set '%s' commands: %s", category, self._commands[category]) + + @classmethod + def _get_cli_arguments(cls, + arg_class: T.Type[FaceSwapArgs], + command: str) -> tuple[str, list[dict[str, T.Any]]]: + """ Extract the command line options from the given cli class + + Parameters + ---------- + arg_class: :class:`~lib.cli.args.FaceSwapArgs` + The class to extract the options from + command: str + The command name to extract the 
options for
+
+        Returns
+        -------
+        info: str
+            The helptext information for given command
+        options: list[dict[str, Any]]
+            The command line options for the given command
+        """
+        args = arg_class(None, command)
+        arg_list = args.argument_list + args.optional_arguments + args.global_arguments
+        logger.debug("Obtained options for '%s'. Info: '%s', options: %s",
+                     command, args.info, len(arg_list))
+        return args.info, arg_list
+
+    @classmethod
+    def _set_control_title(cls, opts: tuple[str, ...]) -> str:
+        """ Take the option switch and format it nicely
+
+        Parameters
+        ----------
+        opts: tuple[str, ...]
+            The option switch for a command line option
+
+        Returns
+        -------
+        str
+            The option switch formatted for display
+        """
+        ctltitle = opts[1] if len(opts) == 2 else opts[0]
+        retval = ctltitle.replace("-", " ").replace("_", " ").strip().title()
+        logger.debug("Formatted '%s' to '%s'", ctltitle, retval)
+        return retval
+
+    @classmethod
+    def _get_data_type(cls, opt: dict[str, T.Any]) -> type:
+        """ Return a data type for passing into control_helper.py to get the correct control
+
+        Parameters
+        ----------
+        opt: dict[str, Any]
+            The option to extract the data type from
+
+        Returns
+        -------
+        :class:`type`
+            The Python type for the option
+        """
+        type_ = opt.get("type")
+        if type_ is not None and isinstance(opt["type"], type):
+            retval = type_
+        elif opt.get("action", "") in ("store_true", "store_false"):
+            retval = bool
+        else:
+            retval = str
+        logger.debug("Setting type to %s for %s", retval, type_)
+        return retval
+
+    @classmethod
+    def _get_rounding(cls, opt: dict[str, T.Any]) -> int | None:
+        """ Return rounding for the given option
+
+        Parameters
+        ----------
+        opt: dict[str, Any]
+            The option to extract the rounding from
+
+        Returns
+        -------
+        int | None
+            int if the data type supports rounding otherwise ``None``
+        """
+        dtype = opt.get("type")
+        if dtype == float:
+            retval = opt.get("rounding", 2)
+        elif dtype == int:
+            retval = opt.get("rounding", 1)
+        else:
+            retval = None
+        logger.debug("Setting rounding to %s for type %s", retval, dtype)
+        return retval
+
+    @classmethod
+    def _expand_action_option(cls,
+                              option: dict[str, T.Any],
+                              options: list[dict[str, T.Any]]) -> None:
+        """ Expand the action option to the full command name
+
+        Parameters
+        ----------
+        option: dict[str, Any]
+            The option to expand the action for
+        options: list[dict[str, Any]]
+            The full list of options for the command
+        """
+        opts = {opt["opts"][0]: opt["opts"][-1]
+                for opt in options}
+        old_val = option["action_option"]
+        new_val = opts[old_val]
+        logger.debug("Updating action option from '%s' to '%s'", old_val, new_val)
+        option["action_option"] = new_val
+
+    def _get_sysbrowser(self,
+                        option: dict[str, T.Any],
+                        options: list[dict[str, T.Any]],
+                        command: str) -> dict[T.Literal["filetypes",
+                                                        "browser",
+                                                        "command",
+                                                        "destination",
+                                                        "action_option"], str | list[str]] | None:
+        """ Return the system file browser and file types if required
+
+        Parameters
+        ----------
+        option: dict[str, Any]
+            The option to obtain the system browser for
+        options: list[dict[str, Any]]
+            The full list of options for the command
+        command: str
+            The command that the options belong to
+
+        Returns
+        -------
+        dict[Literal["filetypes", "browser", "command",
+                     "destination", "action_option"], list[str]] | None
+            The browser information, if valid, or ``None`` if browser not required
+        """
+        action = option.get("action", None)
+        if action not in (actions.DirFullPaths,
+                          actions.FileFullPaths,
+                          actions.FilesFullPaths,
+                          
actions.DirOrFileFullPaths,
+                          actions.DirOrFilesFullPaths,
+                          actions.SaveFileFullPaths,
+                          actions.ContextFullPaths):
+            return None
+
+        retval: dict[T.Literal["filetypes",
+                               "browser",
+                               "command",
+                               "destination",
+                               "action_option"], str | list[str]] = {}
+        action_option = None
+        if option.get("action_option", None) is not None:
+            self._expand_action_option(option, options)
+            action_option = option["action_option"]
+        retval["filetypes"] = option.get("filetypes", "default")
+        if action == actions.FileFullPaths:
+            retval["browser"] = ["load"]
+        elif action == actions.FilesFullPaths:
+            retval["browser"] = ["multi_load"]
+        elif action == actions.SaveFileFullPaths:
+            retval["browser"] = ["save"]
+        elif action == actions.DirOrFileFullPaths:
+            retval["browser"] = ["folder", "load"]
+        elif action == actions.DirOrFilesFullPaths:
+            retval["browser"] = ["folder", "multi_load"]
+        elif action == actions.ContextFullPaths and action_option:
+            retval["browser"] = ["context"]
+            retval["command"] = command
+            retval["action_option"] = action_option
+            retval["destination"] = option.get("dest", option["opts"][1].replace("--", ""))
+        else:
+            retval["browser"] = ["folder"]
+        logger.debug(retval)
+        return retval
+
+    def _process_options(self, command_options: list[dict[str, T.Any]], command: str
+                         ) -> dict[str, CliOption]:
+        """ Process the options for a single command
+
+        Parameters
+        ----------
+        command_options: list[dict[str, Any]]
+            The command line options for the given command
+        command: str
+            The command name to process
+
+        Returns
+        -------
+        dict[str, :class:`CliOption`]
+            The collected command line options for handling by the GUI
+        """
+        retval: dict[str, CliOption] = {}
         for opt in command_options:
-            logger.trace("Processing: %s", opt)
+            logger.debug("Processing: cli option: %s", opt["opts"])
             if opt.get("help", "") == SUPPRESS:
-                logger.trace("Skipping suppressed option: %s", opt)
+                logger.debug("Skipping suppressed option: %s", opt)
                 continue
-            ctl, sysbrowser, filetypes, action_option = self.set_control(opt)
-            opt["control_title"] = self.set_control_title(opt.get("opts", ""))
-            opt["control"] = ctl
-            opt["filesystem_browser"] = sysbrowser
-            opt["filetypes"] = filetypes
-            opt["action_option"] = action_option
-            final_options.append(opt)
-            logger.trace("Processed: %s", opt)
-        return final_options
-
-    @staticmethod
-    def set_control_title(opts):
-        """ Take the option switch and format it nicely """
-        ctltitle = opts[1] if len(opts) == 2 else opts[0]
-        ctltitle = ctltitle.replace("-", " ").replace("_", " ").strip().title()
-        return ctltitle
+            title = self._set_control_title(opt["opts"])
+            cpanel_option = ControlPanelOption(
+                title,
+                self._get_data_type(opt),
+                group=opt.get("group", None),
+                default=opt.get("default", None),
+                choices=opt.get("choices", None),
+                is_radio=opt.get("action", "") == actions.Radio,
+                is_multi_option=opt.get("action", "") == actions.MultiOption,
+                rounding=self._get_rounding(opt),
+                min_max=opt.get("min_max", None),
+                sysbrowser=self._get_sysbrowser(opt, command_options, command),
+                helptext=opt["help"],
+                track_modified=True,
+                command=command)
+            retval[title] = CliOption(cpanel_option=cpanel_option,
+                                      opts=opt["opts"],
+                                      nargs=opt.get("nargs"))
+        logger.debug("Processed: %s", retval)
+        return retval
 
-    def set_control(self, option):
-        """ Set the control and filesystem browser to use for each option """
-        sysbrowser = None
-        action = option.get("action", None)
-        action_option = option.get("action_option", None)
-        filetypes = option.get("filetypes", None)
-        ctl = ttk.Entry
-        
if action in (cli.FullPaths, - cli.DirFullPaths, - cli.FileFullPaths, - cli.DirOrFileFullPaths, - cli.SaveFileFullPaths, - cli.ContextFullPaths): - sysbrowser, filetypes = self.set_sysbrowser(action, - filetypes, - action_option) - elif option.get("min_max", None): - ctl = ttk.Scale - elif option.get("choices", "") != "": - ctl = ttk.Combobox - elif option.get("action", "") == "store_true": - ctl = ttk.Checkbutton - return ctl, sysbrowser, filetypes, action_option - - @staticmethod - def set_sysbrowser(action, filetypes, action_option): - """ Set the correct file system browser and filetypes - for the passed in action """ - sysbrowser = ["folder"] - filetypes = "default" if not filetypes else filetypes - if action == cli.FileFullPaths: - sysbrowser = ["load"] - elif action == cli.SaveFileFullPaths: - sysbrowser = ["save"] - elif action == cli.DirOrFileFullPaths: - sysbrowser = ["folder", "load"] - elif action == cli.ContextFullPaths and action_option: - sysbrowser = ["context"] - logger.debug("sysbrowser: %s, filetypes: '%s'", sysbrowser, filetypes) - return sysbrowser, filetypes - - def set_context_option(self, command): - """ Set the tk_var for the source action option - that dictates the context sensitive file browser. """ - actions = {item["opts"][0]: item["value"] - for item in self.gen_command_options(command)} - for opt in self.gen_command_options(command): - if opt["filesystem_browser"] == ["context"]: - opt["action_option"] = actions[opt["action_option"]] - - def gen_command_options(self, command): - """ Yield each option for specified command """ - for option in self.opts[command]: - yield option - - def options_to_process(self, command=None): - """ Return a consistent object for processing - regardless of whether processing all commands - or just one command for reset and clear """ + def _extract_options(self, arguments: list[T.Type[FaceSwapArgs]]): + """ Extract the collected command line FaceSwapArg options into master options + :attr:`opts` dictionary + + Parameters + ---------- + arguments: list[:class:`~lib.cli.args.FaceSwapArgs`] + The command line class objects to process + """ + retval = {} + for arg_class in arguments: + logger.debug("Processing: '%s'", arg_class.__name__) + command = self._class_name_to_command(arg_class.__name__) + info, options = self._get_cli_arguments(arg_class, command) + opts = T.cast(dict[str, CliOption | str], self._process_options(options, command)) + opts["helptext"] = info + retval[command] = opts + self._opts.update(retval) + + def _build_options(self) -> None: + """ Parse the command line argument modules and populate :attr:`commands` and :attr:`opts` + for each category """ + for category in self.categories: + modules = self._get_modules(category) + classes = self._get_all_classes(modules) + self._store_commands(category, classes) + self._extract_options(classes) + logger.debug("Built '%s'", category) + + def _gen_command_options(self, command: str + ) -> T.Generator[tuple[str, CliOption], None, None]: + """ Yield each option for specified command + + Parameters + ---------- + command: str + The faceswap command to generate the options for + + Yields + ------ + str + The option name for display + :class:`CliOption`: + The option object + """ + for key, val in self._opts.get(command, {}).items(): + if not isinstance(val, CliOption): + continue + yield key, val + + def _options_to_process(self, command: str | None = None) -> list[CliOption]: + """ Return a consistent object for processing regardless of whether processing all commands + or 
just one command for reset and clear. Removes helptext from return value + + Parameters + ---------- + command: str | None, optional + The command to return the options for. ``None`` for all commands. Default ``None`` + + Returns + ------- + list[:class:`CliOption`] + The options to be processed + """ if command is None: - options = [opt for opts in self.opts.values() for opt in opts] - else: - options = [opt for opt in self.gen_command_options(command)] - return options + return [opt for opts in self._opts.values() + for opt in opts if isinstance(opt, CliOption)] + return [opt for opt in self._opts[command] if isinstance(opt, CliOption)] + + def reset(self, command: str | None = None) -> None: + """ Reset the options for all or passed command back to default value - def reset(self, command=None): - """ Reset the options for all or passed command - back to default value """ + Parameters + ---------- + command: str | None, optional + The command to reset the options for. ``None`` to reset for all commands. + Default: ``None`` + """ logger.debug("Resetting options to default. (command: '%s'", command) - for option in self.options_to_process(command): - default = option.get("default", "") - default = "" if default is None else default - if (option.get("nargs", None) - and isinstance(default, (list, tuple))): + for option in self._options_to_process(command): + cp_opt = option.cpanel_option + default = "" if cp_opt.default is None else cp_opt.default + if option.nargs is not None and isinstance(default, (list, tuple)): default = ' '.join(str(val) for val in default) - option["value"].set(default) + cp_opt.set(default) - def clear(self, command=None): - """ Clear the options values for all or passed - commands """ + def clear(self, command: str | None = None) -> None: + """ Clear the options values for all or passed commands + + Parameters + ---------- + command: str | None, optional + The command to clear the options for. ``None`` to clear options for all commands. + Default: ``None`` + """ logger.debug("Clearing options. (command: '%s'", command) - for option in self.options_to_process(command): - if isinstance(option["value"].get(), bool): - option["value"].set(False) - elif isinstance(option["value"].get(), int): - option["value"].set(0) + for option in self._options_to_process(command): + cp_opt = option.cpanel_option + if isinstance(cp_opt.get(), bool): + cp_opt.set(False) + elif isinstance(cp_opt.get(), (int, float)): + cp_opt.set(0) else: - option["value"].set("") + cp_opt.set("") + + def get_option_values(self, command: str | None = None + ) -> dict[str, dict[str, bool | int | float | str]]: + """ Return all or single command control titles with the associated tk_var value + + Parameters + ---------- + command: str | None, optional + The command to get the option values for. ``None`` to get all option values. 
+ Default: ``None`` - def get_option_values(self, command=None): - """ Return all or single command control titles - with the associated tk_var value """ - ctl_dict = dict() - for cmd, opts in self.opts.items(): + Returns + ------- + dict[str, dict[str, bool | int | float | str]] + option values in the format {command: {option_name: option_value}} + """ + ctl_dict: dict[str, dict[str, bool | int | float | str]] = {} + for cmd, opts in self._opts.items(): if command and command != cmd: continue - cmd_dict = dict() - for opt in opts: - cmd_dict[opt["control_title"]] = opt["value"].get() + cmd_dict: dict[str, bool | int | float | str] = {} + for key, val in opts.items(): + if not isinstance(val, CliOption): + continue + cmd_dict[key] = val.cpanel_option.get() ctl_dict[cmd] = cmd_dict - logger.debug("command: '%s', ctl_dict: '%s'", command, ctl_dict) + logger.debug("command: '%s', ctl_dict: %s", command, ctl_dict) return ctl_dict - def get_one_option_variable(self, command, title): - """ Return a single tk_var for the specified - command and control_title """ - for option in self.gen_command_options(command): - if option["control_title"] == title: - return option["value"] + def get_one_option_variable(self, command: str, title: str) -> Variable | None: + """ Return a single :class:`tkinter.Variable` tk_var for the specified command and + control_title + + Parameters + ---------- + command: str + The command to return the variable from + title: str + The option title to return the variable for + + Returns + ------- + :class:`tkinter.Variable` | None + The requested tkinter variable, or ``None`` if it could not be found + """ + for opt_title, option in self._gen_command_options(command): + if opt_title == title: + return option.cpanel_option.tk_var return None - def gen_cli_arguments(self, command): - """ Return the generated cli arguments for - the selected command """ - for option in self.gen_command_options(command): - optval = str(option.get("value", "").get()) - opt = option["opts"][0] - if command in ("extract", "convert") and opt == "-o": - get_images().pathoutput = optval - if optval in ("False", ""): + def gen_cli_arguments(self, command: str) -> T.Generator[tuple[str, ...], None, None]: + """ Yield the generated cli arguments for the selected command + + Parameters + ---------- + command: str + The command to generate the command line arguments for + + Yields + ------ + tuple[str, ...] 
+ The generated command line arguments + """ + output_dir = None + switches = "" + args = [] + for _, option in self._gen_command_options(command): + str_val = str(option.cpanel_option.get()) + switch = option.opts[0] + batch_mode = command == "extract" and switch == "-b" # Check for batch mode + if command in ("extract", "convert") and switch == "-o": # Output location for preview + output_dir = str_val + + if str_val in ("False", ""): # skip no value opts continue - elif optval == "True": - yield (opt, ) - else: - if option.get("nargs", None): - optval = optval.split(" ") - opt = [opt] + optval + + if str_val == "True": # store_true just output the switch + switches += switch[1:] + continue + + if option.nargs is not None: + if "\"" in str_val: + val = [arg[1:-1] for arg in re.findall(r"\".+?\"", str_val)] else: - opt = (opt, optval) - yield opt + val = str_val.split(" ") + arg = (switch, *val) + else: + arg = (switch, str_val) + args.append(arg) + + switch_args = [] if not switches else [(f"-{switches}", )] + yield from switch_args + args + + if command in ("extract", "convert") and output_dir is not None: + get_images().preview_extract.set_faceswap_output_path(output_dir, + batch_mode=batch_mode) + + +__all__ = get_module_objects(__name__) diff --git a/lib/gui/popup_configure.py b/lib/gui/popup_configure.py index 9e259236db..405082c665 100644 --- a/lib/gui/popup_configure.py +++ b/lib/gui/popup_configure.py @@ -1,350 +1,842 @@ #!/usr/bin python3 -""" Configure Plugins popup of the Faceswap GUI """ - -from configparser import ConfigParser +"""The pop-up window of the Faceswap GUI for the setting of configuration options.""" +from __future__ import annotations +import gettext import logging +import os import tkinter as tk - from tkinter import ttk - -from .tooltip import Tooltip -from .utils import get_config, get_images, ContextMenu, set_slider_rounding - -logger = logging.getLogger(__name__) # pylint: disable=invalid-name -POPUP = dict() - - -def popup_config(config, root): - """ Close any open popup and open requested popup """ - if POPUP: - p_key = list(POPUP.keys())[0] - logger.debug("Closing open popup: '%s'", p_key) - POPUP[p_key].destroy() - del POPUP[p_key] - window = ConfigurePlugins(config, root) - POPUP[config[0]] = window - - -class ConfigurePlugins(tk.Toplevel): - """ Pop up for detailed graph/stats for selected session """ - def __init__(self, config, root): - logger.debug("Initializing %s", self.__class__.__name__) +import typing as T + +from lib.config import get_configs +from lib.logger import parse_class_init +from lib.serializer import get_serializer +from lib.utils import get_module_objects + +from .control_helper import ControlPanel, ControlPanelOption +from .custom_widgets import Tooltip +from .utils import FileHandler, get_config, get_images, PATHCACHE + +if T.TYPE_CHECKING: + from lib.config import FaceswapConfig + +logger = logging.getLogger(__name__) + +# LOCALES +_LANG = gettext.translation("gui.tooltips", localedir="locales", fallback=True) +_ = _LANG.gettext + + +class _State(): + """ + Holds the current state of the popup window, ensuring that only 1 instance can ever exist + """ + def __init__(self) -> None: + logger.debug(parse_class_init(locals())) + self._popup: _ConfigurePlugins | None = None + + def open_popup(self, name: str | None = None) -> None: + """Launch the popup, ensuring only one instance is ever open + + Parameters + ---------- + name : str | None, Optional + The name of the configuration file. 
Used for selecting the correct section if required. + Set to ``None`` if no initial section should be selected. Default: ``None`` + """ + logger.debug("name: %s", name) + if self._popup is not None: + logger.debug("Restoring existing popup") + self._popup.update() + self._popup.deiconify() + self._popup.lift() + return + self._popup = _ConfigurePlugins(name) + + def close_popup(self) -> None: + """Destroy the open popup and remove it from tracking.""" + if self._popup is None: + logger.debug("No popup to close. Returning") + return + logger.debug("Destroying popup") + self._popup.destroy() + del self._popup + self._popup = None + + +_STATE = _State() +open_popup = _STATE.open_popup + + +class _ConfigurePlugins(tk.Toplevel): + """Pop-up window for the setting of Faceswap Configuration Options. + + Parameters + ---------- + name : str | None + The name of the section that is being navigated to. Used for opening on the correct + page in the Tree View. ``None`` to open on the first page + """ + def __init__(self, name: str | None) -> None: + logger.debug(parse_class_init(locals())) super().__init__() - name, self.config = config - self.title("{} Plugins".format(name.title())) - self.tk.call('wm', 'iconphoto', self._w, get_images().icons["favicon"]) - - self.set_geometry(root) + self._root = get_config().root + self._set_geometry() + self._tk_vars = {"header": tk.StringVar()} + + theme = {**get_config().user_theme["group_panel"], + **get_config().user_theme["group_settings"]} + header_frame = self._build_header() + content_frame = ttk.Frame(self) + + self._tree = _Tree(content_frame, name, theme).tree + self._tree.bind("", self._select_item) + + self._opts_frame = DisplayArea(self, content_frame, self._tree, theme) + self._opts_frame.pack(fill=tk.BOTH, expand=True, side=tk.RIGHT) + footer_frame = self._build_footer() + + header_frame.pack(fill=tk.X, padx=5, pady=5, side=tk.TOP) + content_frame.pack(fill=tk.BOTH, padx=5, pady=(0, 5), expand=True, side=tk.TOP) + footer_frame.pack(fill=tk.X, padx=5, pady=(0, 5), side=tk.BOTTOM) + + select = name if name else self._tree.get_children()[0] + self._tree.selection_set(select) + self._tree.focus(select) + self._select_item(0) # type:ignore[arg-type] + + self.title("Configure Settings") + self.tk.call('wm', + 'iconphoto', + self._w, # type:ignore[attr-defined] + get_images().icons["favicon"]) + self.protocol("WM_DELETE_WINDOW", _STATE.close_popup) - self.page_frame = ttk.Frame(self) - self.page_frame.pack(fill=tk.BOTH, expand=True) - - self.plugin_info = dict() - self.config_dict_gui = self.get_config() - self.build() - self.update() logger.debug("Initialized %s", self.__class__.__name__) - def set_geometry(self, root): - """ Set pop-up geometry """ + def _set_geometry(self) -> None: + """Set the geometry of the pop-up window""" scaling_factor = get_config().scaling_factor - pos_x = root.winfo_x() + 80 - pos_y = root.winfo_y() + 80 - width = int(720 * scaling_factor) - height = int(400 * scaling_factor) + pos_x = self._root.winfo_x() + 80 + pos_y = self._root.winfo_y() + 80 + width = int(600 * scaling_factor) + height = int(536 * scaling_factor) logger.debug("Pop up Geometry: %sx%s, %s+%s", width, height, pos_x, pos_y) - self.geometry("{}x{}+{}+{}".format(width, height, pos_x, pos_y)) + self.geometry(f"{width}x{height}+{pos_x}+{pos_y}") + + def _build_header(self) -> ttk.Frame: + """Build the main header text and separator. 
+
+
+class _ConfigurePlugins(tk.Toplevel):
+    """Pop-up window for the setting of Faceswap Configuration Options.
+
+    Parameters
+    ----------
+    name : str | None
+        The name of the section that is being navigated to. Used for opening on the correct
+        page in the Tree View. ``None`` to open on the first page
+    """
+    def __init__(self, name: str | None) -> None:
+        logger.debug(parse_class_init(locals()))
         super().__init__()
-        name, self.config = config
-        self.title("{} Plugins".format(name.title()))
-        self.tk.call('wm', 'iconphoto', self._w, get_images().icons["favicon"])
-
-        self.set_geometry(root)
+        self._root = get_config().root
+        self._set_geometry()
+        self._tk_vars = {"header": tk.StringVar()}
+
+        theme = {**get_config().user_theme["group_panel"],
+                 **get_config().user_theme["group_settings"]}
+        header_frame = self._build_header()
+        content_frame = ttk.Frame(self)
+
+        self._tree = _Tree(content_frame, name, theme).tree
+        self._tree.bind("<ButtonRelease-1>", self._select_item)
+
+        self._opts_frame = DisplayArea(self, content_frame, self._tree, theme)
+        self._opts_frame.pack(fill=tk.BOTH, expand=True, side=tk.RIGHT)
+        footer_frame = self._build_footer()
+
+        header_frame.pack(fill=tk.X, padx=5, pady=5, side=tk.TOP)
+        content_frame.pack(fill=tk.BOTH, padx=5, pady=(0, 5), expand=True, side=tk.TOP)
+        footer_frame.pack(fill=tk.X, padx=5, pady=(0, 5), side=tk.BOTTOM)
+
+        select = name if name else self._tree.get_children()[0]
+        self._tree.selection_set(select)
+        self._tree.focus(select)
+        self._select_item(0)  # type:ignore[arg-type]
+
+        self.title("Configure Settings")
+        self.tk.call('wm',
+                     'iconphoto',
+                     self._w,  # type:ignore[attr-defined]
+                     get_images().icons["favicon"])
+        self.protocol("WM_DELETE_WINDOW", _STATE.close_popup)
 
-        self.page_frame = ttk.Frame(self)
-        self.page_frame.pack(fill=tk.BOTH, expand=True)
-
-        self.plugin_info = dict()
-        self.config_dict_gui = self.get_config()
-        self.build()
-        self.update()
         logger.debug("Initialized %s", self.__class__.__name__)
 
-    def set_geometry(self, root):
-        """ Set pop-up geometry """
+    def _set_geometry(self) -> None:
+        """Set the geometry of the pop-up window"""
         scaling_factor = get_config().scaling_factor
-        pos_x = root.winfo_x() + 80
-        pos_y = root.winfo_y() + 80
-        width = int(720 * scaling_factor)
-        height = int(400 * scaling_factor)
+        pos_x = self._root.winfo_x() + 80
+        pos_y = self._root.winfo_y() + 80
+        width = int(600 * scaling_factor)
+        height = int(536 * scaling_factor)
         logger.debug("Pop up Geometry: %sx%s, %s+%s", width, height, pos_x, pos_y)
-        self.geometry("{}x{}+{}+{}".format(width, height, pos_x, pos_y))
+        self.geometry(f"{width}x{height}+{pos_x}+{pos_y}")
+
+    def _build_header(self) -> ttk.Frame:
+        """Build the main header text and separator.
+
+        Returns
+        -------
+        :class:`tkinter.ttk.Frame`
+            The header of the popup configuration window
+        """
+        header_frame = ttk.Frame(self)
+        lbl_frame = ttk.Frame(header_frame)
+
+        self._tk_vars["header"].set("Settings")
+        lbl_header = ttk.Label(lbl_frame,
+                               textvariable=self._tk_vars["header"],
+                               anchor=tk.W,
+                               style="SPanel.Header1.TLabel")
+        lbl_header.pack(fill=tk.X, expand=True, side=tk.LEFT)
+
+        sep = ttk.Frame(header_frame, height=2, relief=tk.RIDGE)
+
+        lbl_frame.pack(fill=tk.X, expand=True, side=tk.TOP)
+        sep.pack(fill=tk.X, pady=(1, 0), side=tk.BOTTOM)
+        return header_frame
+
+    def _build_footer(self) -> ttk.Frame:
+        """Build the main footer buttons and separator.
+
+        Returns
+        -------
+        :class:`ttk.Frame`
+            The footer of the popup configuration window
+        """
+        logger.debug("Adding action buttons")
+        frame = ttk.Frame(self)
+        left_frame = ttk.Frame(frame)
+        right_frame = ttk.Frame(frame)
+
+        btn_saveall = ttk.Button(left_frame,
+                                 text="Save All",
+                                 width=10,
+                                 command=self._opts_frame.save)
+        btn_rstall = ttk.Button(left_frame,
+                                text="Reset All",
+                                width=10,
+                                command=self._opts_frame.reset)
+
+        btn_cls = ttk.Button(right_frame, text="Cancel", width=10, command=_STATE.close_popup)
+        btn_save = ttk.Button(right_frame,
+                              text="Save",
+                              width=10,
+                              command=lambda: self._opts_frame.save(page_only=True))
+        btn_rst = ttk.Button(right_frame,
+                             text="Reset",
+                             width=10,
+                             command=lambda: self._opts_frame.reset(page_only=True))
+
+        Tooltip(btn_cls, text=_("Close without saving"), wrap_length=720)
+        Tooltip(btn_save, text=_("Save this page's config"), wrap_length=720)
+        Tooltip(btn_rst, text=_("Reset this page's config to default values"), wrap_length=720)
+        Tooltip(btn_saveall,
+                text=_("Save all settings for the currently selected config"),
+                wrap_length=720)
+        Tooltip(btn_rstall,
+                text=_("Reset all settings for the currently selected config to default values"),
+                wrap_length=720)
 
-    def get_config(self):
-        """ Format config into useful format for GUI and pull default value if a value has not
-            been supplied """
-        logger.debug("Formatting Config for GUI")
-        conf = dict()
-        for section in self.config.config.sections():
-            self.config.section = section
-            category = section.split(".")[0]
-            options = self.config.defaults[section]
-            conf.setdefault(category, dict())[section] = options
-            for key in options.keys():
-                if key == "helptext":
-                    self.plugin_info[section] = options[key]
-                    continue
-                options[key]["value"] = self.config.config_dict.get(key, options[key]["default"])
-        logger.debug("Formatted Config for GUI: %s", conf)
-        return conf
-
-    def build(self):
-        """ Build the config popup """
-        logger.debug("Building plugin config popup")
-        container = ttk.Notebook(self.page_frame)
-        container.pack(fill=tk.BOTH, expand=True)
-        categories = sorted(list(key for key in self.config_dict_gui.keys()))
-        if "global" in categories:  # Move global to first item
-            categories.insert(0, categories.pop(categories.index("global")))
-        for category in categories:
-            page = self.build_page(container, category)
-            container.add(page, text=category.title())
-
-        self.add_frame_separator()
-        self.add_actions()
-        logger.debug("Built plugin config popup")
-
-    def build_page(self, container, category):
-        """ Build a plugin config page """
-        logger.debug("Building plugin config page: '%s'", category)
-        plugins = sorted(list(key for key in self.config_dict_gui[category].keys()))
-        if any(plugin != category for plugin in plugins):
-            page = ttk.Notebook(container)
-            page.pack(side=tk.TOP, fill=tk.BOTH, expand=True)
-            for plugin in plugins:
-                frame = ConfigFrame(page,
-                                    self.config_dict_gui[category][plugin],
-                                    self.plugin_info[plugin])
-                title = plugin[plugin.rfind(".") + 1:]
-                title = title.replace("_", " ").title()
-                page.add(frame, text=title)
-        else:
-            page = ConfigFrame(container,
-                               self.config_dict_gui[category][plugins[0]],
-                               self.plugin_info[plugins[0]])
-
-        logger.debug("Built plugin config page: '%s'", category)
-
-        return page
-
-    def add_frame_separator(self):
-        """ Add a separator between top and bottom frames """
-        logger.debug("Add frame seperator")
-        sep = ttk.Frame(self.page_frame, height=2, relief=tk.RIDGE)
-        sep.pack(fill=tk.X, pady=(5, 0), side=tk.BOTTOM)
-        logger.debug("Added frame seperator")
-
-    def add_actions(self):
-        """ Add Action buttons """
-        logger.debug("Add action buttons")
-        frame = ttk.Frame(self.page_frame)
-        frame.pack(fill=tk.BOTH, padx=5, pady=5, side=tk.BOTTOM)
-        btn_cls = ttk.Button(frame, text="Cancel", width=10, command=self.destroy)
         btn_cls.pack(padx=2, side=tk.RIGHT)
-        btn_ok = ttk.Button(frame, text="OK", width=10, command=self.save_config)
-        btn_ok.pack(padx=2, side=tk.RIGHT)
+        btn_save.pack(padx=2, side=tk.RIGHT)
+        btn_rst.pack(padx=2, side=tk.RIGHT)
+        btn_saveall.pack(padx=2, side=tk.RIGHT)
+        btn_rstall.pack(padx=2, side=tk.RIGHT)
+
+        left_frame.pack(side=tk.LEFT)
+        right_frame.pack(side=tk.RIGHT)
         logger.debug("Added action buttons")
+        return frame
 
-    def save_config(self):
-        """ Save the config file """
-        logger.debug("Saving config")
-        options = {sect: opts
-                   for value in self.config_dict_gui.values()
-                   for sect, opts in value.items()}
-
-        new_config = ConfigParser(allow_no_value=True)
-        for section, items in self.config.defaults.items():
-            logger.debug("Adding section: '%s')", section)
-            self.config.insert_config_section(section, items["helptext"], config=new_config)
-            for item, def_opt in items.items():
-                if item == "helptext":
-                    continue
-                new_opt = options[section][item]
-                logger.debug("Adding option: (item: '%s', default: '%s' new: '%s'",
-                             item, def_opt, new_opt)
-                helptext = def_opt["helptext"]
-                helptext += self.config.set_helptext_choices(def_opt)
-                helptext += "\n[Default: {}]".format(def_opt["default"])
-                helptext = self.config.format_help(helptext, is_section=False)
-                new_config.set(section, helptext)
-                new_config.set(section, item, str(new_opt["selected"].get()))
-        self.config.config = new_config
-        self.config.save_config()
-        print("Saved config: '{}'".format(self.config.configfile))
-        self.destroy()
-        logger.debug("Saved config")
+    def _select_item(self, event: tk.Event) -> None:  # pylint:disable=unused-argument
+        """Display the settings page for the item selected in the Tree View and update the
+        header text to match.
+
+        Parameters
+        ----------
+        event : :class:`tkinter.Event`
+            The tkinter mouse button release event. Unused.
+        """
+        selection = self._tree.focus()
+        section = selection.split("|")[0]
+        subsections = selection.split("|")[1:] if "|" in selection else []
+        self._tk_vars["header"].set(f"{section.title()} Settings")
+        self._opts_frame.select_options(section, subsections)
+
+
+class _Tree(ttk.Frame):  # pylint:disable=too-many-ancestors
+    """Frame that holds the Tree View Navigator and scroll bar for the configuration pop-up.
+
+    Parameters
+    ----------
+    parent : :class:`tkinter.ttk.Frame`
+        The parent frame to the Tree View area
+    name : str | None
+        The name of the section that is being navigated to. Used for opening on the correct
+        page in the Tree View. ``None`` if no specific area is being navigated to
+    theme : dict[str, Any]
+        The color mapping for the settings pop-up theme
+    """
+    def __init__(self, parent: ttk.Frame, name: str | None, theme: dict[str, T.Any]):
+        logger.debug(parse_class_init(locals()))
+        super().__init__(parent)
+        self._fix_styles(theme)
+
+        frame = ttk.Frame(self, relief=tk.SOLID, borderwidth=1)
+        self._tree = self._build_tree(frame, name)
+        scrollbar = ttk.Scrollbar(frame, orient="vertical", command=self._tree.yview)
+
+        scrollbar.pack(side=tk.RIGHT, fill=tk.Y)
+        self._tree.pack(fill=tk.Y, expand=True)
+        self._tree.configure(yscrollcommand=scrollbar.set)
+        frame.pack(expand=True, fill=tk.Y)
+        self.pack(side=tk.LEFT, fill=tk.Y)
+
+    @property
+    def tree(self) -> ttk.Treeview:
+        """:class:`tkinter.ttk.Treeview` The Tree View held within the frame"""
+        return self._tree
+
+    @classmethod
+    def _fix_styles(cls, theme: dict[str, T.Any]) -> None:
+        """Tkinter has a bug when setting the background style on certain OSes. This fixes the
+        issue so we can set different colored backgrounds.
+
+        We also set some default styles for our tree view.
+
+        Parameters
+        ----------
+        theme: dict[str, Any]
+            The color mapping for the settings pop-up theme
+        """
+        style = ttk.Style()
+
+        # Fix a bug in Tree-view that doesn't show alternate foreground on selection
+        fix_map = lambda o: [elm for elm in style.map("Treeview", query_opt=o)  # noqa:E731 # pylint:disable=C3001
+                             if elm[:2] != ("!disabled", "!selected")]
+
+        # Remove the Borders
+        style.configure("ConfigNav.Treeview", bd=0, background="#F0F0F0")
+        style.layout("ConfigNav.Treeview", [('ConfigNav.Treeview.treearea', {'sticky': 'nswe'})])
+
+        # Set colors
+        style.map("ConfigNav.Treeview",
+                  foreground=fix_map("foreground"),  # type:ignore[arg-type]
+                  background=fix_map("background"))  # type:ignore[arg-type]
+        style.map('ConfigNav.Treeview', background=[('selected', theme["tree_select"])])
+
+    @classmethod
+    def _process_sections(cls,
+                          tree: ttk.Treeview,
+                          sections: list[list[str]],
+                          category: str,
+                          is_open: bool) -> None:
+        """Process the sections of a category's configuration.
+
+        Creates a category's sections, then the sub options for that category
+
+        Parameters
+        ----------
+        tree: :class:`tkinter.ttk.Treeview`
+            The tree view to insert sections into
+        sections: list[list[str]]
+            The sections to insert into the Tree View
+        category: str
+            The category node that these sections sit in
+        is_open: bool
+            ``True`` if the node should be created in "open" mode. ``False`` if it should be
+            closed.
+        """
+        seen = set()
+        for section in sections:
+            if section[-1] == "global":  # Global categories get escalated to parent
+                continue
+            sect = section[0]
+            section_id = f"{category}|{sect}"
+            if sect not in seen:
+                seen.add(sect)
+                text = sect.replace("_", " ").title()
+                tree.insert(category, "end", section_id, text=text, open=is_open, tags="section")
+            if len(section) == 2:
+                opt = section[-1]
+                opt_id = f"{section_id}|{opt}"
+                opt_text = opt.replace("_", " ").title()
+                tree.insert(section_id, "end", opt_id, text=opt_text, open=is_open, tags="option")
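# Illustrative sketch, not part of the patch: _process_sections() builds "|"
# delimited node ids, which _select_item() later splits back into a section and
# its subsections. The plugin/section names shown are hypothetical.
#
#     "extract"                   ->  category node  (tag "category")
#     "extract|mask"              ->  section node   (tag "section")
#     "extract|mask|components"   ->  option node    (tag "option")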
+
+    def _build_tree(self, parent: ttk.Frame, name: str | None) -> ttk.Treeview:
+        """Build the configuration pop-up window.
+
+        Parameters
+        ----------
+        parent : :class:`tkinter.ttk.Frame`
+            The parent frame that holds the treeview
+        name : str | None
+            The name of the section that is being navigated to. Used for opening on the correct
+            page in the Tree View. ``None`` if no specific area is being navigated to
+
+        Returns
+        -------
+        :class:`tkinter.ttk.Treeview`
+            The populated tree view
+        """
+        logger.debug("Building Tree View Navigator")
+        tree = ttk.Treeview(parent, show="tree", style="ConfigNav.Treeview")
+        data = {category: [sect.split(".") for sect in sorted(conf.sections)]
+                for category, conf in get_configs().items()}
+        ordered = sorted(list(data.keys()))
+        categories = ["extract", "train", "convert"]
+        categories += [x for x in ordered if x not in categories]
+
+        for cat in categories:
+            img = get_images().icons.get(f"settings_{cat}", "")
+            text = cat.replace("_", " ").title()
+            text = " " + text if img else text
+            is_open = tk.TRUE if name is None or name == cat else tk.FALSE
+            tree.insert("", "end", cat, text=text, image=img, open=is_open, tags="category")
+            self._process_sections(tree, data[cat], cat, name == cat)
+
+        tree.tag_configure('category', background='#DFDFDF')
+        tree.tag_configure('section', background='#E8E8E8')
+        tree.tag_configure('option', background='#F0F0F0')
+        logger.debug("Built Tree View Navigator")
+        return tree
+
+
+class DisplayArea(ttk.Frame):  # pylint:disable=too-many-ancestors
+    """The option configuration area of the pop up options.
+
+    Parameters
+    ----------
+    top_level : :class:`tk.Toplevel`
+        The tkinter Top Level widget
+    parent : :class:`tkinter.ttk.Frame`
+        The parent frame that holds the Display Area of the pop up configuration window
+    tree : :class:`tkinter.ttk.Treeview`
+        The Tree View navigator for the pop up configuration window
+    theme : dict[str, Any]
+        The color mapping for the settings pop-up theme
+    """
+    def __init__(self,
+                 top_level: tk.Toplevel,
+                 parent: ttk.Frame,
+                 tree: ttk.Treeview,
+                 theme: dict[str, T.Any]) -> None:
+        logger.debug(parse_class_init(locals()))
+        super().__init__(parent)
+        self._theme = theme
+        self._tree = tree
+        self._vars: dict[str, tk.StringVar] = {}
+        self._cache: dict[str, ttk.Frame] = {}
+        self._config_cpanel_dict = self._get_config()
+        self._displayed_frame: ttk.Frame | None = None
+        self._displayed_key: str | None = None
+
+        self._presets = _Presets(self, top_level)
+        self._build_header()
+
+    @property
+    def displayed_key(self) -> str | None:
+        """str : The current display page's lookup key for configuration options."""
+        return self._displayed_key
 
-class ConfigFrame(ttk.Frame):  # pylint: disable=too-many-ancestors
-    """ Config Frame - Holds the Options for config """
+    @property
+    def config_dict(self) -> dict[str, dict[str, str | dict[str, ControlPanelOption]]]:
+        """
+        dict[str, dict[str, str | dict[str, ControlPanelOption]]] : The configuration
+        dictionary for all display pages.
+        """
+        return self._config_cpanel_dict
+
+    def _get_config(self) -> dict[str, dict[str, str | dict[str, ControlPanelOption]]]:
+        """
+        Format the configuration options stored in :attr:`lib.config.FACESWAP_CONFIGS` into a
+        dict of :class:`~lib.gui.control_helper.ControlPanelOption` objects for placement into
+        option frames.
+
+        Returns
+        -------
+        dict[str, dict[str, str | dict[str, :class:`~lib.gui.control_helper.ControlPanelOption`]]]
+            A dictionary of section names to :class:`~lib.gui.control_helper.ControlPanelOption`
+            objects
+        """
+        logger.debug("Formatting Config for GUI")
+        retval: dict[str, dict[str, str | dict[str, ControlPanelOption]]] = {}
+        for plugin, conf in get_configs().items():
+            for section_name, section in conf.sections.items():
+                category = section_name.split(".")[0]
+                sect = section_name.split(".")[-1]
+                # Elevate global to root
+                key = plugin if sect == "global" else f"{plugin}|{category}|{sect}"
+                cp_options: dict[str, ControlPanelOption] = {}
+                for option_name, option in section.options.items():
+                    cp_options[option_name] = ControlPanelOption.from_config_object(option_name,
+                                                                                    option)
+
+                retval[key] = {"helptext": section.helptext, "options": cp_options}
+        logger.debug("Formatted Config for GUI: %s", retval)
+        return retval
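# Illustrative sketch, not part of the patch: the shape of the lookup dictionary
# returned by _get_config() above. A plugin's "global" section is elevated to the
# bare plugin key, while all other sections are keyed "plugin|category|section".
# The section names shown are hypothetical.
#
#     {
#         "train": {"helptext": "...", "options": {...}},                # train global
#         "train|model|original": {"helptext": "...", "options": {...}},
#         "extract|detect|s3fd": {"helptext": "...", "options": {...}},
#     }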
+
+    def _build_presets_buttons(self, frame: ttk.Frame) -> None:
+        """Build the section that holds the preset load and save buttons.
+
+        Parameters
+        ----------
+        frame : :class:`ttk.Frame`
+            The frame that holds the preset buttons
+        """
+        presets_frame = ttk.Frame(frame)
+        for lbl in ("load", "save"):
+            btn = ttk.Button(presets_frame,
+                             image=get_images().icons[lbl],
+                             command=getattr(self._presets, lbl))
+            Tooltip(btn, text=_(f"{lbl.title()} preset for this plugin"), wrap_length=720)
+            btn.pack(padx=2, side=tk.LEFT)
+        presets_frame.pack(side=tk.RIGHT)
+
+    def _build_header(self) -> None:
+        """Build the dynamic header text."""
+        header_frame = ttk.Frame(self)
+        lbl_frame = ttk.Frame(header_frame)
+
+        var = tk.StringVar()
+        lbl = ttk.Label(lbl_frame, textvariable=var, anchor=tk.W, style="SPanel.Header2.TLabel")
+        lbl.pack(fill=tk.X, expand=True, side=tk.TOP)
+
+        self._build_presets_buttons(header_frame)
+        lbl_frame.pack(fill=tk.X, side=tk.LEFT, expand=True)
+        header_frame.pack(fill=tk.X, padx=5, pady=5, side=tk.TOP)
+        self._vars["header"] = var
+
+    def _create_links_page(self, key: str) -> ttk.Frame:
+        """For headings which don't have settings, build a links page to the subsections.
+
+        Parameters
+        ----------
+        key : str
+            The lookup key to set the links page for
+
+        Returns
+        -------
+        :class:`tkinter.ttk.Frame`
+            The created links page
+        """
+        frame = ttk.Frame(self)
+        links = {item.replace(key, "")[1:].split("|")[0]
+                 for item in self._config_cpanel_dict
+                 if item.startswith(key)}
+
+        if not links:
+            return frame
+
+        header_lbl = ttk.Label(frame, text=_("Select a plugin to configure:"))
+        header_lbl.pack(side=tk.TOP, fill=tk.X, padx=5, pady=(5, 10))
+        for link in sorted(links):
+            lbl = ttk.Label(frame,
+                            text=link.replace("_", " ").title(),
+                            anchor=tk.W,
+                            foreground=self._theme["link_color"],
+                            cursor="hand2")
+            lbl.pack(side=tk.TOP, fill=tk.X, padx=10, pady=(0, 5))
+            bind = f"{key}|{link}"
+            lbl.bind("<Button-1>", lambda e, x=bind: self._link_callback(x))  # type:ignore[misc]
 
-    def __init__(self, parent, options, plugin_info):
-        logger.debug("Initializing %s", self.__class__.__name__)
-        ttk.Frame.__init__(self, parent)
-        self.pack(side=tk.TOP, fill=tk.BOTH, expand=True)
+        return frame
 
-        self.options = options
-        self.plugin_info = plugin_info
+    def _cache_page(self, key: str) -> None:
+        """Create the control panel options for the requested configuration and cache.
+ + Parameters + ---------- + key : str + The lookup key to the settings cache + """ + info = self._config_cpanel_dict.get(key, None) + if info is None: + logger.debug("key '%s' does not exist in options. Creating links page.", key) + self._cache[key] = self._create_links_page(key) + else: + opts = T.cast(dict[str, dict[str, ControlPanelOption]], info["options"]) + self._cache[key] = ControlPanel(self, + list(opts.values()), + header_text=info["helptext"], + columns=1, + max_columns=1, + option_columns=4, + style="SPanel", + blank_nones=False) + + def _set_display(self, section: str, subsections: list[str]) -> None: + """Set the correct display page for the given section and subsections. + + Parameters + ---------- + section : str + The main section to be navigated to (or root node) + subsections : list + The full list of subsections ending on the required node + """ + key = "|".join([section] + subsections) + if self._displayed_frame is not None: + self._displayed_frame.pack_forget() + + if key not in self._cache: + self._cache_page(key) + + self._displayed_frame = self._cache[key] + self._displayed_key = key + self._displayed_frame.pack(side=tk.BOTTOM, fill=tk.BOTH, expand=True) + + def select_options(self, section: str, subsections: list[str]) -> None: + """Display the page for the given section and subsections. + + Parameters + ---------- + section : str + The main section to be navigated to (or root node) + subsections : list[str] + The full list of subsections ending on the required node + """ + labels = ["global"] if not subsections else subsections + self._vars["header"].set(" - ".join(sect.replace("_", " ").title() for sect in labels)) + self._set_display(section, subsections) + + def _link_callback(self, identifier: str): + """Set the tree view to the selected item and display the requested page on a link click. + + Parameters + ---------- + identifier : str + The identifier from the tree view for the page to display + """ + parent = "|".join(identifier.split("|")[:-1]) + self._tree.item(parent, open=True) + self._tree.selection_set(identifier) + self._tree.focus(identifier) + split = identifier.split("|") + section = split[0] + subsections = split[1:] if len(split) > 1 else [] + self.select_options(section, subsections) + + def reset(self, page_only: bool = False) -> None: + """Reset all configuration options to their default values. + + Parameters + ---------- + page_only : bool, optional + ``True`` resets just the currently selected page's options to default, ``False`` resets + all plugins within the currently selected config to default. 
Default: ``False`` + """ + logger.debug("Resetting config, page_only: %s", page_only) + selection = self._tree.focus() + if page_only: + if selection not in self._config_cpanel_dict: + logger.info("No configuration options to reset for current page: %s", selection) + return + items = list(T.cast(dict[str, ControlPanelOption], + self._config_cpanel_dict[selection]["options"]).values()) + else: + items = [opt + for key, val in self._config_cpanel_dict.items() + for opt in T.cast(dict[str, ControlPanelOption], val["options"]).values() + if key.startswith(selection.split("|")[0])] + for item in items: + logger.debug("Resetting item '%s' from '%s' to default '%s'", + item.title, item.get(), item.default) + item.set(item.default) + logger.debug("Reset config") + + def _update_config(self, + page_only: bool, + config: FaceswapConfig, + category: str, + current_section: str) -> bool: + """Update the FaceswapConfig item from the currently selected options + + Parameters + ---------- + page_only : bool + ``True`` saves just the currently selected page's options, ``False`` saves all the + plugins options within the currently selected config. + config : :class:`~lib.config.FaceswapConfig` + The original config that is to be addressed + category : str + The configuration category to update + current_section : str + The section of the configuration to update + + Returns + ------- + bool + ``True`` if the config has been updated. ``False`` if it is unchanged + """ + retval = False + for section_name, section in config.sections.items(): + if page_only and section_name != current_section: + logger.debug("Skipping section '%s' for page_only save", section_name) + continue + key = category + key += f"|{section_name.replace('.', '|')}" if section_name != "global" else "" + gui_opts = T.cast(dict[str, ControlPanelOption], + self._config_cpanel_dict[key]["options"]) + for option_name, option in section.options.items(): + new_opt = gui_opts[option_name].get() + if new_opt == option.value or (isinstance(option.value, list) and + set(str(new_opt).split()) == set(option.value)): + logger.debug("Skipping unchanged option '%s'", option_name) + continue + fmt_opt = str(new_opt).split() if isinstance(option.value, list) else new_opt + logger.debug("Updating '%s' from %s to %s", + option_name, repr(option.value), repr(fmt_opt)) + option.set(new_opt) + retval = True + return retval + + def save(self, page_only: bool = False) -> None: + """Save the configuration file to disk. + + Parameters + ---------- + page_only : bool, optional + ``True`` saves just the currently selected page's options, ``False`` saves all the + plugins options within the currently selected config. 
Default: ``False``
+        """
+        logger.debug("Saving config")
+        selection = self._tree.focus()
+        category = selection.split("|")[0]
+        config = get_configs()[category]
 
-        self.canvas = tk.Canvas(self, bd=0, highlightthickness=0)
-        self.canvas.pack(side=tk.LEFT, fill=tk.BOTH, expand=True)
+        if "|" in selection:
+            lookup = ".".join(selection.split("|")[1:])
+        else:  # Expand global out from root node
+            lookup = "global"
 
-        self.optsframe = ttk.Frame(self.canvas)
-        self.optscanvas = self.canvas.create_window((0, 0), window=self.optsframe, anchor=tk.NW)
+        if page_only and lookup not in config.sections:
+            logger.info("No settings to save for the current page")
+            return
 
-        self.build_frame()
-        logger.debug("Initialized %s", self.__class__.__name__)
+        if not self._update_config(page_only, config, category, lookup):
+            logger.info("No config changes to save")
+            return
 
-    def build_frame(self):
-        """ Build the options frame for this command """
-        logger.debug("Add Config Frame")
-        self.add_scrollbar()
-        self.canvas.bind("<Configure>", self.resize_frame)
+        config.save_config()
+        logger.debug("Saved config")
+        if category != "gui":
+            return
 
-        self.add_info()
-        for key, val in self.options.items():
-            if key == "helptext":
-                continue
-            OptionControl(key, val, self.optsframe)
-        logger.debug("Added Config Frame")
+        if not get_config().tk_vars.running_task.get():
+            get_config().root.rebuild()  # type:ignore[attr-defined]
+        else:
+            logger.info("Can't redraw GUI whilst a task is running. GUI Settings will be "
+                        "applied at the next restart.")
+
+
+class _Presets():
+    """Handles the file dialog and loading and saving of plugin preset files.
+
+    Parameters
+    ----------
+    parent : :class:`DisplayArea`
+        The parent display area frame
+    top_level : :class:`tkinter.Toplevel`
+        The top level pop up window
+    """
+    def __init__(self, parent: DisplayArea, top_level: tk.Toplevel):
+        logger.debug(parse_class_init(locals()))
+        self._parent = parent
+        self._popup = top_level
+        self._base_path = os.path.join(PATHCACHE, "presets")
+        self._serializer = get_serializer("json")
+        logger.debug("Initialized: %s", self.__class__.__name__)
 
-    def add_scrollbar(self):
-        """ Add a scrollbar to the options frame """
-        logger.debug("Add Config Scrollbar")
-        scrollbar = ttk.Scrollbar(self, command=self.canvas.yview)
-        scrollbar.pack(side=tk.RIGHT, fill=tk.Y)
-        self.canvas.config(yscrollcommand=scrollbar.set)
-        self.optsframe.bind("<Configure>", self.update_scrollbar)
-        logger.debug("Added Config Scrollbar")
-
-    def update_scrollbar(self, event):  # pylint: disable=unused-argument
-        """ Update the options frame scrollbar """
-        self.canvas.configure(scrollregion=self.canvas.bbox("all"))
-
-    def resize_frame(self, event):
-        """ Resize the options frame to fit the canvas """
-        logger.debug("Resize Config Frame")
-        canvas_width = event.width
-        self.canvas.itemconfig(self.optscanvas, width=canvas_width)
-        logger.debug("Resized Config Frame")
-
-    def add_info(self):
-        """ Plugin information """
-        info_frame = ttk.Frame(self.optsframe)
-        info_frame.pack(fill=tk.X, expand=True)
-        lbl = ttk.Label(info_frame, text="About:", width=20, anchor=tk.W)
-        lbl.pack(padx=5, pady=5, side=tk.LEFT, anchor=tk.N)
-        info = ttk.Label(info_frame, text=self.plugin_info)
-        info.pack(padx=5, pady=5, fill=tk.X, expand=True)
-
-
-class OptionControl():
-    """ Build the correct control for the option parsed and place it on the
-        frame """
-
-    def __init__(self, title, values, option_frame):
-        logger.debug("Initializing %s", self.__class__.__name__)
-        self.title = title
-        self.values = values
-        self.option_frame = option_frame
-
-        self.control = self.set_control()
-        self.control_frame = self.set_control_frame()
-        self.tk_var = self.set_tk_var()
-
-        self.build_full_control()
-        logger.debug("Initialized %s", self.__class__.__name__)
+    @property
+    def _displayed_key(self) -> str:
+        """str : The currently displayed plugin key"""
+        retval = self._parent.displayed_key
+        assert retval is not None
+        return retval
 
     @property
-    def helptext(self):
-        """ Format the help text for tooltips """
-        logger.debug("Format control help: '%s'", self.title)
-        helptext = self.values.get("helptext", "")
-        helptext = helptext.replace("\n\t", "\n - ").replace("%%", "%")
-        helptext = self.title + " - " + helptext
-        logger.debug("Formatted control help: (title: '%s', help: '%s'", self.title, helptext)
-        return helptext
-
-    def set_control(self):
-        """ Set the correct control type for this option """
-        dtype = self.values["type"]
-        choices = self.values["choices"]
-        if choices:
-            control = ttk.Combobox
-        elif dtype == bool:
-            control = ttk.Checkbutton
-        elif dtype in (int, float):
-            control = ttk.Scale
-        else:
-            control = ttk.Entry
-        logger.debug("Setting control '%s' to %s", self.title, control)
-        return control
-
-    def set_control_frame(self):
-        """ Frame to hold control and it's label """
-        logger.debug("Build config control frame")
-        frame = ttk.Frame(self.option_frame)
-        frame.pack(fill=tk.X, expand=True)
-        logger.debug("Built confog control frame")
-        return frame
+    @property
+    def _preset_path(self) -> str:
+        """str : The path to the default preset folder for the currently displayed plugin."""
+        return os.path.join(self._base_path, self._displayed_key.split("|")[0])
 
-    def set_tk_var(self):
-        """ Correct variable type for control """
-        logger.debug("Setting config variable type: '%s'", self.title)
-        var = tk.BooleanVar if self.control == ttk.Checkbutton else tk.StringVar
-        var = var(self.control_frame)
-        logger.debug("Set config variable type: ('%s': %s", self.title, type(var))
-        return var
-
-    def build_full_control(self):
-        """ Build the correct control type for the option passed through """
-        logger.debug("Build confog option control")
-        self.build_control_label()
-        self.build_one_control()
-        self.values["selected"] = self.tk_var
-        logger.debug("Built option control")
-
-    def build_control_label(self):
-        """ Label for control """
-        logger.debug("Build config control label: '%s'", self.title)
-        title = self.title.replace("_", " ").title()
-        lbl = ttk.Label(self.control_frame, text=title, width=20, anchor=tk.W)
-        lbl.pack(padx=5, pady=5, side=tk.LEFT, anchor=tk.N)
-        logger.debug("Built config control label: '%s'", self.title)
-
-    def build_one_control(self):
-        """ Build and place the option controls """
-        logger.debug("Build control: (title: '%s', values: %s)", self.title, self.values)
-        self.tk_var.set(self.values["value"])
-
-        if self.control == ttk.Scale:
-            self.slider_control()
-        else:
-            self.control_to_optionsframe()
-        logger.debug("Built control: '%s'", self.title)
-
-    def slider_control(self):
-        """ A slider control with corresponding Entry box """
-        logger.debug("Add slider control to Config Options Frame: %s", self.control)
-        d_type = self.values["type"]
-        rnd = self.values["rounding"]
-        min_max = self.values["min_max"]
-
-        tbox = ttk.Entry(self.control_frame, width=8, textvariable=self.tk_var, justify=tk.RIGHT)
-        tbox.pack(padx=(0, 5), side=tk.RIGHT)
-        ctl = self.control(
-            self.control_frame,
-            variable=self.tk_var,
-            command=lambda val, var=self.tk_var, dt=d_type, rn=rnd, mm=min_max:
-            set_slider_rounding(val, var, dt, rn, mm))
-        ctl.pack(padx=5, pady=5, fill=tk.X, expand=True)
-        rc_menu = ContextMenu(ctl)
-        rc_menu.cm_bind()
-        ctl["from_"] = min_max[0]
-        ctl["to"] = min_max[1]
-
-        Tooltip(ctl, text=self.helptext, wraplength=720)
-        Tooltip(tbox, text=self.helptext, wraplength=720)
-        logger.debug("Added slider control to Options Frame: %s", self.control)
-
-    def control_to_optionsframe(self):
-        """ Standard non-check buttons sit in the main options frame """
-        logger.debug("Add control to Options Frame: %s", self.control)
-        choices = self.values["choices"]
-        if self.control == ttk.Checkbutton:
-            ctl = self.control(self.control_frame, variable=self.tk_var, text=None)
+    @property
+    def _full_key(self) -> str:
+        """str : The full extrapolated lookup key for the currently displayed page."""
+        full_key = self._displayed_key
+        return full_key if "|" in full_key else f"{full_key}|global"
+
+    def load(self) -> None:
+        """Load a preset on a load preset button press.
+
+        Loads parameters from a saved json file and updates the displayed page.
+        """
+        filename = self._get_filename("load")
+        if not filename:
+            return
+
+        opts = self._serializer.load(filename)
+        if opts.get("__filetype") != "faceswap_preset":
+            logger.warning("'%s' is not a valid plugin preset file", filename)
+            return
+        if opts.get("__section") != self._full_key:
+            logger.warning("You are attempting to load a preset for '%s' into '%s'. Aborted.",
+                           opts.get("__section", "no section"), self._full_key)
+            return
+
+        logger.debug("Loaded preset: %s", opts)
+
+        exist = T.cast(dict[str, ControlPanelOption],
+                       self._parent.config_dict[self._displayed_key]["options"])
+        for key, val in opts.items():
+            if key.startswith("__") or key not in exist:
+                logger.debug("Skipping non-existent item: '%s'", key)
+                continue
+            logger.debug("Setting '%s' to '%s'", key, val)
+            exist[key].set(val)
+        logger.info("Preset loaded from: '%s'", os.path.basename(filename))
+
+    def save(self) -> None:
+        """Save the preset on a save preset button press.
+
+        Compiles the currently displayed configuration options into a json file and saves it to
+        the selected location.
+        """
+        filename = self._get_filename("save")
+        if not filename:
+            return
+
+        opts = T.cast(dict[str, ControlPanelOption],
+                      self._parent.config_dict[self._displayed_key]["options"])
+        preset = {opt: val.get() for opt, val in opts.items()}
+        preset["__filetype"] = "faceswap_preset"
+        preset["__section"] = self._full_key
+        self._serializer.save(filename, preset)
+        logger.info("Preset '%s' saved to: '%s'", self._full_key, filename)
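# Illustrative sketch, not part of the patch: the preset file format that
# load()/save() round-trip through the json serializer. The two dunder keys
# guard against loading a preset into the wrong plugin page; every other key is
# a plain option -> value mapping. The option names shown are hypothetical.
#
#     {
#         "__filetype": "faceswap_preset",
#         "__section": "train|model|original",
#         "learning_rate": 5e-05,
#         "batch_size": 16
#     }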
+
+    def _get_filename(self, action: T.Literal["load", "save"]) -> str | None:
+        """Obtain the filename for load and save preset actions.
+
+        Parameters
+        ----------
+        action : ["load", "save"]
+            The preset action that is being performed
+
+        Returns
+        -------
+        str | None
+            The requested preset filename. ``None`` if no filename found
+        """
+        if not self._parent.config_dict.get(self._displayed_key):
+            logger.info("No settings to %s for the current page.", action)
+            return None
+
+        if action == "save":
+            filename = FileHandler("save_filename",
+                                   "json",
+                                   title="Save Preset...",
+                                   initial_folder=self._preset_path,
+                                   parent=self._parent,
+                                   initial_file=self._get_initial_filename()).return_file
         else:
-            ctl = self.control(self.control_frame, textvariable=self.tk_var)
-        ctl.pack(padx=5, pady=5, fill=tk.X, expand=True)
-        rc_menu = ContextMenu(ctl)
-        rc_menu.cm_bind()
-        if choices:
-            logger.debug("Adding combo choices: %s", choices)
-            ctl["values"] = [choice for choice in choices]
-        Tooltip(ctl, text=self.helptext, wraplength=720)
-        logger.debug("Added control to Options Frame: %s", self.control)
+            filename = FileHandler("filename",
+                                   "json",
+                                   title="Load Preset...",
+                                   initial_folder=self._preset_path,
+                                   parent=self._parent).return_file
+
+        if not filename:
+            logger.debug("%s cancelled", action.title())
+
+        self._raise_toplevel()
+        return filename
+
+    def _get_initial_filename(self) -> str:
+        """Obtain the initial filename for saving a preset.
+
+        The name is based on the plugin's display key. A scan of the default presets folder is
+        done to ensure no filename clash. If a filename does clash, then an integer is added to
+        the end.
+
+        Returns
+        -------
+        str
+            The initial preset filename
+        """
+        _, key = self._full_key.split("|", 1)
+        base_filename = f"{key.replace('|', '_')}_preset"
+
+        i = 0
+        filename = f"{base_filename}.json"
+        while True:
+            if not os.path.exists(os.path.join(self._preset_path, filename)):
+                break
+            logger.debug("File pre-exists: %s", filename)
+            filename = f"{base_filename}_{i}.json"
+            i += 1
+        logger.debug("Initial filename: %s", filename)
+        return filename
+
+    def _raise_toplevel(self) -> None:
+        """Bring Toplevel to the top in case file dialog has hidden it."""
+        self._popup.update()
+        self._popup.deiconify()
+        self._popup.lift()
+
+
+__all__ = get_module_objects(__name__)
diff --git a/lib/gui/popup_session.py b/lib/gui/popup_session.py
new file mode 100644
index 0000000000..f33743cabe
--- /dev/null
+++ b/lib/gui/popup_session.py
@@ -0,0 +1,589 @@
+#!/usr/bin python3
+""" Pop-up Graph launched from the Analysis tab of the Faceswap GUI """
+
+import csv
+import gettext
+import logging
+import tkinter as tk
+
+from dataclasses import dataclass, field
+from tkinter import ttk
+
+from lib.utils import get_module_objects
+
+from .control_helper import ControlBuilder, ControlPanelOption
+from .custom_widgets import Tooltip
+from .display_graph import SessionGraph
+from .analysis import Calculations, Session
+from .utils import FileHandler, get_images, LongRunningTask
+
+logger = logging.getLogger(__name__)
+
+# LOCALES
+_LANG = gettext.translation("gui.tooltips", localedir="locales", fallback=True)
+_ = _LANG.gettext
+
+
+@dataclass
+class SessionTKVars:  # pylint:disable=too-many-instance-attributes
+    """ Dataclass for holding the tk variables required for the session popup
+
+    Parameters
+    ----------
+    buildgraph: :class:`tkinter.BooleanVar`
+        Trigger variable to indicate the graph should be rebuilt
+    status: :class:`tkinter.StringVar`
+        The variable holding the current status of the popup window
+    display: :class:`tkinter.StringVar`
+        Variable indicating the type of information to be displayed
+    scale: :class:`tkinter.StringVar`
+        Variable indicating whether to display as log or linear data
+    raw: :class:`tkinter.BooleanVar`
+        Variable to indicate raw data should be displayed
+ trend: :class:`tkinter.BooleanVar` + Variable to indicate that trend data should be displayed + avg: :class:`tkinter.BooleanVar` + Variable to indicate that rolling average data should be displayed + smoothed: :class:`tkinter.BooleanVar` + Variable to indicate that smoothed data should be displayed + outliers: :class:`tkinter.BooleanVar` + Variable to indicate that outliers should be displayed + loss_keys: dict + Dictionary of names to :class:`tkinter.BooleanVar` indicating whether specific loss items + should be displayed + avgiterations: :class:`tkinter.IntVar` + The number of iterations to use for rolling average + smoothamount: :class:`tkinter.DoubleVar` + The amount of smoothing to apply for smoothed data + """ + buildgraph: tk.BooleanVar + status: tk.StringVar + display: tk.StringVar + scale: tk.StringVar + raw: tk.BooleanVar + trend: tk.BooleanVar + avg: tk.BooleanVar + smoothed: tk.BooleanVar + outliers: tk.BooleanVar + avgiterations: tk.IntVar + smoothamount: tk.DoubleVar + loss_keys: dict[str, tk.BooleanVar] = field(default_factory=dict) + + +class SessionPopUp(tk.Toplevel): + """ Pop up for detailed graph/stats for selected session. + + session_id: int or `"Total"` + The session id number for the selected session from the Analysis tab. Should be the string + `"Total"` if all sessions are being graphed + data_points: int + The number of iterations in the selected session + """ + def __init__(self, session_id: int, data_points: int) -> None: + logger.debug("Initializing: %s: (session_id: %s, data_points: %s)", + self.__class__.__name__, session_id, data_points) + super().__init__() + self._thread: LongRunningTask | None = None # Thread for loading data in background + self._default_view = "avg" if data_points > 1000 else "smoothed" + self._session_id = None if session_id == "Total" else int(session_id) + + self._graph_frame = ttk.Frame(self) + self._graph: SessionGraph | None = None + self._display_data: Calculations | None = None + + self._vars = self._set_vars() + + self._graph_initialised = False + + optsframe = self._layout_frames() + self._build_options(optsframe) + + self._lbl_loading = ttk.Label(self._graph_frame, text="Loading Data...", anchor=tk.CENTER) + self._lbl_loading.pack(fill=tk.BOTH, expand=True) + self.update_idletasks() + + self._compile_display_data() + + logger.debug("Initialized: %s", self.__class__.__name__) + + def _set_vars(self) -> SessionTKVars: + """ Set status tkinter String variable and tkinter Boolean variable to callback when the + graph is ready to build. 
+ + Returns + ------- + :class:`SessionTKVars` + The tkinter Variables for the pop up graph + """ + logger.debug("Setting tk graph build variable and internal variables") + retval = SessionTKVars(buildgraph=tk.BooleanVar(), + status=tk.StringVar(), + display=tk.StringVar(), + scale=tk.StringVar(), + raw=tk.BooleanVar(), + trend=tk.BooleanVar(), + avg=tk.BooleanVar(), + smoothed=tk.BooleanVar(), + outliers=tk.BooleanVar(), + avgiterations=tk.IntVar(), + smoothamount=tk.DoubleVar()) + retval.buildgraph.set(False) + retval.buildgraph.trace("w", self._graph_build) + return retval + + def _layout_frames(self) -> ttk.Frame: + """ Top level container frames """ + logger.debug("Layout frames") + + leftframe = ttk.Frame(self) + sep = ttk.Frame(self, width=2, relief=tk.RIDGE) + + self._graph_frame.pack(side=tk.RIGHT, fill=tk.BOTH, pady=5, expand=True) + sep.pack(fill=tk.Y, side=tk.LEFT) + leftframe.pack(side=tk.LEFT, expand=False, fill=tk.BOTH, pady=5) + + logger.debug("Laid out frames") + + return leftframe + + def _build_options(self, frame: ttk.Frame) -> None: + """ Build Options into the options frame. + + Parameters + ---------- + frame: :class:`tkinter.ttk.Frame` + The frame that the options reside in + """ + logger.debug("Building Options") + self._opts_combobox(frame) + self._opts_checkbuttons(frame) + self._opts_loss_keys(frame) + self._opts_slider(frame) + self._opts_buttons(frame) + sep = ttk.Frame(frame, height=2, relief=tk.RIDGE) + sep.pack(fill=tk.X, pady=(5, 0), side=tk.BOTTOM) + logger.debug("Built Options") + + def _opts_combobox(self, frame: ttk.Frame) -> None: + """ Add the options combo boxes. + + Parameters + ---------- + frame: :class:`tkinter.ttk.Frame` + The frame that the options reside in + """ + logger.debug("Building Combo boxes") + choices = {"Display": ("Loss", "Rate"), "Scale": ("Linear", "Log")} + + for item in ["Display", "Scale"]: + var: tk.StringVar = getattr(self._vars, item.lower()) + + cmbframe = ttk.Frame(frame) + lblcmb = ttk.Label(cmbframe, text=f"{item}:", width=7, anchor=tk.W) + cmb = ttk.Combobox(cmbframe, textvariable=var, width=10) + cmb["values"] = choices[item] + cmb.current(0) + + cmd = self._option_button_reload if item == "Display" else self._graph_scale + var.trace("w", cmd) + hlp = self._set_help(item) + Tooltip(cmbframe, text=hlp, wrap_length=200) + + cmb.pack(fill=tk.X, side=tk.RIGHT) + lblcmb.pack(padx=(0, 2), side=tk.LEFT) + cmbframe.pack(fill=tk.X, pady=5, padx=5, side=tk.TOP) + logger.debug("Built Combo boxes") + + def _opts_checkbuttons(self, frame: ttk.Frame) -> None: + """ Add the options check buttons. + + Parameters + ---------- + frame: :class:`tkinter.ttk.Frame` + The frame that the options reside in + """ + logger.debug("Building Check Buttons") + self._add_section(frame, "Display") + for item in ("raw", "trend", "avg", "smoothed", "outliers"): + if item == "avg": + text = "Show Rolling Average" + elif item == "outliers": + text = "Flatten Outliers" + else: + text = f"Show {item.title()}" + + var: tk.BooleanVar = getattr(self._vars, item) + if item == self._default_view: + var.set(True) + + ctl = ttk.Checkbutton(frame, variable=var, text=text) + hlp = self._set_help(item) + Tooltip(ctl, text=hlp, wrap_length=200) + ctl.pack(side=tk.TOP, padx=5, pady=5, anchor=tk.W) + + logger.debug("Built Check Buttons") + + def _opts_loss_keys(self, frame: ttk.Frame) -> None: + """ Add loss key selections. 
+ + Parameters + ---------- + frame: :class:`tkinter.ttk.Frame` + The frame that the options reside in + """ + logger.debug("Building Loss Key Check Buttons") + loss_keys = Session.get_loss_keys(self._session_id) + lk_vars = {} + section_added = False + for loss_key in sorted(loss_keys): + if loss_key.startswith("total"): + continue + + text = loss_key.replace("_", " ").title() + helptext = _("Display {}").format(text) + + var = tk.BooleanVar() + var.set(True) + lk_vars[loss_key] = var + + if len(loss_keys) == 1: + # Don't display if there's only one item + break + + if not section_added: + self._add_section(frame, "Keys") + section_added = True + + ctl = ttk.Checkbutton(frame, variable=var, text=text) + Tooltip(ctl, text=helptext, wrap_length=200) + ctl.pack(side=tk.TOP, padx=5, pady=5, anchor=tk.W) + + self._vars.loss_keys = lk_vars + logger.debug("Built Loss Key Check Buttons") + + def _opts_slider(self, frame: ttk.Frame) -> None: + """ Add the options entry boxes. + + Parameters + ---------- + frame: :class:`tkinter.ttk.Frame` + The frame that the options reside in + """ + + self._add_section(frame, "Parameters") + logger.debug("Building Slider Controls") + text = "" + dtype: type[int] | type[float] = int + default: int | float = 0 + rounding = 0 + min_max: tuple[int, int | float] = (0, 0) + for item in ("avgiterations", "smoothamount"): + if item == "avgiterations": + dtype = int + text = "Iterations to Average:" + default = 500 + rounding = 25 + min_max = (25, 2500) + elif item == "smoothamount": + dtype = float + text = "Smoothing Amount:" + default = 0.90 + rounding = 2 + min_max = (0, 0.99) + slider = ControlPanelOption(text, + dtype, + default=default, + rounding=rounding, + min_max=min_max, + helptext=self._set_help(item)) + setattr(self._vars, item, slider.tk_var) + ControlBuilder(frame, slider, 1, 19, None, "Analysis.", True) + logger.debug("Built Sliders") + + def _opts_buttons(self, frame: ttk.Frame) -> None: + """ Add the option buttons. + + Parameters + ---------- + frame: :class:`tkinter.ttk.Frame` + The frame that the options reside in + """ + logger.debug("Building Buttons") + btnframe = ttk.Frame(frame) + lblstatus = ttk.Label(btnframe, + width=40, + textvariable=self._vars.status, + anchor=tk.W) + + for btntype in ("reload", "save"): + cmd = getattr(self, f"_option_button_{btntype}") + btn = ttk.Button(btnframe, + image=get_images().icons[btntype], # type:ignore[arg-type] + command=cmd) + hlp = self._set_help(btntype) + Tooltip(btn, text=hlp, wrap_length=200) + btn.pack(padx=2, side=tk.RIGHT) + + lblstatus.pack(side=tk.LEFT, anchor=tk.W, fill=tk.X, expand=True) + btnframe.pack(fill=tk.X, pady=5, padx=5, side=tk.BOTTOM) + logger.debug("Built Buttons") + + @classmethod + def _add_section(cls, frame: ttk.Frame, title: str) -> None: + """ Add a separator and section title between options + + Parameters + ---------- + frame: :class:`tkinter.ttk.Frame` + The frame that the options reside in + title: str + The section title to display + """ + sep = ttk.Frame(frame, height=2, relief=tk.SOLID) + lbl = ttk.Label(frame, text=title) + + lbl.pack(side=tk.TOP, padx=5, pady=0, anchor=tk.CENTER) + sep.pack(fill=tk.X, pady=(5, 0), side=tk.TOP) + + def _option_button_save(self) -> None: + """ Action for save button press. 
""" + logger.debug("Saving File") + savefile = FileHandler("save", "csv").return_file + if not savefile: + logger.debug("Save Cancelled") + return + logger.debug("Saving to: %s", savefile) + assert self._display_data is not None + save_data = self._display_data.stats + fieldnames = sorted(key for key in save_data.keys()) + + with savefile as outfile: + csvout = csv.writer(outfile, delimiter=",") + csvout.writerow(fieldnames) + csvout.writerows(zip(*[save_data[key] for key in fieldnames])) + + def _option_button_reload(self, *args) -> None: # pylint:disable=unused-argument + """ Action for reset button press and checkbox changes. + + Parameters + ---------- + args: tuple + Required for TK Callback but unused + """ + logger.debug("Refreshing Graph") + if not self._graph_initialised: + return + valid = self._compile_display_data() + if not valid: + logger.debug("Invalid data") + return + assert self._graph is not None + self._graph.refresh(self._display_data, + self._vars.display.get(), + self._vars.scale.get()) + logger.debug("Refreshed Graph") + + def _graph_scale(self, *args) -> None: # pylint:disable=unused-argument + """ Action for changing graph scale. + + Parameters + ---------- + args: tuple + Required for TK Callback but unused + """ + assert self._graph is not None + if not self._graph_initialised: + return + self._graph.set_yscale_type(self._vars.scale.get()) + + @classmethod + def _set_help(cls, action: str) -> str: + """ Set the help text for option buttons. + + Parameters + ---------- + action: str + The action to get the help text for + + Returns + ------- + str + The help text for the given action + """ + lookup = { + "reload": _("Refresh graph"), + "save": _("Save display data to csv"), + "avgiterations": _("Number of data points to sample for rolling average"), + "smoothamount": _("Set the smoothing amount. 0 is no smoothing, 0.99 is maximum " + "smoothing"), + "outliers": _("Flatten data points that fall more than 1 standard deviation from the " + "mean to the mean value."), + "avg": _("Display rolling average of the data"), + "smoothed": _("Smooth the data"), + "raw": _("Display raw data"), + "trend": _("Display polynormal data trend"), + "display": _("Set the data to display"), + "scale": _("Change y-axis scale")} + return lookup.get(action.lower(), "") + + def _compile_display_data(self) -> bool: + """ Compile the data to be displayed. + + Returns + ------- + bool + ``True`` if there is valid data to display, ``False`` if not + """ + if self._thread is None: + logger.debug("Compiling Display Data in background thread") + loss_keys = [key for key, val in self._vars.loss_keys.items() + if val.get()] + logger.debug("Selected loss_keys: %s", loss_keys) + + selections = self._selections_to_list() + + if not self._check_valid_selection(loss_keys, selections): + logger.warning("No data to display. 
Not refreshing") + return False + self._vars.status.set("Loading Data...") + + if self._graph is not None: + self._graph.pack_forget() + self._lbl_loading.pack(fill=tk.BOTH, expand=True) + self.update_idletasks() + + kwargs = {"session_id": self._session_id, + "display": self._vars.display.get(), + "loss_keys": loss_keys, + "selections": selections, + "avg_samples": self._vars.avgiterations.get(), + "smooth_amount": self._vars.smoothamount.get(), + "flatten_outliers": self._vars.outliers.get()} + self._thread = LongRunningTask(target=self._get_display_data, + kwargs=kwargs, + widget=self) + self._thread.start() + self.after(1000, self._compile_display_data) + return True + if not self._thread.complete.is_set(): + logger.debug("Popup Data not yet available") + self.after(1000, self._compile_display_data) + return True + + logger.debug("Getting Popup from background Thread") + self._display_data = self._thread.get_result() + self._thread = None + if not self._check_valid_data(): + logger.warning("No valid data to display. Not refreshing") + self._vars.status.set("") + return False + logger.debug("Compiled Display Data") + self._vars.buildgraph.set(True) + return True + + @classmethod + def _get_display_data(cls, **kwargs) -> Calculations: + """ Get the display data in a LongRunningTask. + + Parameters + ---------- + kwargs: dict + The keyword arguments to pass to `lib.gui.analysis.Calculations` + + Returns + ------- + :class:`lib.gui.analysis.Calculations` + The summarized results for the given session + """ + return Calculations(**kwargs) + + def _check_valid_selection(self, loss_keys: list[str], selections: list[str]) -> bool: + """ Check that there will be data to display. + + Parameters + ---------- + loss_keys: list + The selected loss to display + selections: list + The selected checkbox options + + Returns + ------- + bool + ``True` if there is data to be displayed, otherwise ``False`` + """ + display = self._vars.display.get().lower() + logger.debug("Validating selection. (loss_keys: %s, selections: %s, display: %s)", + loss_keys, selections, display) + if not selections or (display == "loss" and not loss_keys): + return False + return True + + def _check_valid_data(self) -> bool: + """ Check that the selections holds valid data to display + NB: len-as-condition is used as data could be a list or a numpy array + + Returns + ------- + bool + ``True` if there is data to be displayed, otherwise ``False`` + """ + assert self._display_data is not None + logger.debug("Validating data. %s", + {key: len(val) for key, val in self._display_data.stats.items()}) + if any(len(val) == 0 # pylint:disable=len-as-condition + for val in self._display_data.stats.values()): + return False + return True + + def _selections_to_list(self) -> list[str]: + """ Compile checkbox selections to a list. 
+ + Returns + ------- + list + The selected options from the check-boxes + """ + logger.debug("Compiling selections to list") + selections = [] + for item in ("raw", "trend", "avg", "smoothed"): + var: tk.BooleanVar = getattr(self._vars, item) + if var.get(): + selections.append(item) + logger.debug("Compiling selections to list: %s", selections) + return selections + + def _graph_build(self, *args) -> None: # pylint:disable=unused-argument + """ Build the graph in the top right paned window + + Parameters + ---------- + args: tuple + Required for TK Callback but unused + """ + if not self._vars.buildgraph.get(): + return + self._vars.status.set("Loading Data...") + logger.debug("Building Graph") + self._lbl_loading.pack_forget() + self.update_idletasks() + if self._graph is None: + graph = SessionGraph(self._graph_frame, + self._display_data, + self._vars.display.get(), + self._vars.scale.get()) + graph.pack(expand=True, fill=tk.BOTH) + graph.build() + self._graph = graph + self._graph_initialised = True + else: + self._graph.refresh(self._display_data, + self._vars.display.get(), + self._vars.scale.get()) + self._graph.pack(fill=tk.BOTH, expand=True) + self._vars.status.set("") + self._vars.buildgraph.set(False) + logger.debug("Built Graph") + + +__all__ = get_module_objects(__name__) diff --git a/lib/gui/project.py b/lib/gui/project.py new file mode 100644 index 0000000000..24a4a392fe --- /dev/null +++ b/lib/gui/project.py @@ -0,0 +1,1033 @@ +#!/usr/bin/env python3 +""" Handling of Faceswap GUI Projects, Tasks and Last Session """ + +import logging +import os +import tkinter as tk +from tkinter import messagebox + +from lib.serializer import get_serializer +from lib.gui import gui_config as cfg +from lib.utils import get_module_objects + + +logger = logging.getLogger(__name__) + + +class _GuiSession(): # pylint:disable=too-few-public-methods + """ Parent class for GUI Session Handlers. + + Parameters + ---------- + config: :class:`lib.gui.utils.Config` + The master GUI config + file_handler: :class:`lib.gui.utils.FileHandler` + A file handler object + + """ + def __init__(self, config, file_handler=None): + # NB file_handler has to be passed in to avoid circular imports + logger.debug("Initializing: %s: (config: %s, file_handler: %s)", + self.__class__.__name__, config, file_handler) + self._serializer = get_serializer("json") + self._config = config + + self._options = None + self._file_handler = file_handler + self._filename = None + self._saved_tasks = None + self._modified = False + logger.debug("Initialized: %s", self.__class__.__name__) + + @property + def _active_tab(self): + """ str: The name of the currently selected :class:`lib.gui.command.CommandNotebook` + tab. """ + notebook = self._config.command_notebook + toolsbook = self._config.tools_notebook + command = notebook.tab(notebook.select(), "text").lower() + if command == "tools": + command = toolsbook.tab(toolsbook.select(), "text").lower() + logger.debug("Active tab: %s", command) + return command + + @property + def _modified_vars(self): + """ dict: The tkinter Boolean vars indicating the modified state for each tab. """ + return self._config.modified_vars + + @property + def _file_exists(self): + """ bool: ``True`` if :attr:`_filename` exists otherwise ``False``. """ + return self._filename is not None and os.path.isfile(self._filename) + + @property + def _cli_options(self): + """ dict: the raw cli options from :attr:`_options` with project fields removed. 
""" + return {key: val for key, val in self._options.items() if isinstance(val, dict)} + + @property + def _default_options(self): + """ dict: The default options for all tabs """ + return self._config.default_options + + @property + def _dirname(self): + """ str: The folder name that :attr:`_filename` resides in. Returns ``None`` if + filename is ``None``. """ + return os.path.dirname(self._filename) if self._filename is not None else None + + @property + def _basename(self): + """ str: The base name of :attr:`_filename`. Returns ``None`` if filename is ``None``. """ + return os.path.basename(self._filename) if self._filename is not None else None + + @property + def _stored_tab_name(self): + """str: The tab_name stored in :attr:`_options` or ``None`` if it does not exist """ + if self._options is None: + return None + return self._options.get("tab_name", None) + + @property + def _selected_to_choices(self): + """ dict: The selected value and valid choices for multi-option, radio or combo options. + """ + valid_choices = {cmd: {opt: {"choices": val.cpanel_option.choices, + "is_multi": val.cpanel_option.is_multi_option} + for opt, val in data.items() + if hasattr(val, "cpanel_option") # Filter out helptext + and val.cpanel_option.choices is not None + } + for cmd, data in self._config.cli_opts.opts.items()} + logger.trace("valid_choices: %s", valid_choices) + retval = {command: {option: {"value": value, + "is_multi": valid_choices[command][option]["is_multi"], + "choices": valid_choices[command][option]["choices"]} + for option, value in options.items() + if value and command in valid_choices + and option in valid_choices[command]} + for command, options in self._options.items() + if isinstance(options, dict)} + logger.trace("returning: %s", retval) + return retval + + def _current_gui_state(self, command=None): + """ The current state of the GUI. + + Parameters + ---------- + command: str, optional + If provided, returns the state of just the given tab command. If ``None`` returns options + for all tabs. Default ``None`` + + Returns + ------- + dict: The options currently set in the GUI + """ + return self._config.cli_opts.get_option_values(command) + + def _set_filename(self, filename=None, sess_type="project"): + """ Set the :attr:`_filename` attribute. + + :attr:`_filename` is set either from a given filename or the result from + a :attr:`_file_handler`. + + Parameters + ---------- + filename: str, optional + An optional filename. If given then this filename will be used otherwise it will be + collected by a :attr:`_file_handler` + + sess_type: {all, project, task}, optional + The session type that the filename is being set for. Dictates the type of file handler + that is opened. 
+ + Returns + ------- + bool: `True` if filename has been successfully set otherwise ``False`` + """ + logger.debug("filename: '%s', sess_type: '%s'", filename, sess_type) + handler = f"config_{sess_type}" + + if filename is None: + logger.debug("Popping file handler") + cfgfile = self._file_handler("open", handler).return_file + if not cfgfile: + logger.debug("No filename given") + return False + filename = cfgfile.name + cfgfile.close() + + if not os.path.isfile(filename): + msg = f"File does not exist: '{filename}'" + logger.error(msg) + return False + ext = os.path.splitext(filename)[1] + if (sess_type == "project" and ext != ".fsw") or (sess_type == "task" and ext != ".fst"): + logger.debug("Invalid file extension for session type: (sess_type: '%s', " + "extension: '%s')", sess_type, ext) + return False + logger.debug("Setting filename: '%s'", filename) + self._filename = filename + return True + + # GUI STATE SETTING + def _set_options(self, command=None): + """ Set the GUI options based on the currently stored properties of :attr:`_options` + and sets the active tab. + + Parameters + ---------- + command: str, optional + The tab to set the options for. If None then sets options for all tabs. + Default: ``None`` + """ + opts = self._get_options_for_command(command) if command else self._cli_options + logger.debug("command: %s, opts: %s", command, opts) + if opts is None: + logger.debug("No options found. Returning") + return + for cmd, opt in opts.items(): + self._set_gui_state_for_command(cmd, opt) + tab_name = self._options.get("tab_name", None) if command is None else command + tab_name = tab_name if tab_name is not None else "extract" + logger.debug("tab_name: %s", tab_name) + self._config.set_active_tab_by_name(tab_name) + + def _get_options_for_command(self, command): + """ Return a single command's options from :attr:`_options` formatted consistently with + an all options dict. + + Parameters + ---------- + command: str + The command to return the options for + + Returns + ------- + dict: The options for a single command in the format {command: options}. If the command + is not found then returns ``None`` + """ + logger.debug(command) + opts = self._options.get(command, None) + retval = {command: opts} + if not opts: + self._config.tk_vars.console_clear.set(True) + logger.info("No %s section found in file", command) + retval = None + logger.debug(retval) + return retval + + def _set_gui_state_for_command(self, command, options): + """ Set the GUI state for the given command. + + Parameters + ---------- + command: str + The tab to set the options for + options: dict + The option values to set the GUI to + """ + logger.debug("command: %s: options: %s", command, options) + if not options: + logger.debug("No options provided, not updating GUI") + return + for srcopt, srcval in options.items(): + optvar = self._config.cli_opts.get_one_option_variable(command, srcopt) + if not optvar: + continue + logger.trace("setting option: (srcopt: %s, optvar: %s, srcval: %s)", + srcopt, optvar, srcval) + optvar.set(srcval) + + def _reset_modified_var(self, command=None): + """ Reset :attr:`_modified_vars` variables back to unmodified (`False`) for all + commands or for the given command. + + Parameters + ---------- + command: str, optional + The command to reset the modified tkinter variable for. If ``None`` then all tkinter + modified variables are reset to `False`. 
Default: ``None`` + """ + for key, tk_var in self._modified_vars.items(): + if (command is None or command == key) and tk_var.get(): + logger.debug("Reset modified state for: (command: %s key: %s)", command, key) + tk_var.set(False) + + # RECENT FILE HANDLING + def _add_to_recent(self, command=None): + """ Add the file for this session to the recent files list. + + Parameters + ---------- + command: str, optional + The command that this session relates to. If `None` then the whole project is added. + Default: ``None`` + """ + logger.debug(command) + if self._filename is None: + logger.debug("No filename for selected file. Not adding to recent.") + return + recent_filename = os.path.join(self._config.pathcache, ".recent.json") + logger.debug("Adding to recent files '%s': (%s, %s)", + recent_filename, self._filename, command) + if not os.path.exists(recent_filename) or os.path.getsize(recent_filename) == 0: + logger.debug("Starting with empty recent_files list") + recent_files = [] + else: + logger.debug("loading recent_files list: %s", recent_filename) + recent_files = self._serializer.load(recent_filename) + logger.debug("Initial recent files: %s", recent_files) + recent_files = self._del_from_recent(self._filename, recent_files) + ftype = "project" if command is None else command + recent_files.insert(0, (self._filename, ftype)) + recent_files = recent_files[:20] + logger.debug("Final recent files: %s", recent_files) + self._serializer.save(recent_filename, recent_files) + + def _del_from_recent(self, filename, recent_files=None, save=False): + """ Remove an item from the recent files list. + + Parameters + ---------- + filename: str + The filename to be removed from the recent files list + recent_files: list, optional + If the recent files list has already been loaded, it can be passed in to avoid + loading again. If ``None`` then load the recent files list from disk. Default: ``None`` + save: bool, optional + Whether the recent files list should be saved after removing the file. ``True`` saves + the file, ``False`` does not. Default: ``False`` + """ + recent_filename = os.path.join(self._config.pathcache, ".recent.json") + if recent_files is None: + logger.debug("Loading file list from disk: %s", recent_filename) + if not os.path.exists(recent_filename) or os.path.getsize(recent_filename) == 0: + logger.debug("No recent file list") + return None + recent_files = self._serializer.load(recent_filename) + filenames = [recent[0] for recent in recent_files] + if filename in filenames: + idx = filenames.index(filename) + logger.debug("Removing from recent file list: %s", filename) + del recent_files[idx] + if save: + logger.debug("Saving recent files list: %s", recent_filename) + self._serializer.save(recent_filename, recent_files) + else: + logger.debug("Filename '%s' does not appear in recent file list", filename) + return recent_files + + def _get_lone_task(self): + """ Get the sole command name from :attr:`_options`. + + Returns + ------- + str: The only existing command name in the current :attr:`_options` dict or ``None`` if + there are multiple commands stored. + """ + command = None + if len(self._cli_options) == 1: + command = list(self._cli_options.keys())[0] + logger.debug(command) + return command + + # DISK IO + def _load(self): + """ Load GUI options from :attr:`_filename` location and set to :attr:`_options`. 
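+
+        A minimal sketch of the expected call order, as used by :class:`Project` and
+        :class:`Tasks` (the path is hypothetical)::
+
+            if self._set_filename("~/proj/swap.fsw", sess_type="project"):
+                if self._load():  # deserialize the json file and validate stored choices
+                    self._set_options()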
+
+        Returns
+        -------
+        bool: ``True`` if successfully loaded otherwise ``False``
+        """
+        if self._file_exists:
+            logger.debug("Loading config")
+            self._options = self._serializer.load(self._filename)
+            self._check_valid_choices()
+            retval = True
+        else:
+            logger.debug("File doesn't exist. Aborting")
+            retval = False
+        return retval
+
+    def _check_valid_choices(self):
+        """ Check whether the loaded file has any selected combo/radio/multi-option values that are
+        no longer valid and remove them so that they are not passed into faceswap. """
+        for command, options in self._selected_to_choices.items():
+            for option, data in options.items():
+                if ((data["is_multi"] and all(v in data["choices"] for v in data["value"].split()))
+                        or not data["is_multi"] and data["value"] in data["choices"]):
+                    continue
+                if data["is_multi"]:
+                    val = " ".join([v for v in data["value"].split() if v in data["choices"]])
+                else:
+                    val = ""
+                val = self._default_options[command][option] if not val else val
+                logger.debug("Updating invalid value to default: (command: '%s', option: '%s', "
+                             "original value: '%s', new value: '%s')", command, option,
+                             self._options[command][option], val)
+                self._options[command][option] = val
+
+    def _save_as_to_filename(self, session_type):
+        """ Set :attr:`_filename` from a save as dialog.
+
+        Parameters
+        ----------
+        session_type: ['all', 'task', 'project']
+            The type of session to pop the save as dialog for. Limits the allowed filetypes
+
+        Returns
+        -------
+        bool:
+            ``True`` if :attr:`_filename` is successfully set otherwise ``False``
+        """
+        logger.debug("Popping save as file handler. session_type: '%s'", session_type)
+        title = f"Save {f'{session_type.title()} ' if session_type != 'all' else ''}As..."
+        cfgfile = self._file_handler("save",
+                                     f"config_{session_type}",
+                                     title=title,
+                                     initial_folder=self._dirname).return_file
+        if not cfgfile:
+            logger.debug("No filename provided. session_type: '%s'", session_type)
+            return False
+        self._filename = cfgfile.name
+        logger.debug("Set filename: (session_type: '%s', filename: '%s')",
+                     session_type, self._filename)
+        cfgfile.close()
+        return True
+
+    def _save(self, command=None):
+        """ Collect the options in the current GUI state and save.
+
+        Obtains the current options set in the GUI for the selected tab and applies them to
+        :attr:`_options`. Saves :attr:`_options` to :attr:`_filename`. Resets
+        :attr:`_modified_vars` for either the given command or all commands.
+
+        Parameters
+        ----------
+        command: str, optional
+            The tab to collect the current state for. If ``None`` then collects the current
+            state for all tabs. Default: ``None``
+        """
+        self._options = self._current_gui_state(command)
+        self._options["tab_name"] = self._active_tab
+        logger.debug("Saving options: (filename: %s, options: %s)", self._filename, self._options)
+        self._serializer.save(self._filename, self._options)
+        self._reset_modified_var(command)
+        self._add_to_recent(command)
+
+
+class Tasks(_GuiSession):
+    """ Faceswap ``.fst`` Task File handling.
+
+    Faceswap tasks handle the management of each individual task tab in the GUI. Unlike
+    :class:`Project`, Tasks holds all of the currently active tasks, rather than an
+    individual task.
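+
+    A rough usage sketch (hypothetical path; the file handler class is normally injected
+    by the GUI at startup)::
+
+        tasks = Tasks(config, file_handler)
+        tasks.load(filename="~/tasks/extract.fst")  # populate the currently active tab
+        tasks.save()                                # write the active tab back to disk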
+
+    Parameters
+    ----------
+    config: :class:`lib.gui.utils.Config`
+        The master GUI config
+    file_handler: :class:`lib.gui.utils.FileHandler`
+        A file handler object
+    """
+    def __init__(self, config, file_handler):
+        super().__init__(config, file_handler)
+        self._tasks = {}
+
+    @property
+    def _is_project(self):
+        """ bool: ``True`` if all tasks are from an overarching session project else ``False``."""
+        retval = False if not self._tasks else all(v.get("is_project", False)
+                                                   for v in self._tasks.values())
+        return retval
+
+    @property
+    def _project_filename(self):
+        """ str: The overarching session project filename."""
+        fname = None
+        if not self._is_project:
+            return fname
+
+        for val in self._tasks.values():
+            fname = val["filename"]
+            break
+        return fname
+
+    def load(self, *args,  # pylint:disable=unused-argument
+             filename=None, current_tab=True):
+        """ Load a task into this :class:`Tasks` class.
+
+        Tasks can be loaded from project ``.fsw`` files or task ``.fst`` files, depending on where
+        this function is being called from.
+
+        Parameters
+        ----------
+        *args: tuple
+            Unused, but needs to be present for arguments passed by tkinter event handling
+        filename: str, optional
+            If a filename is passed in, this will be used, otherwise a file handler will be
+            launched to select the relevant file.
+        current_tab: bool, optional
+            ``True`` if the task to be loaded must be for the currently selected tab. ``False``
+            if loading a task into any tab. If current_tab is ``True`` then tasks can be loaded
+            from ``.fsw`` and ``.fst`` files, otherwise they can only be loaded from ``.fst``
+            files. Default: ``True``
+        """
+        logger.debug("Loading task config: (filename: '%s', current_tab: '%s')",
+                     filename, current_tab)
+
+        # Option to load a specific task from project files
+        sess_type = "all" if current_tab else "task"
+
+        is_legacy = (not self._is_project and
+                     filename is not None and sess_type == "task" and
+                     os.path.splitext(filename)[1] == ".fsw")
+        if is_legacy:
+            logger.debug("Legacy task found: '%s'", filename)
+            filename = self._update_legacy_task(filename)
+
+        filename_set = self._set_filename(filename, sess_type=sess_type)
+        if not filename_set:
+            return
+        loaded = self._load()
+        if not loaded:
+            return
+
+        command = self._active_tab if current_tab else self._stored_tab_name
+        command = self._get_lone_task() if command is None else command
+        if command is None:
+            logger.error("Unable to determine task from the given file: '%s'", filename)
+            return
+        if command not in self._options:
+            logger.error("No '%s' task in '%s'", command, self._filename)
+            return
+
+        self._set_options(command)
+        self._add_to_recent(command)
+
+        if self._is_project:
+            self._filename = self._project_filename
+        elif self._filename.endswith(".fsw"):
+            self._filename = None
+
+        self._add_task(command)
+        if is_legacy:
+            self.save()
+
+        logger.debug("Loaded task config: (command: '%s', filename: '%s')", command, filename)
+
+    def _update_legacy_task(self, filename):
+        """ Update legacy ``.fsw`` tasks to ``.fst`` tasks.
+
+        Tasks loaded from the recent files menu may be passed in with a ``.fsw`` extension.
+        This renames the file and removes it from the recent file list.
+
+        Parameters
+        ----------
+        filename: str
+            The filename of the ``.fsw`` file that needs converting
+
+        Returns
+        -------
+        str:
+            The new filename of the updated tasks file
+        """
+        # TODO remove this code after a period of time. Implemented November 2019
+        logger.debug("original filename: '%s'", filename)
+        fname, ext = os.path.splitext(filename)
+        if ext != ".fsw":
+            logger.debug("Not a .fsw file: '%s'", filename)
+            return filename
+
+        new_filename = f"{fname}.fst"
+        logger.debug("Renaming '%s' to '%s'", filename, new_filename)
+        os.rename(filename, new_filename)
+        self._del_from_recent(filename, save=True)
+        logger.debug("new filename: '%s'", new_filename)
+        return new_filename
+
+    def save(self, save_as=False):
+        """ Save the current GUI state for the active tab to a ``.fst`` faceswap task file.
+
+        Parameters
+        ----------
+        save_as: bool, optional
+            Whether to save to the stored filename, or pop open a file handler to ask for a
+            location. If there is no stored filename, then a file handler will automatically be
+            popped.
+        """
+        logger.debug("Saving config...")
+        self._set_active_task()
+        save_as = save_as or self._is_project or self._filename is None
+
+        if save_as and not self._save_as_to_filename("task"):
+            return
+
+        command = self._active_tab
+        self._save(command=command)
+        self._add_task(command)
+        if not save_as:
+            logger.info("Saved task to: '%s'", self._filename)
+        else:
+            logger.debug("Saved task to: '%s'", self._filename)
+
+    def clear(self):
+        """ Reset all GUI options to their default values for the active tab. """
+        self._config.cli_opts.reset(self._active_tab)
+
+    def reload(self):
+        """ Reset the currently selected tab's GUI options to their last saved state. """
+        self._set_active_task()
+
+        if self._options is None:
+            logger.info("No active task to reload")
+            return
+        logger.debug("Reloading task")
+        self.load(filename=self._filename, current_tab=True)
+        if self._is_project:
+            self._reset_modified_var(self._active_tab)
+
+    def _add_task(self, command):
+        """ Add the currently active task to the internal :attr:`_tasks` dict.
+
+        If the currently stored task is from an overarching session project, then
+        only the options are updated. When resetting a tab to its saved state, a project
+        will always be preferred over a task loaded into the project, so the original
+        reference file name stays with the project.
+
+        Parameters
+        ----------
+        command: str
+            The tab that pertains to the currently active task
+        """
+        self._tasks[command] = {"filename": self._filename,
+                                "options": self._options,
+                                "is_project": self._is_project}
+
+    def clear_tasks(self):
+        """ Clear all of the stored tasks.
+
+        This is required when loading a task stored in a legacy project file, and is only to be
+        called by :class:`Project` when a project has been loaded which is in fact a task.
+        """
+        logger.debug("Clearing stored tasks")
+        self._tasks = {}
+
+    def add_project_task(self, filename, command, options):
+        """ Add an individual task from a loaded :class:`Project` to the internal :attr:`_tasks`
+        dict.
+
+        Project tasks take priority over any other tasks, so the individual tasks from a new
+        project must be placed in the :attr:`_tasks` dict.
+
+        Parameters
+        ----------
+        filename: str
+            The filename of the session project file
+        command: str
+            The tab that this task's options belong to
+        options: dict
+            The options for this task loaded from the project
+        """
+        self._tasks[command] = {"filename": filename, "options": options, "is_project": True}
+
+    def _set_active_task(self, command=None):
+        """ Set the active :attr:`_filename` and :attr:`_options` to the currently selected
+        tab's options.
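+
+        Illustrative effect, assuming a hypothetical stored task::
+
+            self._tasks = {"train": {"filename": "t.fst", "options": {...}, "is_project": False}}
+            self._set_active_task("train")  # _filename -> "t.fst", _options -> {...}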
+
+        Parameters
+        ----------
+        command: str, optional
+            If a command is passed in, then the given tab is set to active. If ``None``, the
+            tab which currently has focus is set to active. Default: ``None``
+        """
+        logger.debug(command)
+        command = self._active_tab if command is None else command
+        task = self._tasks.get(command, None)
+        if task is None:
+            self._filename, self._options = (None, None)
+        else:
+            self._filename, self._options = (task.get("filename", None),
+                                              task.get("options", None))
+        logger.debug("tab: %s, filename: %s, options: %s",
+                     self._active_tab, self._filename, self._options)
+
+
+class Project(_GuiSession):
+    """ Faceswap ``.fsw`` Project File handling.
+
+    Faceswap projects handle the management of all task tabs in the GUI and update
+    the main Faceswap title bar with the project name and modified state.
+
+    Parameters
+    ----------
+    config: :class:`lib.gui.utils.Config`
+        The master GUI config
+    file_handler: :class:`lib.gui.utils.FileHandler`
+        A file handler object
+    """
+
+    def __init__(self, config, file_handler):
+        super().__init__(config, file_handler)
+        self._update_root_title()
+
+    @property
+    def filename(self):
+        """ str: The currently active project filename. """
+        return self._filename
+
+    @property
+    def cli_options(self):
+        """ dict: The raw cli options from :attr:`_options` with project fields removed. """
+        return self._cli_options
+
+    @property
+    def _project_modified(self):
+        """ bool: ``True`` if the project has been modified otherwise ``False``. """
+        return any(var.get() for var in self._modified_vars.values())
+
+    @property
+    def _tasks(self):
+        """ :class:`Tasks`: The current session's :class:`Tasks`. """
+        return self._config.tasks
+
+    def set_default_options(self):
+        """ Set the default options. The default GUI options are stored on Faceswap startup.
+
+        This is exposed as a public method because the default options for a project cannot be
+        set until after the main Command Tabs have been loaded.
+        """
+        logger.debug("Setting options to default")
+        self._options = self._default_options
+
+    # MODIFIED STATE CALLBACK
+    def set_modified_callback(self):
+        """ Add a callback to each of the :attr:`_modified_vars` tkinter variables.
+        When one of these variables changes, :func:`_modified_callback` is triggered
+        with the command that was changed.
+
+        This is exposed as a public method because the callback can only be added after the
+        main Command Tabs have been drawn and their options' initial values have been set.
+        """
+        for key, tkvar in self._modified_vars.items():
+            logger.debug("Adding callback for tab: %s", key)
+            tkvar.trace("w", self._modified_callback)
+
+    def _modified_callback(self, *args):  # pylint:disable=unused-argument
+        """ Update the project modified state on a GUI modification change and
+        update the Faceswap title bar. """
+        if self._project_modified and self._current_gui_state() == self._cli_options:
+            logger.debug("Project is same as stored. Setting modified to False")
+            self._reset_modified_var()
+
+        if self._modified != self._project_modified:
+            logger.debug("Updating project state from variable: (modified: %s)",
+                         self._project_modified)
+            self._modified = self._project_modified
+            self._update_root_title()
+
+    def load(self, *args,  # pylint:disable=unused-argument
+             filename=None, last_session=False):
+        """ Load a project from a saved ``.fsw`` project file.
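+
+        A minimal sketch (hypothetical path)::
+
+            project = Project(config, file_handler)
+            project.load(filename="~/projects/myswap.fsw")
+
+        Legacy ``.fsw`` files that contain a single task are handed off to :class:`Tasks`
+        via :func:`_handoff_legacy_task` rather than loaded as a project.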
+
+        Parameters
+        ----------
+        *args: tuple
+            Unused, but needs to be present for arguments passed by tkinter event handling
+        filename: str, optional
+            If a filename is passed in, this will be used, otherwise a file handler will be
+            launched to select the relevant file.
+        last_session: bool, optional
+            ``True`` if the project is being loaded from the last opened session, ``False`` if
+            the project is being loaded directly from disk. Default: ``False``
+        """
+        logger.debug("Loading project config: (filename: '%s', last_session: %s)",
+                     filename, last_session)
+        filename_set = self._set_filename(filename, sess_type="project")
+
+        if not filename_set:
+            logger.debug("No filename set")
+            return
+        loaded = self._load()
+        if not loaded:
+            logger.debug("Options not loaded")
+            return
+
+        # Legacy .fsw files could store projects or tasks. Check if this is a legacy file
+        # and hand off the file to Tasks if necessary
+        command = self._get_lone_task()
+        legacy = command is not None
+        if legacy:
+            self._handoff_legacy_task()
+            return
+
+        if not last_session:
+            self._set_options()  # Options will be set by last session. Don't set now
+        self._update_tasks()
+        self._add_to_recent()
+        self._reset_modified_var()
+        self._update_root_title()
+        logger.debug("Loaded project config: (command: '%s', filename: '%s')", command, filename)
+
+    def _handoff_legacy_task(self):
+        """ Update legacy tasks saved with the old ``.fsw`` extension to ``.fst`` tasks.
+
+        Hands off file handling to :class:`Tasks` and resets the project to default.
+        """
+        logger.debug("Updating legacy task '%s'", self._filename)
+        filename = self._filename
+        self._filename = None
+        self.set_default_options()
+        self._tasks.clear_tasks()
+        self._tasks.load(filename=filename, current_tab=False)
+        logger.debug("Updated legacy task and reset project")
+
+    def _update_tasks(self):
+        """ Add the tasks from the loaded project to the :class:`Tasks` class. """
+        for key, val in self._cli_options.items():
+            opts = {key: val}
+            opts["tab_name"] = key
+            self._tasks.add_project_task(self._filename, key, opts)
+
+    def reload(self, *args):  # pylint:disable=unused-argument
+        """ Reset all of the GUI's option tabs to their last saved state.
+
+        Parameters
+        ----------
+        *args: tuple
+            Unused, but needs to be present for arguments passed by tkinter event handling
+        """
+        if self._options is None:
+            logger.info("No active project to reload")
+            return
+        logger.debug("Reloading project")
+        self._set_options()
+        self._update_tasks()
+        self._reset_modified_var()
+        self._update_root_title()
+
+    def _update_root_title(self):
+        """ Update the root window title with the project name. Add an asterisk
+        if the file is modified. """
+        text = "" if self._basename is None else self._basename
+        text += "*" if self._modified else ""
+        self._config.set_root_title(text=text)
+
+    def save(self, *args, save_as=False):  # pylint:disable=unused-argument
+        """ Save the current GUI state to a ``.fsw`` project file.
+
+        Parameters
+        ----------
+        *args: tuple
+            Unused, but needs to be present for arguments passed by tkinter event handling
+        save_as: bool, optional
+            Whether to save to the stored filename, or pop open a file handler to ask for a
+            location. If there is no stored filename, then a file handler will automatically be
+            popped.
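+
+        Examples
+        --------
+        Illustrative calls::
+
+            project.save()              # quiet save to the stored .fsw file
+            project.save(save_as=True)  # always pop a save dialog for a new location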
+        """
+        logger.debug("Saving config as...")
+
+        save_as = save_as or self._filename is None
+        if save_as and not self._save_as_to_filename("project"):
+            return
+        self._save()
+        self._update_tasks()
+        self._update_root_title()
+        if not save_as:
+            logger.info("Saved project to: '%s'", self._filename)
+        else:
+            logger.debug("Saved project to: '%s'", self._filename)
+
+    def new(self, *args):  # pylint:disable=unused-argument
+        """ Create a new project with default options.
+
+        Pops a file handler to select a location.
+
+        Parameters
+        ----------
+        *args: tuple
+            Unused, but needs to be present for arguments passed by tkinter event handling
+        """
+        logger.debug("Creating new project")
+        if not self.confirm_close():
+            logger.debug("Creating new project cancelled")
+            return
+
+        cfgfile = self._file_handler("save",
+                                     "config_project",
+                                     title="New Project...",
+                                     initial_folder=self._dirname).return_file
+        if not cfgfile:
+            logger.debug("No filename selected")
+            return
+        self._filename = cfgfile.name
+        cfgfile.close()
+
+        self.set_default_options()
+        self._config.cli_opts.reset()
+        self._save()
+        self._update_root_title()
+
+    def close(self, *args):  # pylint:disable=unused-argument
+        """ Clear the current project and set all options to default.
+
+        Parameters
+        ----------
+        *args: tuple
+            Unused, but needs to be present for arguments passed by tkinter event handling
+        """
+        logger.debug("Close requested")
+        if not self.confirm_close():
+            logger.debug("Close cancelled")
+            return
+        self._config.cli_opts.reset()
+        self._filename = None
+        self.set_default_options()
+        self._reset_modified_var()
+        self._update_root_title()
+        self._config.set_active_tab_by_name(cfg.tab())
+
+    def confirm_close(self):
+        """ Pop a message box to get confirmation that an unsaved project should be closed.
+
+        Returns
+        -------
+        bool: ``True`` if the user confirms the close, ``False`` if the user cancels it
+        """
+        if not self._modified:
+            logger.debug("Project is not modified")
+            return True
+        confirmtxt = "You have unsaved changes.\n\nAre you sure you want to close the project?"
+        if messagebox.askokcancel("Close", confirmtxt, default="cancel", icon="warning"):
+            logger.debug("Close confirmed")
+            return True
+        logger.debug("Close cancelled")
+        return False
+
+
+class LastSession(_GuiSession):
+    """ Faceswap Last Session handling.
+
+    Faceswap :class:`LastSession` handles saving the state of the Faceswap GUI at close and
+    reloading the state at launch.
+
+    Last Session behavior can be configured in :file:`config.gui.ini`.
+
+    Parameters
+    ----------
+    config: :class:`lib.gui.utils.Config`
+        The master GUI config
+    """
+
+    def __init__(self, config):
+        super().__init__(config)
+        self._filename = os.path.join(self._config.pathcache, ".last_session.json")
+        if not self._enabled:
+            return
+
+        if cfg.autosave_last_session() == "prompt":
+            self.ask_load()
+        elif cfg.autosave_last_session() == "always":
+            self.load()
+
+    @property
+    def _enabled(self):
+        """ bool: ``True`` if autosave is enabled otherwise ``False``. """
+        return cfg.autosave_last_session() != "never"
+
+    def from_dict(self, options):
+        """ Set the :attr:`_options` property based on the given options dictionary
+        and update the GUI to use these values.
+
+        This function is required for reloading the GUI state when the GUI has been force
+        refreshed on a config change.
+
+        Parameters
+        ----------
+        options: dict
+            The options to set.
Should be the output of :func:`to_dict` + """ + logger.debug("Setting options from dict: %s", options) + self._options = options + self._set_options() + + def to_dict(self): + """ Collect the current GUI options and place them in a dict for retrieval or storage. + + This function is required for reloading the GUI state when the GUI has been force + refreshed on a config change. + + Returns + ------- + dict: The current cli options ready for saving or retrieval by :func:`from_dict` + """ + opts = self._current_gui_state() + logger.debug("Collected opts: %s", opts) + if not opts or opts == self._default_options: + logger.debug("Default session, or no opts found. Not saving last session.") + return None + opts["tab_name"] = self._active_tab + opts["project"] = self._config.project.filename + logger.debug("Added project items: %s", {k: v for k, v in opts.items() + if k in ("tab_name", "project")}) + return opts + + def ask_load(self): + """ Pop a message box to ask the user if they wish to load their last session. """ + if not self._file_exists: + logger.debug("No last session file found") + elif tk.messagebox.askyesno("Last Session", "Load last session?"): + logger.debug("Loading last session at user request") + self.load() + else: + logger.debug("Not loading last session at user request") + logger.debug("Deleting LastSession file") + os.remove(self._filename) + + def load(self): + """ Load the last session. + + Loads the last saved session options. Checks if a previous project was loaded + and whether there have been changes since the last saved version of the project. + Sets the display and :class:`Project` and :class:`Task` objects accordingly. + """ + loaded = self._load() + if not loaded: + return + self._set_project() + self._set_options() + + def _set_project(self): + """ Set the :class:`Project` if session is resuming from one. """ + if self._options.get("project", None) is None: + logger.debug("No project stored") + else: + logger.debug("Loading stored project") + self._config.project.load(filename=self._options["project"], last_session=True) + + def save(self): + """ Save a snapshot of currently set GUI config options. + + Called on Faceswap shutdown. + """ + if not self._enabled: + logger.debug("LastSession not enabled") + if os.path.exists(self._filename): + logger.debug("Deleting existing LastSession file") + os.remove(self._filename) + return + + opts = self.to_dict() + if opts is None and os.path.exists(self._filename): + logger.debug("Last session default or blank. Clearing saved last session.") + os.remove(self._filename) + if opts is not None: + self._serializer.save(self._filename, opts) + logger.debug("Saved last session. 
(filename: '%s', opts: %s", self._filename, opts) + + +__all__ = get_module_objects(__name__) diff --git a/lib/gui/stats.py b/lib/gui/stats.py deleted file mode 100644 index 2862e6f333..0000000000 --- a/lib/gui/stats.py +++ /dev/null @@ -1,501 +0,0 @@ -#!/usr/bin python3 -""" Stats functions for the GUI """ - -import logging -import time -import os -import warnings - -from math import ceil, sqrt - -import numpy as np -import tensorflow as tf -from lib.Serializer import JSONSerializer - -logger = logging.getLogger(__name__) # pylint: disable=invalid-name - - -def convert_time(timestamp): - """ Convert time stamp to total hours, minutes and seconds """ - hrs = int(timestamp // 3600) - if hrs < 10: - hrs = "{0:02d}".format(hrs) - mins = "{0:02d}".format((int(timestamp % 3600) // 60)) - secs = "{0:02d}".format((int(timestamp % 3600) % 60)) - return hrs, mins, secs - - -class TensorBoardLogs(): - """ Parse and return data from TensorBoard logs """ - def __init__(self, logs_folder): - self.folder_base = logs_folder - self.log_filenames = self.set_log_filenames() - - def set_log_filenames(self): - """ Set the TensorBoard log filenames for all existing sessions """ - logger.debug("Loading log filenames. base_dir: '%s'", self.folder_base) - log_filenames = dict() - for dirpath, _, filenames in os.walk(self.folder_base): - if not any(filename.startswith("events.out.tfevents") for filename in filenames): - continue - logfiles = [filename for filename in filenames - if filename.startswith("events.out.tfevents")] - # Take the last logfile, in case of previous crash - logfile = os.path.join(dirpath, sorted(logfiles)[-1]) - side, session = os.path.split(dirpath) - side = os.path.split(side)[1] - session = int(session[session.rfind("_") + 1:]) - log_filenames.setdefault(session, dict())[side] = logfile - logger.debug("logfiles: %s", log_filenames) - return log_filenames - - def get_loss(self, side=None, session=None): - """ Read the loss from the TensorBoard logs - Specify a side or a session or leave at None for all - """ - logger.debug("Getting loss: (side: %s, session: %s)", side, session) - all_loss = dict() - for sess, sides in self.log_filenames.items(): - if session is not None and sess != session: - logger.debug("Skipping session: %s", sess) - continue - loss = dict() - for sde, logfile in sides.items(): - if side is not None and sde != side: - logger.debug("Skipping side: %s", sde) - continue - for event in tf.train.summary_iterator(logfile): - for summary in event.summary.value: - if "loss" not in summary.tag: - continue - tag = summary.tag.replace("batch_", "") - loss.setdefault(tag, - dict()).setdefault(sde, - list()).append(summary.simple_value) - all_loss[sess] = loss - return all_loss - - def get_timestamps(self, session=None): - """ Read the timestamps from the TensorBoard logs - Specify a session or leave at None for all - NB: For all intents and purposes timestamps are the same for - both sides, so just read from one side """ - logger.debug("Getting timestamps") - all_timestamps = dict() - for sess, sides in self.log_filenames.items(): - if session is not None and sess != session: - logger.debug("Skipping sessions: %s", sess) - continue - for logfile in sides.values(): - timestamps = [event.wall_time - for event in tf.train.summary_iterator(logfile)] - logger.debug("Total timestamps for session %s: %s", sess, len(timestamps)) - all_timestamps[sess] = timestamps - break # break after first file read - return all_timestamps - - -class Session(): - """ The Loaded or current training session 
""" - def __init__(self, model_dir=None, model_name=None): - logger.debug("Initializing %s", self.__class__.__name__) - self.serializer = JSONSerializer - self.state = None - self.modeldir = model_dir # Set and reset by wrapper for training sessions - self.modelname = model_name # Set and reset by wrapper for training sessions - self.tb_logs = None - self.initialized = False - self.session_id = None # Set to specific session_id or current training session - self.summary = SessionsSummary(self) - logger.debug("Initialized %s", self.__class__.__name__) - - @property - def batchsize(self): - """ Return the session batchsize """ - return self.session["batchsize"] - - @property - def config(self): - """ Return config and other information """ - retval = {key: val for key, val in self.state["config"]} - retval["training_size"] = self.state["training_size"] - retval["input_size"] = [val[0] for key, val in self.state["inputs"].items() - if key.startswith("face")][0] - return retval - - @property - def full_summary(self): - """ Retun all sessions summary data""" - return self.summary.compile_stats() - - @property - def iterations(self): - """ Return session iterations """ - return self.session["iterations"] - - @property - def logging_disabled(self): - """ Return whether logging is disabled for this session """ - return self.session["no_logs"] - - @property - def loss(self): - """ Return loss from logs for current session """ - loss_dict = self.tb_logs.get_loss(session=self.session_id)[self.session_id] - return loss_dict - - @property - def loss_keys(self): - """ Return list of unique session loss keys """ - if self.session_id is None: - loss_keys = self.total_loss_keys - else: - loss_keys = set(loss_key for side_keys in self.session["loss_names"].values() - for loss_key in side_keys) - return list(loss_keys) - - @property - def lowest_loss(self): - """ Return the lowest average loss per save iteration seen """ - return self.state["lowest_avg_loss"] - - @property - def session(self): - """ Return current session dictionary """ - return self.state["sessions"][str(self.session_id)] - - @property - def session_ids(self): - """ Return sorted list of all existing session ids in the state file """ - return sorted([int(key) for key in self.state["sessions"].keys()]) - - @property - def timestamps(self): - """ Return timestamps from logs for current session """ - ts_dict = self.tb_logs.get_timestamps(session=self.session_id) - return ts_dict[self.session_id] - - @property - def total_batchsize(self): - """ Return all session batch sizes """ - return {int(sess_id): sess["batchsize"] - for sess_id, sess in self.state["sessions"].items()} - - @property - def total_iterations(self): - """ Return session iterations """ - return self.state["iterations"] - - @property - def total_loss(self): - """ Return collated loss for all session """ - loss_dict = dict() - for sess in self.tb_logs.get_loss().values(): - for loss_key, side_loss in sess.items(): - for side, loss in side_loss.items(): - loss_dict.setdefault(loss_key, dict()).setdefault(side, list()).extend(loss) - return loss_dict - - @property - def total_loss_keys(self): - """ Return list of unique session loss keys across all sessions """ - loss_keys = set(loss_key - for session in self.state["sessions"].values() - for loss_keys in session["loss_names"].values() - for loss_key in loss_keys) - return list(loss_keys) - - @property - def total_timestamps(self): - """ Return timestamps from logs seperated per session for all sessions """ - return 
self.tb_logs.get_timestamps() - - def initialize_session(self, is_training=False, session_id=None): - """ Initialize the training session """ - logger.debug("Initializing session: (is_training: %s, session_id: %s)", - is_training, session_id) - self.load_state_file() - self.tb_logs = TensorBoardLogs(os.path.join(self.modeldir, - "{}_logs".format(self.modelname))) - if is_training: - self.session_id = max(int(key) for key in self.state["sessions"].keys()) - else: - self.session_id = session_id - self.initialized = True - logger.debug("Initialized session") - - def load_state_file(self): - """ Load the current state file """ - state_file = os.path.join(self.modeldir, "{}_state.json".format(self.modelname)) - logger.debug("Loading State: '%s'", state_file) - try: - with open(state_file, "rb") as inp: - state = self.serializer.unmarshal(inp.read().decode("utf-8")) - self.state = state - logger.debug("Loaded state: %s", state) - except IOError as err: - logger.warning("Unable to load state file. Graphing disabled: %s", str(err)) - - -class SessionsSummary(): - """ Calculations for analysis summary stats """ - - def __init__(self, session): - logger.debug("Initializing %s: (session: %s)", self.__class__.__name__, session) - self.session = session - logger.debug("Initialized %s", self.__class__.__name__) - - @property - def time_stats(self): - """ Return session time stats """ - ts_data = self.session.tb_logs.get_timestamps() - time_stats = {sess_id: {"start_time": min(timestamps), - "end_time": max(timestamps), - "iterations": len(timestamps)} - for sess_id, timestamps in ts_data.items()} - return time_stats - - @property - def sessions_stats(self): - """ Return compiled stats """ - compiled = list() - for sess_idx, ts_data in self.time_stats.items(): - elapsed = ts_data["end_time"] - ts_data["start_time"] - batchsize = self.session.total_batchsize[sess_idx] - compiled.append({"session": sess_idx, - "start": ts_data["start_time"], - "end": ts_data["end_time"], - "elapsed": elapsed, - "rate": (batchsize * ts_data["iterations"]) / elapsed, - "batch": batchsize, - "iterations": ts_data["iterations"]}) - return compiled - - def compile_stats(self): - """ Compile sessions stats with totals, format and return """ - logger.debug("Compiling sessions summary data") - compiled_stats = self.sessions_stats - logger.debug("sessions_stats: %s", compiled_stats) - total_stats = self.total_stats(compiled_stats) - compiled_stats.append(total_stats) - compiled_stats = self.format_stats(compiled_stats) - logger.debug("Final stats: %s", compiled_stats) - return compiled_stats - - @staticmethod - def total_stats(sessions_stats): - """ Return total stats """ - logger.debug("Compiling Totals") - elapsed = 0 - rate = 0 - batchset = set() - iterations = 0 - total_summaries = len(sessions_stats) - for idx, summary in enumerate(sessions_stats): - if idx == 0: - starttime = summary["start"] - if idx == total_summaries - 1: - endtime = summary["end"] - elapsed += summary["elapsed"] - rate += summary["rate"] - batchset.add(summary["batch"]) - iterations += summary["iterations"] - batch = ",".join(str(bs) for bs in batchset) - totals = {"session": "Total", - "start": starttime, - "end": endtime, - "elapsed": elapsed, - "rate": rate / total_summaries, - "batch": batch, - "iterations": iterations} - logger.debug(totals) - return totals - - @staticmethod - def format_stats(compiled_stats): - """ Format for display """ - logger.debug("Formatting stats") - for summary in compiled_stats: - hrs, mins, secs = 
convert_time(summary["elapsed"]) - summary["start"] = time.strftime("%x %X", time.gmtime(summary["start"])) - summary["end"] = time.strftime("%x %X", time.gmtime(summary["end"])) - summary["elapsed"] = "{}:{}:{}".format(hrs, mins, secs) - summary["rate"] = "{0:.1f}".format(summary["rate"]) - return compiled_stats - - -class Calculations(): - """ Class to pull raw data for given session(s) and perform calculations """ - def __init__(self, session, display="loss", loss_keys=["loss"], selections=["raw"], - avg_samples=10, flatten_outliers=False, is_totals=False): - logger.debug("Initializing %s: (session: %s, display: %s, loss_keys: %s, selections: %s, " - "avg_samples: %s, flatten_outliers: %s, is_totals: %s", - self.__class__.__name__, session, display, loss_keys, selections, avg_samples, - flatten_outliers, is_totals) - - warnings.simplefilter("ignore", np.RankWarning) - - self.session = session - self.display = display - self.loss_keys = loss_keys - self.selections = selections - self.is_totals = is_totals - self.args = {"avg_samples": int(avg_samples), - "flatten_outliers": flatten_outliers} - self.iterations = 0 - self.stats = None - self.refresh() - logger.debug("Initialized %s", self.__class__.__name__) - - def refresh(self): - """ Refresh the stats """ - logger.debug("Refreshing") - if not self.session.initialized: - logger.warning("Session data is not initialized. Not refreshing") - return - self.iterations = 0 - self.stats = self.get_raw() - self.get_calculations() - self.remove_raw() - logger.debug("Refreshed") - - def get_raw(self): - """ Add raw data to stats dict """ - logger.debug("Getting Raw Data") - - raw = dict() - iterations = set() - if self.display.lower() == "loss": - loss_dict = self.session.total_loss if self.is_totals else self.session.loss - for loss_name, side_loss in loss_dict.items(): - if loss_name not in self.loss_keys: - continue - for side, loss in side_loss.items(): - if self.args["flatten_outliers"]: - loss = self.flatten_outliers(loss) - iterations.add(len(loss)) - raw["raw_{}_{}".format(loss_name, side)] = loss - - self.iterations = 0 if not iterations else min(iterations) - if len(iterations) > 1: - # Crop all losses to the same number of items - if self.iterations == 0: - raw = {lossname: list() for lossname in raw.keys()} - else: - raw = {lossname: loss[:self.iterations] for lossname, loss in raw.items()} - - else: # Rate calulation - data = self.calc_rate_total() if self.is_totals else self.calc_rate() - if self.args["flatten_outliers"]: - data = self.flatten_outliers(data) - self.iterations = len(data) - raw = {"raw_rate": data} - - logger.debug("Got Raw Data") - return raw - - def remove_raw(self): - """ Remove raw values from stats if not requested """ - if "raw" in self.selections: - return - logger.debug("Removing Raw Data from output") - for key in list(self.stats.keys()): - if key.startswith("raw"): - del self.stats[key] - logger.debug("Removed Raw Data from output") - - def calc_rate(self): - """ Calculate rate per iteration """ - logger.debug("Calculating rate") - batchsize = self.session.batchsize - timestamps = self.session.timestamps - iterations = range(len(timestamps) - 1) - rate = [batchsize / (timestamps[i + 1] - timestamps[i]) for i in iterations] - logger.debug("Calculated rate: Item_count: %s", len(rate)) - return rate - - def calc_rate_total(self): - """ Calculate rate per iteration - NB: For totals, gaps between sessions can be large - so time difference has to be reset for each session's - rate calculation """ - 
logger.debug("Calculating totals rate") - batchsizes = self.session.total_batchsize - total_timestamps = self.session.total_timestamps - rate = list() - for sess_id in sorted(total_timestamps.keys()): - batchsize = batchsizes[sess_id] - timestamps = total_timestamps[sess_id] - iterations = range(len(timestamps) - 1) - rate.extend([batchsize / (timestamps[i + 1] - timestamps[i]) for i in iterations]) - logger.debug("Calculated totals rate: Item_count: %s", len(rate)) - return rate - - @staticmethod - def flatten_outliers(data): - """ Remove the outliers from a provided list """ - logger.debug("Flattening outliers") - retdata = list() - samples = len(data) - mean = (sum(data) / samples) - limit = sqrt(sum([(item - mean)**2 for item in data]) / samples) - logger.debug("samples: %s, mean: %s, limit: %s", samples, mean, limit) - - for idx, item in enumerate(data): - if (mean - limit) <= item <= (mean + limit): - retdata.append(item) - else: - logger.debug("Item idx: %s, value: %s flattened to %s", idx, item, mean) - retdata.append(mean) - logger.debug("Flattened outliers") - return retdata - - def get_calculations(self): - """ Perform the required calculations """ - for selection in self.selections: - if selection == "raw": - continue - logger.debug("Calculating: %s", selection) - method = getattr(self, "calc_{}".format(selection)) - raw_keys = [key for key in self.stats.keys() if key.startswith("raw_")] - for key in raw_keys: - selected_key = "{}_{}".format(selection, key.replace("raw_", "")) - self.stats[selected_key] = method(self.stats[key]) - - def calc_avg(self, data): - """ Calculate rolling average """ - logger.debug("Calculating Average") - avgs = list() - presample = ceil(self.args["avg_samples"] / 2) - postsample = self.args["avg_samples"] - presample - datapoints = len(data) - - if datapoints <= (self.args["avg_samples"] * 2): - logger.info("Not enough data to compile rolling average") - return avgs - - for idx in range(0, datapoints): - if idx < presample or idx >= datapoints - postsample: - avgs.append(None) - continue - else: - avg = sum(data[idx - presample:idx + postsample]) \ - / self.args["avg_samples"] - avgs.append(avg) - logger.debug("Calculated Average") - return avgs - - @staticmethod - def calc_trend(data): - """ Compile trend data """ - logger.debug("Calculating Trend") - points = len(data) - if points < 10: - dummy = [None for i in range(points)] - return dummy - x_range = range(points) - fit = np.polyfit(x_range, data, 3) - poly = np.poly1d(fit) - trend = poly(x_range) - logger.debug("Calculated Trend") - return trend diff --git a/lib/gui/statusbar.py b/lib/gui/statusbar.py deleted file mode 100644 index 86524e77fd..0000000000 --- a/lib/gui/statusbar.py +++ /dev/null @@ -1,82 +0,0 @@ -#!/usr/bin python3 -""" Status bar for the GUI """ - -import tkinter as tk -from tkinter import ttk - - -class StatusBar(ttk.Frame): - """ Status Bar for displaying the Status Message and - Progress Bar """ - - def __init__(self, parent): - ttk.Frame.__init__(self, parent) - self.pack(side=tk.BOTTOM, padx=10, pady=2, fill=tk.X, expand=False) - - self.status_message = tk.StringVar() - self.pbar_message = tk.StringVar() - self.pbar_position = tk.IntVar() - - self.status_message.set("Ready") - - self.status() - self.pbar = self.progress_bar() - - def status(self): - """ Place Status into bottom bar """ - statusframe = ttk.Frame(self) - statusframe.pack(side=tk.LEFT, anchor=tk.W, fill=tk.X, expand=False) - - lbltitle = ttk.Label(statusframe, text="Status:", width=6, anchor=tk.W) - 
lbltitle.pack(side=tk.LEFT, expand=False) - - lblstatus = ttk.Label(statusframe, - width=20, - textvariable=self.status_message, - anchor=tk.W) - lblstatus.pack(side=tk.LEFT, anchor=tk.W, fill=tk.X, expand=True) - - def progress_bar(self): - """ Place progress bar into bottom bar """ - progressframe = ttk.Frame(self) - progressframe.pack(side=tk.RIGHT, anchor=tk.E, fill=tk.X) - - lblmessage = ttk.Label(progressframe, textvariable=self.pbar_message) - lblmessage.pack(side=tk.LEFT, padx=3, fill=tk.X, expand=True) - - pbar = ttk.Progressbar(progressframe, - length=200, - variable=self.pbar_position, - maximum=1000, - mode="determinate") - pbar.pack(side=tk.LEFT, padx=2, fill=tk.X, expand=True) - pbar.pack_forget() - return pbar - - def progress_start(self, mode): - """ Set progress bar mode and display """ - self.progress_set_mode(mode) - self.pbar.pack() - - def progress_stop(self): - """ Reset progress bar and hide """ - self.pbar_message.set("") - self.pbar_position.set(0) - self.progress_set_mode("determinate") - self.pbar.pack_forget() - - def progress_set_mode(self, mode): - """ Set the progress bar mode """ - self.pbar.config(mode=mode) - if mode == "indeterminate": - self.pbar.config(maximum=100) - self.pbar.start() - else: - self.pbar.stop() - self.pbar.config(maximum=1000) - - def progress_update(self, message, position, update_position=True): - """ Update the GUIs progress bar and position """ - self.pbar_message.set(message) - if update_position: - self.pbar_position.set(position) diff --git a/lib/gui/theme.py b/lib/gui/theme.py new file mode 100644 index 0000000000..5f84d08fa3 --- /dev/null +++ b/lib/gui/theme.py @@ -0,0 +1,588 @@ +#!/usr/bin/env python3 +""" functions for implementing themes in Faceswap's GUI """ +import logging +import os +import tkinter as tk +from tkinter import ttk + +import numpy as np + +from lib.serializer import get_serializer +from lib.utils import FaceswapError, get_module_objects + + +logger = logging.getLogger(__name__) + + +class Style(): + """ Set the overarching theme and customize widgets. + + Parameters + ---------- + default_font: tuple + The name and size of the default font + root: :class:`tkinter.Tk` + The root tkinter object + path_cache: str + The path to the GUI's cache + """ + def __init__(self, default_font, root, path_cache): + self._root = root + self._font = default_font + default = os.path.join(path_cache, "themes", "default.json") + self._user_theme = get_serializer("json").load(default) + self._style = ttk.Style() + self._widgets = _Widgets(self._style) + self._set_styles() + + @property + def user_theme(self): + """ dict: The currently selected user theme. 
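+
+        Loaded from ``themes/default.json`` inside the GUI cache folder. The top level
+        keys used in this module are ``command_tabs``, ``console``, ``group_panel`` and
+        ``group_settings``, each mapping style names to color strings, along the lines
+        of (value illustrative only)::
+
+            {"console": {"background_color": "#000000", ...}}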
""" + return self._user_theme + + def _set_styles(self): + """ Configure widget theme and styles """ + self._config_settings_group() + # Command page + theme = self._user_theme["command_tabs"] + self._widgets.notebook("CPanel", + theme["frame_border"], + theme["tab_color"], + theme["tab_selected"], + theme["tab_hover"]) + + # Settings Popup + self._style.configure("SPanel.Header1.TLabel", + font=(self._font[0], self._font[1] + 4, "bold")) + self._style.configure("SPanel.Header2.TLabel", + font=(self._font[0], self._font[1] + 2, "bold")) + # Console + theme = self._user_theme["console"] + console_sbar = tuple(tuple(theme[f"scrollbar_{area}_{state}"] + for state in ("normal", "disabled", "active")) + for area in ("background", "foreground", "border")) + self._widgets.scrollbar("Console", + theme["scrollbar_trough"], + theme["scrollbar_border"], + *console_sbar) + self._widgets.frame("Console", + theme["background_color"], + theme["border_color"], + borderwidth=1) + + def _config_settings_group(self): + """ Configures the style of the control panel entry boxes. Used for inputting Faceswap + options or controlling plugin settings. """ + theme = self._user_theme["group_panel"] + for panel_type in ("CPanel", "SPanel"): + if panel_type == "SPanel": # Merge in Settings Panel overrides + theme = {**theme, **self._user_theme["group_settings"]} + self._style.configure(f"{panel_type}.Holder.TFrame", + background=theme["panel_background"]) + # Header Colors on option/group controls + self._style.configure(f"{panel_type}.Group.TLabelframe.Label", + foreground=theme["header_color"]) + self._style.configure(f"{panel_type}.Groupheader.TLabel", + background=theme["header_color"], + foreground=theme["header_font"], + font=(self._font[0], self._font[1], "bold")) + # Widgets and specific areas + self._group_panel_widgets(panel_type, theme) + self._group_panel_infoheader(panel_type, theme) + self._widgets.slider(panel_type, + theme["control_color"], + theme["control_active"], + self._user_theme["group_panel"]["group_background"]) + backgrounds = (theme["control_color"], + theme["control_disabled"], + theme["control_active"]) + foregrounds = (theme["control_disabled"], + theme["control_color"], + theme["control_disabled"]) + borders = (theme["header_color"], theme["control_color"], theme["header_color"]) + self._widgets.scrollbar(panel_type, + theme["scrollbar_trough"], + theme["scrollbar_border"], + backgrounds, + foregrounds, + borders) + self._widgets.combobox(panel_type, + theme["control_color"], + theme["control_active"], + theme["control_disabled"], + theme["header_color"], + theme["group_background"], + theme["group_font"]) + + def _group_panel_infoheader(self, key, theme): + """ Set the theme for the information header box that appears at the top of each group + panel + + Parameters + ---------- + key: str + The section that the slider will belong to + theme: dict + The user configuration theme options + """ + self._widgets.frame(f"{key}.InfoHeader", + theme["info_color"], + theme["info_border"], + borderwidth=1) + + self._style.configure(f"{key}.InfoHeader.TLabel", + background=theme["info_color"], + foreground=theme["info_font"], + font=(self._font[0], self._font[1], "bold")) + self._style.configure(f"{key}.InfoBody.TLabel", + background=theme["info_color"], + foreground=theme["info_font"]) + + def _group_panel_widgets(self, key, theme): + """ Configure the foreground and background colors of common widgets. 
+ + Parameters + ---------- + key: str + The section that the slider will belong to + theme: dict + The user configuration theme options + """ + # Put a border on a group's sub-frame + self._widgets.frame(f"{key}.Subframe.Group", + theme["group_background"], + theme["group_border"], + borderwidth=1) + + # Background and Foreground of widgets and labels + for lbl in ["TLabel", "TFrame", "TLabelframe", "TCheckbutton", "TRadiobutton", + "TLabelframe.Label"]: + self._style.configure(f"{key}.Group.{lbl}", + background=theme["group_background"], + foreground=theme["group_font"]) + + +class _Widgets(): + """ Create custom ttk widget layouts for themed widgets. + + Parameters + ---------- + style: :class:`ttk.Style` + The master style object + """ + def __init__(self, style): + self._images = _TkImage() + self._style = style + + def combobox(self, key, control_color, active_color, arrow_color, control_border, field_color, + field_border): + """ Combo-boxes are fairly complex to style. + + Parameters + ---------- + key: str + The section that the slider will belong to + control_color: str + The color of inactive combo pull down button + active_color: str + The color of combo pull down button when it is hovered or pressed + arrow_color: str + The color of the combo pull down arrow + control_border: str + The color of the combo pull down button border + field_color: str + The color of the input field's background + field_border: str + The color of the input field's border + """ + # All the stock down arrow images are bad + images = {} + for state in ("active", "normal"): + images[f"arrow_{state}"] = self._images.get_image( + (20, 20), + control_color if state == "normal" else active_color, + foreground=arrow_color, + pattern="arrow", + thickness=2, + border_width=1, + border_color=control_border) + + self._style.element_create(f"{key}.Combobox.downarrow", + "image", + images["arrow_normal"], + ("active", images["arrow_active"]), + ("pressed", images["arrow_active"]), + sticky="e", + width=20) + + # None of the themes give us the border control we need, so create an image + box = self._images.get_image((16, 16), + field_color, + border_width=1, + border_color=field_border) + self._style.element_create(f"{key}.Combobox.field", + "image", + box, + border=1, + padding=(6, 0, 0, 0)) + + # Set a layout so we can access required params + self._style.layout(f"{key}.TCombobox", [ + (f"{key}.Combobox.field", { + "children": [ + (f"{key}.Combobox.downarrow", {"side": "right", "sticky": "ns"}), + (f"{key}.Combobox.padding", { + "expand": "1", + "sticky": "nswe", + "children": [(f"{key}.Combobox.focus", { + "expand": "1", + "sticky": "nswe", + "children": [(f"{key}.Combobox.textarea", {"sticky": "nswe"})]})]})], + "sticky": "nswe"})]) + + def frame(self, key, background, border, borderwidth=1): + """ Create a custom frame widget for controlling background and border colors. + + Parameters + ---------- + key: str + The section that the Frame will belong to + background: str + The hex code for the background of the frame + border: str + The hex code for the border of the frame + """ + self._style.element_create(f"{key}.Frame.border", "from", "alt") + self._style.layout(f"{key}.TFrame", + [(f"{key}.Frame.border", {"sticky": "nswe"})]) + self._style.configure(f"{key}.TFrame", + background=background, + relief=tk.SOLID, + borderwidth=borderwidth, + bordercolor=border) + + def notebook(self, key, frame_border, tab_color, tab_selected, tab_hover): + """ Create a custom notebook widget so we can control the colors. 
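+
+        The stock ttk themes do not expose tab colors, so flat color images are
+        generated and mapped to the ``selected`` and ``active`` states. A hypothetical
+        invocation and use of the resulting style::
+
+            widgets.notebook("CPanel", "#0a0a0a", "#333333", "#555555", "#444444")
+            book = ttk.Notebook(parent, style="CPanel.TNotebook")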
+    def notebook(self, key, frame_border, tab_color, tab_selected, tab_hover):
+        """ Create a custom notebook widget so we can control the colors.
+
+        Parameters
+        ----------
+        key: str
+            The section that the notebook will belong to
+        frame_border: str
+            The border color around the tab's contents
+        tab_color: str
+            The color of non selected tabs
+        tab_selected: str
+            The color of selected tabs
+        tab_hover: str
+            The color of hovered tabs
+        """
+        # TODO This lags out the GUI, so need to test where this is failing prior to implementing
+        client = self._images.get_image((8, 8), frame_border)
+        self._style.element_create(f"{key}.Notebook.client", "image", client, border=1)
+
+        tabs = [self._images.get_image((8, 8), color)
+                for color in (tab_color, tab_selected, tab_hover)]
+
+        self._style.element_create(f"{key}.Notebook.tab",
+                                   "image",
+                                   tabs[0],
+                                   ("selected", tabs[1]),
+                                   ("active", tabs[2]),
+                                   padding=(0, 2, 0, 0),
+                                   border=3)
+
+        self._style.layout(f"{key}.TNotebook", [(f"{key}.Notebook.client", {"sticky": "nswe"})])
+        self._style.layout(f"{key}.TNotebook.Tab", [
+            (f"{key}.Notebook.tab", {
+                "sticky": "nswe",
+                "children": [
+                    ("Notebook.padding", {
+                        "side": "top",
+                        "sticky": "nswe",
+                        "children": [
+                            ("Notebook.focus", {
+                                "side": "top",
+                                "sticky": "nswe",
+                                "children": [("Notebook.label", {"side": "top", "sticky": ""})]
+                            })]
+                    })]
+            })])
+
+        self._style.configure(f"{key}.TNotebook", tabmargins=(0, 2, 0, 0))
+        self._style.configure(f"{key}.TNotebook.Tab", padding=(6, 2, 6, 2), expand=(0, 0, 2))
+        self._style.map(f"{key}.TNotebook.Tab", expand=[("selected", (1, 2, 4, 2))])
+
+    def scrollbar(self,  # pylint:disable=too-many-locals
+                  key,
+                  trough_color,
+                  border_color,
+                  control_backgrounds,
+                  control_foregrounds,
+                  control_borders):
+        """ Create a custom scroll bar widget so we can control the colors.
+
+        Parameters
+        ----------
+        key: str
+            The section that the scrollbar will belong to
+        trough_color: str
+            The hex code for the scrollbar trough color
+        border_color: str
+            The hex code for the scrollbar border color
+        control_backgrounds: tuple
+            Tuple of length 3 for the button and slider colors for the states `normal`,
+            `disabled`, `active`
+        control_foregrounds: tuple
+            Tuple of length 3 for the button arrow colors for the states `normal`,
+            `disabled`, `active`
+        control_borders: tuple
+            Tuple of length 3 for the borders of the buttons and slider for the states `normal`,
+            `disabled`, `active`
+        """
+        logger.debug("Creating scrollbar: (key: %s, trough_color: %s, border_color: %s, "
+                     "control_backgrounds: %s, control_foregrounds: %s, control_borders: %s)",
+                     key, trough_color, border_color, control_backgrounds, control_foregrounds,
+                     control_borders)
+        images = {}
+        for idx, state in enumerate(("normal", "disabled", "active")):
+            # Create arrow and slider widgets for each state
+            img_args = ((16, 16), control_backgrounds[idx])
+            for dir_ in ("up", "down"):
+                images[f"img_{dir_}_{state}"] = self._images.get_image(
+                    *img_args,
+                    foreground=control_foregrounds[idx],
+                    pattern="arrow",
+                    direction=dir_,
+                    thickness=4,
+                    border_width=1,
+                    border_color=control_borders[idx])
+            images[f"img_thumb_{state}"] = self._images.get_image(
+                *img_args,
+                border_width=1,
+                border_color=control_borders[idx])
+
+        for element in ("thumb", "uparrow", "downarrow"):
+            # Create the elements with the new images
+            lookup = element.replace("arrow", "")
+            args = (f"{key}.Vertical.Scrollbar.{element}",
+                    "image",
+                    images[f"img_{lookup}_normal"],
+                    ("disabled", images[f"img_{lookup}_disabled"]),
+                    ("pressed !disabled", images[f"img_{lookup}_active"]),
+                    ("active !disabled", images[f"img_{lookup}_active"]))
+            kwargs = {"border": 1, "sticky": "ns"} if element == "thumb" else {}
+            self._style.element_create(*args, **kwargs)
+
+        # Get a configurable trough
+        self._style.element_create(f"{key}.Vertical.Scrollbar.trough", "from", "clam")
+
+        self._style.layout(
+            f"{key}.Vertical.TScrollbar",
+            [(f"{key}.Vertical.Scrollbar.trough", {
+                "sticky": "ns",
+                "children": [
+                    (f"{key}.Vertical.Scrollbar.uparrow", {"side": "top", "sticky": ""}),
+                    (f"{key}.Vertical.Scrollbar.downarrow", {"side": "bottom", "sticky": ""}),
+                    (f"{key}.Vertical.Scrollbar.thumb", {"expand": "1", "sticky": "nswe"})
+                ]
+            })])
+        self._style.configure(f"{key}.Vertical.TScrollbar",
+                              troughcolor=trough_color,
+                              bordercolor=border_color,
+                              troughrelief=tk.SOLID,
+                              troughborderwidth=1)
+
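Once `scrollbar` has registered a key (`_set_styles` above registers "Console"), any vertical scrollbar adopts the custom look through its style name. A fragment-level sketch, assuming the registration has already run and that `parent` and `console` widgets exist:

sbar = ttk.Scrollbar(parent,
                     orient=tk.VERTICAL,
                     style="Console.Vertical.TScrollbar",
                     command=console.yview)
console.configure(yscrollcommand=sbar.set)
sbar.pack(side=tk.RIGHT, fill=tk.Y)
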
+    def slider(self, key, control_color, active_color, trough_color):
+        """ Take a copy of the default ttk.Scale widget and replace the slider element with a
+        version we can control the color and shape of.
+
+        Parameters
+        ----------
+        key: str
+            The section that the slider will belong to
+        control_color: str
+            The color of the slider handle when inactive
+        active_color: str
+            The color of the slider handle when it is hovered or pressed
+        trough_color: str
+            The color of the scale's trough
+        """
+        img_slider = self._images.get_image((10, 25), control_color)
+        img_slider_alt = self._images.get_image((10, 25), active_color)
+
+        self._style.element_create(f"{key}.Horizontal.Scale.trough", "from", "alt")
+        self._style.element_create(f"{key}.Horizontal.Scale.slider",
+                                   "image",
+                                   img_slider,
+                                   ("active", img_slider_alt))
+
+        self._style.layout(
+            f"{key}.Horizontal.TScale",
+            [(f"{key}.Scale.focus", {
+                "expand": "1",
+                "sticky": "nswe",
+                "children": [
+                    (f"{key}.Horizontal.Scale.trough", {
+                        "expand": "1",
+                        "sticky": "nswe",
+                        "children": [
+                            (f"{key}.Horizontal.Scale.track", {"sticky": "we"}),
+                            (f"{key}.Horizontal.Scale.slider", {"side": "left", "sticky": ""})
+                        ]
+                    })
+                ]
+            })])
+
+        self._style.configure(f"{key}.Horizontal.TScale",
+                              background=trough_color,
+                              groovewidth=4,
+                              troughcolor=trough_color)
+
+
+class _TkImage():
+    """ Create a tk image for a given pattern and shape. """
+    def __init__(self):
+        self._cache = []  # We need to keep a reference to every image created
+
+    # Numpy array patterns
+    @classmethod
+    def _get_solid(cls, dimensions):
+        """ Return a solid background color pattern.
+
+        Parameters
+        ----------
+        dimensions: tuple
+            The (`width`, `height`) of the desired tk image
+
+        Returns
+        -------
+        :class:`numpy.ndarray`
+            A 2D, UINT8 array of shape (height, width) of all zeros
+        """
+        return np.zeros((dimensions[1], dimensions[0]), dtype="uint8")
+
+    @classmethod
+    def _get_arrow(cls, dimensions, thickness, direction):
+        """ Return a background color with a "v" arrow in foreground color
+
+        Parameters
+        ----------
+        dimensions: tuple
+            The (`width`, `height`) of the desired tk image
+        thickness: int
+            The thickness of the pattern to be drawn
+        direction: ["left", "up", "right", "down"]
+            The direction that the pattern should be facing
+
+        Returns
+        -------
+        :class:`numpy.ndarray`
+            A 2D, UINT8 array of shape (height, width) with arrow cells set to 1 and background
+            cells set to 0
+        """
+        square_size = min(dimensions[1], dimensions[0])
+        if square_size < 16 or any(dim % 2 != 0 for dim in dimensions):
+            raise FaceswapError("For arrow image, the minimum size across any axis must be 16 "
+                                "and dimensions must all be divisible by 2")
+        crop_size = (square_size // 16) * 16
+        draw_rows = int(6 * crop_size / 16)
+        start_row = dimensions[1] // 2 - draw_rows // 2
+        initial_indent = 2 * (crop_size // 16) + (dimensions[0] - crop_size) // 2
+
+        retval = np.zeros((dimensions[1], dimensions[0]), dtype="uint8")
+        for i in range(start_row, start_row + draw_rows):
+            indent = initial_indent + i - start_row
+            join = (min(indent + thickness, dimensions[0] // 2),
+                    max(dimensions[0] - indent - thickness, dimensions[0] // 2))
+            retval[i, np.r_[indent:join[0], join[1]:dimensions[0] - indent]] = 1
+        if direction in ("right", "left"):
+            retval = np.rot90(retval)
+        if direction in ("up", "left"):
+            retval = np.flip(retval)
+        return retval
+
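`get_image` (below) composites an optional border around these pattern arrays by overwriting everything except the interior with the marker value 2. A small self-contained illustration of that compositing step:

import numpy as np

pattern = np.zeros((8, 8), dtype="uint8")   # what _get_solid((8, 8)) returns
border_width = 1
border = np.ones_like(pattern) + 1          # a canvas of 2s (the border marker)
border[border_width:-border_width,
       border_width:-border_width] = pattern[border_width:-border_width,
                                             border_width:-border_width]
print(border)   # a one-pixel ring of 2s around the original 0s
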
+    def get_image(self,
+                  dimensions,
+                  background,
+                  foreground=None,
+                  pattern="solid",
+                  border_width=0,
+                  border_color=None,
+                  thickness=2,
+                  direction="down"):
+        """ Obtain a tk image.
+
+        Generates the requested image and stores in cache.
+
+        Parameters
+        ----------
+        dimensions: tuple
+            The (`width`, `height`) of the desired tk image
+        background: str
+            The hex code for the background (main) color
+        foreground: str, optional
+            The hex code for the foreground (secondary) color. If ``None`` is provided then a
+            solid background color image will be returned. Default: ``None``
+        pattern: ["solid", "arrow"], optional
+            The pattern to generate for the tk image. Default: `"solid"`
+        border_width: int, optional
+            The thickness of the border to apply. Default: 0
+        border_color: str, optional
+            The color of the border, if one is to be created. Default: ``None`` (use foreground
+            color)
+        thickness: int, optional
+            The thickness of the pattern to be drawn. Default: `2`
+        direction: ["left", "up", "right", "down"], optional
+            The direction that the pattern should be facing. Default: `"down"`
+
+        Returns
+        -------
+        :class:`tkinter.PhotoImage`
+            The generated image
+        """
+        foreground = foreground if foreground else background
+        border_color = border_color if border_color else foreground
+
+        args = [dimensions]
+        if pattern.lower() == "arrow":
+            args.extend([thickness, direction])
+        if pattern.lower() == "border":
+            args.extend([thickness])
+        pattern = getattr(self, f"_get_{pattern.lower()}")(*args)
+
+        if border_width > 0:
+            border = np.ones_like(pattern) + 1
+            border[border_width:-border_width,
+                   border_width:-border_width] = pattern[border_width:-border_width,
+                                                         border_width:-border_width]
+            pattern = border
+
+        return self._create_photoimage(background, foreground, border_color, pattern)
+
+    def _create_photoimage(self, background, foreground, border, pattern):
+        """ Create a tkinter PhotoImage and populate it with the requested color pattern.
+
+        Parameters
+        ----------
+        background: str
+            The hex code for the background (main) color
+        foreground: str
+            The hex code for the foreground (secondary) color
+        border: str
+            The hex code for the border color
+        pattern: :class:`numpy.ndarray`
+            The pattern for the final image, with background colors marked as 0, foreground
+            colors marked as 1 and border colors marked as 2
+
+        Returns
+        -------
+        :class:`tkinter.PhotoImage`
+            The populated image
+        """
+        image = tk.PhotoImage(width=pattern.shape[1], height=pattern.shape[0])
+        self._cache.append(image)
+
+        pixels = "} {".join(" ".join(foreground
+                                     if pxl == 1 else border if pxl == 2 else background
+                                     for pxl in row)
+                            for row in pattern)
+        image.put("{" + pixels + "}")
+        return image
+
+
+__all__ = get_module_objects(__name__)
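Taken together, a typical call into the class above looks like this (a sketch: assumes a Tk root already exists; colors are placeholders):

images = _TkImage()
up_arrow = images.get_image((16, 16),           # width, height
                            "#3c3f41",          # placeholder background
                            foreground="#ffffff",
                            pattern="arrow",
                            direction="up",
                            thickness=4,
                            border_width=1,
                            border_color="#202020")
# up_arrow is a tk.PhotoImage; the instance keeps it referenced in its cache
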
diff --git a/lib/gui/tooltip.py b/lib/gui/tooltip.py
deleted file mode 100755
index d89e8eb58c..0000000000
--- a/lib/gui/tooltip.py
+++ /dev/null
@@ -1,165 +0,0 @@
-#!/usr/bin python3
-""" Tooltip. Pops up help messages for the GUI """
-import platform
-import tkinter as tk
-
-
-class Tooltip:
-    """
-    Create a tooltip for a given widget as the mouse goes on it.
-
-    Adapted from StackOverflow:
-
-    http://stackoverflow.com/questions/3221956/
-    what-is-the-simplest-way-to-make-tooltips-
-    in-tkinter/36221216#36221216
-
-    http://www.daniweb.com/programming/software-development/
-    code/484591/a-tooltip-class-for-tkinter
-
-    - Originally written by vegaseat on 2014.09.09.
-
-    - Modified to include a delay time by Victor Zaccardo on 2016.03.25.
-
-    - Modified
-        - to correct extreme right and extreme bottom behavior,
-        - to stay inside the screen whenever the tooltip might go out on
-          the top but still the screen is higher than the tooltip,
-        - to use the more flexible mouse positioning,
-        - to add customizable background color, padding, waittime and
-          wraplength on creation
-      by Alberto Vassena on 2016.11.05.
-
-      Tested on Ubuntu 16.04/16.10, running Python 3.5.2
-
-    """
-
-    def __init__(self, widget,
-                 *,
-                 background="#FFFFEA",
-                 pad=(5, 3, 5, 3),
-                 text="widget info",
-                 waittime=400,
-                 wraplength=250):
-
-        self.waittime = waittime  # in milliseconds, originally 500
-        self.wraplength = wraplength  # in pixels, originally 180
-        self.widget = widget
-        self.text = text
-        self.widget.bind("<Enter>", self.on_enter)
-        self.widget.bind("<Leave>", self.on_leave)
-        self.widget.bind("<ButtonPress>", self.on_leave)
-        self.background = background
-        self.pad = pad
-        self.ident = None
-        self.topwidget = None
-
-    def on_enter(self, event=None):
-        """ Schedule on an enter event """
-        self.schedule()
-
-    def on_leave(self, event=None):
-        """ Unschedule on a leave event """
-        self.unschedule()
-        self.hide()
-
-    def schedule(self):
-        """ Show the tooltip after wait period """
-        self.unschedule()
-        self.ident = self.widget.after(self.waittime, self.show)
-
-    def unschedule(self):
-        """ Hide the tooltip """
-        id_ = self.ident
-        self.ident = None
-        if id_:
-            self.widget.after_cancel(id_)
-
-    def show(self):
-        """ Show the tooltip """
-        def tip_pos_calculator(widget, label,
-                               *,
-                               tip_delta=(10, 5), pad=(5, 3, 5, 3)):
-            """ Calculate the tooltip position """
-
-            s_width, s_height = widget.winfo_screenwidth(), widget.winfo_screenheight()
-
-            width, height = (pad[0] + label.winfo_reqwidth() + pad[2],
-                             pad[1] + label.winfo_reqheight() + pad[3])
-
-            mouse_x, mouse_y = widget.winfo_pointerxy()
-
-            x_1, y_1 = mouse_x + tip_delta[0], mouse_y + tip_delta[1]
-            x_2, y_2 = x_1 + width, y_1 + height
-
-            x_delta = x_2 - s_width
-            if x_delta < 0:
-                x_delta = 0
-            y_delta = y_2 - s_height
-            if y_delta < 0:
-                y_delta = 0
-
-            offscreen = (x_delta, y_delta) != (0, 0)
-
-            if offscreen:
-
-                if x_delta:
-                    x_1 = mouse_x - tip_delta[0] - width
-
-                if y_delta:
-                    y_1 = mouse_y - tip_delta[1] - height
-
-                offscreen_again = y_1 < 0  # out on the top
-
-                if offscreen_again:
-                    # No further checks will be done.
-
-                    # TIP:
-                    # A further mod might auto-magically augment the
-                    # wraplength when the tooltip is too high to be
-                    # kept inside the screen.
- y_1 = 0 - - return x_1, y_1 - - background = self.background - pad = self.pad - widget = self.widget - - # creates a toplevel window - self.topwidget = tk.Toplevel(widget) - if platform.system() == "Darwin": - # For Mac OS - self.topwidget.tk.call("::tk::unsupported::MacWindowStyle", - "style", self.topwidget._w, - "help", "none") - - # Leaves only the label and removes the app window - self.topwidget.wm_overrideredirect(True) - - win = tk.Frame(self.topwidget, - background=background, - borderwidth=0) - label = tk.Label(win, - text=self.text, - justify=tk.LEFT, - background=background, - relief=tk.SOLID, - borderwidth=0, - wraplength=self.wraplength) - - label.grid(padx=(pad[0], pad[2]), - pady=(pad[1], pad[3]), - sticky=tk.NSEW) - win.grid() - - xpos, ypos = tip_pos_calculator(widget, label) - - self.topwidget.wm_geometry("+%d+%d" % (xpos, ypos)) - - def hide(self): - """ Hide the tooltip """ - topwidget = self.topwidget - if topwidget: - topwidget.destroy() - self.topwidget = None diff --git a/lib/gui/utils.py b/lib/gui/utils.py deleted file mode 100644 index 5632dee7ee..0000000000 --- a/lib/gui/utils.py +++ /dev/null @@ -1,604 +0,0 @@ -#!/usr/bin/env python3 -""" Utility functions for the GUI """ -import logging -import os -import platform -import sys -import tkinter as tk - -from tkinter import filedialog, ttk -from PIL import Image, ImageTk - -from lib.Serializer import JSONSerializer - -logger = logging.getLogger(__name__) # pylint: disable=invalid-name -_CONFIG = None -_IMAGES = None - - -def initialize_config(cli_opts, scaling_factor, pathcache, statusbar, session): - """ Initialize the config and add to global constant """ - global _CONFIG # pylint: disable=global-statement - if _CONFIG is not None: - return - logger.debug("Initializing config: (cli_opts: %s, tk_vars: %s, pathcache: %s, statusbar: %s, " - "session: %s)", cli_opts, scaling_factor, pathcache, statusbar, session) - _CONFIG = Config(cli_opts, scaling_factor, pathcache, statusbar, session) - - -def get_config(): - """ return the _CONFIG constant """ - return _CONFIG - - -def initialize_images(): - """ Initialize the config and add to global constant """ - global _IMAGES # pylint: disable=global-statement - if _IMAGES is not None: - return - logger.debug("Initializing images") - _IMAGES = Images() - - -def get_images(): - """ return the _CONFIG constant """ - return _IMAGES - - -def set_slider_rounding(value, var, d_type, round_to, min_max): - """ Set the underlying variable to correct number based on slider rounding """ - if d_type == float: - var.set(round(float(value), round_to)) - else: - steps = range(min_max[0], min_max[1] + round_to, round_to) - value = min(steps, key=lambda x: abs(x - int(float(value)))) - var.set(value) - - -class FileHandler(): - """ Raise a filedialog box and capture input """ - - def __init__(self, handletype, filetype, command=None, action=None, - variable=None): - logger.debug("Initializing %s: (Handletype: '%s', filetype: '%s', command: '%s', action: " - "'%s', variable: %s)", self.__class__.__name__, handletype, filetype, command, - action, variable) - self.handletype = handletype - all_files = ("All files", "*.*") - self.filetypes = {"default": (all_files,), - "alignments": (("JSON", "*.json"), - ("Pickle", "*.p"), - ("YAML", "*.yaml"), - all_files), - "config": (("Faceswap config files", "*.fsw"), all_files), - "csv": (("Comma separated values", "*.csv"), all_files), - "image": (("Bitmap", "*.bmp"), - ("JPG", "*.jpeg", "*.jpg"), - ("PNG", "*.png"), - ("TIFF", "*.tif", "*.tiff"), - 
all_files), - "state": (("State files", "*.json"), all_files), - "log": (("Log files", "*.log"), all_files), - "video": (("Audio Video Interleave", "*.avi"), - ("Flash Video", "*.flv"), - ("Matroska", "*.mkv"), - ("MOV", "*.mov"), - ("MP4", "*.mp4"), - ("MPEG", "*.mpeg"), - ("WebM", "*.webm"), - all_files)} - self.contexts = { - "effmpeg": { - "input": {"extract": "filename", - "gen-vid": "dir", - "get-fps": "filename", - "get-info": "filename", - "mux-audio": "filename", - "rescale": "filename", - "rotate": "filename", - "slice": "filename"}, - "output": {"extract": "dir", - "gen-vid": "savefilename", - "get-fps": "nothing", - "get-info": "nothing", - "mux-audio": "savefilename", - "rescale": "savefilename", - "rotate": "savefilename", - "slice": "savefilename"} - } - } - self.defaults = self.set_defaults() - self.kwargs = self.set_kwargs(filetype, command, action, variable) - self.retfile = getattr(self, self.handletype.lower())() - logger.debug("Initialized %s", self.__class__.__name__) - - def set_defaults(self): - """ Set the default filetype to be first in list of filetypes, - or set a custom filetype if the first is not correct """ - defaults = {key: val[0][1].replace("*", "") - for key, val in self.filetypes.items()} - defaults["default"] = None - defaults["video"] = ".mp4" - defaults["image"] = ".png" - logger.debug(defaults) - return defaults - - def set_kwargs(self, filetype, command, action, variable=None): - """ Generate the required kwargs for the requested browser """ - logger.debug("Setting Kwargs: (filetype: '%s', command: '%s': action: '%s', " - "variable: '%s')", filetype, command, action, variable) - kwargs = dict() - if self.handletype.lower() == "context": - self.set_context_handletype(command, action, variable) - - if self.handletype.lower() in ( - "open", "save", "filename", "savefilename"): - kwargs["filetypes"] = self.filetypes[filetype] - if self.defaults.get(filetype, None): - kwargs['defaultextension'] = self.defaults[filetype] - if self.handletype.lower() == "save": - kwargs["mode"] = "w" - if self.handletype.lower() == "open": - kwargs["mode"] = "r" - logger.debug("Set Kwargs: %s", kwargs) - return kwargs - - def set_context_handletype(self, command, action, variable): - """ Choose the correct file browser action based on context """ - if self.contexts[command].get(variable, None) is not None: - handletype = self.contexts[command][variable][action] - else: - handletype = self.contexts[command][action] - logger.debug(handletype) - self.handletype = handletype - - def open(self): - """ Open a file """ - logger.debug("Popping Open browser") - return filedialog.askopenfile(**self.kwargs) - - def save(self): - """ Save a file """ - logger.debug("Popping Save browser") - return filedialog.asksaveasfile(**self.kwargs) - - def dir(self): - """ Get a directory location """ - logger.debug("Popping Dir browser") - return filedialog.askdirectory(**self.kwargs) - - def savedir(self): - """ Get a save dir location """ - logger.debug("Popping SaveDir browser") - return filedialog.askdirectory(**self.kwargs) - - def filename(self): - """ Get an existing file location """ - logger.debug("Popping Filename browser") - return filedialog.askopenfilename(**self.kwargs) - - def savefilename(self): - """ Get a save file location """ - logger.debug("Popping SaveFilename browser") - return filedialog.asksaveasfilename(**self.kwargs) - - @staticmethod - def nothing(): # pylint: disable=useless-return - """ Method that does nothing, used for disabling open/save pop up """ - 
logger.debug("Popping Nothing browser") - return - - -class Images(): - """ Holds locations of images and actual images - - Don't call directly. Call get_images() - """ - - def __init__(self): - logger.debug("Initializing %s", self.__class__.__name__) - pathcache = get_config().pathcache - self.pathicons = os.path.join(pathcache, "icons") - self.pathpreview = os.path.join(pathcache, "preview") - self.pathoutput = None - self.previewoutput = None - self.previewtrain = dict() - self.errcount = 0 - self.icons = dict() - self.icons["folder"] = ImageTk.PhotoImage(file=os.path.join( - self.pathicons, "open_folder.png")) - self.icons["load"] = ImageTk.PhotoImage(file=os.path.join( - self.pathicons, "open_file.png")) - self.icons["context"] = ImageTk.PhotoImage(file=os.path.join( - self.pathicons, "open_file.png")) - self.icons["save"] = ImageTk.PhotoImage(file=os.path.join(self.pathicons, "save.png")) - self.icons["reset"] = ImageTk.PhotoImage(file=os.path.join(self.pathicons, "reset.png")) - self.icons["clear"] = ImageTk.PhotoImage(file=os.path.join(self.pathicons, "clear.png")) - self.icons["graph"] = ImageTk.PhotoImage(file=os.path.join(self.pathicons, "graph.png")) - self.icons["zoom"] = ImageTk.PhotoImage(file=os.path.join(self.pathicons, "zoom.png")) - self.icons["move"] = ImageTk.PhotoImage(file=os.path.join(self.pathicons, "move.png")) - self.icons["favicon"] = ImageTk.PhotoImage(file=os.path.join(self.pathicons, "logo.png")) - logger.debug("Initialized %s: (icons: %s)", self.__class__.__name__, self.icons) - - def delete_preview(self): - """ Delete the preview files """ - logger.debug("Deleting previews") - for item in os.listdir(self.pathpreview): - if item.startswith(".gui_training_preview") and item.endswith(".jpg"): - fullitem = os.path.join(self.pathpreview, item) - logger.debug("Deleting: '%s'", fullitem) - os.remove(fullitem) - self.clear_image_cache() - - def clear_image_cache(self): - """ Clear all cached images """ - logger.debug("Clearing image cache") - self.pathoutput = None - self.previewoutput = None - self.previewtrain = dict() - - @staticmethod - def get_images(imgpath): - """ Get the images stored within the given directory """ - logger.trace("Getting images: '%s'", imgpath) - if not os.path.isdir(imgpath): - logger.debug("Folder does not exist") - return None - files = [os.path.join(imgpath, f) - for f in os.listdir(imgpath) if f.endswith((".png", ".jpg"))] - logger.trace("Image files: %s", files) - return files - - def load_latest_preview(self): - """ Load the latest preview image for extract and convert """ - logger.trace("Loading preview image") - imagefiles = self.get_images(self.pathoutput) - if not imagefiles or len(imagefiles) == 1: - logger.debug("No preview to display") - self.previewoutput = None - return - # Get penultimate file so we don't accidentally - # load a file that is being saved - show_file = sorted(imagefiles, key=os.path.getctime)[-2] - img = Image.open(show_file) - img.thumbnail((768, 432)) - logger.trace("Displaying preview: '%s'", show_file) - self.previewoutput = (img, ImageTk.PhotoImage(img)) - - def load_training_preview(self): - """ Load the training preview images """ - logger.trace("Loading Training preview images") - imagefiles = self.get_images(self.pathpreview) - modified = None - if not imagefiles: - logger.debug("No preview to display") - self.previewtrain = dict() - return - for img in imagefiles: - modified = os.path.getmtime(img) if modified is None else modified - name = os.path.basename(img) - name = os.path.splitext(name)[0] - 
name = name[name.rfind("_") + 1:].title() - try: - logger.trace("Displaying preview: '%s'", img) - size = self.get_current_size(name) - self.previewtrain[name] = [Image.open(img), None, modified] - self.resize_image(name, size) - self.errcount = 0 - except ValueError: - # This is probably an error reading the file whilst it's - # being saved so ignore it for now and only pick up if - # there have been multiple consecutive fails - logger.warning("Unable to display preview: (image: '%s', attempt: %s)", - img, self.errcount) - if self.errcount < 10: - self.errcount += 1 - else: - logger.error("Error reading the preview file for '%s'", img) - print("Error reading the preview file for {}".format(name)) - self.previewtrain[name] = None - - def get_current_size(self, name): - """ Return the size of the currently displayed image """ - logger.trace("Getting size: '%s'", name) - if not self.previewtrain.get(name, None): - return None - img = self.previewtrain[name][1] - if not img: - return None - logger.trace("Got size: (name: '%s', width: '%s', height: '%s')", - name, img.width(), img.height()) - return img.width(), img.height() - - def resize_image(self, name, framesize): - """ Resize the training preview image - based on the passed in frame size """ - logger.trace("Resizing image: (name: '%s', framesize: %s", name, framesize) - displayimg = self.previewtrain[name][0] - if framesize: - frameratio = float(framesize[0]) / float(framesize[1]) - imgratio = float(displayimg.size[0]) / float(displayimg.size[1]) - - if frameratio <= imgratio: - scale = framesize[0] / float(displayimg.size[0]) - size = (framesize[0], int(displayimg.size[1] * scale)) - else: - scale = framesize[1] / float(displayimg.size[1]) - size = (int(displayimg.size[0] * scale), framesize[1]) - logger.trace("Scaling: (scale: %s, size: %s", scale, size) - - # Hacky fix to force a reload if it happens to find corrupted - # data, probably due to reading the image whilst it is partially - # saved. If it continues to fail, then eventually raise. 
-        for i in range(0, 1000):
-            try:
-                displayimg = displayimg.resize(size, Image.ANTIALIAS)
-            except OSError:
-                if i == 999:
-                    raise
-                else:
-                    continue
-            break
-
-        self.previewtrain[name][1] = ImageTk.PhotoImage(displayimg)
-
-
-class ContextMenu(tk.Menu):  # pylint: disable=too-many-ancestors
-    """ Pop up menu """
-    def __init__(self, widget):
-        logger.debug("Initializing %s: (widget_class: '%s')",
-                     self.__class__.__name__, widget.winfo_class())
-        super().__init__(tearoff=0)
-        self.widget = widget
-        self.standard_actions()
-        logger.debug("Initialized %s", self.__class__.__name__)
-
-    def standard_actions(self):
-        """ Standard menu actions """
-        self.add_command(label="Cut", command=lambda: self.widget.event_generate("<<Cut>>"))
-        self.add_command(label="Copy", command=lambda: self.widget.event_generate("<<Copy>>"))
-        self.add_command(label="Paste", command=lambda: self.widget.event_generate("<<Paste>>"))
-        self.add_separator()
-        self.add_command(label="Select all", command=self.select_all)
-
-    def cm_bind(self):
-        """ Bind the menu to the widget's Right Click event """
-        button = "<Button-2>" if platform.system() == "Darwin" else "<Button-3>"
-        logger.debug("Binding '%s' to '%s'", button, self.widget.winfo_class())
-        x_offset = int(34 * get_config().scaling_factor)
-        self.widget.bind(button,
-                         lambda event: self.tk_popup(event.x_root + x_offset, event.y_root, 0))
-
-    def select_all(self):
-        """ Select all for Text or Entry widgets """
-        logger.debug("Selecting all for '%s'", self.widget.winfo_class())
-        if self.widget.winfo_class() == "Text":
-            self.widget.focus_force()
-            self.widget.tag_add("sel", "1.0", "end")
-        else:
-            self.widget.focus_force()
-            self.widget.select_range(0, tk.END)
-
-
-class ConsoleOut(ttk.Frame):  # pylint: disable=too-many-ancestors
-    """ The Console out section of the GUI """
-
-    def __init__(self, parent, debug):
-        logger.debug("Initializing %s: (parent: %s, debug: %s)",
-                     self.__class__.__name__, parent, debug)
-        ttk.Frame.__init__(self, parent)
-        self.pack(side=tk.TOP, anchor=tk.W, padx=10, pady=(2, 0),
-                  fill=tk.BOTH, expand=True)
-        self.console = tk.Text(self)
-        rc_menu = ContextMenu(self.console)
-        rc_menu.cm_bind()
-        self.console_clear = get_config().tk_vars['consoleclear']
-        self.set_console_clear_var_trace()
-        self.debug = debug
-        self.build_console()
-        logger.debug("Initialized %s", self.__class__.__name__)
-
-    def set_console_clear_var_trace(self):
-        """ Set the trigger actions for the clear console var
-            when it has been triggered from elsewhere """
-        logger.debug("Set clear trace")
-        self.console_clear.trace("w", self.clear)
-
-    def build_console(self):
-        """ Build and place the console """
-        logger.debug("Build console")
-        self.console.config(width=100, height=6, bg="gray90", fg="black")
-        self.console.pack(side=tk.LEFT, anchor=tk.N, fill=tk.BOTH, expand=True)
-
-        scrollbar = ttk.Scrollbar(self, command=self.console.yview)
-        scrollbar.pack(side=tk.LEFT, fill="y")
-        self.console.configure(yscrollcommand=scrollbar.set)
-
-        self.redirect_console()
-        logger.debug("Built console")
-
-    def redirect_console(self):
-        """ Redirect stdout/stderr to console frame """
-        logger.debug("Redirect console")
-        if self.debug:
-            logger.info("Console debug activated.
Outputting to main terminal") - else: - sys.stdout = SysOutRouter(console=self.console, out_type="stdout") - sys.stderr = SysOutRouter(console=self.console, out_type="stderr") - logger.debug("Redirected console") - - def clear(self, *args): # pylint: disable=unused-argument - """ Clear the console output screen """ - logger.debug("Clear console") - if not self.console_clear.get(): - logger.debug("Console not set for clearing. Skipping") - return - self.console.delete(1.0, tk.END) - self.console_clear.set(False) - logger.debug("Cleared console") - - -class SysOutRouter(): - """ Route stdout/stderr to the console window """ - - def __init__(self, console=None, out_type=None): - logger.debug("Initializing %s: (console: %s, out_type: '%s')", - self.__class__.__name__, console, out_type) - self.console = console - self.out_type = out_type - self.color = ("black" if out_type == "stdout" else "red") - logger.debug("Initialized %s", self.__class__.__name__) - - def write(self, string): - """ Capture stdout/stderr """ - self.console.insert(tk.END, string, self.out_type) - self.console.tag_config(self.out_type, foreground=self.color) - self.console.see(tk.END) - - @staticmethod - def flush(): - """ If flush is forced, send it to normal terminal """ - sys.__stdout__.flush() - - -class Config(): - """ Global configuration settings - - Don't call directly. Call get_config() - """ - - def __init__(self, cli_opts, scaling_factor, pathcache, statusbar, session): - logger.debug("Initializing %s: (cli_opts: %s, scaling_factor: %s, pathcache: %s, " - "statusbar: %s, session: %s)", self.__class__.__name__, cli_opts, - scaling_factor, pathcache, statusbar, session) - self.cli_opts = cli_opts - self.scaling_factor = scaling_factor - self.pathcache = pathcache - self.statusbar = statusbar - self.serializer = JSONSerializer - self.tk_vars = self.set_tk_vars() - self.command_notebook = None # set in command.py - self.session = session - logger.debug("Initialized %s", self.__class__.__name__) - - @property - def command_tabs(self): - """ Return dict of command tab titles with their IDs """ - return {self.command_notebook.tab(tab_id, "text").lower(): tab_id - for tab_id in range(0, self.command_notebook.index("end"))} - - @staticmethod - def set_tk_vars(): - """ TK Variables to be triggered by to indicate - what state various parts of the GUI should be in """ - display = tk.StringVar() - display.set(None) - - runningtask = tk.BooleanVar() - runningtask.set(False) - - actioncommand = tk.StringVar() - actioncommand.set(None) - - generatecommand = tk.StringVar() - generatecommand.set(None) - - consoleclear = tk.BooleanVar() - consoleclear.set(False) - - refreshgraph = tk.BooleanVar() - refreshgraph.set(False) - - updatepreview = tk.BooleanVar() - updatepreview.set(False) - - tk_vars = {"display": display, - "runningtask": runningtask, - "action": actioncommand, - "generate": generatecommand, - "consoleclear": consoleclear, - "refreshgraph": refreshgraph, - "updatepreview": updatepreview} - logger.debug(tk_vars) - return tk_vars - - def load(self, command=None, filename=None): - """ Pop up load dialog for a saved config file """ - logger.debug("Loading config: (command: '%s')", command) - if filename: - with open(filename, "r") as cfgfile: - cfg = self.serializer.unmarshal(cfgfile.read()) - else: - cfgfile = FileHandler("open", "config").retfile - if not cfgfile: - return - cfg = self.serializer.unmarshal(cfgfile.read()) - - if not command and len(cfg.keys()) == 1: - command = list(cfg.keys())[0] - - opts = 
self.get_command_options(cfg, command) if command else cfg
-        if not opts:
-            return
-
-        for cmd, opts in opts.items():
-            self.set_command_args(cmd, opts)
-
-        if command:
-            self.command_notebook.select(self.command_tabs[command])
-
-        self.add_to_recent(cfgfile.name, command)
-        logger.debug("Loaded config: (command: '%s', cfgfile: '%s')", command, cfgfile)
-
-    def get_command_options(self, cfg, command):
-        """ Return the saved options for the requested
-            command, if not loading global options """
-        opts = cfg.get(command, None)
-        retval = {command: opts}
-        if not opts:
-            self.tk_vars["consoleclear"].set(True)
-            print("No {} section found in file".format(command))
-            logger.info("No %s section found in file", command)
-            retval = None
-        logger.debug(retval)
-        return retval
-
-    def set_command_args(self, command, options):
-        """ Pass the saved config items back to the CliOptions """
-        if not options:
-            return
-        for srcopt, srcval in options.items():
-            optvar = self.cli_opts.get_one_option_variable(command, srcopt)
-            if not optvar:
-                continue
-            optvar.set(srcval)
-
-    def save(self, command=None):
-        """ Save the current GUI state to a config file in json format """
-        logger.debug("Saving config: (command: '%s')", command)
-        cfgfile = FileHandler("save", "config").retfile
-        if not cfgfile:
-            return
-        cfg = self.cli_opts.get_option_values(command)
-        cfgfile.write(self.serializer.marshal(cfg))
-        cfgfile.close()
-        self.add_to_recent(cfgfile.name, command)
-        logger.debug("Saved config: (command: '%s', cfgfile: '%s')", command, cfgfile)
-
-    def add_to_recent(self, filename, command):
-        """ Add to recent files """
-        recent_filename = os.path.join(self.pathcache, ".recent.json")
-        logger.debug("Adding to recent files '%s': (%s, %s)", recent_filename, filename, command)
-        with open(recent_filename, "rb") as inp:
-            recent_files = self.serializer.unmarshal(inp.read().decode("utf-8"))
-        logger.debug("Initial recent files: %s", recent_files)
-        filenames = [recent[0] for recent in recent_files]
-        if filename in filenames:
-            idx = filenames.index(filename)
-            del recent_files[idx]
-        recent_files.insert(0, (filename, command))
-        recent_files = recent_files[:20]
-        logger.debug("Final recent files: %s", recent_files)
-        recent_json = self.serializer.marshal(recent_files)
-        with open(recent_filename, "wb") as out:
-            out.write(recent_json.encode("utf-8"))
diff --git a/lib/gui/utils/__init__.py b/lib/gui/utils/__init__.py
new file mode 100644
index 0000000000..24e46983d7
--- /dev/null
+++ b/lib/gui/utils/__init__.py
@@ -0,0 +1,7 @@
+#!/usr/bin python3
+""" Utilities for the Faceswap GUI """
+
+from .config import get_config, initialize_config, PATHCACHE
+from .file_handler import FileHandler
+from .image import get_images, initialize_images, preview_trigger
+from .misc import LongRunningTask
diff --git a/lib/gui/utils/config.py b/lib/gui/utils/config.py
new file mode 100644
index 0000000000..92661e5809
--- /dev/null
+++ b/lib/gui/utils/config.py
@@ -0,0 +1,446 @@
+#!/usr/bin python3
+""" Global configuration options for the Faceswap GUI """
+from __future__ import annotations
+import logging
+import os
+import sys
+import tkinter as tk
+import typing as T
+
+from dataclasses import dataclass, field
+
+from lib.gui import gui_config as cfg
+from lib.gui.project import Project, Tasks
+from lib.gui.theme import Style
+from lib.utils import get_module_objects, PROJECT_ROOT
+
+from .file_handler import FileHandler
+
+if T.TYPE_CHECKING:
+    from lib.gui.options import CliOptions
+    from lib.gui.custom_widgets import StatusBar
+
from lib.gui.command import CommandNotebook + from lib.gui.command import ToolsNotebook + +logger = logging.getLogger(__name__) + +PATHCACHE = os.path.join(PROJECT_ROOT, "lib", "gui", ".cache") +_CONFIG: Config | None = None + + +def initialize_config(root: tk.Tk, + cli_opts: CliOptions | None, + statusbar: StatusBar | None) -> Config | None: + """ Initialize the GUI Master :class:`Config` and add to global constant. + + This should only be called once on first GUI startup. Future access to :class:`Config` + should only be executed through :func:`get_config`. + + Parameters + ---------- + root: :class:`tkinter.Tk` + The root Tkinter object + cli_opts: :class:`lib.gui.options.CliOptions` or ``None`` + The command line options object. Must be provided for main GUI. Must be ``None`` for tools + statusbar: :class:`lib.gui.custom_widgets.StatusBar` or ``None`` + The GUI Status bar. Must be provided for main GUI. Must be ``None`` for tools + + Returns + ------- + :class:`Config` or ``None`` + ``None`` if the config has already been initialized otherwise the global configuration + options + """ + global _CONFIG # pylint:disable=global-statement + if _CONFIG is not None: + return None + logger.debug("Initializing config: (root: %s, cli_opts: %s, " + "statusbar: %s)", root, cli_opts, statusbar) + _CONFIG = Config(root, cli_opts, statusbar) + return _CONFIG + + +def get_config() -> "Config": + """ Get the Master GUI configuration. + + Returns + ------- + :class:`Config` + The Master GUI Config + """ + assert _CONFIG is not None + return _CONFIG + + +class GlobalVariables(): + """ Global tkinter variables accessible from all parts of the GUI. Should only be accessed from + :attr:`get_config().tk_vars` """ + def __init__(self) -> None: + logger.debug("Initializing %s", self.__class__.__name__) + self._display = tk.StringVar() + self._running_task = tk.BooleanVar() + self._is_training = tk.BooleanVar() + self._action_command = tk.StringVar() + self._generate_command = tk.StringVar() + self._console_clear = tk.BooleanVar() + self._refresh_graph = tk.BooleanVar() + self._analysis_folder = tk.StringVar() + + self._initialize_variables() + logger.debug("Initialized %s", self.__class__.__name__) + + @property + def display(self) -> tk.StringVar: + """ :class:`tkinter.StringVar`: The current Faceswap command running """ + return self._display + + @property + def running_task(self) -> tk.BooleanVar: + """ :class:`tkinter.BooleanVar`: ``True`` if a Faceswap task is running otherwise + ``False`` """ + return self._running_task + + @property + def is_training(self) -> tk.BooleanVar: + """ :class:`tkinter.BooleanVar`: ``True`` if Faceswap is currently training otherwise + ``False`` """ + return self._is_training + + @property + def action_command(self) -> tk.StringVar: + """ :class:`tkinter.StringVar`: The command line action to perform """ + return self._action_command + + @property + def generate_command(self) -> tk.StringVar: + """ :class:`tkinter.StringVar`: The command line action to generate """ + return self._generate_command + + @property + def console_clear(self) -> tk.BooleanVar: + """ :class:`tkinter.BooleanVar`: ``True`` if the console should be cleared otherwise + ``False`` """ + return self._console_clear + + @property + def refresh_graph(self) -> tk.BooleanVar: + """ :class:`tkinter.BooleanVar`: ``True`` if the training graph should be refreshed + otherwise ``False`` """ + return self._refresh_graph + + @property + def analysis_folder(self) -> tk.StringVar: + """ :class:`tkinter.StringVar`: Full 
path the analysis folder""" + return self._analysis_folder + + def _initialize_variables(self) -> None: + """ Initialize the default variable values""" + self._display.set("") + self._running_task.set(False) + self._is_training.set(False) + self._action_command.set("") + self._generate_command.set("") + self._console_clear.set(False) + self._refresh_graph.set(False) + self._analysis_folder.set("") + + +@dataclass +class _GuiObjects: + """ Data class for commonly accessed GUI Objects """ + cli_opts: CliOptions | None + tk_vars: GlobalVariables + project: Project + tasks: Tasks + status_bar: StatusBar | None + default_options: dict[str, dict[str, T.Any]] = field(default_factory=dict) + command_notebook: CommandNotebook | None = None + + +class Config(): # pylint:disable=too-many-public-methods + """ The centralized configuration class for holding items that should be made available to all + parts of the GUI. + + This class should be initialized on GUI startup through :func:`initialize_config`. Any further + access to this class should be through :func:`get_config`. + + Parameters + ---------- + root: :class:`tkinter.Tk` + The root Tkinter object + cli_opts: :class:`lib.gui.options.CliOptions` or ``None`` + The command line options object. Must be provided for main GUI. Must be ``None`` for tools + statusbar: :class:`lib.gui.custom_widgets.StatusBar` or ``None`` + The GUI Status bar. Must be provided for main GUI. Must be ``None`` for tools + """ + def __init__(self, + root: tk.Tk, + cli_opts: CliOptions | None, + statusbar: StatusBar | None) -> None: + logger.debug("Initializing %s: (root %s, cli_opts: %s, statusbar: %s)", + self.__class__.__name__, root, cli_opts, statusbar) + self._default_font = T.cast(dict, + tk.font.nametofont("TkDefaultFont").configure())["family"] + self._constants = {"root": root, + "scaling_factor": self._get_scaling(root), + "default_font": self._default_font} + self._gui_objects = _GuiObjects( + cli_opts=cli_opts, + tk_vars=GlobalVariables(), + project=Project(self, FileHandler), + tasks=Tasks(self, FileHandler), + status_bar=statusbar) + + self._style = Style(self.default_font, root, PATHCACHE) + self._user_theme = self._style.user_theme + logger.debug("Initialized %s", self.__class__.__name__) + + # Constants + @property + def root(self) -> tk.Tk: + """ :class:`tkinter.Tk`: The root tkinter window. """ + return self._constants["root"] + + @property + def scaling_factor(self) -> float: + """ float: The scaling factor for current display. """ + return self._constants["scaling_factor"] + + @property + def pathcache(self) -> str: + """ str: The path to the GUI cache folder """ + return PATHCACHE + + # GUI Objects + @property + def cli_opts(self) -> CliOptions: + """ :class:`lib.gui.options.CliOptions`: The command line options for this GUI Session. """ + # This should only be None when a separate tool (not main GUI) is used, at which point + # cli_opts do not exist + assert self._gui_objects.cli_opts is not None + return self._gui_objects.cli_opts + + @property + def tk_vars(self) -> GlobalVariables: + """ dict: The global tkinter variables. """ + return self._gui_objects.tk_vars + + @property + def project(self) -> Project: + """ :class:`lib.gui.project.Project`: The project session handler. """ + return self._gui_objects.project + + @property + def tasks(self) -> Tasks: + """ :class:`lib.gui.project.Tasks`: The session tasks handler. 
""" + return self._gui_objects.tasks + + @property + def default_options(self) -> dict[str, dict[str, T.Any]]: + """ dict: The default options for all tabs """ + return self._gui_objects.default_options + + @property + def statusbar(self) -> StatusBar: + """ :class:`lib.gui.custom_widgets.StatusBar`: The GUI StatusBar + :class:`tkinter.ttk.Frame`. """ + # This should only be None when a separate tool (not main GUI) is used, at which point + # this statusbar does not exist + assert self._gui_objects.status_bar is not None + return self._gui_objects.status_bar + + @property + def command_notebook(self) -> CommandNotebook | None: + """ :class:`lib.gui.command.CommandNotebook`: The main Faceswap Command Notebook. """ + return self._gui_objects.command_notebook + + # Convenience GUI Objects + @property + def tools_notebook(self) -> ToolsNotebook: + """ :class:`lib.gui.command.ToolsNotebook`: The Faceswap Tools sub-Notebook. """ + assert self.command_notebook is not None + return self.command_notebook.tools_notebook + + @property + def modified_vars(self) -> dict[str, tk.BooleanVar]: + """ dict: The command notebook modified tkinter variables. """ + assert self.command_notebook is not None + return self.command_notebook.modified_vars + + @property + def _command_tabs(self) -> dict[str, int]: + """ dict: Command tab titles with their IDs. """ + assert self.command_notebook is not None + return self.command_notebook.tab_names + + @property + def _tools_tabs(self) -> dict[str, int]: + """ dict: Tools command tab titles with their IDs. """ + assert self.command_notebook is not None + return self.command_notebook.tools_tab_names + + @property + def user_theme(self) -> dict[str, T.Any]: # TODO Dataclass + """ dict: The GUI theme selection options. """ + return self._user_theme + + @property + def default_font(self) -> tuple[str, int]: + """ tuple: The selected font as configured in user settings. First item is the font (`str`) + second item the font size (`int`). """ + font = cfg.font() + font = self._default_font if font == "default" else font + return (font, cfg.font_size()) + + @staticmethod + def _get_scaling(root) -> float: + """ Get the display DPI. + + Returns + ------- + float: + The scaling factor + """ + dpi = root.winfo_fpixels("1i") + scaling = dpi / 72.0 + logger.debug("dpi: %s, scaling: %s'", dpi, scaling) + return scaling + + def set_default_options(self) -> None: + """ Set the default options for :mod:`lib.gui.projects` + + The Default GUI options are stored on Faceswap startup. + + Exposed as the :attr:`_default_opts` for a project cannot be set until after the main + Command Tabs have been loaded. + """ + default = self.cli_opts.get_option_values() + logger.debug(default) + self._gui_objects.default_options = default + self.project.set_default_options() + + def set_command_notebook(self, notebook: CommandNotebook) -> None: + """ Set the command notebook to the :attr:`command_notebook` attribute + and enable the modified callback for :attr:`project`. + + Parameters + ---------- + notebook: :class:`lib.gui.command.CommandNotebook` + The main command notebook for the Faceswap GUI + """ + logger.debug("Setting commane notebook: %s", notebook) + self._gui_objects.command_notebook = notebook + self.project.set_modified_callback() + + def set_active_tab_by_name(self, name: str) -> None: + """ Sets the :attr:`command_notebook` or :attr:`tools_notebook` to active based on given + name. 
+    def set_active_tab_by_name(self, name: str) -> None:
+        """ Sets the :attr:`command_notebook` or :attr:`tools_notebook` to active based on the
+        given name.
+
+        Parameters
+        ----------
+        name: str
+            The name of the tab to set active
+        """
+        assert self.command_notebook is not None
+        name = name.lower()
+        if name in self._command_tabs:
+            tab_id = self._command_tabs[name]
+            logger.debug("Setting active tab to: (name: %s, id: %s)", name, tab_id)
+            self.command_notebook.select(tab_id)
+        elif name in self._tools_tabs:
+            self.command_notebook.select(self._command_tabs["tools"])
+            tab_id = self._tools_tabs[name]
+            logger.debug("Setting active Tools tab to: (name: %s, id: %s)", name, tab_id)
+            self.tools_notebook.select(tab_id)
+        else:
+            logger.debug("Name couldn't be found. Setting to id 0: %s", name)
+            self.command_notebook.select(0)
+
+    def set_modified_true(self, command: str) -> None:
+        """ Set the modified variable to ``True`` for the given command in :attr:`modified_vars`.
+
+        Parameters
+        ----------
+        command: str
+            The command to set the modified state to ``True``
+        """
+        tkvar = self.modified_vars.get(command, None)
+        if tkvar is None:
+            logger.debug("No tkvar for command: '%s'", command)
+            return
+        tkvar.set(True)
+        logger.debug("Set modified var to True for: '%s'", command)
+
+    def set_cursor_busy(self, widget: tk.Widget | None = None) -> None:
+        """ Set the root or widget cursor to busy.
+
+        Parameters
+        ----------
+        widget: tkinter object, optional
+            The widget to set the busy cursor for. If the provided value is ``None`` then sets
+            the busy cursor for the whole of the GUI. Default: ``None``.
+        """
+        logger.debug("Setting cursor to busy. widget: %s", widget)
+        component = self.root if widget is None else widget
+        component.config(cursor="watch")  # type: ignore
+        component.update_idletasks()
+
+    def set_cursor_default(self, widget: tk.Widget | None = None) -> None:
+        """ Set the root or widget cursor to default.
+
+        Parameters
+        ----------
+        widget: tkinter object, optional
+            The widget to set the default cursor for. If the provided value is ``None`` then
+            sets the default cursor for the whole of the GUI. Default: ``None``
+        """
+        logger.debug("Setting cursor to default. widget: %s", widget)
+        component = self.root if widget is None else widget
+        component.config(cursor="")  # type: ignore
+        component.update_idletasks()
+
+    def set_root_title(self, text: str | None = None) -> None:
+        """ Set the main title text for Faceswap.
+
+        The title will always begin with 'Faceswap.py'. Additional text can be appended.
+
+        Parameters
+        ----------
+        text: str, optional
+            Additional text to be appended to the GUI title bar. Default: ``None``
+        """
+        title = "Faceswap.py"
+        title += f" - {text}" if text is not None and text else ""
+        self.root.title(title)
+
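`set_geometry` (below) multiplies the requested size by :attr:`scaling_factor`, which `_get_scaling` derives as DPI / 72. A worked example of that arithmetic:

# At 144 DPI: scaling_factor = 144 / 72.0 = 2.0
scaling_factor = 144 / 72.0
width, height = 1280, 720
geometry = f"{round(width * scaling_factor)}x{round(height * scaling_factor)}+80+80"
assert geometry == "2560x1440+80+80"
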
+    def set_geometry(self, width: int, height: int, fullscreen: bool = False) -> None:
+        """ Set the geometry for the root tkinter object.
+
+        Parameters
+        ----------
+        width: int
+            The width to set the window to (prior to scaling)
+        height: int
+            The height to set the window to (prior to scaling)
+        fullscreen: bool, optional
+            Whether to set the window to full-screen mode. If ``True`` then :attr:`width` and
+            :attr:`height` are ignored. Default: ``False``
+        """
+        self.root.tk.call("tk", "scaling", self.scaling_factor)
+        if fullscreen:
+            initial_dimensions = (self.root.winfo_screenwidth(), self.root.winfo_screenheight())
+        else:
+            initial_dimensions = (round(width * self.scaling_factor),
+                                  round(height * self.scaling_factor))
+
+        if fullscreen and sys.platform in ("win32", "darwin"):
+            self.root.state('zoomed')
+        elif fullscreen:
+            self.root.attributes('-zoomed', True)
+        else:
+            self.root.geometry(f"{initial_dimensions[0]}x{initial_dimensions[1]}+80+80")
+        logger.debug("Geometry: %sx%s", *initial_dimensions)
+
+
+__all__ = get_module_objects(__name__)
diff --git a/lib/gui/utils/file_handler.py b/lib/gui/utils/file_handler.py
new file mode 100644
index 0000000000..6d9c8ea0fc
--- /dev/null
+++ b/lib/gui/utils/file_handler.py
@@ -0,0 +1,350 @@
+#!/usr/bin/env python3
+""" File browser utility functions for the Faceswap GUI. """
+import logging
+import platform
+import tkinter as tk
+from tkinter import filedialog, ttk
+import typing as T
+
+from lib.utils import get_module_objects
+
+logger = logging.getLogger(__name__)
+_FILETYPE = T.Literal["default", "alignments", "config_project", "config_task",
+                      "config_all", "csv", "image", "ini", "json", "model", "state", "log",
+                      "video"]
+_HANDLETYPE = T.Literal["open", "save", "filename", "filename_multi", "save_filename",
+                        "context", "dir"]
+
+
+class FileHandler():  # pylint:disable=too-few-public-methods
+    """ Handles all GUI File Dialog actions and tasks.
+
+    Parameters
+    ----------
+    handle_type: ['open', 'save', 'filename', 'filename_multi', 'save_filename', 'context', \
+        'dir']
+        The type of file dialog to return. `open` and `save` will perform the open and save
+        actions and return the file. `filename` returns the filename from an `open` dialog.
+        `filename_multi` allows for multi-selection of files and returns a list of files
+        selected. `save_filename` returns the filename from a `save as` dialog. `context` is a
+        context sensitive parameter that returns a certain dialog based on the current options.
+        `dir` asks for a folder location.
+    file_type: ['default', 'alignments', 'config_project', 'config_task', 'config_all', 'csv', \
+        'image', 'ini', 'json', 'model', 'state', 'log', 'video'] or ``None``
+        The type of file that this dialog is for. `default` allows selection of any files. Other
+        options limit the file type selection
+    title: str, optional
+        The title to display on the file dialog. If `None` then the default title will be used.
+        Default: ``None``
+    initial_folder: str, optional
+        The folder to initially open with the file dialog. If `None` then tkinter will decide.
+        Default: ``None``
+    initial_file: str, optional
+        The filename to set with the file dialog. If `None` then no initial filename is
+        specified. Default: ``None``
+    command: str, optional
+        Required for context handling file dialog, otherwise unused. Default: ``None``
+    action: str, optional
+        Required for context handling file dialog, otherwise unused. Default: ``None``
+    variable: str, optional
+        Required for context handling file dialog, otherwise unused. The variable to associate
+        with this file dialog. Default: ``None``
+    parent: :class:`tkinter.Frame` | :class:`tkinter.ttk.Frame`, optional
+        The parent that is launching the file dialog. ``None`` sets this to root.
Default: ``None`` + + Attributes + ---------- + return_file: str or object + The return value from the file dialog + + Example + ------- + >>> handler = FileHandler('filename', 'video', title='Select a video...') + >>> video_file = handler.return_file + >>> print(video_file) + '/path/to/selected/video.mp4' + """ + + def __init__(self, + handle_type: _HANDLETYPE, + file_type: _FILETYPE | None, + title: str | None = None, + initial_folder: str | None = None, + initial_file: str | None = None, + command: str | None = None, + action: str | None = None, + variable: str | None = None, + parent: tk.Frame | ttk.Frame | None = None) -> None: + logger.debug("Initializing %s: (handle_type: '%s', file_type: '%s', title: '%s', " + "initial_folder: '%s', initial_file: '%s', command: '%s', action: '%s', " + "variable: %s, parent: %s)", self.__class__.__name__, handle_type, file_type, + title, initial_folder, initial_file, command, action, variable, parent) + self._handletype = handle_type + self._dummy_master = self._set_dummy_master() + self._defaults = self._set_defaults() + self._kwargs = self._set_kwargs(title, + initial_folder, + initial_file, + file_type, + command, + action, + variable, + parent) + self.return_file = getattr(self, f"_{self._handletype.lower()}")() + self._remove_dummy_master() + + logger.debug("Initialized %s", self.__class__.__name__) + + @property + def _filetypes(self) -> dict[str, list[tuple[str, str]]]: + """ dict: The accepted extensions for each file type for opening/saving """ + all_files = ("All files", "*.*") + filetypes = { + "default": [all_files], + "alignments": [("Faceswap Alignments", "*.fsa"), all_files], + "config_project": [("Faceswap Project files", "*.fsw"), all_files], + "config_task": [("Faceswap Task files", "*.fst"), all_files], + "config_all": [("Faceswap Project and Task files", "*.fst *.fsw"), all_files], + "csv": [("Comma separated values", "*.csv"), all_files], + "image": [("Bitmap", "*.bmp"), + ("JPG", "*.jpeg *.jpg"), + ("PNG", "*.png"), + ("TIFF", "*.tif *.tiff"), + all_files], + "ini": [("Faceswap config files", "*.ini"), all_files], + "json": [("JSON file", "*.json"), all_files], + "model": [("Keras model files", "*.keras"), all_files], + "state": [("State files", "*.json"), all_files], + "log": [("Log files", "*.log"), all_files], + "video": [("Audio Video Interleave", "*.avi"), + ("Flash Video", "*.flv"), + ("Matroska", "*.mkv"), + ("MOV", "*.mov"), + ("MP4", "*.mp4"), + ("MPEG", "*.mpeg *.mpg *.ts *.vob"), + ("WebM", "*.webm"), + ("Windows Media Video", "*.wmv"), + all_files]} + + # Add in multi-select options and upper case extensions for Linux + for key in filetypes: + if platform.system() == "Linux": + filetypes[key] = [item + if item[0] == "All files" + else (item[0], f"{item[1]} {item[1].upper()}") + for item in filetypes[key]] + if len(filetypes[key]) > 2: + multi = [f"{key.title()} Files"] + multi.append(" ".join([ftype[1] + for ftype in filetypes[key] if ftype[0] != "All files"])) + filetypes[key].insert(0, T.cast(tuple[str, str], tuple(multi))) + return filetypes + + @property + def _contexts(self) -> dict[str, dict[str, str | dict[str, str]]]: + """dict: Mapping of commands, actions and their corresponding file dialog for context + handle types. 
""" + return {"effmpeg": {"input": {"extract": "filename", + "gen-vid": "dir", + "get-fps": "filename", + "get-info": "filename", + "mux-audio": "filename", + "rescale": "filename", + "rotate": "filename", + "slice": "filename"}, + "output": {"extract": "dir", + "gen-vid": "save_filename", + "get-fps": "nothing", + "get-info": "nothing", + "mux-audio": "save_filename", + "rescale": "save_filename", + "rotate": "save_filename", + "slice": "save_filename"}}} + + @classmethod + def _set_dummy_master(cls) -> tk.Frame | None: + """ Add an option to force black font on Linux file dialogs KDE issue that displays light + font on white background). + + This is a pretty hacky solution, but tkinter does not allow direct editing of file dialogs, + so we create a dummy frame and add the foreground option there, so that the file dialog can + inherit the foreground. + + Returns + ------- + tkinter.Frame or ``None`` + The dummy master frame for Linux systems, otherwise ``None`` + """ + if platform.system().lower() == "linux": + frame = tk.Frame() + frame.option_add("*foreground", "black") + retval: tk.Frame | None = frame + else: + retval = None + return retval + + def _remove_dummy_master(self) -> None: + """ Destroy the dummy master widget on Linux systems. """ + if platform.system().lower() != "linux" or self._dummy_master is None: + return + self._dummy_master.destroy() + del self._dummy_master + self._dummy_master = None + + def _set_defaults(self) -> dict[str, str | None]: + """ Set the default file type for the file dialog. Generally the first found file type + will be used, but this is overridden if it is not appropriate. + + Returns + ------- + dict: + The default file extension for each file type + """ + defaults: dict[str, str | None] = { + key: next(ext for ext in val[0][1].split(" ")).replace("*", "") + for key, val in self._filetypes.items()} + defaults["default"] = None + defaults["video"] = ".mp4" + defaults["image"] = ".png" + logger.debug(defaults) + return defaults + + def _set_kwargs(self, + title: str | None, + initial_folder: str | None, + initial_file: str | None, + file_type: _FILETYPE | None, + command: str | None, + action: str | None, + variable: str | None, + parent: tk.Frame | ttk.Frame | None + ) -> dict[str, None | tk.Frame | ttk.Frame | str | list[tuple[str, str]]]: + """ Generate the required kwargs for the requested file dialog browser. + + Parameters + ---------- + title: str + The title to display on the file dialog. If `None` then the default title will be used. + initial_folder: str + The folder to initially open with the file dialog. If `None` then tkinter will decide. + initial_file: str + The filename to set with the file dialog. If `None` then tkinter no initial filename + is. + file_type: ['default', 'alignments', 'config_project', 'config_task', 'config_all', \ + 'csv', 'image', 'ini', 'state', 'log', 'video'] or ``None`` + The type of file that this dialog is for. `default` allows selection of any files. + Other options limit the file type selection + command: str + Required for context handling file dialog, otherwise unused. + action: str + Required for context handling file dialog, otherwise unused. + variable: str, optional + Required for context handling file dialog, otherwise unused. The variable to associate + with this file dialog. Default: ``None`` + parent: :class:`tkinter.Frame` | :class:`tkinter.tk.Frame | None + The parent that is launching the file dialog. 
``None`` sets this to root + + Returns + ------- + dict: + The key word arguments for the file dialog to be launched + """ + logger.debug("Setting Kwargs: (title: %s, initial_folder: %s, initial_file: '%s', " + "file_type: '%s', command: '%s': action: '%s', variable: '%s', parent: %s)", + title, initial_folder, initial_file, file_type, command, action, variable, + parent) + + kwargs: dict[str, None | tk.Frame | ttk.Frame | str | list[tuple[str, str]]] = { + "master": self._dummy_master} + + if self._handletype.lower() == "context": + assert command is not None and action is not None and variable is not None + self._set_context_handletype(command, action, variable) + + if title is not None: + kwargs["title"] = title + + if initial_folder is not None: + kwargs["initialdir"] = initial_folder + + if initial_file is not None: + kwargs["initialfile"] = initial_file + + if parent is not None: + kwargs["parent"] = parent + + if self._handletype.lower() in ( + "open", "save", "filename", "filename_multi", "save_filename"): + assert file_type is not None + kwargs["filetypes"] = self._filetypes[file_type] + if self._defaults.get(file_type): + kwargs['defaultextension'] = self._defaults[file_type] + if self._handletype.lower() == "save": + kwargs["mode"] = "w" + if self._handletype.lower() == "open": + kwargs["mode"] = "r" + logger.debug("Set Kwargs: %s", kwargs) + return kwargs + + def _set_context_handletype(self, command: str, action: str, variable: str) -> None: + """ Sets the correct handle type based on context. + + Parameters + ---------- + command: str + The command that is being executed. Used to look up the context actions + action: str + The action that is being performed. Used to look up the correct file dialog + variable: str + The variable associated with this file dialog + """ + if self._contexts[command].get(variable, None) is not None: + handletype = T.cast(dict[str, dict[str, dict[str, str]]], + self._contexts)[command][variable][action] + else: + handletype = T.cast(dict[str, dict[str, str]], + self._contexts)[command][action] + logger.debug(handletype) + self._handletype = T.cast(_HANDLETYPE, handletype) + + def _open(self) -> T.IO | None: + """ Open a file. """ + logger.debug("Popping Open browser") + return filedialog.askopenfile(**self._kwargs) # type: ignore + + def _save(self) -> T.IO | None: + """ Save a file. """ + logger.debug("Popping Save browser") + return filedialog.asksaveasfile(**self._kwargs) # type: ignore + + def _dir(self) -> str: + """ Get a directory location. """ + logger.debug("Popping Dir browser") + return filedialog.askdirectory(**self._kwargs) # type: ignore + + def _savedir(self) -> str: + """ Get a save directory location. """ + logger.debug("Popping SaveDir browser") + return filedialog.askdirectory(**self._kwargs) # type: ignore + + def _filename(self) -> str: + """ Get an existing file location. """ + logger.debug("Popping Filename browser") + return filedialog.askopenfilename(**self._kwargs) # type: ignore + + def _filename_multi(self) -> tuple[str, ...]: + """ Get multiple existing file locations. """ + logger.debug("Popping Filename browser") + return filedialog.askopenfilenames(**self._kwargs) # type: ignore + + def _save_filename(self) -> str: + """ Get a save file location. """ + logger.debug("Popping Save Filename browser") + return filedialog.asksaveasfilename(**self._kwargs) # type: ignore + + @staticmethod + def _nothing() -> None: # pylint:disable=useless-return + """ Method that does nothing, used for disabling open/save pop up. 
""" + logger.debug("Popping Nothing browser") + return + + +__all__ = get_module_objects(__name__) diff --git a/lib/gui/utils/image.py b/lib/gui/utils/image.py new file mode 100644 index 0000000000..05305aab75 --- /dev/null +++ b/lib/gui/utils/image.py @@ -0,0 +1,665 @@ +#!/usr/bin python3 +""" Utilities for handling images in the Faceswap GUI """ +from __future__ import annotations +import logging +import os +import typing as T + +import cv2 +import numpy as np +from PIL import Image, ImageDraw, ImageTk + +from lib.gui import gui_config as cfg +from lib.training.preview_cv import PreviewBuffer +from lib.utils import get_module_objects + +from .config import get_config, PATHCACHE + +if T.TYPE_CHECKING: + from collections.abc import Sequence + +logger = logging.getLogger(__name__) +_IMAGES: Images | None = None +_PREVIEW_TRIGGER: PreviewTrigger | None = None +TRAININGPREVIEW = ".gui_training_preview.png" + + +def initialize_images() -> None: + """ Initialize the :class:`Images` handler and add to global constant. + + This should only be called once on first GUI startup. Future access to :class:`Images` + handler should only be executed through :func:`get_images`. + """ + global _IMAGES # pylint:disable=global-statement + if _IMAGES is not None: + return + logger.debug("Initializing images") + _IMAGES = Images() + + +def get_images() -> "Images": + """ Get the Master GUI Images handler. + + Returns + ------- + :class:`Images` + The Master GUI Images handler + """ + assert _IMAGES is not None + return _IMAGES + + +def _get_previews(image_path: str) -> list[str]: + """ Get the images stored within the given directory. + + Parameters + ---------- + image_path: str + The folder containing images to be scanned + + Returns + ------- + list: + The image filenames stored within the given folder + + """ + logger.debug("Getting images: '%s'", image_path) + if not os.path.isdir(image_path): + logger.debug("Folder does not exist") + return [] + files = [os.path.join(image_path, f) + for f in os.listdir(image_path) if f.lower().endswith((".png", ".jpg"))] + logger.debug("Image files: %s", files) + return files + + +class PreviewTrain(): + """ Handles the loading of the training preview image(s) and adding to the display buffer + + Parameters + ---------- + cache_path: str + Full path to the cache folder that contains the preview images + """ + def __init__(self, cache_path: str) -> None: + logger.debug("Initializing %s: (cache_path: '%s')", self.__class__.__name__, cache_path) + self._buffer = PreviewBuffer() + self._cache_path = cache_path + self._modified: float = 0.0 + self._error_count: int = 0 + logger.debug("Initialized %s", self.__class__.__name__) + + @property + def buffer(self) -> PreviewBuffer: + """ :class:`~lib.training.PreviewBuffer` The preview buffer for the training preview + image. """ + return self._buffer + + def load(self) -> bool: + """ Load the latest training preview image(s) from disk and add to :attr:`buffer` """ + logger.trace("Loading Training preview images") # type:ignore + image_files = _get_previews(self._cache_path) + filename = next((fname for fname in image_files + if os.path.basename(fname) == TRAININGPREVIEW), "") + img: np.ndarray | None = None + if not filename: + logger.trace("No preview to display") # type:ignore + return False + try: + modified = os.path.getmtime(filename) + if modified <= self._modified: + logger.trace("preview '%s' not updated. 
Current timestamp: %s, " # type:ignore + "existing timestamp: %s", filename, modified, self._modified) + return False + + logger.debug("Loading preview: '%s'", filename) + img = cv2.imread(filename, cv2.IMREAD_UNCHANGED) + assert img is not None + self._modified = modified + self._buffer.add_image(os.path.basename(filename), img) + self._error_count = 0 + except (ValueError, AssertionError): + # This is probably an error reading the file whilst it's being saved so ignore it + # for now and only pick up if there have been multiple consecutive fails + logger.debug("Unable to display preview: (image: '%s', attempt: %s)", + img, self._error_count) + if self._error_count < 10: + self._error_count += 1 + else: + logger.error("Error reading the preview file for '%s'", filename) + return False + + logger.debug("Loaded preview: '%s' (%s)", filename, img.shape) + return True + + def reset(self) -> None: + """ Reset the preview buffer when the display page has been disabled. + + Notes + ----- + The buffer requires resetting, otherwise the re-enabled preview window hangs waiting for a + training image that has already been marked as processed + """ + logger.debug("Resetting training preview") + del self._buffer + self._buffer = PreviewBuffer() + self._modified = 0.0 + self._error_count = 0 + + +class PreviewExtract(): + """ Handles the loading of preview images for extract and convert + + Parameters + ---------- + cache_path: str + Full path to the cache folder that contains the preview images + """ + def __init__(self, cache_path: str) -> None: + logger.debug("Initializing %s: (cache_path: '%s')", self.__class__.__name__, cache_path) + self._cache_path = cache_path + + self._batch_mode = False + self._output_path = "" + + self._modified: float = 0.0 + self._filenames: list[str] = [] + self._images: np.ndarray | None = None + self._placeholder: np.ndarray | None = None + + self._preview_image: Image.Image | None = None + self._preview_image_tk: ImageTk.PhotoImage | None = None + + logger.debug("Initialized %s", self.__class__.__name__) + + @property + def image(self) -> ImageTk.PhotoImage: + """:class:`PIL.ImageTk.PhotoImage` The preview image for displaying in a tkinter canvas """ + assert self._preview_image_tk is not None + return self._preview_image_tk + + def save(self, filename: str) -> None: + """ Save the currently displaying preview image to the given location + + Parameters + ---------- + filename: str + The full path to the filename to save the preview image to + """ + logger.debug("Saving preview to %s", filename) + assert self._preview_image is not None + self._preview_image.save(filename) + + def set_faceswap_output_path(self, location: str, batch_mode: bool = False) -> None: + """ Set the path that will contain the output from an Extract or Convert task. + + Required so that the GUI can fetch output images to display for return in + :attr:`preview_image`. + + Parameters + ---------- + location: str + The output location that has been specified for an Extract or Convert task + batch_mode: bool + ``True`` if extracting in batch mode otherwise False + """ + self._output_path = location + self._batch_mode = batch_mode + + def _get_newest_folder(self) -> str: + """ Obtain the most recent folder created in the extraction output folder when processing + in batch mode. + + Returns + ------- + str + The most recently modified folder within the parent output folder. 
If no folders have
+            been created, returns the parent output folder
+
+        """
+        folders = [] if not os.path.exists(self._output_path) else [
+            os.path.join(self._output_path, folder)
+            for folder in os.listdir(self._output_path)
+            if os.path.isdir(os.path.join(self._output_path, folder))]
+
+        folders.sort(key=os.path.getmtime)
+        retval = folders[-1] if folders else self._output_path
+        logger.debug("sorted folders: %s, return value: %s", folders, retval)
+        return retval
+
+    def _get_newest_filenames(self, image_files: list[str]) -> list[str]:
+        """ Return image filenames that have been modified since the last check.
+
+        Parameters
+        ----------
+        image_files: list
+            The list of image files to check the modification date for
+
+        Returns
+        -------
+        list:
+            A list of images that have been modified since the last check
+        """
+        if not self._modified:
+            retval = image_files
+        else:
+            retval = [fname for fname in image_files
+                      if os.path.getmtime(fname) > self._modified]
+        if not retval:
+            logger.debug("No new images in output folder")
+        else:
+            self._modified = max(os.path.getmtime(img) for img in retval)
+            logger.debug("Number new images: %s, Last Modified: %s",
+                         len(retval), self._modified)
+        return retval
+
+    def _pad_and_border(self, image: Image.Image, size: int) -> np.ndarray:
+        """ Pad rectangular images to a square and draw borders
+
+        Parameters
+        ----------
+        image: :class:`PIL.Image`
+            The image to process
+        size: int
+            The size of the image as it should be displayed
+
+        Returns
+        -------
+        :class:`numpy.ndarray`:
+            The processed image
+        """
+        if image.size[0] != image.size[1]:
+            # Pad to square
+            new_img = Image.new("RGB", (size, size))
+            new_img.paste(image, ((size - image.size[0]) // 2, (size - image.size[1]) // 2))
+            image = new_img
+        draw = ImageDraw.Draw(image)
+        draw.rectangle(((0, 0), (size, size)), outline="#E5E5E5", width=1)
+        retval = np.array(image)
+        logger.trace("image shape: %s", retval.shape)  # type: ignore
+        return retval
+
+    def _process_samples(self,
+                         samples: list[np.ndarray],
+                         filenames: list[str],
+                         num_images: int) -> bool:
+        """ Process the latest sample images into a displayable image.
+
+        Parameters
+        ----------
+        samples: list
+            The list of extract/convert preview images to display
+        filenames: list
+            The full path to the filenames corresponding to the images
+        num_images: int
+            The number of images that should be displayed
+
+        Returns
+        -------
+        bool
+            ``True`` if samples were successfully compiled otherwise ``False``
+        """
+        asamples = np.array(samples)
+        if not np.any(asamples):
+            logger.debug("No preview images collected.")
+            return False
+
+        self._filenames = (self._filenames + filenames)[-num_images:]
+        cache = self._images
+
+        if cache is None:
+            logger.debug("Creating new cache")
+            cache = asamples[-num_images:]
+        else:
+            logger.debug("Appending to existing cache")
+            cache = np.concatenate((cache, asamples))[-num_images:]
+
+        self._images = cache
+        assert self._images is not None
+        logger.debug("Cache shape: %s", self._images.shape)
+        return True
+
+    def _load_images_to_cache(self,  # pylint:disable=too-many-locals
+                              image_files: list[str],
+                              frame_dims: tuple[int, int],
+                              thumbnail_size: int) -> bool:
+        """ Load preview images to the image cache.
+
+        Load new images and append to cache, filtering the cache to the number of thumbnails that
+        will fit inside the display panel.
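+        For example, a 600 x 400 pixel panel with 128 pixel thumbnails holds
+        ``(600 // 128) * (400 // 128) = 4 * 3 = 12`` thumbnails.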
+
+        Parameters
+        ----------
+        image_files: list
+            A list of new image files that have been modified since the last check
+        frame_dims: tuple
+            The (width (`int`), height (`int`)) of the display panel that will display the preview
+        thumbnail_size: int
+            The size of each thumbnail that should be created
+
+        Returns
+        -------
+        bool
+            ``True`` if images were successfully loaded to cache otherwise ``False``
+        """
+        logger.debug("Number image_files: %s, frame_dims: %s, thumbnail_size: %s",
+                     len(image_files), frame_dims, thumbnail_size)
+        num_images = (frame_dims[0] // thumbnail_size) * (frame_dims[1] // thumbnail_size)
+        logger.debug("num_images: %s", num_images)
+        if num_images == 0:
+            return False
+        samples: list[np.ndarray] = []
+        start_idx = len(image_files) - num_images if len(image_files) > num_images else 0
+        show_files = sorted(image_files, key=os.path.getctime)[start_idx:]
+        dropped_files = []
+        for fname in show_files:
+            try:
+                img_file = Image.open(fname)
+            except PermissionError as err:
+                logger.debug("Permission error opening preview file: '%s'. Original error: %s",
+                             fname, str(err))
+                dropped_files.append(fname)
+                continue
+            except Exception as err:  # pylint:disable=broad-except
+                # Swallow any issues with opening an image rather than spamming console
+                # Can happen when trying to read partially saved images
+                logger.debug("Error opening preview file: '%s'. Original error: %s",
+                             fname, str(err))
+                dropped_files.append(fname)
+                continue
+
+            width, height = img_file.size
+            scaling = thumbnail_size / max(width, height)
+            logger.debug("image width: %s, height: %s, scaling: %s", width, height, scaling)
+
+            try:
+                img = img_file.resize((int(width * scaling), int(height * scaling)))
+            except OSError as err:
+                # Image only gets loaded when we call a method, so may error on partial loads
+                logger.debug("OS Error resizing preview image: '%s'. Original error: %s",
+                             fname, err)
+                dropped_files.append(fname)
+                continue
+
+            samples.append(self._pad_and_border(img, thumbnail_size))
+
+        return self._process_samples(samples,
+                                     [fname for fname in show_files if fname not in dropped_files],
+                                     num_images)
+
+    def _create_placeholder(self, thumbnail_size: int) -> None:
+        """ Create a placeholder image for when there are fewer thumbnails available
+        than columns to display them.
+
+        Parameters
+        ----------
+        thumbnail_size: int
+            The size of the thumbnail that the placeholder should replicate
+        """
+        logger.debug("Creating placeholder. thumbnail_size: %s", thumbnail_size)
+        placeholder = Image.new("RGB", (thumbnail_size, thumbnail_size))
+        draw = ImageDraw.Draw(placeholder)
+        draw.rectangle(((0, 0), (thumbnail_size, thumbnail_size)), outline="#E5E5E5", width=1)
+        nplaceholder = np.array(placeholder)
+        self._placeholder = nplaceholder
+        logger.debug("Created placeholder. shape: %s", nplaceholder.shape)
+
+    def _place_previews(self, frame_dims: tuple[int, int]) -> Image.Image | None:
+        """ Format the preview thumbnails stored in the cache into a grid fitting the display
+        panel.
+
+        Parameters
+        ----------
+        frame_dims: tuple
+            The (width (`int`), height (`int`)) of the display panel that will display the preview
+
+        Returns
+        -------
+        :class:`PIL.Image` or ``None``
+            The final preview display image
+        """
+        if self._images is None:
+            logger.debug("No images in cache. Returning None")
+            return None
+        samples = self._images.copy()
+        num_images, thumbnail_size = samples.shape[:2]
+        if self._placeholder is None:
+            self._create_placeholder(thumbnail_size)
+
+        logger.debug("num_images: %s, thumbnail_size: %s", num_images, thumbnail_size)
+        cols, rows = frame_dims[0] // thumbnail_size, frame_dims[1] // thumbnail_size
+        logger.debug("cols: %s, rows: %s", cols, rows)
+        if cols == 0 or rows == 0:
+            logger.debug("Cols or Rows is zero. No items to display")
+            return None
+
+        remainder = (cols * rows) - num_images
+        if remainder != 0:
+            logger.debug("Padding sample display. Remainder: %s", remainder)
+            assert self._placeholder is not None
+            placeholder = np.concatenate([np.expand_dims(self._placeholder, 0)] * remainder)
+            samples = np.concatenate((samples, placeholder))
+
+        display = np.vstack([np.hstack(T.cast("Sequence", samples[row * cols: (row + 1) * cols]))
+                             for row in range(rows)])
+        logger.debug("display shape: %s", display.shape)
+        return Image.fromarray(display)
+
+    def load_latest_preview(self, thumbnail_size: int, frame_dims: tuple[int, int]) -> bool:
+        """ Load the latest preview image for extract and convert.
+
+        Retrieves the latest preview images from the faceswap output folder, resizes to thumbnails
+        and lays out for display. Places the images into :attr:`preview_image` for loading into
+        the display panel.
+
+        Parameters
+        ----------
+        thumbnail_size: int
+            The size of each thumbnail that should be created
+        frame_dims: tuple
+            The (width (`int`), height (`int`)) of the display panel that will display the preview
+
+        Returns
+        -------
+        bool
+            ``True`` if a preview was successfully loaded otherwise ``False``
+        """
+        logger.debug("Loading preview image: (thumbnail_size: %s, frame_dims: %s)",
+                     thumbnail_size, frame_dims)
+        image_path = self._get_newest_folder() if self._batch_mode else self._output_path
+        image_files = _get_previews(image_path)
+        gui_preview = os.path.join(self._output_path, ".gui_preview.jpg")
+        if not image_files or (len(image_files) == 1 and gui_preview not in image_files):
+            logger.debug("No preview to display")
+            return False
+        # Filter to just the gui_preview if it exists in the output folder
+        image_files = [gui_preview] if gui_preview in image_files else image_files
+        logger.debug("Image Files: %s", len(image_files))
+
+        image_files = self._get_newest_filenames(image_files)
+        if not image_files:
+            return False
+
+        if not self._load_images_to_cache(image_files, frame_dims, thumbnail_size):
+            logger.debug("Failed to load any preview images")
+            if gui_preview in image_files:
+                # Reset last modified for failed loading of a gui preview image so it is picked
+                # up next time
+                self._modified = 0.0
+            return False
+
+        if image_files == [gui_preview]:
+            # Delete the preview image so that the main scripts know to output another
+            logger.debug("Deleting preview image")
+            os.remove(image_files[0])
+        show_image = self._place_previews(frame_dims)
+        if not show_image:
+            self._preview_image = None
+            self._preview_image_tk = None
+            return False
+
+        logger.debug("Displaying preview: %s", self._filenames)
+        self._preview_image = show_image
+        self._preview_image_tk = ImageTk.PhotoImage(show_image)
+        return True
+
+    def delete_previews(self) -> None:
+        """ Remove any image preview files """
+        for fname in self._filenames:
+            if os.path.basename(fname) == ".gui_preview.jpg":
+                logger.debug("Deleting: '%s'", fname)
+                try:
+                    os.remove(fname)
+                except FileNotFoundError:
+                    logger.debug("File does not exist: %s", fname)
+
+
+class Images():
+    """ The centralized image repository for holding all icons and images required by the GUI.
+
+    This class should be initialized on GUI startup through :func:`initialize_images`. Any further
+    access to this class should be through :func:`get_images`.
+    """
+    def __init__(self) -> None:
+        logger.debug("Initializing %s", self.__class__.__name__)
+        self._pathpreview = os.path.join(PATHCACHE, "preview")
+        self._pathoutput: str | None = None
+        self._batch_mode = False
+        self._preview_train = PreviewTrain(self._pathpreview)
+        self._preview_extract = PreviewExtract(self._pathpreview)
+        self._icons = self._load_icons()
+        logger.debug("Initialized %s", self.__class__.__name__)
+
+    @property
+    def preview_train(self) -> PreviewTrain:
+        """ :class:`PreviewTrain` The object handling the training preview images """
+        return self._preview_train
+
+    @property
+    def preview_extract(self) -> PreviewExtract:
+        """ :class:`PreviewExtract` The object handling the extract and convert preview images """
+        return self._preview_extract
+
+    @property
+    def icons(self) -> dict[str, ImageTk.PhotoImage]:
+        """ dict: The faceswap icons for all parts of the GUI. The dictionary key is the icon
+        name (`str`), the value is the icon sized and formatted for display
+        (:class:`PIL.ImageTk.PhotoImage`).
+
+        Example
+        -------
+        >>> icons = get_images().icons
+        >>> save = icons["save"]
+        >>> button = ttk.Button(parent, image=save)
+        >>> button.pack()
+        """
+        return self._icons
+
+    @staticmethod
+    def _load_icons() -> dict[str, ImageTk.PhotoImage]:
+        """ Scan the icons cache folder and load the icons into :attr:`icons` for retrieval
+        throughout the GUI.
+
+        Returns
+        -------
+        dict:
+            The icons formatted as described in :attr:`icons`
+
+        """
+        size = cfg.icon_size()
+        size = int(round(size * get_config().scaling_factor))
+        icons: dict[str, ImageTk.PhotoImage] = {}
+        pathicons = os.path.join(PATHCACHE, "icons")
+        for fname in os.listdir(pathicons):
+            name, ext = os.path.splitext(fname)
+            if ext != ".png":
+                continue
+            img = Image.open(os.path.join(pathicons, fname))
+            pimg = ImageTk.PhotoImage(img.resize((size, size), resample=Image.Resampling.HAMMING))
+            icons[name] = pimg
+        logger.debug(icons)
+        return icons
+
+    def delete_preview(self) -> None:
+        """ Delete the preview files in the cache folder and reset the image cache.
+
+        Should be called when terminating tasks, or when Faceswap starts up or shuts down.
+        """
+        logger.debug("Deleting previews")
+        for item in os.listdir(self._pathpreview):
+            if item.startswith(os.path.splitext(TRAININGPREVIEW)[0]) and item.endswith((".jpg",
+                                                                                        ".png")):
+                fullitem = os.path.join(self._pathpreview, item)
+                logger.debug("Deleting: '%s'", fullitem)
+                os.remove(fullitem)
+
+        self._preview_extract.delete_previews()
+        del self._preview_train
+        del self._preview_extract
+        self._preview_train = PreviewTrain(self._pathpreview)
+        self._preview_extract = PreviewExtract(self._pathpreview)
+
+
+class PreviewTrigger():
+    """ Triggers to indicate to the underlying Faceswap process that the preview image should
+    be updated.
+
+    Writes a file to the cache folder that is picked up by the main process.
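+
+    Example
+    -------
+    A minimal usage sketch:
+
+    >>> trigger = preview_trigger()
+    >>> trigger.set("update")  # ask the running task for a full preview refresh
+    >>> trigger.clear()  # remove any outstanding trigger files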
+ """ + def __init__(self) -> None: + logger.debug("Initializing: %s", self.__class__.__name__) + self._trigger_files = {"update": os.path.join(PATHCACHE, ".preview_trigger"), + "mask_toggle": os.path.join(PATHCACHE, ".preview_mask_toggle")} + logger.debug("Initialized: %s (trigger_files: %s)", + self.__class__.__name__, self._trigger_files) + + def set(self, trigger_type: T.Literal["update", "mask_toggle"]): + """ Place the trigger file into the cache folder + + Parameters + ---------- + trigger_type: ["update", "mask_toggle"] + The type of action to trigger. 'update': Full preview update. 'mask_toggle': toggle + mask on and off + """ + trigger = self._trigger_files[trigger_type] + if not os.path.isfile(trigger): + with open(trigger, "w", encoding="utf8"): + pass + logger.debug("Set preview trigger: %s", trigger) + + def clear(self, trigger_type: T.Literal["update", "mask_toggle"] | None = None) -> None: + """ Remove the trigger file from the cache folder. + + Parameters + ---------- + trigger_type: ["update", "mask_toggle", ``None``], optional + The trigger to clear. 'update': Full preview update. 'mask_toggle': toggle mask on + and off. ``None`` - clear all triggers. Default: ``None`` + """ + if trigger_type is None: + triggers = list(self._trigger_files.values()) + else: + triggers = [self._trigger_files[trigger_type]] + for trigger in triggers: + if os.path.isfile(trigger): + os.remove(trigger) + logger.debug("Removed preview trigger: %s", trigger) + + +def preview_trigger() -> PreviewTrigger: + """ Set the global preview trigger if it has not already been set and return. + + Returns + ------- + :class:`PreviewTrigger` + The trigger to indicate to the main faceswap process that it should perform a training + preview update + """ + global _PREVIEW_TRIGGER # pylint:disable=global-statement + if _PREVIEW_TRIGGER is None: + _PREVIEW_TRIGGER = PreviewTrigger() + return _PREVIEW_TRIGGER + + +__all__ = get_module_objects(__name__) diff --git a/lib/gui/utils/misc.py b/lib/gui/utils/misc.py new file mode 100644 index 0000000000..c559e329cc --- /dev/null +++ b/lib/gui/utils/misc.py @@ -0,0 +1,114 @@ +#!/usr/bin/env python3 +""" Miscellaneous Utility functions for the GUI. Includes LongRunningTask object """ +from __future__ import annotations +import logging +import sys +import typing as T + +from threading import Event, Thread +from queue import Queue + +from lib.utils import get_module_objects + +from .config import get_config + +if T.TYPE_CHECKING: + from collections.abc import Callable + from types import TracebackType + from lib.multithreading import _ErrorType + + +logger = logging.getLogger(__name__) + + +class LongRunningTask(Thread): + """ Runs long running tasks in a background thread to prevent the GUI from becoming + unresponsive. + + This is sub-classed from :class:`Threading.Thread` so check documentation there for base + parameters. Additional parameters listed below. + + Parameters + ---------- + widget: tkinter object, optional + The widget that this :class:`LongRunningTask` is associated with. Used for setting the busy + cursor in the correct location. Default: ``None``. 
+ """ + _target: Callable + _args: tuple + _kwargs: dict[str, T.Any] + _name: str + + def __init__(self, + target: Callable | None = None, + name: str | None = None, + args: tuple = (), + kwargs: dict[str, T.Any] | None = None, + *, + daemon: bool = True, + widget=None): + logger.debug("Initializing %s: (target: %s, name: %s, args: %s, kwargs: %s, " + "daemon: %s)", self.__class__.__name__, target, name, args, kwargs, + daemon) + super().__init__(target=target, name=name, args=args, kwargs=kwargs, + daemon=daemon) + self.err: _ErrorType = None + self._widget = widget + self._config = get_config() + self._config.set_cursor_busy(widget=self._widget) + self._complete = Event() + self._queue: Queue = Queue() + logger.debug("Initialized %s", self.__class__.__name__,) + + @property + def complete(self) -> Event: + """ :class:`threading.Event`: Event is set if the thread has completed its task, + otherwise it is unset. + """ + return self._complete + + def run(self) -> None: + """ Commence the given task in a background thread. """ + try: + if self._target is not None: + retval = self._target(*self._args, **self._kwargs) + self._queue.put(retval) + except Exception: # pylint:disable=broad-except + self.err = T.cast(tuple[type[BaseException], BaseException, "TracebackType"], + sys.exc_info()) + assert self.err is not None + logger.debug("Error in thread (%s): %s", self._name, + self.err[1].with_traceback(self.err[2])) + finally: + self._complete.set() + # Avoid a ref-cycle if the thread is running a function with + # an argument that has a member that points to the thread. + del self._target, self._args, self._kwargs + + def get_result(self) -> T.Any: + """ Return the result from the given task. + + Returns + ------- + varies: + The result of the thread will depend on the given task. If a call is made to + :func:`get_result` prior to the thread completing its task then ``None`` will be + returned + """ + if not self._complete.is_set(): + logger.warning("Aborting attempt to retrieve result from a LongRunningTask that is " + "still running") + return None + if self.err: + logger.debug("Error caught in thread") + self._config.set_cursor_default(widget=self._widget) + raise self.err[1].with_traceback(self.err[2]) + + logger.debug("Getting result from thread") + retval = self._queue.get() + logger.debug("Got result from thread") + self._config.set_cursor_default(widget=self._widget) + return retval + + +__all__ = get_module_objects(__name__) diff --git a/lib/gui/wrapper.py b/lib/gui/wrapper.py index c5363697a6..6cef8a3c3c 100644 --- a/lib/gui/wrapper.py +++ b/lib/gui/wrapper.py @@ -1,89 +1,164 @@ #!/usr/bin python3 """ Process wrapper for underlying faceswap commands for the GUI """ +from __future__ import annotations import os import logging import re import signal -from subprocess import PIPE, Popen, TimeoutExpired import sys +import typing as T + +from subprocess import PIPE, Popen from threading import Thread from time import time import psutil -from .utils import get_config, get_images +from lib.gui import gui_config as cfg +from lib.utils import get_module_objects + +from .analysis import Session +from .utils import get_config, get_images, LongRunningTask, preview_trigger + +if os.name == "nt": + import win32console # pylint:disable=import-error -logger = logging.getLogger(__name__) # pylint: disable=invalid-name +logger = logging.getLogger(__name__) class ProcessWrapper(): """ Builds command, launches and terminates the underlying faceswap process. 
Updates GUI display depending on state """ - def __init__(self, pathscript=None): - logger.debug("Initializing %s: (pathscript: %s)", self.__class__.__name__, pathscript) - self.tk_vars = get_config().tk_vars - self.set_callbacks() - self.pathscript = pathscript - self.command = None - self.statusbar = get_config().statusbar - self.task = FaceswapControl(self) + def __init__(self) -> None: + logger.debug("Initializing %s", self.__class__.__name__) + self._tk_vars = get_config().tk_vars + self._set_callbacks() + self._command: str | None = None + """ str | None: The currently executing command, when process running or ``None`` """ + + self._statusbar = get_config().statusbar + self._training_session_location: dict[T.Literal["model_name", "model_folder"], str] = {} + self._task = FaceswapControl(self) logger.debug("Initialized %s", self.__class__.__name__) - def set_callbacks(self): - """ Set the tk variable callbacks """ - logger.debug("Setting tk variable traces") - self.tk_vars["action"].trace("w", self.action_command) - self.tk_vars["generate"].trace("w", self.generate_command) + @property + def task(self) -> FaceswapControl: + """ :class:`FaceswapControl`: The object that controls the underlying faceswap process """ + return self._task - def action_command(self, *args): - """ The action to perform when the action button is pressed """ - if not self.tk_vars["action"].get(): + def _set_callbacks(self) -> None: + """ Set the tkinter variable callbacks for performing an action or generating a command """ + logger.debug("Setting tk variable traces") + self._tk_vars.action_command.trace("w", self._action_command) + self._tk_vars.generate_command.trace("w", self._generate_command) + + def _action_command(self, *args: tuple[str, str, str]): # pylint:disable=unused-argument + """ Callback for when the Action button is pressed. Process command line options and + launches the action + + Parameters + ---------- + args: + tuple[str, str, str] + Tkinter variable callback args. Required but unused + """ + if not self._tk_vars.action_command.get(): return - category, command = self.tk_vars["action"].get().split(",") + category, command = self._tk_vars.action_command.get().split(",") - if self.tk_vars["runningtask"].get(): - self.task.terminate() + if self._tk_vars.running_task.get(): + self._task.terminate() else: - self.command = command - args = self.prepare(category) - self.task.execute_script(command, args) - self.tk_vars["action"].set(None) - - def generate_command(self, *args): - """ Generate the command line arguments and output """ - if not self.tk_vars["generate"].get(): + self._command = command + fs_args = self._prepare(T.cast(T.Literal["faceswap", "tools"], category)) + self._task.execute_script(command, fs_args) + self._tk_vars.action_command.set("") + + def _generate_command(self, # pylint:disable=unused-argument + *args: tuple[str, str, str]) -> None: + """ Callback for when the Generate button is pressed. Process command line options and + output the cli command + + Parameters + ---------- + args: + tuple[str, str, str] + Tkinter variable callback args. 
Required but unused + """ + if not self._tk_vars.generate_command.get(): return - category, command = self.tk_vars["generate"].get().split(",") - args = self.build_args(category, command=command, generate=True) - self.tk_vars["consoleclear"].set(True) - logger.debug(" ".join(args)) - print(" ".join(args)) - self.tk_vars["generate"].set(None) - - def prepare(self, category): - """ Prepare the environment for execution """ + category, command = self._tk_vars.generate_command.get().split(",") + fs_args = self._build_args(category, command=command, generate=True) + self._tk_vars.console_clear.set(True) + logger.debug(" ".join(fs_args)) + print(" ".join(fs_args)) + self._tk_vars.generate_command.set("") + + def _prepare(self, category: T.Literal["faceswap", "tools"]) -> list[str]: + """ Prepare the environment for execution, Sets the 'running task' and 'console clear' + global tkinter variables. If training, sets the 'is training' variable + + Parameters + ---------- + category: str, ["faceswap", "tools"] + The script that is executing the command + + Returns + ------- + list[str] + The command line arguments to execute for the faceswap job + """ logger.debug("Preparing for execution") - self.tk_vars["runningtask"].set(True) - self.tk_vars["consoleclear"].set(True) + assert self._command is not None + self._tk_vars.running_task.set(True) + self._tk_vars.console_clear.set(True) + if self._command == "train": + self._tk_vars.is_training.set(True) print("Loading...") - self.statusbar.status_message.set("Executing - {}.py".format(self.command)) - mode = "indeterminate" if self.command in ("effmpeg", "train") else "determinate" - self.statusbar.progress_start(mode) + self._statusbar.message.set(f"Executing - {self._command}.py") + mode: T.Literal["indeterminate", + "determinate"] = ("indeterminate" if self._command in ("effmpeg", "train") + else "determinate") + self._statusbar.start(mode) - args = self.build_args(category) - self.tk_vars["display"].set(self.command) + args = self._build_args(category) + self._tk_vars.display.set(self._command) logger.debug("Prepared for execution") return args - def build_args(self, category, command=None, generate=False): - """ Build the faceswap command and arguments list """ + def _build_args(self, + category: str, + command: str | None = None, + generate: bool = False) -> list[str]: + """ Build the faceswap command and arguments list. + + If training, pass the model folder and name to the training + :class:`lib.gui.analysis.Session` for the GUI. + + Parameters + ---------- + category: str, ["faceswap", "tools"] + The script that is executing the command + command: str, optional + The main faceswap command to execute, if provided. The currently running task if + ``None``. Default: ``None`` + generate: bool, optional + ``True`` if the command is just to be generated for display. 
``False`` if the command
+            is to be executed
+
+        Returns
+        -------
+        list[str]
+            The full faceswap command to be executed or displayed
+        """
         logger.debug("Build cli arguments: (category: %s, command: %s, generate: %s)",
                      category, command, generate)
-        command = self.command if not command else command
-        script = "{}.{}".format(category, "py")
-        pathexecscript = os.path.join(self.pathscript, script)
+        command = self._command if not command else command
+        assert command is not None
+        script = f"{category}.py"
+        pathexecscript = os.path.join(os.path.realpath(os.path.dirname(sys.argv[0])), script)
 
         args = [sys.executable] if generate else [sys.executable, "-u"]
         args.extend([pathexecscript, command])
@@ -92,281 +167,539 @@ def build_args(self, category, command=None, generate=False):
         for cliopt in cli_opts.gen_cli_arguments(command):
             args.extend(cliopt)
             if command == "train" and not generate:
-                self.init_training_session(cliopt)
+                self._get_training_session_info(cliopt)
+
         if not generate:
-            args.append("-gui")  # Indicate to Faceswap that we are running the GUI
+            args.append("-G")  # Indicate to Faceswap that we are running the GUI
         if generate:
             # Delimit args with spaces
-            args = ['"{}"'.format(arg) if " " in arg else arg for arg in args]
+            args = [f'"{arg}"' if " " in arg and not arg.startswith(("[", "("))
+                    and not arg.endswith(("]", ")")) else arg
+                    for arg in args]
         logger.debug("Built cli arguments: (%s)", args)
         return args
 
-    @staticmethod
-    def init_training_session(cliopt):
-        """ Set the session stats for disable logging, model folder and model name """
-        session = get_config().session
-        if cliopt[0] == "-t":
-            session.modelname = cliopt[1].lower().replace("-", "_")
-            logger.debug("modelname: '%s'", session.modelname)
-        if cliopt[0] == "-m":
-            session.modeldir = cliopt[1]
-            logger.debug("modeldir: '%s'", session.modeldir)
-
-    def terminate(self, message):
-        """ Finalize wrapper when process has exited """
+    def _get_training_session_info(self, cli_option: tuple[str, ...]) -> None:
+        """ Set the model folder and model name to :attr:`_training_session_location` so the global
+        session picks them up for logging to the graph and analysis tab.
+
+        Parameters
+        ----------
+        cli_option: tuple[str, ...]
+            The command line option to be checked for model folder or name
+        """
+        if cli_option[0] == "-t":
+            self._training_session_location["model_name"] = cli_option[1].lower().replace("-", "_")
+            logger.debug("model_name: '%s'", self._training_session_location["model_name"])
+        if cli_option[0] == "-m":
+            self._training_session_location["model_folder"] = cli_option[1]
+            logger.debug("model_folder: '%s'", self._training_session_location["model_folder"])
+
+    def terminate(self, message: str) -> None:
+        """ Finalize wrapper when process has exited. Stops the progress bar, sets the status
+        message. If the terminating task is 'train', then triggers the training close down actions
+
+        Parameters
+        ----------
+        message: str
+            The message to display in the status bar
+        """
         logger.debug("Terminating Faceswap processes")
-        self.tk_vars["runningtask"].set(False)
-        self.statusbar.progress_stop()
-        self.statusbar.status_message.set(message)
-        self.tk_vars["display"].set(None)
+        self._tk_vars.running_task.set(False)
+        if self._task.command == "train":
+            self._tk_vars.is_training.set(False)
+            Session.stop_training()
+        self._statusbar.stop()
+        self._statusbar.message.set(message)
+        self._tk_vars.display.set("")
         get_images().delete_preview()
-        get_config().session.__init__()
-        self.command = None
+        preview_trigger().clear(trigger_type=None)
+        self._command = None
         logger.debug("Terminated Faceswap processes")
         print("Process exited.")
 
 
 class FaceswapControl():
-    """ Control the underlying Faceswap tasks """
-    def __init__(self, wrapper):
-        logger.debug("Initializing %s", self.__class__.__name__)
-        self.wrapper = wrapper
-        self.statusbar = get_config().statusbar
-        self.command = None
-        self.args = None
-        self.process = None
-        self.train_stats = {"iterations": 0, "timestamp": None}
-        self.consoleregex = {
-            "loss": re.compile(r"([a-zA-Z_]+):.*?(\d+\.\d+)"),
-            "tqdm": re.compile(r".*?(?P<pct>\d+%).*?(?P<itm>\d+/\d+)\W\["
-                               r"(?P<tme>\d+:\d+<.*),\W(?P<rte>.*)[a-zA-Z/]*\]")}
+    """ Control the underlying Faceswap tasks.
+
+    Parameters
+    ----------
+    wrapper: :class:`ProcessWrapper`
+        The object responsible for managing this faceswap task
+    """
+    def __init__(self, wrapper: ProcessWrapper) -> None:
+        logger.debug("Initializing %s (wrapper: %s)", self.__class__.__name__, wrapper)
+        self._wrapper = wrapper
+        self._session_info = wrapper._training_session_location
+        self._config = get_config()
+        self._statusbar = self._config.statusbar
+        self._command: str | None = None
+        self._process: Popen | None = None
+        self._thread: LongRunningTask | None = None
+        self._train_stats: dict[T.Literal["iterations", "timestamp"],
+                                int | float | None] = {"iterations": 0, "timestamp": None}
+        self._consoleregex: dict[T.Literal["loss", "tqdm", "ffmpeg"], re.Pattern] = {
+            "loss": re.compile(r"[\W]+(\d+)?[\W]+([a-zA-Z\s]*)[\W]+?(\d+\.\d+)"),
+            "tqdm": re.compile(r"(?P<dsc>.*?)(?P<pct>\d+%).*?(?P<itm>\S+/\S+)\W\["
+                               r"(?P<tme>[\d+:]+<.*),\W(?P<rte>.*)[a-zA-Z/]*\]"),
+            "ffmpeg": re.compile(r"([a-zA-Z]+)=\s*(-?[\d|N/A]\S+)")}
+        self._first_loss_seen = False
         logger.debug("Initialized %s", self.__class__.__name__)
 
-    def execute_script(self, command, args):
-        """ Execute the requested Faceswap Script """
+    @property
+    def command(self) -> str | None:
+        """ str | None: The currently executing command, when process running or ``None`` """
+        return self._command
+
+    def execute_script(self, command: str, args: list[str]) -> None:
+        """ Execute the requested Faceswap Script
+
+        Parameters
+        ----------
+        command: str
+            The faceswap command that is to be run
+        args: list[str]
+            The full command line arguments to be executed
+        """
         logger.debug("Executing Faceswap: (command: '%s', args: %s)", command, args)
-        self.command = command
-        kwargs = {"stdout": PIPE,
-                  "stderr": PIPE,
-                  "bufsize": 1,
-                  "universal_newlines": True}
-
-        self.process = Popen(args, **kwargs, stdin=PIPE)
-        self.thread_stdout()
-        self.thread_stderr()
+        self._thread = None
+        self._command = command
+
+        proc = Popen(args,  # pylint:disable=consider-using-with
+                     stdout=PIPE,
+                     stderr=PIPE,
+                     bufsize=1,
+                     text=True,
+                     stdin=PIPE,
+                     encoding="utf-8",
+                     errors="backslashreplace")
+        self._process = proc
+        self._thread_stdout()
+        self._thread_stderr()
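+        # Both reader threads run as daemons (see _thread_stdout/_thread_stderr) so
+        # subprocess output is consumed without blocking the GUI event loop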
logger.debug("Executed Faceswap") - def read_stdout(self): - """ Read stdout from the subprocess. If training, pass the loss - values to Queue """ + def _process_training_determinate_function(self, output: str) -> bool: + """ Process an stdout/stderr message to check for determinate TQDM output when training + + Parameters + ---------- + output: str + The stdout/stderr string to test + + Returns + ------- + bool + ``True`` if a determinate TQDM line was parsed when training otherwise ``False`` + """ + if self._command == "train" and not self._first_loss_seen and self._capture_tqdm(output): + self._statusbar.set_mode("determinate") + return True + return False + + def _process_progress_stdout(self, output: str) -> bool: + """ Process stdout for any faceswap processes that update the status/progress bar(s) + + Parameters + ---------- + output: str + The output line read from stdout + + Returns + ------- + bool + ``True`` if all actions have been completed on the output line otherwise ``False`` + """ + if self._process_training_determinate_function(output): + return True + + if self._command == "train" and self._capture_loss(output): + return True + + if self._command == "train" and output.strip() == "\x1b[2K": # Clear line command for cli + return True + + if self._command == "effmpeg" and self._capture_ffmpeg(output): + return True + + if self._command not in ("train", "effmpeg") and self._capture_tqdm(output): + return True + + return False + + def _process_training_stdout(self, output: str) -> None: + """ Process any triggers that are required to update the GUI when Faceswap is running a + training session. + + Parameters + ---------- + output: str + The output line read from stdout + """ + tk_vars = get_config().tk_vars + if self._command != "train" or not tk_vars.is_training.get(): + return + + t_output = output.strip().lower() + if "[saved model]" not in t_output or t_output.endswith("[saved model]"): + # Not a saved model line or saving the model for a reason other than standard saving + return + + logger.debug("Trigger GUI Training update") + logger.trace("tk_vars: %s", {itm: var.get() # type:ignore[attr-defined] + for itm, var in tk_vars.__dict__.items()}) + if not Session.is_training: + # Don't initialize session until after the first save as state file must exist first + logger.debug("Initializing curret training session") + Session.initialize_session(self._session_info["model_folder"], + self._session_info["model_name"], + is_training=True) + tk_vars.refresh_graph.set(True) + + def _read_stdout(self) -> None: + """ Read stdout from the subprocess. 
""" logger.debug("Opening stdout reader") + assert self._process is not None while True: try: - output = self.process.stdout.readline() + buff = self._process.stdout + assert buff is not None + output: str = buff.readline() except ValueError as err: if str(err).lower().startswith("i/o operation on closed file"): break raise - if output == "" and self.process.poll() is not None: + + if output == "" and self._process.poll() is not None: break - if output: - if (self.command == "train" and self.capture_loss(output)) or ( - self.command != "train" and self.capture_tqdm(output)): - continue - if self.command == "train" and output.strip().endswith("saved models"): - logger.debug("Trigger update preview") - self.wrapper.tk_vars["updatepreview"].set(True) - print(output.strip()) - returncode = self.process.poll() - message = self.set_final_status(returncode) - self.wrapper.terminate(message) + + if output and self._process_progress_stdout(output): + continue + + if output.strip(): + self._process_training_stdout(output) + print(output.rstrip()) + + returncode = self._process.poll() + assert returncode is not None + self._first_loss_seen = False + message = self._set_final_status(returncode) + self._wrapper.terminate(message) logger.debug("Terminated stdout reader. returncode: %s", returncode) - def read_stderr(self): + def _read_stderr(self) -> None: """ Read stdout from the subprocess. If training, pass the loss values to Queue """ logger.debug("Opening stderr reader") + assert self._process is not None while True: try: - output = self.process.stderr.readline() + buff = self._process.stderr + assert buff is not None + output: str = buff.readline() except ValueError as err: if str(err).lower().startswith("i/o operation on closed file"): break raise - if output == "" and self.process.poll() is not None: + if output == "" and self._process.poll() is not None: break if output: - if self.command != "train" and self.capture_tqdm(output): + if self._command != "train" and self._capture_tqdm(output): + continue + if self._process_training_determinate_function(output): continue print(output.strip(), file=sys.stderr) logger.debug("Terminated stderr reader") - def thread_stdout(self): - """ Put the subprocess stdout so that it can be read without - blocking """ + def _thread_stdout(self) -> None: + """ Put the subprocess stdout so that it can be read without blocking """ logger.debug("Threading stdout") - thread = Thread(target=self.read_stdout) + thread = Thread(target=self._read_stdout) thread.daemon = True thread.start() logger.debug("Threaded stdout") - def thread_stderr(self): - """ Put the subprocess stderr so that it can be read without - blocking """ + def _thread_stderr(self) -> None: + """ Put the subprocess stderr so that it can be read without blocking """ logger.debug("Threading stderr") - thread = Thread(target=self.read_stderr) + thread = Thread(target=self._read_stderr) thread.daemon = True thread.start() logger.debug("Threaded stderr") - def capture_loss(self, string): - """ Capture loss values from stdout """ - logger.trace("Capturing loss") + def _capture_loss(self, string: str) -> bool: + """ Capture loss values from stdout + + Parameters + ---------- + string: str + An output line read from stdout + + Returns + ------- + bool + ``True`` if a loss line was captured from stdout, otherwise ``False`` + """ + logger.trace("Capturing loss") # type:ignore[attr-defined] if not str.startswith(string, "["): - logger.trace("Not loss message. Returning False") + logger.trace("Not loss message. 
Returning False") # type:ignore[attr-defined] return False - loss = self.consoleregex["loss"].findall(string) - if len(loss) < 2: - logger.trace("Not loss message. Returning False") + loss = self._consoleregex["loss"].findall(string) + if len(loss) != 2 or not all(len(itm) == 3 for itm in loss): + logger.trace("Not loss message. Returning False") # type:ignore[attr-defined] return False - message = "" - for item in loss: - message += "{}: {} ".format(item[0], item[1]) + message = f"Total Iterations: {int(loss[0][0])} | " + message += " ".join([f"{itm[1]}: {itm[2]}" for itm in loss]) if not message: - logger.trace("Error creating loss message. Returning False") + logger.trace( # type:ignore[attr-defined] + "Error creating loss message. Returning False") return False - iterations = self.train_stats["iterations"] + iterations = self._train_stats["iterations"] + assert isinstance(iterations, int) if iterations == 0: - # Initialize session stats and set initial timestamp - self.train_stats["timestamp"] = time() - - if not get_config().session.initialized and iterations > 0: - # Don't initialize session until after the first iteration as state - # file must exist first - get_config().session.initialize_session(is_training=True) - self.wrapper.tk_vars["refreshgraph"].set(True) + # Set initial timestamp + self._train_stats["timestamp"] = time() iterations += 1 - if iterations % 100 == 0: - self.wrapper.tk_vars["refreshgraph"].set(True) - self.train_stats["iterations"] = iterations - - elapsed = self.calc_elapsed() - message = "Elapsed: {} Iteration: {} {}".format(elapsed, - self.train_stats["iterations"], message) - self.statusbar.progress_update(message, 0, False) - logger.trace("Succesfully captured loss: %s", message) + self._train_stats["iterations"] = iterations + + elapsed = self._calculate_elapsed() + message = (f"Elapsed: {elapsed} | " + f"Session Iterations: {self._train_stats['iterations']} {message}") + + if not self._first_loss_seen: + self._statusbar.set_mode("indeterminate") + self._first_loss_seen = True + + self._statusbar.progress_update(message, 0, False) + logger.trace("Succesfully captured loss: %s", message) # type:ignore[attr-defined] return True - def calc_elapsed(self): - """ Calculate and format time since training started """ + def _calculate_elapsed(self) -> str: + """ Calculate and format time since training started + + Returns + ------- + str + The amount of time elapsed since training started in HH:mm:ss format + """ now = time() - elapsed_time = now - self.train_stats["timestamp"] + timestamp = self._train_stats["timestamp"] + assert isinstance(timestamp, float) + elapsed_time = now - timestamp try: - hrs = int(elapsed_time // 3600) - if hrs < 10: - hrs = "{0:02d}".format(hrs) - mins = "{0:02d}".format((int(elapsed_time % 3600) // 60)) - secs = "{0:02d}".format((int(elapsed_time % 3600) % 60)) + i_hrs = int(elapsed_time // 3600) + hrs = f"{i_hrs:02d}" if i_hrs < 10 else str(i_hrs) + mins = f"{(int(elapsed_time % 3600) // 60):02d}" + secs = f"{(int(elapsed_time % 3600) % 60):02d}" except ZeroDivisionError: - hrs = "00" - mins = "00" - secs = "00" - return "{}:{}:{}".format(hrs, mins, secs) - - def capture_tqdm(self, string): - """ Capture tqdm output for progress bar """ - logger.trace("Capturing tqdm") - tqdm = self.consoleregex["tqdm"].match(string) - if not tqdm: + hrs = mins = secs = "00" + return f"{hrs}:{mins}:{secs}" + + def _capture_tqdm(self, string: str) -> bool: + """ Capture tqdm output for progress bar + + Parameters + ---------- + string: str + An output 
line read from stdout + + Returns + ------- + bool + ``True`` if a tqdm line was captured from stdout, otherwise ``False`` + """ + logger.trace("Capturing tqdm") # type:ignore[attr-defined] + mtqdm = self._consoleregex["tqdm"].match(string) + if not mtqdm: return False - tqdm = tqdm.groupdict() + tqdm = mtqdm.groupdict() if any("?" in val for val in tqdm.values()): - logger.trace("tqdm initializing. Skipping") + logger.trace("tqdm initializing. Skipping") # type:ignore[attr-defined] return True - processtime = "Elapsed: {} Remaining: {}".format(tqdm["tme"].split("<")[0], - tqdm["tme"].split("<")[1]) - message = "{} | {} | {} | {}".format(processtime, - tqdm["rte"], - tqdm["itm"], - tqdm["pct"]) - - current, total = tqdm["itm"].split("/") - position = int((float(current) / float(total)) * 1000) - - self.statusbar.progress_update(message, position, True) - logger.trace("Succesfully captured tqdm message: %s", message) + description = tqdm["dsc"].strip() + description = description if description == "" else f"{description[:-1]} | " + processtime = (f"Elapsed: {tqdm['tme'].split('<')[0]} " + f"Remaining: {tqdm['tme'].split('<')[1]}") + msg = f"{description}{processtime} | {tqdm['rte']} | {tqdm['itm']} | {tqdm['pct']}" + + position = tqdm["pct"].replace("%", "") + position = int(position) if position.isdigit() else 0 + + self._statusbar.progress_update(msg, position, True) + logger.trace("Succesfully captured tqdm message: %s", msg) # type:ignore[attr-defined] return True - def terminate(self): - """ Terminate the subprocess """ + def _capture_ffmpeg(self, string: str) -> bool: + """ Capture ffmpeg output for progress bar + + Parameters + ---------- + string: str + An output line read from stdout + + Returns + ------- + bool + ``True`` if an ffmpeg line was captured from stdout, otherwise ``False`` + """ + logger.trace("Capturing ffmpeg") # type:ignore[attr-defined] + ffmpeg = self._consoleregex["ffmpeg"].findall(string) + if len(ffmpeg) < 7: + logger.trace("Not ffmpeg message. Returning False") # type:ignore[attr-defined] + return False + + message = "" + for item in ffmpeg: + message += f"{item[0]}: {item[1]} " + if not message: + logger.trace( # type:ignore[attr-defined] + "Error creating ffmpeg message. Returning False") + return False + + self._statusbar.progress_update(message, 0, False) + logger.trace("Succesfully captured ffmpeg message: %s", # type:ignore[attr-defined] + message) + return True + + def terminate(self) -> None: + """ Terminate the running process in a LongRunningTask so console can still be updated + console """ + if self._thread is None: + logger.debug("Terminating wrapper in LongRunningTask") + self._thread = LongRunningTask(target=self._terminate_in_thread, + args=(self._command, self._process)) + if self._command == "train": + get_config().tk_vars.is_training.set(False) + self._thread.start() + self._config.root.after(1000, self.terminate) + elif not self._thread.complete.is_set(): + logger.debug("Not finished terminating") + self._config.root.after(1000, self.terminate) + else: + logger.debug("Termination Complete. 
Cleaning up") + _ = self._thread.get_result() # Terminate the LongRunningTask object + self._thread = None + + def _terminate_in_thread(self, command: str, process: Popen) -> bool: + """ Terminate the subprocess + + Parameters + ---------- + command: str + The command that is running + + process: :class:`subprocess.Popen` + The running process + + Returns + ------- + bool + ``True`` when this function exits + """ logger.debug("Terminating wrapper") - if self.command == "train": + if command == "train": + timeout = cfg.timeout() logger.debug("Sending Exit Signal") print("Sending Exit Signal", flush=True) - try: - now = time() - if os.name == "nt": - try: - logger.debug("Sending carriage return to process") - self.process.communicate(input="\n", timeout=60) - except TimeoutExpired: - raise ValueError("Timeout reached sending Exit Signal") - else: - logger.debug("Sending SIGINT to process") - self.process.send_signal(signal.SIGINT) - while True: - timeelapsed = time() - now - if self.process.poll() is not None: - break - if timeelapsed > 60: - raise ValueError("Timeout reached sending Exit Signal") - return - except ValueError as err: - logger.error("Error terminating process", exc_info=True) - print(err) - else: - logger.debug("Terminating Process...") - print("Terminating Process...") - children = psutil.Process().children(recursive=True) - for child in children: - child.terminate() - _, alive = psutil.wait_procs(children, timeout=10) - if not alive: - logger.debug("Terminated") - print("Terminated") - return - - logger.debug("Termination timed out. Killing Process...") - print("Termination timed out. Killing Process...") - for child in alive: - child.kill() - _, alive = psutil.wait_procs(alive, timeout=10) - if not alive: - logger.debug("Killed") - print("Killed") + now = time() + if os.name == "nt": + logger.debug("Sending carriage return to process") + con_in = win32console.GetStdHandle( # pylint:disable=c-extension-no-member + win32console.STD_INPUT_HANDLE) # pylint:disable=c-extension-no-member + keypress = self._generate_windows_keypress("\n") + con_in.WriteConsoleInput([keypress]) else: - for child in alive: - msg = "Process {} survived SIGKILL. 
Giving up".format(child) - logger.debug(msg) - print(msg) + logger.debug("Sending SIGINT to process") + process.send_signal(signal.SIGINT) + while True: + timeelapsed = time() - now + if process.poll() is not None: + break + if timeelapsed > timeout: + logger.error("Timeout reached sending Exit Signal") + self._terminate_all_children() + else: + self._terminate_all_children() + return True + + @classmethod + def _generate_windows_keypress(cls, character: str) -> bytes: + """ Generate a Windows keypress + + Parameters + ---------- + character: str + The caracter to generate the keypress for + + Returns + ------- + bytes + The generated Windows keypress + """ + buf = win32console.PyINPUT_RECORDType( # pylint:disable=c-extension-no-member + win32console.KEY_EVENT) # pylint:disable=c-extension-no-member + buf.KeyDown = 1 + buf.RepeatCount = 1 + buf.Char = character + return buf + + @classmethod + def _terminate_all_children(cls) -> None: + """ Terminates all children """ + logger.debug("Terminating Process...") + print("Terminating Process...", flush=True) + children = psutil.Process().children(recursive=True) + for child in children: + child.terminate() + _, alive = psutil.wait_procs(children, timeout=10) + if not alive: + logger.debug("Terminated") + print("Terminated") + return - def set_final_status(self, returncode): - """ Set the status bar output based on subprocess return code """ + logger.debug("Termination timed out. Killing Process...") + print("Termination timed out. Killing Process...", flush=True) + for child in alive: + child.kill() + _, alive = psutil.wait_procs(alive, timeout=10) + if not alive: + logger.debug("Killed") + print("Killed") + else: + for child in alive: + msg = f"Process {child} survived SIGKILL. Giving up" + logger.debug(msg) + print(msg) + + def _set_final_status(self, returncode: int) -> str: + """ Set the status bar output based on subprocess return code and reset training stats + + Parameters + ---------- + returncode: int + The returncode from the terminated process + + Returns + ------- + str + The final statusbar text + """ logger.debug("Setting final status. returncode: %s", returncode) + self._train_stats = {"iterations": 0, "timestamp": None} if returncode in (0, 3221225786): status = "Ready" elif returncode == -15: - status = "Terminated - {}.py".format(self.command) + status = f"Terminated - {self._command}.py" elif returncode == -9: - status = "Killed - {}.py".format(self.command) + status = f"Killed - {self._command}.py" elif returncode == -6: - status = "Aborted - {}.py".format(self.command) + status = f"Aborted - {self._command}.py" else: - status = "Failed - {}.py. Return Code: {}".format(self.command, returncode) + status = f"Failed - {self._command}.py. 
Return Code: {returncode}" logger.debug("Set final status: %s", status) return status + + +__all__ = get_module_objects(__name__) diff --git a/lib/image.py b/lib/image.py new file mode 100644 index 0000000000..26eb1de46c --- /dev/null +++ b/lib/image.py @@ -0,0 +1,1703 @@ +#!/usr/bin python3 +""" Utilities for working with images and videos """ +from __future__ import annotations +import json +import logging +import re +import subprocess +import os +import struct +import sys +import typing as T + +from ast import literal_eval +from bisect import bisect +from concurrent import futures +from zlib import crc32 + +import cv2 +import imageio +import imageio_ffmpeg as im_ffm +import numpy as np +from tqdm import tqdm + +from lib.multithreading import MultiThread +from lib.queue_manager import queue_manager, QueueEmpty +from lib.utils import (convert_to_secs, FaceswapError, get_image_paths, + get_module_objects, VIDEO_EXTENSIONS) + + +if T.TYPE_CHECKING: + from lib.align.alignments import PNGHeaderDict + +logger = logging.getLogger(__name__) + +# ################### # +# <<< IMAGE UTILS >>> # +# ################### # + + +# <<< IMAGE IO >>> # + +class FfmpegReader(imageio.plugins.ffmpeg.FfmpegFormat.Reader): # type:ignore + """ Monkey patch imageio ffmpeg to use keyframes whilst seeking """ + def __init__(self, format, request): + super().__init__(format, request) + self._frame_pts = None + self._keyframes = None + self.use_patch = False + + def get_frame_info(self, frame_pts=None, keyframes=None): + """ Store the source video's keyframes in :attr:`_frame_info" for the current video for use + in :func:`initialize`. + + Parameters + ---------- + frame_pts: list, optional + A list corresponding to the video frame count of the pts_time per frame. If this and + `keyframes` are provided, then analyzing the video is skipped and the values from the + given lists are used. Default: ``None`` + keyframes: list, optional + A list containing the frame numbers of each key frame. if this and `frame_pts` are + provided, then analyzing the video is skipped and the values from the given lists are + used. Default: ``None`` + """ + if frame_pts is not None and keyframes is not None: + logger.debug("Video meta information provided. Not analyzing video") + self._frame_pts = frame_pts + self._keyframes = keyframes + return len(frame_pts), dict(pts_time=self._frame_pts, keyframes=self._keyframes) + + assert isinstance(self._filename, str), "Video path must be a string" + + # NB: The below video filter applies the detected frame rate prior to showinfo. This + # appears to help prevent an issue where the number of timestamp entries generated by + # showinfo does not correspond to the number of frames that the video file generates. + # This is because the demuxer will duplicate frames to meet the required frame rate. + # This **may** cause issues so be aware. + + # Also, drop frame rates (i.e 23.98, 29.97 and 59.94) will introduce rounding errors which + # means sync will drift on generated pts. These **should** be the only 'drop-frame rates' + # that appear in video files, but this is video files, and nothing is guaranteed. + # (The actual values for these should be 24000/1001, 30000/1001 and 60000/1001 + # respectively). The solutions to round these values is hacky at best, so: + # TODO find a more robust method for extracting/handling drop-frame rates. 
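+        # Illustrative worked example of the adjustment below (hypothetical reported rate):
+        # a video reporting fps=29.97 rounds to 30.0; the difference of 0.03 falls inside the
+        # (0.01, 0.10) window, so the filter is fed the rational rate "30000/1001"
+        # (29.97002997...) instead of the rounded float, avoiding pts drift.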
+ + fps = self._meta["fps"] + rounded_fps = round(fps, 0) + if 0.01 < rounded_fps - fps < 0.10: # 0.90 - 0.99 + new_fps = f"{int(rounded_fps * 1000)}/1001" + logger.debug("Adjusting drop-frame fps: %s to %s", fps, new_fps) + fps = new_fps + + cmd = [im_ffm.get_ffmpeg_exe(), + "-hide_banner", + "-copyts", + "-i", self._filename, + "-vf", f"fps=fps={fps},showinfo", + "-start_number", "0", + "-an", + "-f", "null", + "-"] + logger.debug("FFMPEG Command: '%s'", " ".join(cmd)) + process = subprocess.Popen(cmd, + stderr=subprocess.STDOUT, + stdout=subprocess.PIPE, + universal_newlines=True) + frame_pts = [] + key_frames = [] + last_update = 0 + pbar = tqdm(desc="Analyzing Video", + leave=False, + total=int(self._meta["duration"]), + unit="secs") + while True: + output = process.stdout.readline().strip() + if output == "" and process.poll() is not None: + break + if "iskey" not in output: + continue + logger.trace("Keyframe line: %s", output) # type:ignore[attr-defined] + line = re.split(r"\s+|:\s*", output) + pts_time = float(line[line.index("pts_time") + 1]) + frame_no = int(line[line.index("n") + 1]) + frame_pts.append(pts_time) + if "iskey:1" in output: + key_frames.append(frame_no) + + logger.trace("pts_time: %s, frame_no: %s", # type:ignore[attr-defined] + pts_time, frame_no) + if int(pts_time) == last_update: + # Floating points make TQDM display poorly, so only update on full + # second increments + continue + pbar.update(int(pts_time) - last_update) + last_update = int(pts_time) + pbar.close() + return_code = process.poll() + frame_count = len(frame_pts) + logger.debug("Return code: %s, frame_pts: %s, keyframes: %s, frame_count: %s", + return_code, frame_pts, key_frames, frame_count) + + self._frame_pts = frame_pts + self._keyframes = key_frames + return frame_count, dict(pts_time=self._frame_pts, keyframes=self._keyframes) + + def _previous_keyframe_info(self, index=0): + """ Return the previous keyframe's pts_time and frame number """ + prev_keyframe_idx = bisect(self._keyframes, index) - 1 + prev_keyframe = self._keyframes[prev_keyframe_idx] + prev_pts_time = self._frame_pts[prev_keyframe] + logger.trace("keyframe pts_time: %s, keyframe: %s", # type:ignore[attr-defined] + prev_pts_time, prev_keyframe) + return prev_pts_time, prev_keyframe + + def _initialize(self, index=0): # noqa:C901 + """ Replace ImageIO _initialize with a version that explictly uses keyframes. + + Notes + ----- + This introduces a minor change by seeking fast to the previous keyframe and then discarding + subsequent frames until the desired frame is reached. In testing, setting -ss flag either + prior to input, or both prior (fast) and after (slow) would not always bring back the + correct frame for all videos. Navigating to the previous keyframe then discarding frames + until the correct frame is reached appears to work well. + """ + # pylint:disable-all + if self._read_gen is not None: + self._read_gen.close() + + iargs = [] + oargs = [] + skip_frames = 0 + + # Create input args + iargs += self._arg_input_params + if self.request._video: + iargs += ["-f", CAM_FORMAT] # noqa + if self._arg_pixelformat: + iargs += ["-pix_fmt", self._arg_pixelformat] + if self._arg_size: + iargs += ["-s", self._arg_size] + elif index > 0: # re-initialize / seek + # Note: only works if we initialized earlier, and now have meta. Some info here: + # https://trac.ffmpeg.org/wiki/Seeking + # There are two ways to seek, one before -i (input_params) and after (output_params). 
+ # The former is fast, because it uses keyframes, the latter is slow but accurate. + # According to the article above, the fast method should also be accurate from ffmpeg + # version 2.1, however in version 4.1 our tests start failing again. Not sure why, but + # we can solve this by combining slow and fast. + # Further note: The old method would go back 10 seconds and then seek slow. This was + # still somewhat unresponsive and did not always land on the correct frame. This monkey + # patched version goes to the previous keyframe then discards frames until the correct + # frame is landed on. + if self.use_patch and self._frame_pts is None: + self.get_frame_info() + + if self.use_patch: + keyframe_pts, keyframe = self._previous_keyframe_info(index) + seek_fast = keyframe_pts + skip_frames = index - keyframe + else: + starttime = index / self._meta["fps"] + seek_slow = min(10, starttime) + seek_fast = starttime - seek_slow + + # We used to have this epsilon earlier, when we did not use + # the slow seek. I don't think we need it anymore. + # epsilon = -1 / self._meta["fps"] * 0.1 + iargs += ["-ss", "%.06f" % (seek_fast)] + if not self.use_patch: + oargs += ["-ss", "%.06f" % (seek_slow)] + + # Output args, for writing to pipe + if self._arg_size: + oargs += ["-s", self._arg_size] + if self.request.kwargs.get("fps", None): + fps = float(self.request.kwargs["fps"]) + oargs += ["-r", "%.02f" % fps] + oargs += self._arg_output_params + + # Get pixelformat and bytes per pixel + pix_fmt = self._pix_fmt + bpp = self._depth * self._bytes_per_channel + + # Create generator + rf = self._ffmpeg_api.read_frames + self._read_gen = rf( + self._filename, pix_fmt, bpp, input_params=iargs, output_params=oargs + ) + + # Read meta data. This start the generator (and ffmpeg subprocess) + if self.request._video: + # With cameras, catch error and turn into IndexError + try: + meta = self._read_gen.__next__() + except IOError as err: + err_text = str(err) + if "darwin" in sys.platform: + if "Unknown input format: 'avfoundation'" in err_text: + err_text += ( + "Try installing FFMPEG using " + "home brew to get a version with " + "support for cameras." + ) + raise IndexError( + "No camera at {}.\n\n{}".format(self.request._video, err_text) + ) + else: + self._meta.update(meta) + elif index == 0: + self._meta.update(self._read_gen.__next__()) + else: + if self.use_patch: + frames_skipped = 0 + while skip_frames != frames_skipped: + # Skip frames that are not the desired frame + _ = self._read_gen.__next__() + frames_skipped += 1 + self._read_gen.__next__() # we already have meta data + + +imageio.plugins.ffmpeg.FfmpegFormat.Reader = FfmpegReader # type: ignore + + +@T.overload +def read_image(filename: str, + raise_error: T.Literal[False] = False, + with_metadata: T.Literal[False] = False) -> np.ndarray | None: ... + + +@T.overload +def read_image(filename: str, + raise_error: T.Literal[True], + with_metadata: T.Literal[False] = False) -> np.ndarray: ... + + +@T.overload +def read_image(filename: str, + raise_error: T.Literal[False] = False, + *, + with_metadata: T.Literal[True]) -> tuple[np.ndarray, PNGHeaderDict]: ... + + +@T.overload +def read_image(filename: str, + raise_error: T.Literal[True], + with_metadata: T.Literal[True]) -> np.ndarray: ... + + +def read_image(filename: str, raise_error: bool = False, with_metadata: bool = False + ) -> np.ndarray | None | tuple[np.ndarray, PNGHeaderDict]: + """ Read an image file from a file location. 
+
+    Extends the functionality of :func:`cv2.imread()` by ensuring that an image was actually
+    loaded. Errors can be logged and ignored so that the process can continue on an image load
+    failure.
+
+    Parameters
+    ----------
+    filename : str
+        Full path to the image to be loaded.
+    raise_error: bool, optional
+        If ``True`` then any failures (including the returned image being ``None``) will be
+        raised. If ``False`` then an error message will be logged, but the error will not be
+        raised. Default: ``False``
+    with_metadata : bool, optional
+        Only returns a value if the images loaded are extracted Faceswap faces. If ``True`` then
+        returns the Faceswap metadata stored within a face image's .png exif header.
+        Default: ``False``
+
+    Returns
+    -------
+    batch : :class:`numpy.ndarray`
+        The image in `BGR` channel order for the corresponding :attr:`filename`
+    metadata : :class:`~lib.align.alignments.PNGHeaderDict`, optional
+        The faceswap metadata corresponding to the image. Only returned if
+        `with_metadata` is ``True``
+
+    Example
+    -------
+    >>> image_file = "/path/to/image.png"
+    >>> try:
+    >>>     image = read_image(image_file, raise_error=True, with_metadata=False)
+    >>> except Exception:
+    >>>     raise ValueError("There was an error")
+    """
+    logger.trace("Requested image: '%s'", filename)  # type:ignore[attr-defined]
+    success = True
+    image = None
+    retval: np.ndarray | tuple[np.ndarray, PNGHeaderDict] | None = None
+    try:
+        with open(filename, "rb") as infile:
+            raw_file = infile.read()
+            image = cv2.imdecode(np.frombuffer(raw_file, dtype="uint8"), cv2.IMREAD_COLOR)
+        if image is None:
+            raise ValueError("Image is None")
+        if with_metadata:
+            metadata = T.cast("PNGHeaderDict", png_read_meta(raw_file))
+            retval = (image, metadata)
+        else:
+            retval = image
+    except TypeError as err:
+        success = False
+        msg = "Error while reading image (TypeError): '{}'".format(filename)
+        msg += ". Original error message: {}".format(str(err))
+        logger.error(msg)
+        if raise_error:
+            raise Exception(msg)
+    except ValueError as err:
+        success = False
+        msg = ("Error while reading image. This can be caused by special characters in the "
+               "filename or a corrupt image file: '{}'".format(filename))
+        msg += ". Original error message: {}".format(str(err))
+        logger.error(msg)
+        if raise_error:
+            raise Exception(msg)
+    except Exception as err:  # pylint:disable=broad-except
+        success = False
+        msg = "Failed to load image '{}'. Original Error: {}".format(filename, str(err))
+        logger.error(msg)
+        if raise_error:
+            raise Exception(msg)
+    logger.trace("Loaded image: '%s'. Success: %s", filename, success)  # type:ignore[attr-defined]
+    return retval
+
+
+@T.overload
+def read_image_batch(filenames: list[str], with_metadata: T.Literal[False] = False
+                     ) -> np.ndarray: ...
+
+
+@T.overload
+def read_image_batch(filenames: list[str], with_metadata: T.Literal[True]
+                     ) -> tuple[np.ndarray, list[PNGHeaderDict]]: ...
+
+
+def read_image_batch(filenames: list[str], with_metadata: bool = False
+                     ) -> np.ndarray | tuple[np.ndarray, list[PNGHeaderDict]]:
+    """ Load a batch of images from the given file locations.
+
+    Leverages multi-threading to load multiple images from disk at the same time, leading to
+    vastly reduced image read times.
+
+    Parameters
+    ----------
+    filenames : list[str]
+        A list of full paths to the images to be loaded.
+    with_metadata : bool, optional
+        Only returns a value if the images loaded are extracted Faceswap faces.
If ``True`` then
+        returns the Faceswap metadata stored within each face's .png exif header.
+        Default: ``False``
+
+    Returns
+    -------
+    batch : :class:`numpy.ndarray`
+        The batch of images in `BGR` channel order returned in the order of :attr:`filenames`
+    metadata : list[:class:`~lib.align.alignments.PNGHeaderDict`], optional
+        The faceswap metadata corresponding to each image in the batch. Only returned if
+        `with_metadata` is ``True``
+
+    Notes
+    -----
+    As the images are compiled into a batch, they should all be of the same dimensions, otherwise
+    an inhomogeneous array will be returned
+
+    Example
+    -------
+    >>> image_filenames = ["/path/to/image_1.png", "/path/to/image_2.png", "/path/to/image_3.png"]
+    >>> images = read_image_batch(image_filenames)
+    >>> print(images.shape)
+    ... (3, 64, 64, 3)
+    >>> images, metadata = read_image_batch(image_filenames, with_metadata=True)
+    >>> print(images.shape)
+    ... (3, 64, 64, 3)
+    >>> print(len(metadata))
+    ... 3
+    """
+    logger.trace("Requested batch: '%s'", filenames)  # type:ignore[attr-defined]
+    batch: list[np.ndarray | None] = [None for _ in range(len(filenames))]
+    meta: list[PNGHeaderDict | None] = [None for _ in range(len(filenames))]
+
+    with futures.ThreadPoolExecutor() as executor:
+        images = {executor.submit(  # NOTE submit strips positionals, breaking type-checking
+            read_image,  # type:ignore[arg-type]
+            filename,
+            raise_error=True,  # pyright:ignore[reportArgumentType]
+            with_metadata=with_metadata): idx  # pyright:ignore[reportArgumentType]
+            for idx, filename in enumerate(filenames)}
+
+        for future in futures.as_completed(images):
+            result = T.cast(np.ndarray | tuple[np.ndarray, "PNGHeaderDict"], future.result())
+            ret_idx = images[future]
+            if with_metadata:
+                assert isinstance(result, tuple)
+                batch[ret_idx], meta[ret_idx] = result
+            else:
+                assert isinstance(result, np.ndarray)
+                batch[ret_idx] = result
+
+    arr_batch = np.array(batch)
+    retval: np.ndarray | tuple[np.ndarray, list[PNGHeaderDict]]
+    if with_metadata:
+        retval = (arr_batch, T.cast(list["PNGHeaderDict"], meta))
+    else:
+        retval = arr_batch
+
+    logger.trace(  # type:ignore[attr-defined]
+        "Returning images: (filenames: %s, batch shape: %s, with_metadata: %s)",
+        filenames, arr_batch.shape, with_metadata)
+    return retval
+
+
+def read_image_meta(filename):
+    """ Read the Faceswap metadata stored in an extracted face's exif header.
+
+    Parameters
+    ----------
+    filename: str
+        Full path to the image to retrieve the meta information for.
+
+    Returns
+    -------
+    dict
+        The output dictionary will contain the `width` and `height` of the png image as well as any
+        `itxt` information.
+
+    Example
+    -------
+    >>> image_file = "/path/to/image.png"
+    >>> metadata = read_image_meta(image_file)
+    >>> width = metadata["width"]
+    >>> height = metadata["height"]
+    >>> faceswap_info = metadata["itxt"]
+    """
+    retval = dict()
+    if os.path.splitext(filename)[-1].lower() != ".png":
+        # Get the dimensions directly from the image for non-pngs
+        logger.trace(  # type:ignore[attr-defined]
+            "Non png found.
Loading file for dimensions: '%s'",
+            filename)
+        img = cv2.imread(filename)
+        retval["height"], retval["width"] = img.shape[:2]
+        return retval
+    with open(filename, "rb") as infile:
+        try:
+            chunk = infile.read(8)
+        except PermissionError as err:
+            raise PermissionError(f"PermissionError while reading: {filename}") from err
+
+        if chunk != b"\x89PNG\r\n\x1a\n":
+            raise ValueError(f"Invalid header found in png: {filename}")
+
+        while True:
+            chunk = infile.read(8)
+            length, field = struct.unpack(">I4s", chunk)
+            logger.trace(  # type:ignore[attr-defined]
+                "Read chunk: (chunk: %s, length: %s, field: %s)",
+                chunk, length, field)
+            if not chunk or field == b"IDAT":
+                break
+            if field == b"IHDR":
+                # Get dimensions
+                chunk = infile.read(8)
+                retval["width"], retval["height"] = struct.unpack(">II", chunk)
+                length -= 8
+            elif field == b"iTXt":
+                keyword, value = infile.read(length).split(b"\0", 1)
+                if keyword == b"faceswap":
+                    retval["itxt"] = literal_eval(value[4:].decode("utf-8", errors="replace"))
+                    break
+                else:
+                    logger.trace("Skipping iTXt chunk: '%s'",  # type:ignore[attr-defined]
+                                 keyword.decode("latin-1", errors="ignore"))
+                    length = 0  # Reset marker for next chunk
+            infile.seek(length + 4, 1)
+    logger.trace("filename: %s, metadata: %s", filename, retval)  # type:ignore[attr-defined]
+    return retval
+
+
+def read_image_meta_batch(filenames):
+    """ Read the Faceswap metadata stored in a batch of extracted faces' exif headers.
+
+    Leverages multi-threading to load multiple images from disk at the same time
+    leading to vastly reduced image read times. Creates a generator to retrieve filenames
+    with their metadata as they are calculated.
+
+    Notes
+    -----
+    The order of returned values is non-deterministic, so items will most likely not be yielded
+    in the same order as the input filenames
+
+    Parameters
+    ----------
+    filenames: list
+        A list of ``str`` full paths to the images to be loaded.
+
+    Yields
+    ------
+    tuple
+        (**filename** (`str`), **metadata** (`dict`) )
+
+    Example
+    -------
+    >>> image_filenames = ["/path/to/image_1.png", "/path/to/image_2.png", "/path/to/image_3.png"]
+    >>> for filename, meta in read_image_meta_batch(image_filenames):
+    >>>
+    """
+    logger.trace("Requested batch: '%s'", filenames)  # type:ignore[attr-defined]
+    executor = futures.ThreadPoolExecutor()
+    with executor:
+        logger.debug("Submitting %s items to executor", len(filenames))
+        read_meta = {executor.submit(read_image_meta, filename): filename
+                     for filename in filenames}
+        logger.debug("Successfully submitted %s items to executor", len(filenames))
+        for future in futures.as_completed(read_meta):
+            retval = (read_meta[future], future.result())
+            logger.trace("Yielding: %s", retval)  # type:ignore[attr-defined]
+            yield retval
+
+
+def pack_to_itxt(metadata):
+    """ Pack the given metadata dictionary to a PNG iTXt header field.
+
+    Parameters
+    ----------
+    metadata: dict or bytes
+        The dictionary to write to the header. Can be pre-encoded as utf-8.
+ + Returns + ------- + bytes + A byte encoded PNG iTXt field, including chunk header and CRC + """ + if not isinstance(metadata, bytes): + metadata = str(metadata).encode("utf-8", "strict") + key = "faceswap".encode("latin-1", "strict") + + chunk = key + b"\0\0\0\0\0" + metadata + crc = struct.pack(">I", crc32(chunk, crc32(b"iTXt")) & 0xFFFFFFFF) + length = struct.pack(">I", len(chunk)) + retval = length + b"iTXt" + chunk + crc + return retval + + +def update_existing_metadata(filename, metadata): + """ Update the png header metadata for an existing .png extracted face file on the filesystem. + + Parameters + ---------- + filename: str + The full path to the face to be updated + metadata: dict or bytes + The dictionary to write to the header. Can be pre-encoded as utf-8. + """ + + tmp_filename = filename + "~" + with open(filename, "rb") as png, open(tmp_filename, "wb") as tmp: + chunk = png.read(8) + if chunk != b"\x89PNG\r\n\x1a\n": + raise ValueError(f"Invalid header found in png: {filename}") + tmp.write(chunk) + + while True: + chunk = png.read(8) + length, field = struct.unpack(">I4s", chunk) + logger.trace( # type:ignore[attr-defined] + "Read chunk: (chunk: %s, length: %s, field: %s)", + chunk, length, field) + + if field == b"IDAT": # Write out all remaining data + logger.trace("Writing image data and closing png") # type:ignore[attr-defined] + tmp.write(chunk + png.read()) + break + + if field != b"iTXt": # Write non iTXt chunk straight out + logger.trace("Copying existing chunk") # type:ignore[attr-defined] + tmp.write(chunk + png.read(length + 4)) # Header + CRC + continue + + keyword, value = png.read(length).split(b"\0", 1) + if keyword != b"faceswap": + # Write existing non fs-iTXt data + CRC + logger.trace("Copying non-faceswap iTXt chunk: %s", # type:ignore[attr-defined] + keyword) + tmp.write(keyword + b"\0" + value + png.read(4)) + continue + + logger.trace("Updating faceswap iTXt chunk") # type:ignore[attr-defined] + tmp.write(pack_to_itxt(metadata)) + png.seek(4, 1) # Skip old CRC + + os.replace(tmp_filename, filename) + + +def encode_image(image: np.ndarray, + extension: str, + encoding_args: tuple[int, ...] | None = None, + metadata: PNGHeaderDict | dict[str, T.Any] | bytes | None = None) -> bytes: + """ Encode an image. + + Parameters + ---------- + image: numpy.ndarray + The image to be encoded in `BGR` channel order. + extension: str + A compatible `cv2` image file extension that the final image is to be saved to. + encoding_args: tuple[int, ...], optional + Any encoding arguments to pass to cv2's imencode function + metadata: dict or bytes, optional + Metadata for the image. If provided, and the extension is png or tiff, this information + will be written to the PNG itxt header. 
Default:``None`` Can be provided as a python dict + or pre-encoded + + Returns + ------- + encoded_image: bytes + The image encoded into the correct file format as bytes + + Example + ------- + >>> image_file = "/path/to/image.png" + >>> image = read_image(image_file) + >>> encoded_image = encode_image(image, ".jpg") + """ + if metadata and extension.lower() not in (".png", ".tif"): + raise ValueError("Metadata is only supported for .png and .tif images") + args = tuple() if encoding_args is None else encoding_args + + retval = cv2.imencode(extension, image, args)[1].tobytes() + if metadata: + func = {".png": png_write_meta, ".tif": tiff_write_meta}[extension] + retval = func(retval, metadata) + return retval + + +def png_write_meta(image: bytes, data: PNGHeaderDict | dict[str, T.Any] | bytes) -> bytes: + """ Write Faceswap information to a png's iTXt field. + + Parameters + ---------- + image: bytes + The bytes encoded png file to write header data to + data: dict or bytes + The dictionary to write to the header. Can be pre-encoded as utf-8. + + Notes + ----- + This is a fairly stripped down and non-robust header writer to fit a very specific task. OpenCV + will not write any iTXt headers to the PNG file, so we make the assumption that the only iTXt + header that exists is the one that we created for storing alignments. + + References + ---------- + PNG Specification: https://www.w3.org/TR/2003/REC-PNG-20031110/ + + """ + split = image.find(b"IDAT") - 4 + retval = image[:split] + pack_to_itxt(data) + image[split:] + return retval + + +def tiff_write_meta(image: bytes, data: PNGHeaderDict | dict[str, T.Any] | bytes) -> bytes: + """ Write Faceswap information to a tiff's image_description field. + + Parameters + ---------- + png: bytes + The bytes encoded tiff file to write header data to + data: dict or bytes + The data to write to the image-description field. If provided as a dict, then it should be + a json serializable object, otherwise it should be data encoded as ascii bytes + + Notes + ----- + This handles a very specific task of adding, and populating, an ImageDescription field in a + Tiff file generated by OpenCV. For any other usecases it will likely fail + """ + if not isinstance(data, bytes): + data = json.dumps(data, ensure_ascii=True).encode("ascii") + + assert image[:2] == b"II", "Not a supported TIFF file" + assert struct.unpack(" 270: + insert_idx = i # Log insert location of image description + + if size <= 4: # value in offset column + ifd += tag + continue + + ifd += tag[:8] + tag_offset = struct.unpack(" dict[str, T.Any]: + """ Read information stored in a Tiff's Image Description field + + Returns + ------- + dict[str, Any] + Any arbitrary information stored in the TIFF header (for example matrix information for + the patch writer) + """ + assert image[:2] == b"II", "Not a supported TIFF file" + assert struct.unpack(" PNGHeaderDict | dict[str, T.Any]: + """ Read the Faceswap information stored in a png's iTXt field. + + Parameters + ---------- + image: bytes + The bytes encoded png file to read header data from + + Returns + ------- + :class:`~lib.align.alignments.PNGHeaderDict` | dict[str, Any] + The Faceswap information stored in the PNG header. This will either be a PNGHeaderDict + if an extracted face, or other arbitrary information (for example for the Patch Writer) + + Notes + ----- + This is a very stripped down, non-robust and non-secure header reader to fit a very specific + task. 
OpenCV will not write any iTXt headers to the PNG file, so we make the assumption that + the only iTXt header that exists is the one that Faceswap created for storing alignments. + """ + retval: PNGHeaderDict | None = None + pointer = 0 + while True: + pointer = image.find(b"iTXt", pointer) - 4 + if pointer < 0: + logger.trace("No metadata in png") # type:ignore[attr-defined] + break + length = struct.unpack(">I", image[pointer:pointer + 4])[0] + pointer += 8 + keyword, value = image[pointer:pointer + length].split(b"\0", 1) + if keyword == b"faceswap": + retval = literal_eval(value[4:].decode("utf-8", errors="ignore")) + break + logger.trace("Skipping iTXt chunk: '%s'", # type:ignore[attr-defined] + keyword.decode("latin-1", errors="ignore")) + pointer += length + 4 + assert retval is not None + return retval + + +def generate_thumbnail(image, size=96, quality=60): + """ Generate a jpg thumbnail for the given image. + + Parameters + ---------- + image: :class:`numpy.ndarray` + Three channel BGR image to convert to a jpg thumbnail + size: int + The width and height, in pixels, that the thumbnail should be generated at + quality: int + The jpg quality setting to use + + Returns + ------- + :class:`numpy.ndarray` + The given image encoded to a jpg at the given size and quality settings + """ + logger.trace("Input shape: %s, size: %s, quality: %s", # type:ignore[attr-defined] + image.shape, size, quality) + orig_size = image.shape[0] + if orig_size != size: + interp = cv2.INTER_AREA if orig_size > size else cv2.INTER_CUBIC + image = cv2.resize(image, (size, size), interpolation=interp) + retval = cv2.imencode(".jpg", image, [cv2.IMWRITE_JPEG_QUALITY, quality])[1] + logger.trace("Output shape: %s", retval.shape) # type:ignore[attr-defined] + return retval + + +def batch_convert_color(batch, colorspace): + """ Convert a batch of images from one color space to another. + + Converts a batch of images by reshaping the batch prior to conversion rather than iterating + over the images. This leads to a significant speed up in the convert process. + + Parameters + ---------- + batch: numpy.ndarray + A batch of images. + colorspace: str + The OpenCV Color Conversion Code suffix. For example for BGR to LAB this would be + ``'BGR2LAB'``. + See https://docs.opencv.org/4.1.1/d8/d01/group__imgproc__color__conversions.html for a full + list of color codes. + + Returns + ------- + numpy.ndarray + The batch converted to the requested color space. + + Example + ------- + >>> images_bgr = numpy.array([image1, image2, image3]) + >>> images_lab = batch_convert_color(images_bgr, "BGR2LAB") + + Notes + ----- + This function is only compatible for color space conversions that have the same image shape + for source and destination color spaces. + + If you use :func:`batch_convert_color` with 8-bit images, the conversion will have some + information lost. For many cases, this will not be noticeable but it is recommended + to use 32-bit images in cases that need the full range of colors or that convert an image + before an operation and then convert back. + """ + logger.trace( # type:ignore[attr-defined] + "Batch converting: (batch shape: %s, colorspace: %s)", + batch.shape, colorspace) + original_shape = batch.shape + batch = batch.reshape((original_shape[0] * original_shape[1], *original_shape[2:])) + batch = cv2.cvtColor(batch, getattr(cv2, "COLOR_{}".format(colorspace))) + return batch.reshape(original_shape) + + +def hex_to_rgb(hexcode): + """ Convert a hex number to it's RGB counterpart. 
+ + Parameters + ---------- + hexcode: str + The hex code to convert (e.g. `"#0d25ac"`) + + Returns + ------- + tuple + The hex code as a 3 integer (`R`, `G`, `B`) tuple + """ + value = hexcode.lstrip("#") + chars = len(value) + return tuple(int(value[i:i + chars // 3], 16) for i in range(0, chars, chars // 3)) + + +def rgb_to_hex(rgb): + """ Convert an RGB tuple to it's hex counterpart. + + Parameters + ---------- + rgb: tuple + The (`R`, `G`, `B`) integer values to convert (e.g. `(0, 255, 255)`) + + Returns + ------- + str: + The 6 digit hex code with leading `#` applied + """ + return "#{:02x}{:02x}{:02x}".format(*rgb) + + +# ################### # +# <<< VIDEO UTILS >>> # +# ################### # + +def count_frames(filename, fast=False): + """ Count the number of frames in a video file + + There is no guaranteed accurate way to get a count of video frames without iterating through + a video and decoding every frame. + + :func:`count_frames` can return an accurate count (albeit fairly slowly) or a possibly less + accurate count, depending on the :attr:`fast` parameter. A progress bar is displayed. + + Parameters + ---------- + filename: str + Full path to the video to return the frame count from. + fast: bool, optional + Whether to count the frames without decoding them. This is significantly faster but + accuracy is not guaranteed. Default: ``False``. + + Returns + ------- + int: + The number of frames in the given video file. + + Example + ------- + >>> filename = "/path/to/video.mp4" + >>> frame_count = count_frames(filename) + """ + logger.debug("filename: %s, fast: %s", filename, fast) + assert isinstance(filename, str), "Video path must be a string" + + cmd = [im_ffm.get_ffmpeg_exe(), "-i", filename, "-map", "0:v:0"] + if fast: + cmd.extend(["-c", "copy"]) + cmd.extend(["-f", "null", "-"]) + + logger.debug("FFMPEG Command: '%s'", " ".join(cmd)) + process = subprocess.Popen(cmd, + stderr=subprocess.STDOUT, + stdout=subprocess.PIPE, + universal_newlines=True, encoding="utf8") + pbar = None + duration = None + init_tqdm = False + update = 0 + frames = 0 + while True: + output = process.stdout.readline().strip() + if output == "" and process.poll() is not None: + break + + if output.startswith("Duration:"): + logger.debug("Duration line: %s", output) + idx = output.find("Duration:") + len("Duration:") + duration = int(convert_to_secs(*output[idx:].split(",", 1)[0].strip().split(":"))) + logger.debug("duration: %s", duration) + if output.startswith("frame="): + logger.debug("frame line: %s", output) + if not init_tqdm: + logger.debug("Initializing tqdm") + pbar = tqdm(desc="Analyzing Video", leave=False, total=duration, unit="secs") + init_tqdm = True + time_idx = output.find("time=") + len("time=") + frame_idx = output.find("frame=") + len("frame=") + frames = int(output[frame_idx:].strip().split(" ")[0].strip()) + vid_time = int(convert_to_secs(*output[time_idx:].split(" ")[0].strip().split(":"))) + logger.debug("frames: %s, vid_time: %s", frames, vid_time) + prev_update = update + update = vid_time + pbar.update(update - prev_update) + if pbar is not None: + pbar.close() + return_code = process.poll() + logger.debug("Return code: %s, frames: %s", return_code, frames) + return frames + + +class ImageIO(): + """ Perform disk IO for images or videos in a background thread. + + This is the parent thread for :class:`ImagesLoader` and :class:`ImagesSaver` and should not + be called directly. + + Parameters + ---------- + path: str or list + The path to load or save images to/from. 
For loading this can be a folder which contains + images, video file or a list of image files. For saving this must be an existing folder. + queue_size: int + The amount of images to hold in the internal buffer. + args: tuple, optional + The arguments to be passed to the loader or saver thread. Default: ``None`` + + See Also + -------- + lib.image.ImagesLoader : Background Image Loader inheriting from this class. + lib.image.ImagesSaver : Background Image Saver inheriting from this class. + """ + + def __init__(self, path, queue_size, args=None): + logger.debug("Initializing %s: (path: %s, queue_size: %s, args: %s)", + self.__class__.__name__, path, queue_size, args) + + self._args = tuple() if args is None else args + + self._location = path + self._check_location_exists() + + queue_name = queue_manager.add_queue(name=self.__class__.__name__, + maxsize=queue_size, + create_new=True) + self._queue = queue_manager.get_queue(queue_name) + self._thread = None + + @property + def location(self): + """ str: The folder or video that was passed in as the :attr:`path` parameter. """ + return self._location + + def _check_location_exists(self): + """ Check whether the input location exists. + + Raises + ------ + FaceswapError + If the given location does not exist + """ + if isinstance(self.location, str) and not os.path.exists(self.location): + raise FaceswapError("The location '{}' does not exist".format(self.location)) + if isinstance(self.location, (list, tuple)) and not all(os.path.exists(location) + for location in self.location): + raise FaceswapError("Not all locations in the input list exist") + + def _set_thread(self): + """ Set the background thread for the load and save iterators and launch it. """ + logger.trace("Setting thread") # type:ignore[attr-defined] + if self._thread is not None and self._thread.is_alive(): + logger.trace("Thread pre-exists and is alive: %s", # type:ignore[attr-defined] + self._thread) + return + self._thread = MultiThread(self._process, + self._queue, + name=self.__class__.__name__, + thread_count=1) + logger.debug("Set thread: %s", self._thread) + self._thread.start() + + def _process(self, queue): + """ Image IO process to be run in a thread. Override for loader/saver process. + + Parameters + ---------- + queue: queue.Queue() + The ImageIO Queue + """ + raise NotImplementedError + + def close(self): + """ Closes down and joins the internal threads """ + logger.debug("Received Close") + if self._thread is not None: + self._thread.join() + del self._thread + self._thread = None + logger.debug("Closed") + + +class ImagesLoader(ImageIO): + """ Perform image loading from a folder of images or a video. + + Images will be loaded and returned in the order that they appear in the folder, or in the video + to ensure deterministic ordering. Loading occurs in a background thread, caching 8 images at a + time so that other processes do not need to wait on disk reads. + + See also :class:`ImageIO` for additional attributes. + + Parameters + ---------- + path: str or list + The path to load images from. This can be a folder which contains images a video file or a + list of image files. + queue_size: int, optional + The amount of images to hold in the internal buffer. Default: 8. + fast_count: bool, optional + When loading from video, the video needs to be parsed frame by frame to get an accurate + count. This can be done quite quickly without guaranteed accuracy, or slower with + guaranteed accuracy. 
Set to ``True`` to count quickly, or ``False`` to count slower + but accurately. Default: ``True``. + skip_list: list, optional + Optional list of frame/image indices to not load. Any indices provided here will be skipped + when executing the :func:`load` function from the given location. Default: ``None`` + count: int, optional + If the number of images that the loader will encounter is already known, it can be passed + in here to skip the image counting step, which can save time at launch. Set to ``None`` if + the count is not already known. Default: ``None`` + + Examples + -------- + Loading from a video file: + + >>> loader = ImagesLoader('/path/to/video.mp4') + >>> for filename, image in loader.load(): + >>> + """ + + def __init__(self, + path: str | list[str], + queue_size: int = 8, + fast_count: bool = True, + skip_list: list[int] | None = None, + count: int | None = None) -> None: + logger.debug("Initializing %s: (path: %s, queue_size: %s, fast_count: %s, skip_list: %s, " + "count: %s)", self.__class__.__name__, path, queue_size, fast_count, + skip_list, count) + + super().__init__(path, queue_size=queue_size) + self._skip_list = set() if skip_list is None else set(skip_list) + self._is_video = self._check_for_video() + self._fps = self._get_fps() + + self._count = None + self._file_list: list[str] = [] + self._get_count_and_filelist(fast_count, count) + + @property + def count(self) -> int: + """ int: The number of images or video frames in the source location. This count includes + any files that will ultimately be skipped if a :attr:`skip_list` has been provided. See + also: :attr:`process_count`""" + assert self._count is not None + return self._count + + @property + def process_count(self) -> int: + """ int: The number of images or video frames to be processed (IE the total count less + items that are to be skipped from the :attr:`skip_list`)""" + return self.count - len(self._skip_list) + + @property + def is_video(self): + """ bool: ``True`` if the input is a video, ``False`` if it is not """ + return self._is_video + + @property + def fps(self): + """ float: For an input folder of images, this will always return 25fps. If the input is a + video, then the fps of the video will be returned. """ + return self._fps + + @property + def file_list(self) -> list[str]: + """ list[str]: A full list of files in the source location. This includes any files that + will ultimately be skipped if a :attr:`skip_list` has been provided. If the input is a + video then this is a list of dummy filenames as corresponding to an alignments file """ + return self._file_list + + def add_skip_list(self, skip_list): + """ Add a skip list to this :class:`ImagesLoader` + + Parameters + ---------- + skip_list: list + A list of indices corresponding to the frame indices that should be skipped by the + :func:`load` function. + """ + logger.debug(skip_list) + self._skip_list = set(skip_list) + + def _check_for_video(self): + """ Check whether the input is a video + + Returns + ------- + bool: 'True' if input is a video 'False' if it is a folder. + + Raises + ------ + FaceswapError + If the given location is a file and does not have a valid video extension. 
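+
+        Example
+        -------
+        Illustrative results (paths are hypothetical):
+
+        >>> ImagesLoader('/path/to/video.mp4').is_video
+        ... True
+        >>> ImagesLoader('/path/to/folder/of/images').is_video
+        ... False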
+ + """ + if not isinstance(self.location, str) or os.path.isdir(self.location): + retval = False + elif os.path.splitext(self.location)[1].lower() in VIDEO_EXTENSIONS: + retval = True + else: + raise FaceswapError("The input file '{}' is not a valid video".format(self.location)) + logger.debug("Input '%s' is_video: %s", self.location, retval) + return retval + + def _get_fps(self): + """ Get the Frames per Second. + + If the input is a folder of images than 25.0 will be returned, as it is not possible to + calculate the fps just from frames alone. For video files the correct FPS will be returned. + + Returns + ------- + float: The Frames per Second of the input sources + """ + if self._is_video: + reader = imageio.get_reader(self.location, "ffmpeg") + retval = reader.get_meta_data()["fps"] + reader.close() + else: + retval = 25.0 + logger.debug(retval) + return retval + + def _get_count_and_filelist(self, fast_count, count): + """ Set the count of images to be processed and set the file list + + If the input is a video, a dummy file list is created for checking against an + alignments file, otherwise it will be a list of full filenames. + + Parameters + ---------- + fast_count: bool + When loading from video, the video needs to be parsed frame by frame to get an accurate + count. This can be done quite quickly without guaranteed accuracy, or slower with + guaranteed accuracy. Set to ``True`` to count quickly, or ``False`` to count slower + but accurately. + count: int + The number of images that the loader will encounter if already known, otherwise + ``None`` + """ + if self._is_video: + self._count = int(count_frames(self.location, + fast=fast_count)) if count is None else count + self._file_list = [self._dummy_video_framename(i) for i in range(self.count)] + else: + if isinstance(self.location, (list, tuple)): + self._file_list = self.location + else: + self._file_list = get_image_paths(self.location) + self._count = len(self.file_list) if count is None else count + + logger.debug("count: %s", self.count) + logger.trace("filelist: %s", self.file_list) # type:ignore[attr-defined] + + def _process(self, queue): + """ The load thread. + + Loads from a folder of images or from a video and puts to a queue + + Parameters + ---------- + queue: queue.Queue() + The ImageIO Queue + """ + iterator = self._from_video if self._is_video else self._from_folder + logger.debug("Load iterator: %s", iterator) + for retval in iterator(): + filename, image = retval[:2] + if image is None or (not image.any() and image.ndim not in (2, 3)): + # All black frames will return not numpy.any() so check dims too + logger.warning("Unable to open image. Skipping: '%s'", filename) + continue + logger.trace("Putting to queue: %s", # type:ignore[attr-defined] + [v.shape if isinstance(v, np.ndarray) else v for v in retval]) + queue.put(retval) + logger.trace("Putting EOF") # type:ignore[attr-defined] + queue.put("EOF") + + def _from_video(self): + """ Generator for loading frames from a video + + Yields + ------ + filename: str + The dummy filename of the loaded video frame. + image: numpy.ndarray + The loaded video frame. 
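+
+        Example
+        -------
+        An illustrative sketch (path hypothetical); frames are yielded in `BGR` order with
+        dummy filenames:
+
+        >>> loader = ImagesLoader('/path/to/video.mp4')
+        >>> filename, frame = next(loader._from_video())
+        >>> filename
+        ... 'video_000001.mp4'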
+ """ + logger.debug("Loading frames from video: '%s'", self.location) + reader = imageio.get_reader(self.location, "ffmpeg") + for idx, frame in enumerate(reader): + if idx in self._skip_list: + logger.trace("Skipping frame %s due to skip list", # type:ignore[attr-defined] + idx) + continue + # Convert to BGR for cv2 compatibility + frame = frame[:, :, ::-1] + filename = self._dummy_video_framename(idx) + logger.trace("Loading video frame: '%s'", filename) # type:ignore[attr-defined] + yield filename, frame + reader.close() + + def _dummy_video_framename(self, index): + """ Return a dummy filename for video files. The file name is made up of: + _. + + Parameters + ---------- + index: int + The index number for the frame in the video file + + Notes + ----- + Indexes start at 0, frame numbers start at 1, so index is incremented by 1 + when creating the filename + + Returns + ------- + str: A dummied filename for a video frame """ + vidname, ext = os.path.splitext(os.path.basename(self.location)) + return f"{vidname}_{index + 1:06d}{ext}" + + def _from_folder(self): + """ Generator for loading images from a folder + + Yields + ------ + filename: str + The filename of the loaded image. + image: numpy.ndarray + The loaded image. + """ + logger.debug("Loading frames from folder: '%s'", self.location) + for idx, filename in enumerate(self.file_list): + if idx in self._skip_list: + logger.trace("Skipping frame %s due to skip list") # type:ignore[attr-defined] + continue + image_read = read_image(filename, raise_error=False) + retval = filename, image_read + if retval[1] is None: + logger.warning("Frame not loaded: '%s'", filename) + continue + yield retval + + def load(self): + """ Generator for loading images from the given :attr:`location` + + If :class:`FacesLoader` is in use then the Faceswap metadata of the image stored in the + image exif file is added as the final item in the output `tuple`. + + Yields + ------ + filename: str + The filename of the loaded image. + image: numpy.ndarray + The loaded image. + metadata: dict, (:class:`FacesLoader` only) + The Faceswap metadata associated with the loaded image. + """ + logger.debug("Initializing Load Generator") + self._set_thread() + while True: + self._thread.check_and_raise_error() + try: + retval = self._queue.get(True, 1) + except QueueEmpty: + continue + if retval == "EOF": + logger.trace("Got EOF") # type:ignore[attr-defined] + break + logger.trace("Yielding: %s", # type:ignore[attr-defined] + [v.shape if isinstance(v, np.ndarray) else v for v in retval]) + yield retval + logger.debug("Closing Load Generator") + self.close() + + +class FacesLoader(ImagesLoader): + """ Loads faces from a faces folder along with the face's Faceswap metadata. 
+ + Examples + -------- + Loading faces with their Faceswap metadata: + + >>> loader = FacesLoader('/path/to/faces/folder') + >>> for filename, face, metadata in loader.load(): + >>> + """ + def __init__(self, path, skip_list=None, count=None): + logger.debug("Initializing %s: (path: %s, count: %s)", self.__class__.__name__, + path, count) + super().__init__(path, queue_size=8, skip_list=skip_list, count=count) + + def _get_count_and_filelist(self, fast_count, count): + """ Override default implementation to only return png files from the source folder + + Parameters + ---------- + fast_count: bool + Not used for faces loader + count: int + The number of images that the loader will encounter if already known, otherwise + ``None`` + """ + if isinstance(self.location, (list, tuple)): + file_list = self.location + else: + file_list = get_image_paths(self.location) + + self._file_list = [fname for fname in file_list + if os.path.splitext(fname)[-1].lower() == ".png"] + self._count = len(self.file_list) if count is None else count + + logger.debug("count: %s", self.count) + logger.trace("filelist: %s", self.file_list) # type:ignore[attr-defined] + + def _from_folder(self): + """ Generator for loading images from a folder + Faces will only ever be loaded from a folder, so this is the only function requiring + an override + + Yields + ------ + filename: str + The filename of the loaded image. + image: numpy.ndarray + The loaded image. + metadata: dict + The Faceswap metadata associated with the loaded image. + """ + logger.debug("Loading images from folder: '%s'", self.location) + for idx, filename in enumerate(self.file_list): + if idx in self._skip_list: + logger.trace("Skipping face %s due to skip list") # type:ignore[attr-defined] + continue + image_read = read_image(filename, raise_error=False, with_metadata=True) + retval = filename, *image_read + if retval[1] is None: + logger.warning("Face not loaded: '%s'", filename) + continue + yield retval + + +class SingleFrameLoader(ImagesLoader): + """ Allows direct access to a frame by filename or frame index. + + As we are interested in instant access to frames, there is no requirement to process in a + background thread, as either way we need to wait for the frame to load. + + Parameters + ---------- + video_meta_data: dict, optional + Existing video meta information containing the pts_time and iskey flags for the given + video. Used in conjunction with single_frame_reader for faster seeks. Providing this means + that the video does not need to be scanned again. Set to ``None`` if the video is to be + scanned. Default: ``None`` + """ + def __init__(self, path, video_meta_data=None): + logger.debug("Initializing %s: (path: %s, video_meta_data: %s)", + self.__class__.__name__, path, video_meta_data) + self._video_meta_data = dict() if video_meta_data is None else video_meta_data + self._reader = None + super().__init__(path, queue_size=1, fast_count=False) + + @property + def video_meta_data(self): + """ dict: For videos contains the keys `frame_pts` holding a list of time stamps for each + frame and `keyframes` holding the frame index of each key frame. + + Notes + ----- + Only populated if the input is a video and single frame reader is being used, otherwise + returns ``None``. 
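+
+        Example
+        -------
+        An illustrative sketch (path hypothetical; assumes the input was a video):
+
+        >>> loader = SingleFrameLoader('/path/to/video.mp4')
+        >>> sorted(loader.video_meta_data)
+        ... ['keyframes', 'pts_time']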
+ """ + return self._video_meta_data + + def _get_count_and_filelist(self, fast_count, count): + if self._is_video: + self._reader = imageio.get_reader(self.location, "ffmpeg") + self._reader.use_patch = True + count, video_meta_data = self._reader.get_frame_info( + frame_pts=self._video_meta_data.get("pts_time", None), + keyframes=self._video_meta_data.get("keyframes", None)) + self._video_meta_data = video_meta_data + super()._get_count_and_filelist(fast_count, count) + + def image_from_index(self, index: int) -> tuple[str, np.ndarray]: + """ Return a single image from :attr:`file_list` for the given index. + + Parameters + ---------- + index: int + The index number (frame number) of the frame to retrieve. NB: The first frame is + index `0` + + Returns + ------- + filename: str + The filename of the returned image + image: :class:`numpy.ndarray` + The image for the given index + + Notes + ----- + Retrieving frames from video files can be slow as the whole video file needs to be + iterated to retrieve the requested frame. If a frame has already been retrieved, then + retrieving frames of a higher index will be quicker than retrieving frames of a lower + index, as iteration needs to start from the beginning again when navigating backwards. + + We do not use a background thread for this task, as it is assumed that requesting an image + by index will be done when required. + """ + if self.is_video: + image = self._reader.get_data(index)[..., ::-1] + filename = self._dummy_video_framename(index) + else: + file_list = [f for idx, f in enumerate(self._file_list) + if idx not in self._skip_list] if self._skip_list else self._file_list + + filename = file_list[index] + image = read_image(filename, raise_error=True) + filename = os.path.basename(filename) + logger.trace("index: %s, filename: %s image shape: %s", # type:ignore[attr-defined] + index, filename, image.shape) + return filename, image + + +class ImagesSaver(ImageIO): + """ Perform image saving to a destination folder. + + Images are saved in a background ThreadPoolExecutor to allow for concurrent saving. + See also :class:`ImageIO` for additional attributes. + + Parameters + ---------- + path: str + The folder to save images to. This must be an existing folder. + queue_size: int, optional + The amount of images to hold in the internal buffer. Default: 8. + as_bytes: bool, optional + ``True`` if the image is already encoded to bytes, ``False`` if the image is a + :class:`numpy.ndarray`. Default: ``False``. 
+
+    Examples
+    --------
+
+    >>> loader = ImagesLoader('/path/to/load/folder')
+    >>> saver = ImagesSaver('/path/to/save/folder')
+    >>> for filename, image in loader.load():
+    >>>     saver.save(filename, image)
+    >>> saver.close()
+    """
+
+    def __init__(self, path, queue_size=8, as_bytes=False):
+        logger.debug("Initializing %s: (path: %s, queue_size: %s, as_bytes: %s)",
+                     self.__class__.__name__, path, queue_size, as_bytes)
+
+        super().__init__(path, queue_size=queue_size)
+        self._as_bytes = as_bytes
+
+    def _check_location_exists(self):
+        """ Check whether the output location exists and is a folder
+
+        Raises
+        ------
+        FaceswapError
+            If the given location does not exist or the location is not a folder
+        """
+        if not isinstance(self.location, str):
+            raise FaceswapError("The output location must be a string not a "
+                                "{}".format(type(self.location)))
+        super()._check_location_exists()
+        if not os.path.isdir(self.location):
+            raise FaceswapError("The output location '{}' is not a folder".format(self.location))
+
+    def _process(self, queue):
+        """ Saves images from the save queue to the given :attr:`location` inside a thread.
+
+        Parameters
+        ----------
+        queue: queue.Queue()
+            The ImageIO Queue
+        """
+        executor = futures.ThreadPoolExecutor(thread_name_prefix=self.__class__.__name__)
+        while True:
+            item = queue.get()
+            if item == "EOF":
+                logger.debug("EOF received")
+                break
+            logger.trace("Submitting: '%s'", item[0])  # type:ignore[attr-defined]
+            executor.submit(self._save, *item)
+        executor.shutdown()
+
+    def _save(self,
+              filename: str,
+              image: bytes | np.ndarray,
+              sub_folder: str | None) -> None:
+        """ Save a single image inside a ThreadPoolExecutor
+
+        Parameters
+        ----------
+        filename: str
+            The filename of the image to be saved. NB: Any folders passed in with the filename
+            will be stripped and replaced with :attr:`location`.
+        image: bytes or :class:`numpy.ndarray`
+            The encoded image or numpy array to be saved
+        sub_folder: str or ``None``
+            If the file should be saved in a subfolder in the output location, the subfolder should
+            be provided here. ``None`` for no subfolder.
+        """
+        location = os.path.join(self.location, sub_folder) if sub_folder else self._location
+        if sub_folder and not os.path.exists(location):
+            os.makedirs(location)
+
+        filename = os.path.join(location, os.path.basename(filename))
+        try:
+            if self._as_bytes:
+                assert isinstance(image, bytes)
+                with open(filename, "wb") as out_file:
+                    out_file.write(image)
+            else:
+                assert isinstance(image, np.ndarray)
+                cv2.imwrite(filename, image)
+            logger.trace("Saved image: '%s'", filename)  # type:ignore[attr-defined]
+        except Exception as err:  # pylint:disable=broad-except
+            logger.error("Failed to save image '%s'. Original Error: %s", filename, str(err))
+        del image
+        del filename
+
+    def save(self,
+             filename: str,
+             image: bytes | np.ndarray,
+             sub_folder: str | None = None) -> None:
+        """ Save the given image in the background thread
+
+        Ensure that :func:`close` is called once all save operations are complete.
+
+        Parameters
+        ----------
+        filename: str
+            The filename of the image to be saved. NB: Any folders passed in with the filename
+            will be stripped and replaced with :attr:`location`.
+        image: bytes or :class:`numpy.ndarray`
+            The encoded image or numpy array to be saved
+        sub_folder: str, optional
+            If the file should be saved in a subfolder in the output location, the subfolder should
+            be provided here. ``None`` for no subfolder.
Default: ``None``
+        """
+        self._set_thread()
+        logger.trace("Putting to save queue: '%s'", filename)  # type:ignore[attr-defined]
+        self._queue.put((filename, image, sub_folder))
+
+    def close(self):
+        """ Signal to the save threads that they should be closed and cleanly shut down
+        the saver """
+        logger.debug("Putting EOF to save queue")
+        self._queue.put("EOF")
+        super().close()
+
+
+__all__ = get_module_objects(__name__)
diff --git a/lib/keras_utils.py b/lib/keras_utils.py
new file mode 100644
index 0000000000..916f3963cd
--- /dev/null
+++ b/lib/keras_utils.py
@@ -0,0 +1,354 @@
+#!/usr/bin/env python3
+""" Common multi-backend Keras utilities """
+from __future__ import annotations
+import typing as T
+
+import numpy as np
+
+from keras import ops, Variable
+
+from lib.utils import get_module_objects
+
+if T.TYPE_CHECKING:
+    from keras import KerasTensor
+
+# TODO these can probably be switched to pure pytorch
+
+
+def frobenius_norm(matrix: KerasTensor,
+                   axis: int = -1,
+                   keep_dims: bool = True,
+                   epsilon: float = 1e-15) -> KerasTensor:
+    """ Frobenius normalization for Keras Tensor
+
+    Parameters
+    ----------
+    matrix: :class:`keras.KerasTensor`
+        The matrix to normalize
+    axis: int, optional
+        The axis to normalize. Default: `-1`
+    keep_dims: bool, optional
+        Whether to retain the original matrix shape or not. Default: ``True``
+    epsilon: float, optional
+        Epsilon to apply to the normalization to prevent NaN errors on zero values
+
+    Returns
+    -------
+    :class:`keras.KerasTensor`
+        The normalized output
+    """
+    return ops.sqrt(ops.sum(ops.power(matrix, 2), axis=axis, keepdims=keep_dims) + epsilon)
+
+
+def replicate_pad(image: KerasTensor, padding: int) -> KerasTensor:
+    """ Apply replication padding to an input batch of images. Expects 4D tensor in BHWC format.
+
+    Notes
+    -----
+    At the time of writing Keras does not have a native replication padding method.
+    The implementation here is probably not the most efficient, but it is a pure keras method
+    which should work ok.
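+
+    Example
+    -------
+    A minimal sketch of the expected shapes (values are illustrative):
+
+    >>> import numpy as np
+    >>> from keras import ops
+    >>> image = ops.convert_to_tensor(np.zeros((1, 4, 4, 3), dtype="float32"))
+    >>> padded = replicate_pad(image, 2)  # shape becomes (1, 8, 8, 3)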
+
+    Parameters
+    ----------
+    image: :class:`keras.KerasTensor`
+        Image tensor to pad
+    padding: int
+        The amount of padding to apply to each side of the input image
+
+    Returns
+    -------
+    :class:`keras.KerasTensor`
+        The input image with replication padding applied
+    """
+    top_pad = ops.tile(image[:, :1, ...], (1, padding, 1, 1))
+    bottom_pad = ops.tile(image[:, -1:, ...], (1, padding, 1, 1))
+    pad_top_bottom = ops.concatenate([top_pad, image, bottom_pad], axis=1)
+    left_pad = ops.tile(pad_top_bottom[..., :1, :], (1, 1, padding, 1))
+    right_pad = ops.tile(pad_top_bottom[..., -1:, :], (1, 1, padding, 1))
+    padded = ops.concatenate([left_pad, pad_top_bottom, right_pad], axis=2)
+    return padded
+
+
+class ColorSpaceConvert():
+    """ Transforms inputs between different color spaces on the GPU
+
+    Notes
+    -----
+    The following color space transformations are implemented:
+        - rgb to lab
+        - rgb to xyz
+        - srgb to rgb
+        - srgb to ycxcz
+        - xyz to ycxcz
+        - xyz to lab
+        - xyz to rgb
+        - ycxcz to rgb
+        - ycxcz to xyz
+
+    Parameters
+    ----------
+    from_space: str
+        One of `"srgb"`, `"rgb"`, `"xyz"`, `"ycxcz"`
+    to_space: str
+        One of `"lab"`, `"rgb"`, `"ycxcz"`, `"xyz"`
+
+    Raises
+    ------
+    ValueError
+        If the requested color space conversion is not defined
+    """
+    def __init__(self, from_space: str, to_space: str) -> None:
+        functions = {"rgb_lab": self._rgb_to_lab,
+                     "rgb_xyz": self._rgb_to_xyz,
+                     "srgb_rgb": self._srgb_to_rgb,
+                     "srgb_ycxcz": self._srgb_to_ycxcz,
+                     "xyz_ycxcz": self._xyz_to_ycxcz,
+                     "xyz_lab": self._xyz_to_lab,
+                     "xyz_rgb": self._xyz_to_rgb,
+                     "ycxcz_rgb": self._ycxcz_to_rgb,
+                     "ycxcz_xyz": self._ycxcz_to_xyz}
+        func_name = f"{from_space.lower()}_{to_space.lower()}"
+        if func_name not in functions:
+            raise ValueError(f"The color transform {from_space} to {to_space} is not defined.")
+
+        self._func = functions[func_name]
+        self._ref_illuminant = Variable(np.array([[[0.950428545, 1.000000000, 1.088900371]]]),
+                                        dtype="float32",
+                                        trainable=False)
+        self._inv_ref_illuminant = 1. / self._ref_illuminant
+
+        self._rgb_xyz_map = self._get_rgb_xyz_map()
+        self._xyz_multipliers = Variable([116, 500, 200], dtype="float32", trainable=False)
+
+    @classmethod
+    def _get_rgb_xyz_map(cls) -> tuple[KerasTensor, KerasTensor]:
+        """ Obtain the mapping and inverse mapping for rgb to xyz color space conversion.
+
+        Returns
+        -------
+        tuple
+            The mapping and inverse Tensors for rgb to xyz color space conversion
+        """
+        mapping = np.array([[10135552 / 24577794, 8788810 / 24577794, 4435075 / 24577794],
+                            [2613072 / 12288897, 8788810 / 12288897, 887015 / 12288897],
+                            [1425312 / 73733382, 8788810 / 73733382, 70074185 / 73733382]])
+        inverse = np.linalg.inv(mapping)
+        return (Variable(mapping, dtype="float32", trainable=False),
+                Variable(inverse, dtype="float32", trainable=False))
+
+    def __call__(self, image: KerasTensor) -> KerasTensor:
+        """ Call the colorspace conversion function.
+
+        Parameters
+        ----------
+        image: :class:`keras.KerasTensor`
+            The image tensor in the colorspace defined by :attr:`from_space`
+
+        Returns
+        -------
+        :class:`keras.KerasTensor`
+            The image tensor in the colorspace defined by :attr:`to_space`
+        """
+        return self._func(image)
+
+    def _rgb_to_lab(self, image: KerasTensor) -> KerasTensor:
+        """ RGB to LAB conversion.
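+
+        The conversion chains :func:`_rgb_to_xyz` and :func:`_xyz_to_lab`.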
+
+        Parameters
+        ----------
+        image: :class:`keras.KerasTensor`
+            The image tensor in RGB format
+
+        Returns
+        -------
+        :class:`keras.KerasTensor`
+            The image tensor in LAB format
+        """
+        converted = self._rgb_to_xyz(image)
+        return self._xyz_to_lab(converted)
+
+    def _rgb_xyz_rgb(self, image: KerasTensor, mapping: KerasTensor) -> KerasTensor:
+        """ RGB to XYZ or XYZ to RGB conversion.
+
+        Notes
+        -----
+        The conversion in both directions is the same, but the mapping matrix for XYZ to RGB is
+        the inverse of RGB to XYZ.
+
+        References
+        ----------
+        https://www.image-engineering.de/library/technotes/958-how-to-convert-between-srgb-and-ciexyz
+
+        Parameters
+        ----------
+        image: :class:`keras.KerasTensor`
+            The image tensor to be converted
+        mapping: :class:`keras.KerasTensor`
+            The mapping matrix to perform either the XYZ to RGB or RGB to XYZ color space
+            conversion
+
+        Returns
+        -------
+        :class:`keras.KerasTensor`
+            The image tensor in the converted color space
+        """
+        dim = image.shape
+        image = ops.transpose(image, (0, 3, 1, 2))
+        image = ops.reshape(image, (dim[0], dim[3], dim[1] * dim[2]))
+        converted = ops.transpose(ops.dot(mapping, image), (0, 2, 1))
+        return ops.reshape(converted, dim)
+
+    def _rgb_to_xyz(self, image: KerasTensor) -> KerasTensor:
+        """ RGB to XYZ conversion.
+
+        Parameters
+        ----------
+        image: :class:`keras.KerasTensor`
+            The image tensor in RGB format
+
+        Returns
+        -------
+        :class:`keras.KerasTensor`
+            The image tensor in XYZ format
+        """
+        return self._rgb_xyz_rgb(image, self._rgb_xyz_map[0])
+
+    @classmethod
+    def _srgb_to_rgb(cls, image: KerasTensor) -> KerasTensor:
+        """ SRGB to RGB conversion.
+
+        Notes
+        -----
+        RGB Image is clipped to a small epsilon to stabilize training
+
+        Parameters
+        ----------
+        image: :class:`keras.KerasTensor`
+            The image tensor in SRGB format
+
+        Returns
+        -------
+        :class:`keras.KerasTensor`
+            The image tensor in RGB format
+        """
+        limit = np.float32(0.04045)
+        return ops.where(image > limit,
+                         ops.power((ops.clip(image, limit, np.inf) + 0.055) / 1.055, 2.4),
+                         image / 12.92)
+
+    def _srgb_to_ycxcz(self, image: KerasTensor) -> KerasTensor:
+        """ SRGB to YcXcZ conversion.
+
+        Parameters
+        ----------
+        image: :class:`keras.KerasTensor`
+            The image tensor in SRGB format
+
+        Returns
+        -------
+        :class:`keras.KerasTensor`
+            The image tensor in YcXcZ format
+        """
+        converted = self._srgb_to_rgb(image)
+        converted = self._rgb_to_xyz(converted)
+        return self._xyz_to_ycxcz(converted)
+
+    def _xyz_to_lab(self, image: KerasTensor) -> KerasTensor:
+        """ XYZ to LAB conversion.
+
+        Parameters
+        ----------
+        image: :class:`keras.KerasTensor`
+            The image tensor in XYZ format
+
+        Returns
+        -------
+        :class:`keras.KerasTensor`
+            The image tensor in LAB format
+        """
+        image = image * self._inv_ref_illuminant
+        delta = 6 / 29
+        delta_cube = delta ** 3
+        factor = 1 / (3 * (delta ** 2))
+
+        clamped_term = ops.power(ops.clip(image, delta_cube, np.inf), 1.0 / 3.0)
+        div = factor * image + (4 / 29)
+
+        image = ops.where(image > delta_cube, clamped_term, div)
+
+        return ops.concatenate([self._xyz_multipliers[0] * image[..., 1:2] - 16.,
+                                self._xyz_multipliers[1:] * (image[..., :2] - image[..., 1:3])],
+                               axis=-1)
+
+    def _xyz_to_rgb(self, image: KerasTensor) -> KerasTensor:
+        """ XYZ to RGB conversion.
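+
+        Applies the inverse of the RGB to XYZ mapping matrix.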
+ + Parameters + ---------- + image: :class:`keras.KerasTensor` + The image tensor in XYZ format + + Returns + ------- + :class:`keras.KerasTensor` + The image tensor in RGB format + """ + return self._rgb_xyz_rgb(image, self._rgb_xyz_map[1]) + + def _xyz_to_ycxcz(self, image: KerasTensor) -> KerasTensor: + """ XYZ to YcXcZ conversion. + + Parameters + ---------- + image: :class:`keras.KerasTensor` + The image tensor in XYZ format + + Returns + ------- + :class:`keras.KerasTensor` + The image tensor in YcXcZ format + """ + image = image * self._inv_ref_illuminant + return ops.concatenate([self._xyz_multipliers[0] * image[..., 1:2] - 16., + self._xyz_multipliers[1:] * (image[..., :2] - image[..., 1:3])], + axis=-1) + + def _ycxcz_to_rgb(self, image: KerasTensor) -> KerasTensor: + """ YcXcZ to RGB conversion. + + Parameters + ---------- + image: :class:`keras.KerasTensor` + The image tensor in YcXcZ format + + Returns + ------- + :class:`keras.KerasTensor` + The image tensor in RGB format + """ + converted = self._ycxcz_to_xyz(image) + return self._xyz_to_rgb(converted) + + def _ycxcz_to_xyz(self, image: KerasTensor) -> KerasTensor: + """ YcXcZ to XYZ conversion. + + Parameters + ---------- + image: :class:`keras.KerasTensor` + The image tensor in YcXcZ format + + Returns + ------- + :class:`keras.KerasTensor` + The image tensor in XYZ format + """ + ch_y = (image[..., 0:1] + 16.) / self._xyz_multipliers[0] + return ops.concatenate([ch_y + (image[..., 1:2] / self._xyz_multipliers[1]), + ch_y, + ch_y - (image[..., 2:3] / self._xyz_multipliers[2])], + axis=-1) * self._ref_illuminant + + +__all__ = get_module_objects(__name__) diff --git a/lib/keypress.py b/lib/keypress.py index a36a049600..4505d67feb 100644 --- a/lib/keypress.py +++ b/lib/keypress.py @@ -17,24 +17,28 @@ """ import os +import sys + +from lib.utils import get_module_objects # Windows if os.name == "nt": - import msvcrt # pylint: disable=import-error + import msvcrt # pylint:disable=import-error # Posix (Linux, OS X) else: - import sys import termios import atexit from select import select +# pylint:disable=possibly-used-before-assignment + class KBHit: """ Creates a KBHit object that you can call to do various keyboard things. """ def __init__(self, is_gui=False): self.is_gui = is_gui - if os.name == "nt" or self.is_gui: + if os.name == "nt" or self.is_gui or not sys.stdout.isatty(): pass else: # Save the terminal settings @@ -43,7 +47,7 @@ def __init__(self, is_gui=False): self.old_term = termios.tcgetattr(self.file_desc) # New terminal setting unbuffered - self.new_term[3] = (self.new_term[3] & ~termios.ICANON & ~termios.ECHO) + self.new_term[3] = self.new_term[3] & ~termios.ICANON & ~termios.ECHO termios.tcsetattr(self.file_desc, termios.TCSAFLUSH, self.new_term) # Support normal-terminal reset at exit @@ -51,21 +55,21 @@ def __init__(self, is_gui=False): def set_normal_term(self): """ Resets to normal terminal. On Windows this is a no-op. """ - if os.name == "nt" or self.is_gui: + if os.name == "nt" or self.is_gui or not sys.stdout.isatty(): pass else: termios.tcsetattr(self.file_desc, termios.TCSAFLUSH, self.old_term) - @staticmethod - def getch(): + def getch(self): """ Returns a keyboard character after kbhit() has been called. Should not be called in the same program as getarrow(). 
""" + if (self.is_gui or not sys.stdout.isatty()) and os.name != "nt": + return None if os.name == "nt": - return msvcrt.getch().decode("utf-8") + return msvcrt.getch().decode("utf-8", errors="replace") return sys.stdin.read(1) - @staticmethod - def getarrow(): + def getarrow(self): """ Returns an arrow-key code after kbhit() has been called. Codes are 0 : up 1 : right @@ -73,6 +77,8 @@ def getarrow(): 3 : left Should not be called in the same program as getch(). """ + if (self.is_gui or not sys.stdout.isatty()) and os.name != "nt": + return None if os.name == "nt": msvcrt.getch() # skip 0xE0 char = msvcrt.getch() @@ -81,12 +87,16 @@ def getarrow(): char = sys.stdin.read(3)[2] vals = [65, 67, 66, 68] - return vals.index(ord(char.decode("utf-8"))) + return vals.index(ord(char.decode("utf-8", errors="replace"))) - @staticmethod - def kbhit(): + def kbhit(self): """ Returns True if keyboard character was hit, False otherwise. """ + if (self.is_gui or not sys.stdout.isatty()) and os.name != "nt": + return None if os.name == "nt": return msvcrt.kbhit() d_r, _, _ = select([sys.stdin], [], [], 0) return d_r != [] + + +__all__ = get_module_objects(__name__) diff --git a/lib/logger.py b/lib/logger.py index 0e64687bbc..d9dbbb1d3a 100644 --- a/lib/logger.py +++ b/lib/logger.py @@ -1,25 +1,25 @@ #!/usr/bin/python -""" Logging Setup """ +""" Logging Functions for Faceswap. """ +# NOTE: Don't import non stdlib packages. This module is accessed by setup.py import collections import logging -from logging.handlers import QueueHandler, QueueListener, RotatingFileHandler +from logging.handlers import RotatingFileHandler import os +import platform import re import sys +import typing as T +import time import traceback from datetime import datetime -from time import sleep -from lib.queue_manager import queue_manager -from lib.sysinfo import sysinfo +from lib.utils import get_module_objects -LOG_QUEUE = queue_manager._log_queue # pylint: disable=protected-access - -class MultiProcessingLogger(logging.Logger): - """ Create custom logger with custom levels """ - def __init__(self, name): +class FaceswapLogger(logging.Logger): + """ A standard :class:`logging.logger` with additional "verbose" and "trace" levels added. """ + def __init__(self, name: str) -> None: for new_level in (("VERBOSE", 15), ("TRACE", 5)): level_name, level_num = new_level if hasattr(logging, level_name): @@ -28,150 +28,547 @@ def __init__(self, name): setattr(logging, level_name, level_num) super().__init__(name) - def verbose(self, msg, *args, **kwargs): - """ - Log 'msg % args' with severity 'VERBOSE'. + def verbose(self, msg: str, *args, **kwargs) -> None: + # pylint:disable=wrong-spelling-in-docstring + """ Create a log message at severity level 15. + + Parameters + ---------- + msg: str + The log message to be recorded at Verbose level + args: tuple + Standard logging arguments + kwargs: dict + Standard logging key word arguments """ if self.isEnabledFor(15): self._log(15, msg, args, **kwargs) - def trace(self, msg, *args, **kwargs): - """ - Log 'msg % args' with severity 'VERBOSE'. + def trace(self, msg: str, *args, **kwargs) -> None: + # pylint:disable=wrong-spelling-in-docstring + """ Create a log message at severity level 5. 
+
+        Parameters
+        ----------
+        msg: str
+            The log message to be recorded at Trace level
+        args: tuple
+            Standard logging arguments
+        kwargs: dict
+            Standard logging key word arguments
+        """
+        if self.isEnabledFor(5):
+            self._log(5, msg, args, **kwargs)
+
+
+class ColoredFormatter(logging.Formatter):
+    """ Overrides the standard :class:`logging.Formatter` to enable colored message level
+    labels on supported platforms
+
+    Parameters
+    ----------
+    fmt: str
+        The format string for the message as a whole
+    pad_newlines: bool, optional
+        If ``True`` new lines will be padded to appear in line with the log message, if ``False``
+        they will be left aligned
+
+    kwargs: dict
+        Standard :class:`logging.Formatter` keyword arguments
+    """
+    def __init__(self, fmt: str, pad_newlines: bool = False, **kwargs) -> None:
+        super().__init__(fmt, **kwargs)
+        self._use_color = self._get_color_compatibility()
+        self._level_colors = {"CRITICAL": "\033[31m",  # red
+                              "ERROR": "\033[31m",  # red
+                              "WARNING": "\033[33m",  # yellow
+                              "INFO": "\033[32m",  # green
+                              "VERBOSE": "\033[34m"}  # blue
+        self._default_color = "\033[0m"
+        self._newline_padding = self._get_newline_padding(pad_newlines, fmt)
+
+    @classmethod
+    def _get_color_compatibility(cls) -> bool:
+        """ Return whether the system supports color ansi codes. Most OSes do, other than
+        Windows releases below Windows 10 version 1511.
+
+        Returns
+        -------
+        bool
+            ``True`` if the system supports color ansi codes otherwise ``False``
+        """
+        if platform.system().lower() != "windows":
+            return True
+        try:
+            win = sys.getwindowsversion()  # type:ignore # pylint:disable=no-member
+            if win.major >= 10 and win.build >= 10586:
+                return True
+        except Exception:  # pylint:disable=broad-except
+            return False
+        return False
+
+    def _get_newline_padding(self, pad_newlines: bool, fmt: str) -> int:
+        """ Parses the format string to obtain padding for newlines if requested
+
+        Parameters
+        ----------
+        pad_newlines: bool, optional
+            If ``True`` new lines will be padded to appear in line with the log message, if
+            ``False`` they will be left aligned
+        fmt: str
+            The format string for the message as a whole
+
+        Returns
+        -------
+        int
+            The amount of padding to apply to the front of newlines
+        """
+        if not pad_newlines:
+            return 0
+        msg_idx = fmt.find("%(message)") + 1
+        filtered = fmt[:msg_idx - 1]
+        spaces = filtered.count(" ")
+        pads = [int(pad.replace("s", "")) for pad in re.findall(r"\ds", filtered)]
+        if "asctime" in filtered:
+            pads.append(self._get_sample_time_string())
+        return sum(pads) + spaces
+
+    def _get_sample_time_string(self) -> int:
+        """ Format a sample time string and return its length, for calculating the correct
+        padding.
+
+        This may be inaccurate when ticking over an integer from single to double digits, but that
+        shouldn't be a huge issue.
+
+        Returns
+        -------
+        int
+            The length of the formatted date-time string
+        """
+        sample_time = time.time()
+        date_format = self.datefmt if self.datefmt else self.default_time_format
+        datestring = time.strftime(date_format, logging.Formatter.converter(sample_time))
+        if not self.datefmt and self.default_msec_format:
+            msecs = (sample_time - int(sample_time)) * 1000
+            datestring = self.default_msec_format % (datestring, msecs)
+        return len(datestring)
+
+    def format(self, record: logging.LogRecord) -> str:
+        """ Color the log message level if supported otherwise return the standard log message.
+
+        Parameters
+        ----------
+        record: :class:`logging.LogRecord`
+            The incoming log record to be formatted for entry into the logger.
+ + Returns + ------- + str + The formatted log message + """ + formatted = super().format(record) + levelname = record.levelname + if self._use_color and levelname in self._level_colors: + formatted = re.sub(levelname, + f"{self._level_colors[levelname]}{levelname}{self._default_color}", + formatted, + 1) + if self._newline_padding: + formatted = formatted.replace("\n", f"\n{' ' * self._newline_padding}") + return formatted + + class FaceswapFormatter(logging.Formatter): - """ Override formatter to strip newlines and multiple spaces from logger - Messages that begin with "R|" should be handled as is + """ Overrides the standard :class:`logging.Formatter`. + + Strip newlines from incoming log messages. + + Rewrites some upstream warning messages to debug level to avoid spamming the console. """ - def format(self, record): - if record.msg.startswith("R|"): - record.msg = record.msg[2:] - record.strip_spaces = False - elif record.strip_spaces: - record.msg = re.sub(" +", " ", record.msg.replace("\n", "\\n").replace("\r", "\\r")) - return super().format(record) + + def format(self, record: logging.LogRecord) -> str: + """ Strip new lines from log records and rewrite certain warning messages to debug level. + + Parameters + ---------- + record : :class:`logging.LogRecord` + The incoming log record to be formatted for entry into the logger. + + Returns + ------- + str + The formatted log message + """ + record.message = record.getMessage() + record = self._lower_external(record) + # strip newlines + if record.levelno < 30 and ("\n" in record.message or "\r" in record.message): + record.message = record.message.replace("\n", "\\n").replace("\r", "\\r") + + if self.usesTime(): + record.asctime = self.formatTime(record, self.datefmt) + msg = self.formatMessage(record) + if record.exc_info: + # Cache the traceback text to avoid converting it multiple times + # (it's constant anyway) + if not record.exc_text: + record.exc_text = self.formatException(record.exc_info) + if record.exc_text: + if msg[-1:] != "\n": + msg = msg + "\n" + msg = msg + record.exc_text + if record.stack_info: + if msg[-1:] != "\n": + msg = msg + "\n" + msg = msg + self.formatStack(record.stack_info) + return msg + + @classmethod + def _lower_external(cls, record: logging.LogRecord) -> logging.LogRecord: + """ Some external libs log at a higher level than we would really like, so lower their + log level. + + Specifically: Matplotlib font properties + + Parameters + ---------- + record: :class:`logging.LogRecord` + The log record to check for rewriting + + Returns + ---------- + :class:`logging.LogRecord` + The log rewritten or untouched record + """ + if (record.levelno == 20 and record.funcName == "__init__" + and record.module == "font_manager"): + # Matplotlib font manager + record.levelno = 10 + record.levelname = "DEBUG" + + return record class RollingBuffer(collections.deque): - """File-like that keeps a certain number of lines of text in memory.""" - def write(self, buffer): - """ Write line to buffer """ + """File-like that keeps a certain number of lines of text in memory for writing out to the + crash log. """ + + def write(self, buffer: str) -> None: + """ Splits lines from the incoming buffer and writes them out to the rolling buffer. 
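+
+        Trailing whitespace is stripped and each line is appended individually with a newline
+        added, so the buffer always holds whole lines.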
+ + Parameters + ---------- + buffer: str + The log messages to write to the rolling buffer + """ for line in buffer.rstrip().splitlines(): - self.append(line + "\n") + self.append(f"{line}\n") -def set_root_logger(loglevel=logging.INFO, queue=LOG_QUEUE): +class TqdmHandler(logging.StreamHandler): + """ Overrides :class:`logging.StreamHandler` to use :func:`tqdm.tqdm.write` rather than writing + to :func:`sys.stderr` so that log messages do not mess up tqdm progress bars. """ + + def emit(self, record: logging.LogRecord) -> None: + """ Format the incoming message and pass to :func:`tqdm.tqdm.write`. + + Parameters + ---------- + record : :class:`logging.LogRecord` + The incoming log record to be formatted for entry into the logger. + """ + # tqdm is imported here as it won't be installed when setup.py is running + from tqdm import tqdm # pylint:disable=import-outside-toplevel + msg = self.format(record) + tqdm.write(msg) + + +def _set_root_logger(loglevel: int = logging.INFO) -> logging.Logger: """ Setup the root logger. - Loaded in main process and into any spawned processes - Automatically added in multithreading.py""" - rootlogger = logging.getLogger() - q_handler = QueueHandler(queue) - rootlogger.addHandler(q_handler) - rootlogger.setLevel(loglevel) + Parameters + ---------- + loglevel: int, optional + The log level to set the root logger to. Default :attr:`logging.INFO` -def log_setup(loglevel, logfile, command): - """ initial log set up. """ + Returns + ------- + :class:`logging.Logger` + The root logger for Faceswap + """ + rootlogger = logging.getLogger() + rootlogger.setLevel(loglevel) + logging.captureWarnings(True) + return rootlogger + + +def log_setup(loglevel, log_file: str, command: str, is_gui: bool = False) -> None: + """ Set up logging for Faceswap. + + Sets up the root logger, the formatting for the crash logger and the file logger, and sets up + the crash, file and stream log handlers. + + Parameters + ---------- + loglevel: str + The requested log level that Faceswap should be run at. + log_file: str + The location of the log file to write Faceswap's log to + command: str + The Faceswap command that is being run. Used to dictate whether the log file should + have "_gui" appended to the filename or not. + is_gui: bool, optional + Whether Faceswap is running in the GUI or not. Dictates where the stream handler should + output messages to. 
Default: ``False`` + """ numeric_loglevel = get_loglevel(loglevel) root_loglevel = min(logging.DEBUG, numeric_loglevel) - set_root_logger(loglevel=root_loglevel) - log_format = FaceswapFormatter("%(asctime)s %(processName)-15s %(threadName)-15s " - "%(module)-15s %(funcName)-25s %(levelname)-8s %(message)s", - datefmt="%m/%d/%Y %H:%M:%S") - f_handler = file_handler(numeric_loglevel, logfile, log_format, command) - s_handler = stream_handler(numeric_loglevel) - c_handler = crash_handler(log_format) - - q_listener = QueueListener(LOG_QUEUE, f_handler, s_handler, c_handler, - respect_handler_level=True) - q_listener.start() - logging.info("Log level set to: %s", loglevel.upper()) + rootlogger = _set_root_logger(loglevel=root_loglevel) - -def file_handler(loglevel, logfile, log_format, command): - """ Add a logging rotating file handler """ - if logfile is not None: - filename = logfile + if command == "setup": + log_format = FaceswapFormatter("%(asctime)s %(module)-16s %(funcName)-30s %(levelname)-8s " + "%(message)s", datefmt="%m/%d/%Y %H:%M:%S") + s_handler = _stream_setup_handler(numeric_loglevel) + f_handler = _file_handler(root_loglevel, log_file, log_format, command) + else: + log_format = FaceswapFormatter("%(asctime)s %(processName)-15s %(threadName)-30s " + "%(module)-15s %(funcName)-30s %(levelname)-8s %(message)s", + datefmt="%m/%d/%Y %H:%M:%S") + s_handler = _stream_handler(numeric_loglevel, is_gui) + f_handler = _file_handler(numeric_loglevel, log_file, log_format, command) + + rootlogger.addHandler(f_handler) + rootlogger.addHandler(s_handler) + + if command != "setup": + c_handler = _crash_handler(log_format) + rootlogger.addHandler(c_handler) + logging.info("Log level set to: %s", loglevel.upper()) + + +def _file_handler(loglevel, + log_file: str, + log_format: FaceswapFormatter, + command: str) -> RotatingFileHandler: + """ Add a rotating file handler for the current Faceswap session. 1 backup is always kept. + + Parameters + ---------- + loglevel: str + The requested log level that messages should be logged at. + log_file: str + The location of the log file to write Faceswap's log to + log_format: :class:`FaceswapFormatter: + The formatting to store log messages as + command: str + The Faceswap command that is being run. Used to dictate whether the log file should + have "_gui" appended to the filename or not. + + Returns + ------- + :class:`logging.RotatingFileHandler` + The logging file handler + """ + if log_file: + filename = log_file else: filename = os.path.join(os.path.dirname(os.path.realpath(sys.argv[0])), "faceswap") - # Windows has issues sharing the log file with subprocesses, so log GUI separately + # Windows has issues sharing the log file with sub-processes, so log GUI separately filename += "_gui.log" if command == "gui" else ".log" should_rotate = os.path.isfile(filename) - log_file = RotatingFileHandler(filename, backupCount=1) + handler = RotatingFileHandler(filename, backupCount=1, encoding="utf-8") if should_rotate: - log_file.doRollover() - log_file.setFormatter(log_format) - log_file.setLevel(loglevel) - return log_file - - -def stream_handler(loglevel): - """ Add a logging cli handler """ + handler.doRollover() + handler.setFormatter(log_format) + handler.setLevel(loglevel) + return handler + + +def _stream_handler(loglevel: int, is_gui: bool) -> logging.StreamHandler | TqdmHandler: + """ Add a stream handler for the current Faceswap session. The stream handler will only ever + output at a maximum of VERBOSE level to avoid spamming the console. 
+ + Parameters + ---------- + loglevel: int + The requested log level that messages should be logged at. + is_gui: bool, optional + Whether Faceswap is running in the GUI or not. Dictates where the stream handler should + output messages to. + + Returns + ------- + :class:`TqdmHandler` or :class:`logging.StreamHandler` + The stream handler to use + """ # Don't set stdout to lower than verbose loglevel = max(loglevel, 15) log_format = FaceswapFormatter("%(asctime)s %(levelname)-8s %(message)s", datefmt="%m/%d/%Y %H:%M:%S") - log_console = logging.StreamHandler(sys.stdout) + if is_gui: + # tqdm.write inserts extra lines in the GUI, so use standard output as + # it is not needed there. + log_console = logging.StreamHandler(sys.stdout) + else: + log_console = TqdmHandler(sys.stdout) log_console.setFormatter(log_format) log_console.setLevel(loglevel) return log_console -def crash_handler(log_format): - """ Add a handler that sores the last 50 debug lines to `debug_buffer` - for use in crash reports """ - log_crash = logging.StreamHandler(debug_buffer) +def _stream_setup_handler(loglevel: int) -> logging.StreamHandler: + """ Add a stream handler for faceswap's setup.py script + This stream handler outputs a limited set of easy to use information using colored labels + if available. It will only ever output at a minimum of INFO level + + Parameters + ---------- + loglevel: int + The requested log level that messages should be logged at. + + Returns + ------- + :class:`logging.StreamHandler` + The stream handler to use + """ + loglevel = max(loglevel, 15) + log_format = ColoredFormatter("%(levelname)-8s %(message)s", pad_newlines=True) + handler = logging.StreamHandler(sys.stdout) + handler.setFormatter(log_format) + handler.setLevel(loglevel) + return handler + + +def _crash_handler(log_format: FaceswapFormatter) -> logging.StreamHandler: + """ Add a handler that stores the last 100 debug lines to :attr:'_DEBUG_BUFFER' for use in + crash reports. + + Parameters + ---------- + log_format: :class:`FaceswapFormatter: + The formatting to store log messages as + + Returns + ------- + :class:`logging.StreamHandler` + The crash log handler + """ + log_crash = logging.StreamHandler(_DEBUG_BUFFER) log_crash.setFormatter(log_format) log_crash.setLevel(logging.DEBUG) return log_crash -def get_loglevel(loglevel): - """ Check valid log level supplied and return numeric log level """ +def get_loglevel(loglevel: str) -> int: + """ Check whether a valid log level has been supplied, and return the numeric log level that + corresponds to the given string level. + + Parameters + ---------- + loglevel: str + The loglevel that has been requested + + Returns + ------- + int + The numeric representation of the given loglevel + """ numeric_level = getattr(logging, loglevel.upper(), None) if not isinstance(numeric_level, int): - raise ValueError("Invalid log level: %s" % loglevel) - + raise ValueError(f"Invalid log level: {loglevel}") return numeric_level -def crash_log(): - """ Write debug_buffer to a crash log on crash """ - path = os.getcwd() - filename = os.path.join(path, datetime.now().strftime("crash_report.%Y.%m.%d.%H%M%S%f.log")) - - # Wait until all log items have been processed - while not LOG_QUEUE.empty(): - sleep(1) +def crash_log() -> str: + """ On a crash, write out the contents of :func:`_DEBUG_BUFFER` containing the last 100 lines + of debug messages to a crash report in the root Faceswap folder. 
-    freeze_log = list(debug_buffer)
-    with open(filename, "w") as outfile:
+
+    Returns
+    -------
+    str
+        The filename of the file that contains the crash report
+    """
+    original_traceback = traceback.format_exc().encode("utf-8")
+    path = os.path.dirname(os.path.realpath(sys.argv[0]))
+    filename = os.path.join(path, datetime.now().strftime("crash_report.%Y.%m.%d.%H%M%S%f.log"))
+    freeze_log = [line.encode("utf-8") for line in _DEBUG_BUFFER]
+    try:
+        from lib.system.sysinfo import sysinfo  # pylint:disable=import-outside-toplevel
+    except Exception:  # pylint:disable=broad-except
+        sysinfo = ("\n\nThere was an error importing System Information from "
+                   f"lib.system.sysinfo. This is probably a bug which should be fixed:"
+                   f"\n{traceback.format_exc()}")
+    with open(filename, "wb") as outfile:
         outfile.writelines(freeze_log)
-        traceback.print_exc(file=outfile)
-        outfile.write(sysinfo.full_info())
+        outfile.write(original_traceback)
+        outfile.write(sysinfo.encode("utf-8"))
     return filename


-# Add a flag to logging.LogRecord to not strip formatting from particular records
-old_factory = logging.getLogRecordFactory()
+def _process_value(value: T.Any) -> T.Any:
+    """ Process the values from a local dict and return in a loggable format
+
+    Parameters
+    ----------
+    value: Any
+        The dictionary value
+
+    Returns
+    -------
+    Any
+        The original or amended value
+    """
+    if isinstance(value, (list, tuple, set)) and len(value) > 10:
+        return f'[type: "{type(value).__name__}" len: {len(value)}]'
+
+    try:
+        import numpy as np  # pylint:disable=import-outside-toplevel
+    except ImportError:
+        return repr(value)
+
+    if isinstance(value, np.ndarray) and np.prod(value.shape) > 10:
+        return f'[type: "{type(value).__name__}" shape: {value.shape}, dtype: "{value.dtype}"]'
+
+    return repr(value)


-def faceswap_logrecord(*args, **kwargs):
-    record = old_factory(*args, **kwargs)
-    record.strip_spaces = True
+def parse_class_init(locals_dict: dict[str, T.Any]) -> str:
+    """ Parse a locals dict from a class and return in a format suitable for logging
+
+    Parameters
+    ----------
+    locals_dict: dict[str, T.Any]
+        A locals() dictionary from a newly initialized class
+
+    Returns
+    -------
+    str
+        The locals information suitable for logging
+    """
+    delimit = {k: _process_value(v)
+               for k, v in locals_dict.items() if k != "self"}
+    dsp = ", ".join(f"{k}={v}" for k, v in delimit.items())
+    dsp = f"({dsp})" if dsp else ""
+    return f"Initializing {locals_dict['self'].__class__.__name__}{dsp}"
+
+
+_OLD_FACTORY = logging.getLogRecordFactory()
+
+
+def _faceswap_logrecord(*args, **kwargs) -> logging.LogRecord:
+    """ Add a flag to :class:`logging.LogRecord` to not strip formatting from particular
+    records. """
+    record = _OLD_FACTORY(*args, **kwargs)
+    record.strip_spaces = True  # type:ignore
     return record


-logging.setLogRecordFactory(faceswap_logrecord)
+logging.setLogRecordFactory(_faceswap_logrecord)
 # Set logger class to custom logger
-logging.setLoggerClass(MultiProcessingLogger)
+logging.setLoggerClass(FaceswapLogger)
+
+# Stores the last 100 debug messages
+_DEBUG_BUFFER = RollingBuffer(maxlen=100)
+

-# Stores the last 50 debug messages
-debug_buffer = RollingBuffer(maxlen=50)  # pylint: disable=invalid-name
+__all__ = get_module_objects(__name__)
diff --git a/lib/model/autoclip.py b/lib/model/autoclip.py
new file mode 100644
index 0000000000..03d1a54af7
--- /dev/null
+++ b/lib/model/autoclip.py
@@ -0,0 +1,64 @@
+""" Auto clipper for clipping gradients.
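+
+The L2 norm of each incoming set of gradients is added to a rolling history. Gradients are
+then clipped to the requested percentile of that history, so the clipping threshold adapts
+as training progresses.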
""" +from __future__ import annotations + +import logging +import typing as T + +import numpy as np +import torch + +from lib.logger import parse_class_init +from lib.utils import get_module_objects + +if T.TYPE_CHECKING: + from keras import KerasTensor + +logger = logging.getLogger(__name__) + + +class AutoClipper(): + """ AutoClip: Adaptive Gradient Clipping for Source Separation Networks + + Parameters + ---------- + clip_percentile: int + The percentile to clip the gradients at + history_size: int, optional + The number of iterations of data to use to calculate the norm Default: ``10000`` + + References + ---------- + Adapted from: https://github.com/pseeth/autoclip + original paper: https://arxiv.org/abs/2007.14469 + """ + def __init__(self, clip_percentile: int, history_size: int = 10000) -> None: + logger.debug(parse_class_init(locals())) + + self._clip_percentile = clip_percentile + self._history_size = history_size + self._grad_history: list[float] = [] + + logger.debug("Initialized %s", self.__class__.__name__) + + def __call__(self, gradients: list[KerasTensor]) -> list[KerasTensor]: + """ Call the AutoClip function. + + Parameters + ---------- + gradients: list[:class:`keras.KerasTensor`] + The list of gradient tensors for the optimizer + + Returns + ---------- + list[:class:`keras.KerasTensor`] + The autoclipped gradients + """ + self._grad_history.append(sum(g.data.norm(2).item() ** 2 + for g in gradients if g is not None) ** (1. / 2)) + self._grad_history = self._grad_history[-self._history_size:] + clip_value = np.percentile(self._grad_history, self._clip_percentile) + torch.nn.utils.clip_grad_norm_(gradients, T.cast(float, clip_value)) + return gradients + + +__all__ = get_module_objects(__name__) diff --git a/lib/model/backup_restore.py b/lib/model/backup_restore.py new file mode 100644 index 0000000000..d60ecd0536 --- /dev/null +++ b/lib/model/backup_restore.py @@ -0,0 +1,232 @@ +#!/usr/bin/env python3 + +""" Functions for backing up, restoring and creating model snapshots. """ + +import logging +import os +from datetime import datetime +from shutil import copyfile, copytree, rmtree + +from lib.serializer import get_serializer +from lib.utils import get_folder, get_module_objects + +logger = logging.getLogger(__name__) + + +class Backup(): + """ Performs the back up of models at each save iteration, and the restoring of models from + their back up location. + + Parameters + ---------- + model_dir: str + The folder that contains the model to be backed up + model_name: str + The name of the model that is to be backed up + """ + def __init__(self, model_dir: str, model_name: str) -> None: + logger.debug("Initializing %s: (model_dir: '%s', model_name: '%s')", + self.__class__.__name__, model_dir, model_name) + self.model_dir = str(model_dir) + self.model_name = model_name + logger.debug("Initialized %s", self.__class__.__name__) + + def _check_valid(self, filename: str, for_restore: bool = False) -> bool: + """ Check if the passed in filename is valid for a backup or restore operation. + + Parameters + ---------- + filename: str + The filename that is to be checked for backup or restore + for_restore: bool, optional + ``True`` if the checks are to be performed for restoring a model, ``False`` if the + checks are to be performed for backing up a model. 
Default: ``False``
+
+        Returns
+        -------
+        bool
+            ``True`` if the given file is valid for a backup/restore operation otherwise ``False``
+        """
+        fullpath = os.path.join(self.model_dir, filename)
+        if not filename.startswith(self.model_name):
+            # Any filename that does not start with the model name is invalid
+            # for all operations
+            retval = False
+        elif for_restore and filename.endswith(".bk"):
+            # Only filenames ending in .bk are valid for restoring
+            retval = True
+        elif not for_restore and ((os.path.isfile(fullpath) and not filename.endswith(".bk")) or
+                                  (os.path.isdir(fullpath) and
+                                   filename == f"{self.model_name}_logs")):
+            # Only filenames that do not end with .bk or folders that are the logs folder
+            # are valid for backup
+            retval = True
+        else:
+            retval = False
+        logger.debug("'%s' valid for backup operation: %s", filename, retval)
+        return retval
+
+    @staticmethod
+    def backup_model(full_path: str) -> None:
+        """ Backup a model file.
+
+        The backed up file is saved with the original filename in the original location with `.bk`
+        appended to the end of the name.
+
+        Parameters
+        ----------
+        full_path: str
+            The full path to a `.keras` model file or a `.json` state file
+        """
+        backupfile = full_path + ".bk"
+        if os.path.exists(backupfile):
+            os.remove(backupfile)
+        if os.path.exists(full_path):
+            logger.verbose("Backing up: '%s' to '%s'",  # type:ignore[attr-defined]
+                           full_path, backupfile)
+            copyfile(full_path, backupfile)
+
+    def snapshot_models(self, iterations: int) -> None:
+        """ Take a snapshot of the model at the current state and back it up.
+
+        The snapshot is a copy of the model folder located in the same root location
+        as the original model file, with the number of iterations appended to the end
+        of the folder name.
+
+        Parameters
+        ----------
+        iterations: int
+            The number of iterations that the model has trained when performing the snapshot.
+        """
+        print("\x1b[2K", end="\r")  # Erase the current line
+        logger.verbose("Saving snapshot")  # type:ignore[attr-defined]
+        snapshot_dir = f"{self.model_dir}_snapshot_{iterations}_iters"
+
+        if os.path.isdir(snapshot_dir):
+            logger.debug("Removing previously existing snapshot folder: '%s'", snapshot_dir)
+            rmtree(snapshot_dir)
+
+        dst = get_folder(snapshot_dir)
+        for filename in os.listdir(self.model_dir):
+            if not self._check_valid(filename, for_restore=False):
+                logger.debug("Not snapshotting file: '%s'", filename)
+                continue
+            srcfile = os.path.join(self.model_dir, filename)
+            dstfile = os.path.join(dst, filename)
+
+            logger.debug("Saving snapshot: '%s' > '%s'", srcfile, dstfile)
+            if os.path.isdir(srcfile):
+                copytree(srcfile, dstfile)
+            else:
+                copyfile(srcfile, dstfile)
+        logger.info("Saved snapshot (%s iterations)", iterations)
+
+    def restore(self) -> None:
+        """ Restores a model from backup.
+
+        The original model files are migrated into a folder within the original model folder
+        named `{model_name}_archived_{timestamp}`. The `.bk` backup files are then moved to
+        the location of the previously existing model files. Logs that were generated after
+        the last backup was taken are removed. """
+        archive_dir = self._move_archived()
+        self._restore_files()
+        self._restore_logs(archive_dir)
+
+    def _move_archived(self) -> str:
+        """ Move archived files to the archived folder.
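+
+        The existing model files are renamed into a timestamped
+        ``{model_name}_archived_{timestamp}`` folder created inside the model directory.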
+ + Returns + ------- + str + The name of the generated archive folder + """ + logger.info("Archiving existing model files...") + now = datetime.now().strftime("%Y%m%d_%H%M%S") + archive_dir = os.path.join(self.model_dir, f"{self.model_name}_archived_{now}") + os.mkdir(archive_dir) + for filename in os.listdir(self.model_dir): + if not self._check_valid(filename, for_restore=False): + logger.debug("Not moving file to archived: '%s'", filename) + continue + logger.verbose( # type:ignore[attr-defined] + "Moving '%s' to archived model folder: '%s'", filename, archive_dir) + src = os.path.join(self.model_dir, filename) + dst = os.path.join(archive_dir, filename) + os.rename(src, dst) + logger.verbose("Archived existing model files") # type:ignore[attr-defined] + return archive_dir + + def _restore_files(self) -> None: + """ Restore files from .bk """ + logger.info("Restoring models from backup...") + for filename in os.listdir(self.model_dir): + if not self._check_valid(filename, for_restore=True): + logger.debug("Not restoring file: '%s'", filename) + continue + dstfile = os.path.splitext(filename)[0] + logger.verbose("Restoring '%s' to '%s'", # type:ignore[attr-defined] + filename, dstfile) + src = os.path.join(self.model_dir, filename) + dst = os.path.join(self.model_dir, dstfile) + copyfile(src, dst) + logger.verbose("Restored models from backup") # type:ignore[attr-defined] + + def _restore_logs(self, archive_dir: str) -> None: + """ Restores the log files up to and including the last backup. + + Parameters + ---------- + archive_dir: str + The full path to the model's archive folder + """ + logger.info("Restoring Logs...") + session_names = self._get_session_names() + log_dirs = self._get_log_dirs(archive_dir, session_names) + for log_dir in log_dirs: + src = os.path.join(archive_dir, log_dir) + dst = os.path.join(self.model_dir, log_dir) + logger.verbose("Restoring logfile: %s", dst) # type:ignore[attr-defined] + copytree(src, dst) + logger.verbose("Restored Logs") # type:ignore[attr-defined] + + def _get_session_names(self) -> list[str]: + """ Get the existing session names from a state file. + + Returns + ------- + list[str] + The session names that exist for the model + """ + serializer = get_serializer("json") + state_file = os.path.join(self.model_dir, + f"{self.model_name}_state.{serializer.file_extension}") + state = serializer.load(state_file) + session_names = [f"session_{key}" for key in state["sessions"].keys()] + logger.debug("Session to restore: %s", session_names) + return session_names + + def _get_log_dirs(self, archive_dir: str, session_names: list[str]) -> list[str]: + """ Get the session log directory paths in the archive folder. 
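+
+        The archived logs folder is walked, and only folders that match the given session
+        names are returned, with paths expressed relative to the archive folder.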
+ + Parameters + ---------- + archive_dir: str + The full path to the model's archive folder + session_names: list[str] + The name of the training sessions that exist for the model + + Returns + ------- + list[str] + The full paths to the log folders + """ + archive_logs = os.path.join(archive_dir, f"{self.model_name}_logs") + paths = [os.path.join(dirpath.replace(archive_dir, "")[1:], folder) + for dirpath, dirnames, _ in os.walk(archive_logs) + for folder in dirnames + if folder in session_names] + logger.debug("log folders to restore: %s", paths) + return paths + + +__all__ = get_module_objects(__name__) diff --git a/lib/model/initializers.py b/lib/model/initializers.py index c2b4ce6043..908dd6e892 100644 --- a/lib/model/initializers.py +++ b/lib/model/initializers.py @@ -1,81 +1,360 @@ #!/usr/bin/env python3 -""" Custom Initializers for faceswap.py - Initializers from: - shoanlu GAN: https://github.com/shaoanlu/faceswap-GAN""" +""" Custom Initializers for faceswap.py """ +from __future__ import annotations +import logging import sys import inspect -import tensorflow as tf -from keras import initializers -from keras.utils.generic_utils import get_custom_objects +import typing as T +from keras import backend as K, initializers, ops +from keras import saving, Variable +from keras.src.initializers.random_initializers import compute_fans -def icnr_keras(shape, dtype=None): - """ - Custom initializer for subpix upscaling - From https://github.com/kostyaev/ICNR - Note: upscale factor is fixed to 2, and the base initializer is fixed to random normal. - """ - # TODO Roll this into ICNR_init when porting GAN 2.2 - shape = list(shape) - scale = 2 - initializer = tf.keras.initializers.RandomNormal(0, 0.02) +import numpy as np - new_shape = shape[:3] + [int(shape[3] / (scale ** 2))] - var_x = initializer(new_shape, dtype) - var_x = tf.transpose(var_x, perm=[2, 0, 1, 3]) - var_x = tf.image.resize_nearest_neighbor(var_x, size=(shape[0] * scale, shape[1] * scale)) - var_x = tf.space_to_depth(var_x, block_size=scale) - var_x = tf.transpose(var_x, perm=[1, 2, 0, 3]) - return var_x +from lib.logger import parse_class_init +from lib.utils import get_module_objects +if T.TYPE_CHECKING: + from keras import KerasTensor -class ICNR(initializers.Initializer): # pylint: disable=invalid-name - ''' - ICNR initializer for checkerboard artifact free sub pixel convolution +logger = logging.getLogger(__name__) - Andrew Aitken et al. Checkerboard artifact free sub-pixel convolution - https://arxiv.org/pdf/1707.02937.pdf https://distill.pub/2016/deconv-checkerboard/ - Parameters: - initializer: initializer used for sub kernels (orthogonal, glorot uniform, etc.) - scale: scale factor of sub pixel convolution (upsampling from 8x8 to 16x16 is scale 2) - Return: +class ICNR(initializers.Initializer): + """ ICNR initializer for checkerboard artifact free sub pixel convolution + + Parameters + ---------- + initializer: :class:`keras.initializers.Initializer` + The initializer used for sub kernels (orthogonal, glorot uniform, etc.) + scale: int, optional + scaling factor of sub pixel convolution (up sampling from 8x8 to 16x16 is scale 2). + Default: `2` + + Returns + ------- + :class:`keras.KerasTensor` The modified kernel weights - Example: - x = conv2d(... weights_initializer=ICNR(initializer=he_uniform(), scale=2)) - ''' - def __init__(self, initializer, scale=2): - self.scale = scale - self.initializer = initializer + Example + ------- + >>> x = conv2d(... 
weights_initializer=ICNR(initializer=he_uniform(), scale=2)) + + References + ---------- + Andrew Aitken et al. Checkerboard artifact free sub-pixel convolution + https://arxiv.org/pdf/1707.02937.pdf, https://distill.pub/2016/deconv-checkerboard/ + """ + + def __init__(self, + initializer: dict[str, T.Any] | initializers.Initializer, + scale: int = 2) -> None: + logger.debug(parse_class_init(locals())) - def __call__(self, shape, dtype='float32'): # tf needs partition_info=None + self._scale = scale + self._initializer = initializer + + logger.debug("Initialized %s", self.__class__.__name__) + + def __call__(self, + shape: list[int] | tuple[int, ...], + dtype: str | None = "float32") -> KerasTensor: + """ Call function for the ICNR initializer. + + Parameters + ---------- + shape: list[int] | tuple[int, ...] + The required resized shape for the output tensor + dtype: str + The data type for the tensor + kwargs: dict[str, Any] + Standard keras initializer keyword arguments + + Returns + ------- + :class:`keras.KerasTensor` + The modified kernel weights + """ shape = list(shape) - if self.scale == 1: - return self.initializer(shape) - new_shape = shape[:3] + [shape[3] // (self.scale ** 2)] - if type(self.initializer) is dict: - self.initializer = initializers.deserialize(self.initializer) - var_x = self.initializer(new_shape, dtype) - var_x = tf.transpose(var_x, perm=[2, 0, 1, 3]) - var_x = tf.image.resize_nearest_neighbor( - var_x, - size=(shape[0] * self.scale, shape[1] * self.scale), - align_corners=True) - var_x = tf.space_to_depth(var_x, block_size=self.scale, data_format='NHWC') - var_x = tf.transpose(var_x, perm=[1, 2, 0, 3]) - return var_x - - def get_config(self): - config = {'scale': self.scale, - 'initializer': self.initializer - } - base_config = super(ICNR, self).get_config() + + if self._scale == 1: + if isinstance(self._initializer, dict): + return next(i for i in self._initializer.values()) + return self._initializer(shape) + + new_shape = shape[:3] + [shape[3] // (self._scale ** 2)] + size = [s * self._scale for s in new_shape[:2]] + + if isinstance(self._initializer, dict): + self._initializer = initializers.deserialize(self._initializer) + + var_x = self._initializer(new_shape, dtype) + var_x = ops.transpose(var_x, [2, 0, 1, 3]) + var_x = ops.image.resize(var_x, + size, + interpolation="nearest", + data_format="channels_last") + var_x = self._space_to_depth(T.cast("KerasTensor", var_x)) + var_x = ops.transpose(var_x, [1, 2, 0, 3]) + + logger.debug("ICNR Output shape: %s", var_x.shape) + return T.cast("KerasTensor", var_x) + + def _space_to_depth(self, input_tensor: KerasTensor) -> KerasTensor: + """ Space to depth Keras implementation. + + Parameters + ---------- + input_tensor: :class:`keras.KerasTensor` + The tensor to be manipulated + + Returns + ------- + :class:`keras.KerasTensor` + The manipulated input tensor + """ + batch, height, width, depth = input_tensor.shape + assert height is not None and width is not None + new_height, new_width = height // 2, width // 2 + inter_shape = (batch, new_height, self._scale, new_width, self._scale, depth) + + var_x = ops.reshape(input_tensor, inter_shape) + var_x = ops.transpose(var_x, (0, 1, 3, 2, 4, 5)) + retval = ops.reshape(var_x, (batch, new_height, new_width, -1)) + + logger.debug("Space to depth - Input shape: %s, Output shape: %s", + input_tensor.shape, retval.shape) + return T.cast("KerasTensor", retval) + + def get_config(self) -> dict[str, T.Any]: + """ Return the ICNR Initializer configuration. 
+ + Returns + ------- + dict[str, Any] + The configuration for ICNR Initialization + """ + config = {"scale": self._scale, "initializer": self._initializer} + base_config = super().get_config() + return dict(list(base_config.items()) + list(config.items())) + + +class ConvolutionAware(initializers.Initializer): + """ + Initializer that generates orthogonal convolution filters in the Fourier space. If this + initializer is passed a shape that is not 3D or 4D, orthogonal initialization will be used. + + Adapted, fixed and optimized from: + https://github.com/keras-team/keras-contrib/blob/master/keras_contrib/initializers/convaware.py + + Parameters + ---------- + eps_std: float, optional + The Standard deviation for the random normal noise used to break symmetry in the inverse + Fourier transform. Default: 0.05 + seed: int | None, optional + Used to seed the random generator. Default: ``None`` + initialized: bool, optional + This should always be set to ``False``. To avoid Keras re-calculating the values every time + the model is loaded, this parameter is internally set on first time initialization. + Default:``False`` + + Returns + ------- + :class:`keras.Variable` + The modified kernel weights + + References + ---------- + Armen Aghajanyan, https://arxiv.org/abs/1702.06295 + """ + + def __init__(self, + eps_std: float = 0.05, + seed: int | None = None, + initialized: bool = False) -> None: + logger.debug(parse_class_init(locals())) + + self._eps_std = eps_std + self._seed = seed + self._orthogonal = initializers.OrthogonalInitializer() + self._he_uniform = initializers.HeUniform() + self._initialized = initialized + + logger.debug("Initialized %s", self.__class__.__name__) + + @classmethod + def _symmetrize(cls, inputs: np.ndarray) -> np.ndarray: + """ Make the given tensor symmetrical. + + Parameters + ---------- + inputs: :class:`numpy.ndarray` + The input tensor to make symmetrical + + Returns + ------- + :class:`numpy.ndarray` + The symmetrical output + """ + var_a = np.transpose(inputs, axes=(0, 1, 3, 2)) + diag = var_a.diagonal(axis1=2, axis2=3) + var_b = np.array([[np.diag(arr) for arr in batch] for batch in diag]) + retval = inputs + var_a - var_b + logger.debug("Input shape: %s. Output shape: %s", inputs.shape, retval.shape) + return retval + + def _create_basis(self, filters_size: int, filters: int, size: int, dtype: str) -> np.ndarray: + """ Create the basis for convolutional aware initialization + + Parameters + ---------- + filters_size: int + The size of the filter + filters: int + The number of filters + dtype: str + The data type + + Returns + ------- + :class:`numpy.ndarray` + The output array + """ + if size == 1: + return np.random.normal(0.0, self._eps_std, (filters_size, filters, size)) + nbb = filters // size + 1 + var_a = np.random.normal(0.0, 1.0, (filters_size, nbb, size, size)) + var_a = self._symmetrize(var_a) + var_u = np.linalg.svd(var_a)[0].transpose(0, 1, 3, 2) + retval = np.reshape(var_u, (filters_size, nbb * size, size))[:, :filters, :].astype(dtype) + logger.debug("filters_size: %s, filters: %s, size: %s, dtype: %s, output: %s", + filters_size, filters, size, dtype, retval.shape) + return retval + + @classmethod + def _scale_filters(cls, filters: np.ndarray, variance: float) -> np.ndarray: + """ Scale the given filters. 
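+
+        The filters are multiplied by the square root of the ratio between the requested
+        variance and their current variance.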
+
+        Parameters
+        ----------
+        filters: :class:`numpy.ndarray`
+            The filters to scale
+        variance: float
+            The amount of variance
+
+        Returns
+        -------
+        :class:`numpy.ndarray`
+            The scaled filters
+        """
+        c_var = np.var(filters)
+        var_p = np.sqrt(variance / c_var)
+        retval = filters * var_p
+        logger.debug("Scaled filters (filters: %s, variance: %s, output: %s)",
+                     filters.shape, variance, retval.shape)
+        return retval
+
+    def __call__(self,  # pylint: disable=too-many-locals
+                 shape: list[int] | tuple[int, ...],
+                 dtype: str | None = None) -> Variable:
+        """ Call function for the Convolution Aware initializer.
+
+        Parameters
+        ----------
+        shape: list[int] | tuple[int, ...]
+            The required shape for the output tensor
+        dtype: str
+            The data type for the tensor
+
+        Returns
+        -------
+        :class:`keras.Variable`
+            The modified kernel weights
+        """
+        if self._initialized:  # Avoid re-calculating initializer when loading a saved model
+            return T.cast("Variable", self._he_uniform(shape, dtype=dtype))
+        dtype = K.floatx() if dtype is None else dtype
+        logger.info("Calculating Convolution Aware Initializer for shape: %s", shape)
+        rank = len(shape)
+        if self._seed is not None:
+            np.random.seed(self._seed)
+
+        fan_in, _ = compute_fans(shape)
+        variance = 2 / fan_in
+
+        kernel_shape: tuple[int, ...]
+        transpose_dimensions: tuple[int, ...]
+        correct_ifft: T.Callable
+        correct_fft: T.Callable
+
+        if rank == 3:
+            row, stack_size, filters_size = shape
+
+            transpose_dimensions = (2, 1, 0)
+            kernel_shape = (row,)
+            correct_ifft = lambda shape, s=[None]: np.fft.irfft(shape, s[0])  # noqa:E731,E501 pylint:disable=unnecessary-lambda-assignment
+
+            correct_fft = np.fft.rfft
+
+        elif rank == 4:
+            row, column, stack_size, filters_size = shape
+
+            transpose_dimensions = (2, 3, 1, 0)
+            kernel_shape = (row, column)
+            correct_ifft = np.fft.irfft2
+            correct_fft = np.fft.rfft2
+
+        elif rank == 5:
+            var_x, var_y, var_z, stack_size, filters_size = shape
+
+            transpose_dimensions = (3, 4, 0, 1, 2)
+            kernel_shape = (var_x, var_y, var_z)
+            correct_fft = np.fft.rfftn
+            correct_ifft = np.fft.irfftn
+
+        else:
+            self._initialized = True
+            return Variable(self._orthogonal(shape), dtype=dtype)
+
+        kernel_fourier_shape = correct_fft(np.zeros(kernel_shape)).shape
+
+        basis = self._create_basis(filters_size,
+                                   stack_size,
+                                   T.cast(int, np.prod(kernel_fourier_shape)),
+                                   dtype)
+        basis = basis.reshape((filters_size, stack_size,) + kernel_fourier_shape)
+        randoms = np.random.normal(0, self._eps_std, basis.shape[:-2] + kernel_shape)
+        init = correct_ifft(basis, kernel_shape) + randoms
+        init = self._scale_filters(init, variance)
+        self._initialized = True
+        retval = Variable(init.transpose(transpose_dimensions), dtype=dtype, name="conv_aware")
+        logger.debug("ConvAware output: %s", retval)
+        return retval
+
+    def get_config(self) -> dict[str, T.Any]:
+        """ Return the Convolutional Aware Initializer configuration.
+ + Returns + ------- + dict[str, Any] + The configuration for Convolutional Aware Initialization + """ + config = {"eps_std": self._eps_std, + "seed": self._seed, + "initialized": self._initialized} + # pylint:disable=duplicate-code + base_config = super().get_config() return dict(list(base_config.items()) + list(config.items())) +# pylint:disable=duplicate-code # Update initializers into Keras custom objects for name, obj in inspect.getmembers(sys.modules[__name__]): if inspect.isclass(obj) and obj.__module__ == __name__: - get_custom_objects().update({name: obj}) + saving.get_custom_objects().update({name: obj}) + + +__all__ = get_module_objects(__name__) diff --git a/lib/model/layers.py b/lib/model/layers.py index c848b7f23d..51bc644eb7 100644 --- a/lib/model/layers.py +++ b/lib/model/layers.py @@ -1,99 +1,412 @@ #!/usr/bin/env python3 -""" Custom Layers for faceswap.py - Layers from: - the original https://www.reddit.com/r/deepfakes/ code sample + contribs - shoanlu GAN: https://github.com/shaoanlu/faceswap-GAN""" +""" Custom Layers for faceswap.py. """ +from __future__ import annotations -from __future__ import absolute_import - -import sys import inspect +import logging +import operator +import sys +import typing as T + +from keras import InputSpec, Layer, ops, saving + +from lib.logger import parse_class_init +from lib.utils import get_module_objects + +if T.TYPE_CHECKING: + from keras import KerasTensor + + +logger = logging.getLogger(__name__) + + +class _GlobalPooling2D(Layer): # pylint:disable=too-many-ancestors + """Abstract class for different global pooling 2D layers. """ + def __init__(self, data_format: str | None = None, **kwargs) -> None: + logger.debug(parse_class_init(locals())) + + super().__init__(**kwargs) + self.data_format = "channels_last" if data_format is None else data_format + self.input_spec = InputSpec(ndim=4) + logger.debug("Initialized %s", self.__class__.__name__) + + def compute_output_shape(self, input_shape: tuple[int, ...] # pylint:disable=arguments-differ + ) -> tuple[int, ...]: + """ Compute the output shape based on the input shape. + + Parameters + ---------- + input_shape: tuple + The input shape to the layer + """ + if self.data_format == "channels_last": + return (input_shape[0], input_shape[3]) + return (input_shape[0], input_shape[1]) + + def call(self, inputs: KerasTensor, *args, **kwargs # pylint:disable=arguments-differ + ) -> KerasTensor: + """ Override to call the layer. + + Parameters + ---------- + inputs: :class:`keras.KerasTensor` + The input to the layer + + Returns + ------- + :class:`keras.KerasTensor` + The output from the layer + + """ + raise NotImplementedError + + def get_config(self) -> dict[str, T.Any]: + """ Set the Keras config """ + config = {"data_format": self.data_format} + base_config = super().get_config() + return dict(list(base_config.items()) + list(config.items())) + + +class GlobalMinPooling2D(_GlobalPooling2D): # pylint:disable=too-many-ancestors,abstract-method + """Global minimum pooling operation for spatial data. """ + + def call(self, inputs: KerasTensor, *args, **kwargs # pylint:disable=arguments-differ + ) -> KerasTensor: + """This is where the layer's logic lives. 
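+
+        Takes the minimum over the two spatial dimensions, reducing ``(batch, height, width,
+        channels)`` to ``(batch, channels)`` for "channels_last" data. A rough sketch, with
+        illustrative shapes:
+
+        >>> pooled = GlobalMinPooling2D()(ops.zeros((2, 4, 4, 8)))
+        >>> tuple(pooled.shape)
+        (2, 8)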
+
+        Parameters
+        ----------
+        inputs: :class:`keras.KerasTensor`
+            Input tensor, or list/tuple of input tensors
+
+        Returns
+        -------
+        :class:`keras.KerasTensor`
+            A tensor or list/tuple of tensors
+        """
+        if self.data_format == "channels_last":
+            pooled = ops.min(inputs, axis=[1, 2])
+        else:
+            pooled = ops.min(inputs, axis=[2, 3])
+        return pooled
+
+
+class GlobalStdDevPooling2D(_GlobalPooling2D):  # pylint:disable=too-many-ancestors,abstract-method
+    """Global standard deviation pooling operation for spatial data. """
+
+    def call(self, inputs: KerasTensor, *args, **kwargs  # pylint:disable=arguments-differ
+             ) -> KerasTensor:
+        """This is where the layer's logic lives.
+
+        Parameters
+        ----------
+        inputs: :class:`keras.KerasTensor`
+            Input tensor, or list/tuple of input tensors
+
+        Returns
+        -------
+        :class:`keras.KerasTensor`
+            A tensor or list/tuple of tensors
+        """
+        if self.data_format == "channels_last":
+            pooled = ops.std(inputs, axis=[1, 2])
+        else:
+            pooled = ops.std(inputs, axis=[2, 3])
+        return pooled
+
+
+class KResizeImages(Layer):  # pylint:disable=too-many-ancestors,abstract-method
+    """ A custom upscale function that uses :func:`keras.ops.image.resize` to upsample.
+
+    Parameters
+    ----------
+    size: int or float, optional
+        The scale to upsample to. Default: `2`
+    interpolation: ["nearest", "bilinear"], optional
+        The interpolation to use. Default: `"nearest"`
+    kwargs: dict
+        The standard Keras Layer keyword arguments (if any)
+    """
+    def __init__(self,
+                 size: int = 2,
+                 interpolation: T.Literal["nearest", "bilinear"] = "nearest",
+                 **kwargs) -> None:
+        logger.debug(parse_class_init(locals()))
+        super().__init__(**kwargs)
+        self.size = size
+        self.interpolation = interpolation
+        logger.debug("Initialized %s", self.__class__.__name__)
+
+    def call(self, inputs: KerasTensor, *args, **kwargs  # pylint:disable=arguments-differ
+             ) -> KerasTensor:
+        """ Call the upsample layer
+
+        Parameters
+        ----------
+        inputs: :class:`keras.KerasTensor`
+            Input tensor, or list/tuple of input tensors
+
+        Returns
+        -------
+        :class:`keras.KerasTensor`
+            A tensor or list/tuple of tensors
+        """
+        height, width = inputs.shape[1:3]
+        assert height is not None and width is not None
+        # keras.ops.image.resize expects size in (height, width) order
+        size = int(round(height * self.size)), int(round(width * self.size))
+        retval = ops.image.resize(inputs,
+                                  size,
+                                  interpolation=self.interpolation,
+                                  data_format="channels_last")
+        return retval
+
+    def compute_output_shape(self, input_shape: tuple[int, ...]  # pylint:disable=arguments-differ
+                             ) -> tuple[int, ...]:
+        """Computes the output shape of the layer.
+
+        This is the input shape with size dimensions multiplied by :attr:`size`
+
+        Parameters
+        ----------
+        input_shape: tuple or list of tuples
+            Shape tuple (tuple of integers) or list of shape tuples (one per output tensor of the
+            layer). Shape tuples can include None for free dimensions, instead of an integer.
+
+        Returns
+        -------
+        tuple
+            An output shape tuple
+        """
+        batch, height, width, channels = input_shape
+        return (batch, int(round(height * self.size)), int(round(width * self.size)), channels)
+
+    def get_config(self) -> dict[str, T.Any]:
+        """Returns the config of the layer.
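+
+        The returned dictionary is sufficient to rebuild the layer. A rough sketch:
+
+        >>> layer = KResizeImages(size=2, interpolation="bilinear")
+        >>> restored = KResizeImages.from_config(layer.get_config())
+        >>> restored.size, restored.interpolation
+        (2, 'bilinear')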
+ + Returns + -------- + dict + A python dictionary containing the layer configuration + """ + config = {"size": self.size, "interpolation": self.interpolation} + base_config = super().get_config() + return dict(list(base_config.items()) + list(config.items())) + -import tensorflow as tf -import keras.backend as K +class L2Normalize(Layer): # pylint:disable=too-many-ancestors,abstract-method + """ Normalizes a tensor w.r.t. the L2 norm alongside the specified axis. -from keras.engine import InputSpec, Layer -from keras.utils import conv_utils -from keras.utils.generic_utils import get_custom_objects -from keras import initializers -from keras.layers import ZeroPadding2D + Parameters + ---------- + axis: int + The axis to perform normalization across + kwargs: dict + The standard Keras Layer keyword arguments (if any) + """ + def __init__(self, axis: int, **kwargs) -> None: + logger.debug(parse_class_init(locals())) + self.axis = axis + super().__init__(**kwargs) + logger.debug("Initialized %s", self.__class__.__name__) + + def compute_output_shape(self, input_shape: tuple[int, ...] # pylint:disable=arguments-differ + ) -> tuple[int, ...]: + """ Compute the output shape based on the input shape. + + Parameters + ---------- + input_shape: tuple + The input shape to the layer + """ + return input_shape + + def call(self, inputs: KerasTensor, *args, **kwargs # pylint:disable=arguments-differ + ) -> KerasTensor: + """This is where the layer's logic lives. + + Parameters + ---------- + inputs: :class:`keras.KerasTensor` + Input tensor, or list/tuple of input tensors + + Returns + ------- + :class:`keras.KerasTensor` + A tensor or list/tuple of tensors + """ + return ops.normalize(inputs, self.axis, order=2) + def get_config(self) -> dict[str, T.Any]: + """Returns the config of the layer. -class PixelShuffler(Layer): - """ PixelShuffler layer for Keras - by t-ae: https://gist.github.com/t-ae/6e1016cc188104d123676ccef3264981 """ - # pylint: disable=C0103 - def __init__(self, size=(2, 2), data_format=None, **kwargs): - super(PixelShuffler, self).__init__(**kwargs) - self.data_format = K.normalize_data_format(data_format) - self.size = conv_utils.normalize_tuple(size, 2, 'size') + A layer config is a Python dictionary (serializable) containing the configuration of a + layer. The same layer can be reinstated later (without its trained weights) from this + configuration. - def call(self, inputs, **kwargs): + The configuration of a layer does not include connectivity information, nor the layer + class name. These are handled by `Network` (one layer of abstraction above). - input_shape = K.int_shape(inputs) + Returns + -------- + dict + A python dictionary containing the layer configuration + """ + config = super().get_config() + config["axis"] = self.axis + return config + + +class PixelShuffler(Layer): # pylint:disable=too-many-ancestors,abstract-method + """ PixelShuffler layer for Keras. + + This layer requires a Convolution2D prior to it, having output filters computed according to + the formula :math:`filters = k * (scale_factor * scale_factor)` where `k` is a user defined + number of filters (generally larger than 32) and `scale_factor` is the up-scaling factor + (generally 2). + + This layer performs the depth to space operation on the convolution filters, and returns a + tensor with the size as defined below. + + Notes + ----- + In practice, it is useful to have a second convolution layer after the + :class:`PixelShuffler` layer to speed up the learning process. 
However, if you are stacking + multiple :class:`PixelShuffler` blocks, it may increase the number of parameters greatly, + so the Convolution layer after :class:`PixelShuffler` layer can be removed. + + Example + ------- + >>> # A standard sub-pixel up-scaling block + >>> x = Convolution2D(256, 3, 3, padding="same", activation="relu")(...) + >>> u = PixelShuffler(size=(2, 2))(x) + [Optional] + >>> x = Convolution2D(256, 3, 3, padding="same", activation="relu")(u) + + Parameters + ---------- + size: tuple, optional + The (`h`, `w`) scaling factor for up-scaling. Default: `(2, 2)` + data_format: ["channels_first", "channels_last", ``None``], optional + The data format for the input. Default: ``None`` + kwargs: dict + The standard Keras Layer keyword arguments (if any) + + References + ---------- + https://gist.github.com/t-ae/6e1016cc188104d123676ccef3264981 + """ + def __init__(self, + size: int | tuple[int, int] = (2, 2), + data_format: str | None = None, + **kwargs) -> None: + logger.debug(parse_class_init(locals())) + super().__init__(**kwargs) + self.data_format = "channels_last" if data_format is None else data_format + self.size = (size, size) if isinstance(size, int) else tuple(size) + logger.debug("Initialized %s", self.__class__.__name__) + + def call(self, inputs: KerasTensor, *args, **kwargs # pylint:disable=arguments-differ + ) -> KerasTensor: + """This is where the layer's logic lives. + + Parameters + ---------- + inputs: :class:`keras.KerasTensor` + Input tensor, or list/tuple of input tensors + + Returns + ------- + :class:`keras.KerasTensor` + A tensor or list/tuple of tensors + """ + input_shape = inputs.shape if len(input_shape) != 4: - raise ValueError('Inputs should have rank ' + + raise ValueError("Inputs should have rank " + str(4) + - '; Received input shape:', str(input_shape)) + "; Received input shape:", str(input_shape)) - if self.data_format == 'channels_first': - batch_size, c, h, w = input_shape + out = None + if self.data_format == "channels_first": + batch_size, channels, height, width = input_shape + assert height is not None and width is not None and channels is not None if batch_size is None: batch_size = -1 - rh, rw = self.size - oh, ow = h * rh, w * rw - oc = c // (rh * rw) - - out = K.reshape(inputs, (batch_size, rh, rw, oc, h, w)) - out = K.permute_dimensions(out, (0, 3, 4, 1, 5, 2)) - out = K.reshape(out, (batch_size, oc, oh, ow)) - elif self.data_format == 'channels_last': - batch_size, h, w, c = input_shape + r_height, r_width = self.size + o_height, o_width = height * r_height, width * r_width + o_channels = channels // (r_height * r_width) + + out = ops.reshape(inputs, (batch_size, r_height, r_width, o_channels, height, width)) + out = ops.transpose(out, (0, 3, 4, 1, 5, 2)) + out = ops.reshape(out, (batch_size, o_channels, o_height, o_width)) + elif self.data_format == "channels_last": + batch_size, height, width, channels = input_shape + assert height is not None and width is not None and channels is not None if batch_size is None: batch_size = -1 - rh, rw = self.size - oh, ow = h * rh, w * rw - oc = c // (rh * rw) - - out = K.reshape(inputs, (batch_size, h, w, rh, rw, oc)) - out = K.permute_dimensions(out, (0, 1, 3, 2, 4, 5)) - out = K.reshape(out, (batch_size, oh, ow, oc)) - return out - - def compute_output_shape(self, input_shape): - + r_height, r_width = self.size + o_height, o_width = height * r_height, width * r_width + o_channels = channels // (r_height * r_width) + + out = ops.reshape(inputs, (batch_size, height, width, r_height, 
r_width, o_channels)) + out = ops.transpose(out, (0, 1, 3, 2, 4, 5)) + out = ops.reshape(out, (batch_size, o_height, o_width, o_channels)) + assert out is not None + return T.cast("KerasTensor", out) + + def compute_output_shape(self, # pylint:disable=arguments-differ + input_shape: tuple[int | None, ...]) -> tuple[int | None, ...]: + """Computes the output shape of the layer. + + Assumes that the layer will be built to match that input shape provided. + + Parameters + ---------- + input_shape: tuple or list of tuples + Shape tuple (tuple of integers) or list of shape tuples (one per output tensor of the + layer). Shape tuples can include None for free dimensions, instead of an integer. + + Returns + ------- + tuple + An input shape tuple + """ if len(input_shape) != 4: - raise ValueError('Inputs should have rank ' + + raise ValueError("Inputs should have rank " + str(4) + - '; Received input shape:', str(input_shape)) + "; Received input shape:", str(input_shape)) - if self.data_format == 'channels_first': + retval: tuple[int | None, ...] + if self.data_format == "channels_first": height = None width = None if input_shape[2] is not None: height = input_shape[2] * self.size[0] if input_shape[3] is not None: width = input_shape[3] * self.size[1] - channels = input_shape[1] // self.size[0] // self.size[1] + chs = input_shape[1] + assert chs is not None + channels = chs // self.size[0] // self.size[1] if channels * self.size[0] * self.size[1] != input_shape[1]: - raise ValueError('channels of input and size are incompatible') + raise ValueError("channels of input and size are incompatible") retval = (input_shape[0], channels, height, width) - elif self.data_format == 'channels_last': + else: height = None width = None if input_shape[1] is not None: height = input_shape[1] * self.size[0] if input_shape[2] is not None: width = input_shape[2] * self.size[1] - channels = input_shape[3] // self.size[0] // self.size[1] + chs = input_shape[3] + assert chs is not None + channels = chs // self.size[0] // self.size[1] if channels * self.size[0] * self.size[1] != input_shape[3]: - raise ValueError('channels of input and size are incompatible') + raise ValueError("channels of input and size are incompatible") retval = (input_shape[0], height, @@ -101,238 +414,339 @@ def compute_output_shape(self, input_shape): channels) return retval - def get_config(self): - config = {'size': self.size, - 'data_format': self.data_format} - base_config = super(PixelShuffler, self).get_config() + def get_config(self) -> dict[str, T.Any]: + """Returns the config of the layer. + + A layer config is a Python dictionary (serializable) containing the configuration of a + layer. The same layer can be reinstated later (without its trained weights) from this + configuration. + + The configuration of a layer does not include connectivity information, nor the layer + class name. These are handled by `Network` (one layer of abstraction above). + + Returns + -------- + dict + A python dictionary containing the layer configuration + """ + config = {"size": self.size, + "data_format": self.data_format} + base_config = super().get_config() return dict(list(base_config.items()) + list(config.items())) -class Scale(Layer): - """ - GAN Custom Scal Layer - Code borrows from https://github.com/flyyufelix/cnn_finetune +class QuickGELU(Layer): # pylint:disable=too-many-ancestors,abstract-method + """ Applies GELU approximation that is fast but somewhat inaccurate. + + Parameters + ---------- + name: str, optional + The name for the layer. 
Default: "QuickGELU" + kwargs: dict + The standard Keras Layer keyword arguments (if any) """ - def __init__(self, weights=None, axis=-1, gamma_init='zero', **kwargs): - self.axis = axis - self.gamma_init = initializers.get(gamma_init) - self.initial_weights = weights - super(Scale, self).__init__(**kwargs) + def __init__(self, name: str = "QuickGELU", **kwargs) -> None: + logger.debug(parse_class_init(locals())) + super().__init__(name=name, **kwargs) + logger.debug("Initialized %s", self.__class__.__name__) + + def compute_output_shape(self, input_shape: tuple[int, ...] # pylint:disable=arguments-differ + ) -> tuple[int, ...]: + """ Compute the output shape based on the input shape. + + Parameters + ---------- + input_shape: tuple + The input shape to the layer + """ + return input_shape - def build(self, input_shape): - self.input_spec = [InputSpec(shape=input_shape)] + def call(self, inputs: KerasTensor, *args, **kwargs # pylint:disable=arguments-differ + ) -> KerasTensor: + """ Call the QuickGELU layerr - # Compatibility with TensorFlow >= 1.0.0 - self.gamma = K.variable(self.gamma_init((1,)), name='{}_gamma'.format(self.name)) - self.trainable_weights = [self.gamma] + Parameters + ---------- + inputs : :class:`keras.KerasTensor` + The input Tensor - if self.initial_weights is not None: - self.set_weights(self.initial_weights) - del self.initial_weights + Returns + ------- + :class:`keras.KerasTensor` + The output Tensor + """ + return inputs * ops.sigmoid(1.702 * inputs) - def call(self, x, mask=None): - return self.gamma * x - def get_config(self): - config = {"axis": self.axis} - base_config = super(Scale, self).get_config() - return dict(list(base_config.items()) + list(config.items())) +class ReflectionPadding2D(Layer): # pylint:disable=too-many-ancestors,abstract-method + """Reflection-padding layer for 2D input (e.g. picture). + This layer can add rows and columns at the top, bottom, left and right side of an image tensor. -class SubPixelUpscaling(Layer): - # pylint: disable=C0103 - """ Sub-pixel convolutional upscaling layer based on the paper "Real-Time - Single Image and Video Super-Resolution Using an Efficient Sub-Pixel - Convolutional Neural Network" (https://arxiv.org/abs/1609.05158). - This layer requires a Convolution2D prior to it, having output filters - computed according to the formula : - filters = k * (scale_factor * scale_factor) - where k = a user defined number of filters (generally larger than 32) - scale_factor = the upscaling factor (generally 2) - This layer performs the depth to space operation on the convolution - filters, and returns a tensor with the size as defined below. - # Example : - ```python - # A standard subpixel upscaling block - x = Convolution2D(256, 3, 3, padding="same", activation="relu")(...) - u = SubPixelUpscaling(scale_factor=2)(x) - [Optional] - x = Convolution2D(256, 3, 3, padding="same", activation="relu")(u) - ``` - In practice, it is useful to have a second convolution layer after the - SubPixelUpscaling layer to speed up the learning process. - However, if you are stacking multiple SubPixelUpscaling blocks, - it may increase the number of parameters greatly, so the Convolution - layer after SubPixelUpscaling layer can be removed. - # Arguments - scale_factor: Upscaling factor. - data_format: Can be None, "channels_first" or "channels_last". 
- # Input shape - 4D tensor with shape: - `(samples, k * (scale_factor * scale_factor) channels, rows, cols)` - if data_format="channels_first" - or 4D tensor with shape: - `(samples, rows, cols, k * (scale_factor * scale_factor) channels)` - if data_format="channels_last". - # Output shape - 4D tensor with shape: - `(samples, k channels, rows * scale_factor, cols * scale_factor))` - if data_format="channels_first" - or 4D tensor with shape: - `(samples, rows * scale_factor, cols * scale_factor, k channels)` - if data_format="channels_last". + Parameters + ---------- + stride: int, optional + The stride of the following convolution. Default: `2` + kernel_size: int, optional + The kernel size of the following convolution. Default: `5` + kwargs: dict + The standard Keras Layer keyword arguments (if any) """ + def __init__(self, stride: int = 2, kernel_size: int = 5, **kwargs) -> None: + logger.debug(parse_class_init(locals())) - def __init__(self, scale_factor=2, data_format=None, **kwargs): - super(SubPixelUpscaling, self).__init__(**kwargs) - - self.scale_factor = scale_factor - self.data_format = K.normalize_data_format(data_format) + if isinstance(stride, (tuple, list)): + assert len(stride) == 2 and stride[0] == stride[1] + stride = stride[0] + self.stride = stride + self.kernel_size = kernel_size + self.input_spec: list[InputSpec] | None = None + super().__init__(**kwargs) - def build(self, input_shape): - pass + logger.debug("Initialized %s", self.__class__.__name__) - def call(self, x, mask=None): - y = self.depth_to_space(x, self.scale_factor, self.data_format) - return y + def build(self, input_shape: KerasTensor) -> None: + """Creates the layer weights. - def compute_output_shape(self, input_shape): - if self.data_format == "channels_first": - b, k, r, c = input_shape - return (b, - k // (self.scale_factor ** 2), - r * self.scale_factor, - c * self.scale_factor) - b, r, c, k = input_shape - return (b, - r * self.scale_factor, - c * self.scale_factor, - k // (self.scale_factor ** 2)) - - @classmethod - def depth_to_space(cls, ipt, scale, data_format=None): - """ Uses phase shift algorithm to convert channels/depth - for spatial resolution """ - if data_format is None: - data_format = K.image_data_format() - data_format = data_format.lower() - ipt = cls._preprocess_conv2d_input(ipt, data_format) - out = tf.depth_to_space(ipt, scale) - out = cls._postprocess_conv2d_output(out, data_format) - return out - - @staticmethod - def _postprocess_conv2d_output(x, data_format): - """Transpose and cast the output from conv2d if needed. - # Arguments - x: A tensor. - data_format: string, `"channels_last"` or `"channels_first"`. - # Returns - A tensor. - """ + Must be implemented on all layers that have weights. - if data_format == "channels_first": - x = tf.transpose(x, (0, 3, 1, 2)) - - if K.floatx() == "float64": - x = tf.cast(x, "float64") - return x - - @staticmethod - def _preprocess_conv2d_input(x, data_format): - """Transpose and cast the input before the conv2d. - # Arguments - x: input tensor. - data_format: string, `"channels_last"` or `"channels_first"`. - # Returns - A tensor. + Parameters + ---------- + input_shape: :class:`keras.KerasTensor` + Keras tensor (future input to layer) or ``list``/``tuple`` of Keras tensors to + reference for weight shape computations. """ - if K.dtype(x) == "float64": - x = tf.cast(x, "float32") - if data_format == "channels_first": - # TF uses the last dimension as channel dimension, - # instead of the 2nd one. 
- # TH input shape: (samples, input_depth, rows, cols) - # TF input shape: (samples, rows, cols, input_depth) - x = tf.transpose(x, (0, 2, 3, 1)) - return x - - def get_config(self): - config = {"scale_factor": self.scale_factor, - "data_format": self.data_format} - base_config = super(SubPixelUpscaling, self).get_config() - return dict(list(base_config.items()) + list(config.items())) - + self.input_spec = [InputSpec(shape=input_shape)] + super().build(input_shape) -class ReflectionPadding2D(Layer): - def __init__(self, stride=2, kernel_size=5, **kwargs): - ''' - # Arguments - stride: stride of following convolution (2) - kernel_size: kernel size of following convolution (5,5) - ''' - self.stride = stride - self.kernel_size = kernel_size - super(ReflectionPadding2D, self).__init__(**kwargs) + def compute_output_shape(self, *args, **kwargs) -> tuple[int | None, ...]: + """Computes the output shape of the layer. - def build(self, input_shape): - self.input_spec = [InputSpec(shape=input_shape)] - super(ReflectionPadding2D, self).build(input_shape) + Assumes that the layer will be built to match that input shape provided. - def compute_output_shape(self, input_shape): - """ If you are using "channels_last" configuration""" + Returns + ------- + tuple + An input shape tuple + """ + assert self.input_spec is not None input_shape = self.input_spec[0].shape + assert input_shape is not None + assert input_shape[1] is not None and input_shape[2] is not None in_width, in_height = input_shape[2], input_shape[1] - kernel_width, kernel_height = self.kernel_size, self.kernel_size + kernel_width, kernel_height = self.kernel_size, self.kernel_size - if (in_height % self.stride == 0): + if (in_height % self.stride) == 0: padding_height = max(kernel_height - self.stride, 0) else: padding_height = max(kernel_height - (in_height % self.stride), 0) - if (in_width % self.stride == 0): + if (in_width % self.stride) == 0: padding_width = max(kernel_width - self.stride, 0) else: - padding_width = max(kernel_width- (in_width % self.stride), 0) + padding_width = max(kernel_width - (in_width % self.stride), 0) return (input_shape[0], input_shape[1] + padding_height, input_shape[2] + padding_width, input_shape[3]) - def call(self, x, mask=None): + def call(self, inputs: KerasTensor, *args, **kwargs # pylint:disable=arguments-differ + ) -> KerasTensor: + """This is where the layer's logic lives. 
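+
+        The amount of padding mirrors the "same"-padding arithmetic of the convolution that
+        follows: when the input size divides evenly by :attr:`stride`, ``kernel_size - stride``
+        pixels are added to each spatial dimension, split across the two edges. A rough sketch,
+        assuming a 64px input with the defaults ``stride=2, kernel_size=5``:
+
+        >>> padded = ReflectionPadding2D(stride=2, kernel_size=5)(ops.ones((1, 64, 64, 3)))
+        >>> tuple(padded.shape)
+        (1, 67, 67, 3)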
+ + Parameters + ---------- + inputs: :class:`keras.KerasTensor` + Input tensor, or list/tuple of input tensors + + Returns + ------- + :class:`keras.KerasTensor` + A tensor or list/tuple of tensors + """ + assert self.input_spec is not None input_shape = self.input_spec[0].shape + assert input_shape is not None + assert input_shape[1] is not None and input_shape[2] is not None in_width, in_height = input_shape[2], input_shape[1] - kernel_width, kernel_height = self.kernel_size, self.kernel_size + kernel_width, kernel_height = self.kernel_size, self.kernel_size - if (in_height % self.stride == 0): + if (in_height % self.stride) == 0: padding_height = max(kernel_height - self.stride, 0) else: padding_height = max(kernel_height - (in_height % self.stride), 0) - if (in_width % self.stride == 0): + if (in_width % self.stride) == 0: padding_width = max(kernel_width - self.stride, 0) else: - padding_width = max(kernel_width- (in_width % self.stride), 0) + padding_width = max(kernel_width - (in_width % self.stride), 0) padding_top = padding_height // 2 padding_bot = padding_height - padding_top padding_left = padding_width // 2 padding_right = padding_width - padding_left - return tf.pad(x, [[0,0], - [padding_top, padding_bot], - [padding_left, padding_right], - [0,0] ], - 'REFLECT') + return ops.pad(inputs, + [[0, 0], [padding_top, padding_bot], [padding_left, padding_right], [0, 0]], + mode="reflect") + + def get_config(self) -> dict[str, T.Any]: + """Returns the config of the layer. + + A layer config is a Python dictionary (serializable) containing the configuration of a + layer. The same layer can be reinstated later (without its trained weights) from this + configuration. + + The configuration of a layer does not include connectivity information, nor the layer + class name. These are handled by `Network` (one layer of abstraction above). + + Returns + -------- + dict + A python dictionary containing the layer configuration + """ + config = {"stride": self.stride, + "kernel_size": self.kernel_size} + base_config = super().get_config() + return dict(list(base_config.items()) + list(config.items())) + + +class Swish(Layer): # pylint:disable=too-many-ancestors,abstract-method + """ Swish Activation Layer implementation for Keras. + + Parameters + ---------- + beta: float, optional + The beta value to apply to the activation function. Default: `1.0` + kwargs: dict + The standard Keras Layer keyword arguments (if any) + + References + ----------- + Swish: a Self-Gated Activation Function: https://arxiv.org/abs/1710.05941v1 + """ + def __init__(self, beta: float = 1.0, **kwargs) -> None: + logger.debug(parse_class_init(locals())) + super().__init__(**kwargs) + self.beta = beta + logger.debug("Initialized %s", self.__class__.__name__) + + def compute_output_shape(self, input_shape: tuple[int, ...] # pylint:disable=arguments-differ + ) -> tuple[int, ...]: + """ Compute the output shape based on the input shape. + + Parameters + ---------- + input_shape: tuple + The input shape to the layer + """ + return input_shape + + def call(self, inputs: KerasTensor, *args, **kwargs # pylint:disable=arguments-differ + ) -> KerasTensor: + """ Call the Swish Activation function. 
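+
+        The layer computes ``swish(beta * inputs)``, which for ``beta=1.0`` reduces to the
+        standard swish/SiLU ``x * sigmoid(x)``. A rough sketch:
+
+        >>> layer = Swish(beta=1.0)
+        >>> float(layer(ops.convert_to_tensor([0.0]))[0])
+        0.0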
+
+        Parameters
+        ----------
+        inputs: :class:`keras.KerasTensor`
+            Input tensor, or list/tuple of input tensors
+
+        Returns
+        -------
+        :class:`keras.KerasTensor`
+            A tensor or list/tuple of tensors
+        """
+        return ops.nn.swish(inputs * self.beta)
 
     def get_config(self):
-        config = {'stride': self.stride,
-                  'kernel_size': self.kernel_size}
-        base_config = super(ReflectionPadding2D, self).get_config()
-        return dict(list(base_config.items()) + list(config.items()))
+        """Returns the config of the layer.
+
+        Adds the :attr:`beta` to config.
+
+        Returns
+        --------
+        dict
+            A python dictionary containing the layer configuration
+        """
+        config = super().get_config()
+        config["beta"] = self.beta
+        return config
+
+
+class ScalarOp(Layer):  # pylint:disable=too-many-ancestors,abstract-method
+    """ A layer for scalar operations for migrating TFLambdaOps in Keras 2 models to Keras 3.
+    This layer should not be used directly.
+
+    Parameters
+    ----------
+    operation: Literal["multiply", "truediv", "add", "subtract"]
+        The scalar operation to perform
+    value: float
+        The scalar value to use
+    """
+    def __init__(self,
+                 operation: T.Literal["multiply", "truediv", "add", "subtract"],
+                 value: float,
+                 **kwargs) -> None:
+        logger.debug(parse_class_init(locals()))
+        assert operation in ("multiply", "truediv", "add", "subtract")
+        self._operation = operation
+        self._operator = {"multiply": operator.mul,
+                          "truediv": operator.truediv,
+                          "add": operator.add,
+                          "subtract": operator.sub}[operation]
+        self._value = value
+
+        if "name" not in kwargs:
+            kwargs["name"] = f"ScalarOp_{operation}"
+        super().__init__(**kwargs)
+
+        logger.debug("Initialized %s", self.__class__.__name__)
+
+    def compute_output_shape(self, input_shape: tuple[int, ...]  # pylint:disable=arguments-differ
+                             ) -> tuple[int, ...]:
+        """ Output shape is the same as the input shape.
+
+        Parameters
+        ----------
+        input_shape: tuple
+            The input shape to the layer
+        """
+        return input_shape
+
+    def call(self, inputs: KerasTensor, *args, **kwargs  # pylint:disable=arguments-differ
+             ) -> KerasTensor:
+        """ Call the Scalar operation function.
+
+        Parameters
+        ----------
+        inputs: :class:`keras.KerasTensor`
+            Input tensor, or list/tuple of input tensors
+
+        Returns
+        -------
+        :class:`keras.KerasTensor`
+            A tensor or list/tuple of tensors
+        """
+        return self._operator(inputs, self._value)
+
+    def get_config(self):
+        """Returns the config of the layer.
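+
+        The operation is stored by name rather than as the operator itself, so the layer can be
+        serialized and restored by Keras. A rough sketch:
+
+        >>> config = ScalarOp("truediv", 255.0).get_config()
+        >>> config["operation"], config["value"]
+        ('truediv', 255.0)
+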
+ Returns + -------- + dict + A python dictionary containing the layer configuration + """ + config = super().get_config() + config["operation"] = self._operation + config["value"] = self._value + return config # Update layers into Keras custom objects -for name, obj in inspect.getmembers(sys.modules[__name__]): +for name_, obj in inspect.getmembers(sys.modules[__name__]): if inspect.isclass(obj) and obj.__module__ == __name__: - get_custom_objects().update({name: obj}) + saving.get_custom_objects().update({name_: obj}) + + +__all__ = get_module_objects(__name__) diff --git a/lib/model/losses.py b/lib/model/losses.py deleted file mode 100644 index 66b6873310..0000000000 --- a/lib/model/losses.py +++ /dev/null @@ -1,844 +0,0 @@ -#!/usr/bin/env python3 -""" Custom Loss Functions for faceswap.py - Losses from: - keras.contrib - dfaker: https://github.com/dfaker/df - shoanlu GAN: https://github.com/shaoanlu/faceswap-GAN""" - -from __future__ import absolute_import - -import keras.backend as K -from keras.layers import Lambda, concatenate -import numpy as np -import tensorflow as tf -from tensorflow.contrib.distributions import Beta - -from .normalization import InstanceNormalization - - -class DSSIMObjective(): - """ DSSIM Loss Function - - Code copy and pasted, with minor ammendments from: - https://github.com/keras-team/keras-contrib/blob/master/keras_contrib/losses/dssim.py - - MIT License - - Copyright (c) 2017 Fariz Rahman - - Permission is hereby granted, free of charge, to any person obtaining a copy - of this software and associated documentation files (the "Software"), to deal - in the Software without restriction, including without limitation the rights - to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - copies of the Software, and to permit persons to whom the Software is - furnished to do so, subject to the following conditions: - - The above copyright notice and this permission notice shall be included in all - copies or substantial portions of the Software. - - THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - SOFTWARE. """ - # pylint: disable=C0103 - def __init__(self, k1=0.01, k2=0.03, kernel_size=3, max_value=1.0): - """ - Difference of Structural Similarity (DSSIM loss function). Clipped - between 0 and 0.5 - Note : You should add a regularization term like a l2 loss in - addition to this one. - Note : In theano, the `kernel_size` must be a factor of the output - size. So 3 could not be the `kernel_size` for an output of 32. 
- # Arguments - k1: Parameter of the SSIM (default 0.01) - k2: Parameter of the SSIM (default 0.03) - kernel_size: Size of the sliding window (default 3) - max_value: Max value of the output (default 1.0) - """ - self.__name__ = 'DSSIMObjective' - self.kernel_size = kernel_size - self.k1 = k1 - self.k2 = k2 - self.max_value = max_value - self.c1 = (self.k1 * self.max_value) ** 2 - self.c2 = (self.k2 * self.max_value) ** 2 - self.dim_ordering = K.image_data_format() - self.backend = K.backend() - - @staticmethod - def __int_shape(x): - return K.int_shape(x) - - def __call__(self, y_true, y_pred): - # There are additional parameters for this function - # Note: some of the 'modes' for edge behavior do not yet have a - # gradient definition in the Theano tree and cannot be used for - # learning - - kernel = [self.kernel_size, self.kernel_size] - y_true = K.reshape(y_true, [-1] + list(self.__int_shape(y_pred)[1:])) - y_pred = K.reshape(y_pred, [-1] + list(self.__int_shape(y_pred)[1:])) - - patches_pred = self.extract_image_patches(y_pred, - kernel, - kernel, - 'valid', - self.dim_ordering) - patches_true = self.extract_image_patches(y_true, - kernel, - kernel, - 'valid', - self.dim_ordering) - - # Reshape to get the var in the cells - _, w, h, c1, c2, c3 = self.__int_shape(patches_pred) - patches_pred = K.reshape(patches_pred, [-1, w, h, c1 * c2 * c3]) - patches_true = K.reshape(patches_true, [-1, w, h, c1 * c2 * c3]) - # Get mean - u_true = K.mean(patches_true, axis=-1) - u_pred = K.mean(patches_pred, axis=-1) - # Get variance - var_true = K.var(patches_true, axis=-1) - var_pred = K.var(patches_pred, axis=-1) - # Get std dev - covar_true_pred = K.mean( - patches_true * patches_pred, axis=-1) - u_true * u_pred - - ssim = (2 * u_true * u_pred + self.c1) * ( - 2 * covar_true_pred + self.c2) - denom = (K.square(u_true) + K.square(u_pred) + self.c1) * ( - var_pred + var_true + self.c2) - ssim /= denom # no need for clipping, c1 + c2 make the denom non-zero - return K.mean((1.0 - ssim) / 2.0) - - @staticmethod - def _preprocess_padding(padding): - """Convert keras' padding to tensorflow's padding. - # Arguments - padding: string, `"same"` or `"valid"`. - # Returns - a string, `"SAME"` or `"VALID"`. - # Raises - ValueError: if `padding` is invalid. 
- """ - if padding == 'same': - padding = 'SAME' - elif padding == 'valid': - padding = 'VALID' - else: - raise ValueError('Invalid padding:', padding) - return padding - - def extract_image_patches(self, x, ksizes, ssizes, padding='same', - data_format='channels_last'): - ''' - Extract the patches from an image - # Parameters - x : The input image - ksizes : 2-d tuple with the kernel size - ssizes : 2-d tuple with the strides size - padding : 'same' or 'valid' - data_format : 'channels_last' or 'channels_first' - # Returns - The (k_w,k_h) patches extracted - TF ==> (batch_size,w,h,k_w,k_h,c) - TH ==> (batch_size,w,h,c,k_w,k_h) - ''' - kernel = [1, ksizes[0], ksizes[1], 1] - strides = [1, ssizes[0], ssizes[1], 1] - padding = self._preprocess_padding(padding) - if data_format == 'channels_first': - x = K.permute_dimensions(x, (0, 2, 3, 1)) - _, _, _, ch_i = K.int_shape(x) - patches = tf.extract_image_patches(x, kernel, strides, [1, 1, 1, 1], - padding) - # Reshaping to fit Theano - _, w, h, ch = K.int_shape(patches) - patches = tf.reshape(tf.transpose(tf.reshape(patches, - [-1, w, h, - tf.floordiv(ch, ch_i), - ch_i]), - [0, 1, 2, 4, 3]), - [-1, w, h, ch_i, ksizes[0], ksizes[1]]) - if data_format == 'channels_last': - patches = K.permute_dimensions(patches, [0, 1, 2, 4, 5, 3]) - return patches - -# <<< START: from Dfaker >>> # -class PenalizedLoss(): # pylint: disable=too-few-public-methods - """ Penalized Loss - from: https://github.com/dfaker/df """ - def __init__(self, mask, loss_func, mask_prop=1.0): - self.mask = mask - self.loss_func = loss_func - self.mask_prop = mask_prop - self.mask_as_k_inv_prop = 1-mask_prop - - def __call__(self, y_true, y_pred): - # pylint: disable=invalid-name - tro, tgo, tbo = tf.split(y_true, 3, 3) - pro, pgo, pbo = tf.split(y_pred, 3, 3) - - tr = tro - tg = tgo - tb = tbo - - pr = pro - pg = pgo - pb = pbo - m = self.mask - - m = m * self.mask_prop - m += self.mask_as_k_inv_prop - tr *= m - tg *= m - tb *= m - - pr *= m - pg *= m - pb *= m - - y = tf.concat([tr, tg, tb], 3) - p = tf.concat([pr, pg, pb], 3) - - # yo = tf.stack([tro,tgo,tbo],3) - # po = tf.stack([pro,pgo,pbo],3) - - return self.loss_func(y, p) -# <<< END: from Dfaker >>> # - - -# <<< START: from Shoanlu GAN >>> # -def first_order(var_x, axis=1): - """ First Order Function from Shoanlu GAN """ - img_nrows = var_x.shape[1] - img_ncols = var_x.shape[2] - if axis == 1: - return K.abs(var_x[:, :img_nrows - 1, :img_ncols - 1, :] - var_x[:, 1:, :img_ncols - 1, :]) - if axis == 2: - return K.abs(var_x[:, :img_nrows - 1, :img_ncols - 1, :] - var_x[:, :img_nrows - 1, 1:, :]) - return None - - -def calc_loss(pred, target, loss='l2'): - """ Calculate Loss from Shoanlu GAN """ - if loss.lower() == "l2": - return K.mean(K.square(pred - target)) - if loss.lower() == "l1": - return K.mean(K.abs(pred - target)) - if loss.lower() == "cross_entropy": - return -K.mean(K.log(pred + K.epsilon()) * target + - K.log(1 - pred + K.epsilon()) * (1 - target)) - raise ValueError('Recieve an unknown loss type: {}.'.format(loss)) - - -def cyclic_loss(net_g1, net_g2, real1): - """ Cyclic Loss Function from Shoanlu GAN """ - fake2 = net_g2(real1)[-1] # fake2 ABGR - fake2 = Lambda(lambda x: x[:, :, :, 1:])(fake2) # fake2 BGR - cyclic1 = net_g1(fake2)[-1] # cyclic1 ABGR - cyclic1 = Lambda(lambda x: x[:, :, :, 1:])(cyclic1) # cyclic1 BGR - loss = calc_loss(cyclic1, real1, loss='l1') - return loss - - -def adversarial_loss(net_d, real, fake_abgr, distorted, gan_training="mixup_LSGAN", **weights): - """ Adversarial Loss Function 
from Shoanlu GAN """ - alpha = Lambda(lambda x: x[:, :, :, :1])(fake_abgr) - fake_bgr = Lambda(lambda x: x[:, :, :, 1:])(fake_abgr) - fake = alpha * fake_bgr + (1-alpha) * distorted - - if gan_training == "mixup_LSGAN": - dist = Beta(0.2, 0.2) - lam = dist.sample() - mixup = lam * concatenate([real, distorted]) + (1 - lam) * concatenate([fake, distorted]) - pred_fake = net_d(concatenate([fake, distorted])) - pred_mixup = net_d(mixup) - loss_d = calc_loss(pred_mixup, lam * K.ones_like(pred_mixup), "l2") - loss_g = weights['w_D'] * calc_loss(pred_fake, K.ones_like(pred_fake), "l2") - mixup2 = lam * concatenate([real, - distorted]) + (1 - lam) * concatenate([fake_bgr, - distorted]) - pred_fake_bgr = net_d(concatenate([fake_bgr, distorted])) - pred_mixup2 = net_d(mixup2) - loss_d += calc_loss(pred_mixup2, lam * K.ones_like(pred_mixup2), "l2") - loss_g += weights['w_D'] * calc_loss(pred_fake_bgr, K.ones_like(pred_fake_bgr), "l2") - elif gan_training == "relativistic_avg_LSGAN": - real_pred = net_d(concatenate([real, distorted])) - fake_pred = net_d(concatenate([fake, distorted])) - loss_d = K.mean(K.square(real_pred - K.ones_like(fake_pred)))/2 - loss_d += K.mean(K.square(fake_pred - K.zeros_like(fake_pred)))/2 - loss_g = weights['w_D'] * K.mean(K.square(fake_pred - K.ones_like(fake_pred))) - - fake_pred2 = net_d(concatenate([fake_bgr, distorted])) - loss_d += K.mean(K.square(real_pred - K.mean(fake_pred2, axis=0) - - K.ones_like(fake_pred2)))/2 - loss_d += K.mean(K.square(fake_pred2 - K.mean(real_pred, axis=0) - - K.zeros_like(fake_pred2)))/2 - loss_g += weights['w_D'] * K.mean(K.square(real_pred - K.mean(fake_pred2, axis=0) - - K.zeros_like(fake_pred2)))/2 - loss_g += weights['w_D'] * K.mean(K.square(fake_pred2 - K.mean(real_pred, axis=0) - - K.ones_like(fake_pred2)))/2 - else: - raise ValueError("Receive an unknown GAN training method: {gan_training}") - return loss_d, loss_g - - -def reconstruction_loss(real, fake_abgr, mask_eyes, model_outputs, **weights): - """ Reconstruction Loss Function from Shoanlu GAN """ - alpha = Lambda(lambda x: x[:, :, :, :1])(fake_abgr) - fake_bgr = Lambda(lambda x: x[:, :, :, 1:])(fake_abgr) - - loss_g = 0 - loss_g += weights['w_recon'] * calc_loss(fake_bgr, real, "l1") - loss_g += weights['w_eyes'] * K.mean(K.abs(mask_eyes*(fake_bgr - real))) - - for out in model_outputs[:-1]: - out_size = out.get_shape().as_list() - resized_real = tf.image.resize_images(real, out_size[1:3]) - loss_g += weights['w_recon'] * calc_loss(out, resized_real, "l1") - return loss_g - - -def edge_loss(real, fake_abgr, mask_eyes, **weights): - """ Edge Loss Function from Shoanlu GAN """ - alpha = Lambda(lambda x: x[:, :, :, :1])(fake_abgr) - fake_bgr = Lambda(lambda x: x[:, :, :, 1:])(fake_abgr) - - loss_g = 0 - loss_g += weights['w_edge'] * calc_loss(first_order(fake_bgr, axis=1), - first_order(real, axis=1), "l1") - loss_g += weights['w_edge'] * calc_loss(first_order(fake_bgr, axis=2), - first_order(real, axis=2), "l1") - shape_mask_eyes = mask_eyes.get_shape().as_list() - resized_mask_eyes = tf.image.resize_images(mask_eyes, - [shape_mask_eyes[1]-1, shape_mask_eyes[2]-1]) - loss_g += weights['w_eyes'] * K.mean(K.abs(resized_mask_eyes * - (first_order(fake_bgr, axis=1) - - first_order(real, axis=1)))) - loss_g += weights['w_eyes'] * K.mean(K.abs(resized_mask_eyes * - (first_order(fake_bgr, axis=2) - - first_order(real, axis=2)))) - return loss_g - - -def perceptual_loss(real, fake_abgr, distorted, mask_eyes, vggface_feats, **weights): - """ Perceptual Loss Function from Shoanlu GAN 
""" - alpha = Lambda(lambda x: x[:, :, :, :1])(fake_abgr) - fake_bgr = Lambda(lambda x: x[:, :, :, 1:])(fake_abgr) - fake = alpha * fake_bgr + (1-alpha) * distorted - - def preprocess_vggface(var_x): - var_x = (var_x + 1) / 2 * 255 # channel order: BGR - var_x -= [91.4953, 103.8827, 131.0912] - return var_x - - real_sz224 = tf.image.resize_images(real, [224, 224]) - real_sz224 = Lambda(preprocess_vggface)(real_sz224) - dist = Beta(0.2, 0.2) - lam = dist.sample() # use mixup trick here to reduce foward pass from 2 times to 1. - mixup = lam*fake_bgr + (1-lam)*fake - fake_sz224 = tf.image.resize_images(mixup, [224, 224]) - fake_sz224 = Lambda(preprocess_vggface)(fake_sz224) - real_feat112, real_feat55, real_feat28, real_feat7 = vggface_feats(real_sz224) - fake_feat112, fake_feat55, fake_feat28, fake_feat7 = vggface_feats(fake_sz224) - - # Apply instance norm on VGG(ResNet) features - # From MUNIT https://github.com/NVlabs/MUNIT - loss_g = 0 - - def instnorm(): - return InstanceNormalization() - - loss_g += weights['w_pl'][0] * calc_loss(instnorm()(fake_feat7), - instnorm()(real_feat7), "l2") - loss_g += weights['w_pl'][1] * calc_loss(instnorm()(fake_feat28), - instnorm()(real_feat28), "l2") - loss_g += weights['w_pl'][2] * calc_loss(instnorm()(fake_feat55), - instnorm()(real_feat55), "l2") - loss_g += weights['w_pl'][3] * calc_loss(instnorm()(fake_feat112), - instnorm()(real_feat112), "l2") - return loss_g - -# <<< END: from Shoanlu GAN >>> # - - -def generalized_loss_function(y_true, y_pred, a = 1.0, c=1.0/255.0): - ''' - generalized function used to return a large variety of mathematical loss functions - primary benefit is smooth, differentiable version of L1 loss - - Barron, J. A More General Robust Loss Function - https://arxiv.org/pdf/1701.03077.pdf - - Parameters: - a: penalty factor. larger number give larger weight to large deviations - c: scale factor used to adjust to the input scale (i.e. inputs of mean 1e-4 or 256 ) - - Return: - a loss value from the results of function(y_pred - y_true) - - Example: - a=1.0, x>>c , c=1.0/255.0 will give a smoothly differentiable version of L1 / MAE loss - a=1.999999 (lim as a->2), c=1.0/255.0 will give L2 / RMSE loss - ''' - x = y_pred - y_true - loss = (K.abs(2.0-a)/a) * ( K.pow( K.pow(x/c, 2.0)/K.abs(2.0-a) + 1.0 , (a/2.0)) - 1.0 ) - return K.mean(loss, axis=-1) * c - - -def staircase_loss(y_true, y_pred, a = 16.0, c=1.0/255.0): - h = c - w = c - x = K.clip(K.abs(y_true - y_pred) - 0.5 * c, 0.0, 1.0) - loss = h*( K.tanh(a*((x/w)-tf.floor(x/w)-0.5)) / ( 2.0*K.tanh(a/2.0) ) + 0.5 + tf.floor(x/w)) - loss += 1e-10 - return K.mean(loss, axis=-1) - - -def gradient_loss(y_true, y_pred): - ''' - Calculates the first and second order gradient difference between pixels of an image in the x and y dimensions. - These gradients are then compared between the ground truth and the predicted image and the difference is taken. - The difference used is a smooth L1 norm ( approximate to MAE but differable at zero ) - When used as a loss, its minimization will result in predicted images approaching the same level of sharpness - / blurriness as the ground truth. - - TV+TV2 Regularization with Nonconvex Sparseness-Inducing Penalty for Image Restoration, Chengwu Lu & Hua Huang, 2014 - (http://downloads.hindawi.com/journals/mpe/2014/790547.pdf) - - Parameters: - y_true: The predicted frames at each scale. - y_true: The ground truth frames at each scale - - Return: - The GD loss. 
- ''' - - assert 4 == K.ndim(y_true) - y_true.set_shape([None,80,80,3]) - y_pred.set_shape([None,80,80,3]) - TV_weight = 1.0 - TV2_weight = 1.0 - loss = 0.0 - - def diff_x(X): - Xleft = X[:, :, 1, :] - X[:, :, 0, :] - Xinner = tf.unstack(X[:, :, 2:, :] - X[:, :, :-2, :], axis=2) - Xright = X[:, :, -1, :] - X[:, :, -2, :] - Xout = [Xleft] + Xinner + [Xright] - Xout = tf.stack(Xout,axis=2) - return Xout * 0.5 - - def diff_y(X): - Xtop = X[:, 1, :, :] - X[:, 0, :, :] - Xinner = tf.unstack(X[:, 2:, :, :] - X[:, :-2, :, :], axis=1) - Xbot = X[:, -1, :, :] - X[:, -2, :, :] - Xout = [Xtop] + Xinner + [Xbot] - Xout = tf.stack(Xout,axis=1) - return Xout * 0.5 - - def diff_xx(X): - Xleft = X[:, :, 1, :] + X[:, :, 0, :] - Xinner = tf.unstack(X[:, :, 2:, :] + X[:, :, :-2, :], axis=2) - Xright = X[:, :, -1, :] + X[:, :, -2, :] - Xout = [Xleft] + Xinner + [Xright] - Xout = tf.stack(Xout,axis=2) - return Xout - 2.0 * X - - def diff_yy(X): - Xtop = X[:, 1, :, :] + X[:, 0, :, :] - Xinner = tf.unstack(X[:, 2:, :, :] + X[:, :-2, :, :], axis=1) - Xbot = X[:, -1, :, :] + X[:, -2, :, :] - Xout = [Xtop] + Xinner + [Xbot] - Xout = tf.stack(Xout,axis=1) - return Xout - 2.0 * X - - def diff_xy(X): - #xout1 - top_left = X[:, 1, 1, :]+X[:, 0, 0, :] - inner_left = tf.unstack(X[:, 2:, 1, :]+X[:, :-2, 0, :], axis=1) - bot_left = X[:, -1, 1, :]+X[:, -2, 0, :] - X_left = [top_left] + inner_left + [bot_left] - X_left = tf.stack(X_left, axis=1) - - top_mid = X[:, 1, 2:, :]+X[:, 0, :-2, :] - mid_mid = tf.unstack(X[:, 2:, 2:, :]+X[:, :-2, :-2, :], axis=1) - bot_mid = X[:, -1, 2:, :]+X[:, -2, :-2, :] - X_mid = [top_mid] + mid_mid + [bot_mid] - X_mid = tf.stack(X_mid, axis=1) - - top_right = X[:, 1, -1, :]+X[:, 0, -2, :] - inner_right = tf.unstack(X[:, 2:, -1, :]+X[:, :-2, -2, :], axis=1) - bot_right = X[:, -1, -1, :]+X[:, -2, -2, :] - X_right = [top_right] + inner_right + [bot_right] - X_right = tf.stack(X_right, axis=1) - - X_mid = tf.unstack(X_mid, axis=2) - Xout1 = [X_left] + X_mid + [X_right] - Xout1 = tf.stack(Xout1, axis=2) - - #Xout2 - top_left = X[:, 0, 1, :]+X[:, 1, 0, :] - inner_left = tf.unstack(X[:, :-2, 1, :]+X[:, 2:, 0, :], axis=1) - bot_left = X[:, -2, 1, :]+X[:, -1, 0, :] - X_left = [top_left] + inner_left + [bot_left] - X_left = tf.stack(X_left, axis=1) - - top_mid = X[:, 0, 2:, :]+X[:, 1, :-2, :] - mid_mid = tf.unstack(X[:, :-2, 2:, :]+X[:, 2:, :-2, :], axis=1) - bot_mid = X[:, -2, 2:, :]+X[:, -1, :-2, :] - X_mid = [top_mid] + mid_mid + [bot_mid] - X_mid = tf.stack(X_mid, axis=1) - - top_right = X[:, 0, -1, :]+X[:, 1, -2, :] - inner_right = tf.unstack(X[:, :-2, -1, :]+X[:, 2:, -2, :], axis=1) - bot_right = X[:, -2, -1, :]+X[:, -1, -2, :] - X_right = [top_right] + inner_right + [bot_right] - X_right = tf.stack(X_right, axis=1) - - X_mid = tf.unstack(X_mid, axis=2) - Xout2 = [X_left] + X_mid + [X_right] - Xout2 = tf.stack(Xout2, axis=2) - - return (Xout1 - Xout2) * 0.25 - - loss += TV_weight * ( generalized_loss_function(diff_x(y_true), diff_x(y_pred), a=1.999999) + - generalized_loss_function(diff_y(y_true), diff_y(y_pred), a=1.999999) ) - - loss += TV2_weight * ( generalized_loss_function(diff_xx(y_true), diff_xx(y_pred), a=1.999999) + - generalized_loss_function(diff_yy(y_true), diff_yy(y_pred), a=1.999999) + - 2.0 * generalized_loss_function(diff_xy(y_true), diff_xy(y_pred), a=1.999999) ) - - return loss / ( TV_weight + TV2_weight ) - - -def scharr_edges(image, magnitude): - ''' - Returns a tensor holding modified Scharr edge maps. 
- Arguments: - image: Image tensor with shape [batch_size, h, w, d] and type float32. - The image(s) must be 2x2 or larger. - magnitude: Boolean to determine if the edge magnitude or edge direction is returned - Returns: - Tensor holding edge maps for each channel. Returns a tensor with shape - [batch_size, h, w, d, 2] where the last two dimensions hold [[dy[0], dx[0]], - [dy[1], dx[1]], ..., [dy[d-1], dx[d-1]]] calculated using the Scharr filter. - ''' - - # Define vertical and horizontal Scharr filters. - static_image_shape = image.get_shape() - image_shape = tf.shape(image) - ''' - #modified 3x3 Scharr - kernels = [[[-17.0, -61.0, -17.0], [0.0, 0.0, 0.0], [17.0, 61.0, 17.0]], - [[-17.0, 0.0, 17.0], [-61.0, 0.0, 61.0], [-17.0, 0.0, 17.0]]] - ''' - # 5x5 Scharr - kernels = [[[-1.0, -2.0, -3.0, -2.0, -1.0], [-1.0, -2.0, -6.0, -2.0, -1.0], [0.0, 0.0, 0.0, 0.0, 0.0], [1.0, 2.0, 6.0, 2.0, 1.0], [1.0, 2.0, 3.0, 2.0, 1.0]], - [[-1.0, -1.0, 0.0, 1.0, 1.0], [-2.0, -2.0, 0.0, 2.0, 2.0], [-3.0, -6.0, 0.0, 6.0, 3.0], [-2.0, -2.0, 0.0, 2.0, 2.0], [-1.0, -1.0, 0.0, 1.0, 1.0]]] - num_kernels = len(kernels) - kernels = np.transpose(np.asarray(kernels), (1, 2, 0)) - kernels = np.expand_dims(kernels, -2) / np.sum(np.abs(kernels)) - kernels_tf = tf.constant(kernels, dtype=image.dtype) - kernels_tf = tf.tile(kernels_tf, [1, 1, image_shape[-1], 1], name='scharr_filters') - - # Use depth-wise convolution to calculate edge maps per channel. - pad_sizes = [[0, 0], [2, 2], [2, 2], [0, 0]] - padded = tf.pad(image, pad_sizes, mode='REFLECT') - - # Output tensor has shape [batch_size, h, w, d * num_kernels]. - strides = [1, 1, 1, 1] - output = tf.nn.depthwise_conv2d(padded, kernels_tf, strides, 'VALID') - - # Reshape to [batch_size, h, w, d, num_kernels]. - shape = tf.concat([image_shape, [num_kernels]], 0) - output = tf.reshape(output, shape=shape) - output.set_shape(static_image_shape.concatenate([num_kernels])) - - if magnitude: # magnitude of edges - output = tf.sqrt(tf.reduce_sum(tf.square(output),axis=-1)) - else: # direction of edges - output = tf.atan(tf.squeeze(tf.div(output[:,:,:,:,0]/output[:,:,:,:,1]))) - - return output - - -def gmsd_loss(y_true,y_pred): - ''' - Improved image quality metric over MS-SSIM with easier calc - http://www4.comp.polyu.edu.hk/~cslzhang/IQA/GMSD/GMSD.htm - https://arxiv.org/ftp/arxiv/papers/1308/1308.3052.pdf - ''' - true_edge_mag = scharr_edges(y_true,True) - pred_edge_mag = scharr_edges(y_pred,True) - c = 0.002 - upper = 2.0 * tf.multiply(true_edge_mag,pred_edge_mag) + c - lower = tf.square(true_edge_mag) + tf.square(pred_edge_mag) + c - GMS = tf.div(upper,lower) - _mean, _var = tf.nn.moments(GMS, axes=[1,2], keep_dims=True) - GMSD = tf.reduce_mean(tf.sqrt(_var), axis=-1) # single metric value per image in tensor [?,1,1] - return K.tile(GMSD,[1,64,64]) # need to expand to [?,height,width] dimensions for Keras ... modify to not be hard-coded - - -def ms_ssim(img1, img2, max_val=1.0, power_factors=(0.0517, 0.3295, 0.3462, 0.2726)): - ''' - Computes the MS-SSIM between img1 and img2. - This function assumes that `img1` and `img2` are image batches, i.e. the last - three dimensions are [height, width, channels]. - Note: The true SSIM is only defined on grayscale. This function does not - perform any colorspace transform. (If input is already YUV, then it will - compute YUV SSIM average.) - Original paper: Wang, Zhou, Eero P. Simoncelli, and Alan C. Bovik. "Multiscale - structural similarity for image quality assessment." Signals, Systems and - Computers, 2004. 
- Arguments: - img1: First image batch. - img2: Second image batch. Must have the same rank as img1. - max_val: The dynamic range of the images (i.e., the difference between the - maximum the and minimum allowed values). - power_factors: Iterable of weights for each of the scales. The number of - scales used is the length of the list. Index 0 is the unscaled - resolution's weight and each increasing scale corresponds to the image - being downsampled by 2. Defaults to (0.0448, 0.2856, 0.3001, 0.2363, - 0.1333), which are the values obtained in the original paper. - Returns: - A tensor containing an MS-SSIM value for each image in batch. The values - are in range [0, 1]. Returns a tensor with shape: - broadcast(img1.shape[:-3], img2.shape[:-3]). - ''' - - def _verify_compatible_image_shapes(img1, img2): - ''' - Checks if two image tensors are compatible for applying SSIM or PSNR. - This function checks if two sets of images have ranks at least 3, and if the - last three dimensions match. - Args: - img1: Tensor containing the first image batch. - img2: Tensor containing the second image batch. - Returns: - A tuple containing: the first tensor shape, the second tensor shape, and a - list of control_flow_ops.Assert() ops implementing the checks. - Raises: - ValueError: When static shape check fails. - ''' - shape1 = img1.get_shape().with_rank_at_least(3) - shape2 = img2.get_shape().with_rank_at_least(3) - shape1[-3:].assert_is_compatible_with(shape2[-3:]) - - if shape1.ndims is not None and shape2.ndims is not None: - for dim1, dim2 in zip(reversed(shape1[:-3]), reversed(shape2[:-3])): - if not (dim1 == 1 or dim2 == 1 or dim1.is_compatible_with(dim2)): - raise ValueError('Two images are not compatible: %s and %s' % (shape1, shape2)) - - # Now assign shape tensors. - shape1, shape2 = tf.shape_n([img1, img2]) - - # TODO(sjhwang): Check if shape1[:-3] and shape2[:-3] are broadcastable. - checks = [] - checks.append(tf.Assert(tf.greater_equal(tf.size(shape1), 3),[shape1, shape2], summarize=10)) - checks.append(tf.Assert(tf.reduce_all(tf.equal(shape1[-3:], shape2[-3:])),[shape1, shape2], summarize=10)) - - return shape1, shape2, checks - - def _ssim_per_channel(img1, img2, max_val=1.0): - ''' - Computes SSIM index between img1 and img2 per color channel. - This function matches the standard SSIM implementation from: - Wang, Z., Bovik, A. C., Sheikh, H. R., & Simoncelli, E. P. (2004). Image - quality assessment: from error visibility to structural similarity. IEEE - transactions on image processing. - Details: - - 11x11 Gaussian filter of width 1.5 is used. - - k1 = 0.01, k2 = 0.03 as in the original paper. - Args: - img1: First image batch. - img2: Second image batch. - max_val: The dynamic range of the images (i.e., the difference between the - maximum the and minimum allowed values). - Returns: - A pair of tensors containing and channel-wise SSIM and contrast-structure - values. The shape is [..., channels]. - ''' - - def _fspecial_gauss(size, sigma): - ''' - Function to mimic the 'fspecial' gaussian MATLAB function. - ''' - size = tf.convert_to_tensor(size, 'int32') - sigma = tf.convert_to_tensor(sigma) - - coords = tf.cast(tf.range(size), sigma.dtype) - coords -= tf.cast(size - 1, sigma.dtype) / 2.0 - - g = tf.square(coords) - g *= -0.5 / tf.square(sigma) - - g = tf.reshape(g, shape=[1, -1]) + tf.reshape(g, shape=[-1, 1]) - g = tf.reshape(g, shape=[1, -1]) # For tf.nn.softmax(). 
- g = tf.nn.softmax(g) - return tf.reshape(g, shape=[size, size, 1, 1]) - - def _ssim_helper(x, y, max_val, kernel, compensation=1.0): - ''' - Helper function for computing SSIM. - SSIM estimates covariances with weighted sums. The default parameters - use a biased estimate of the covariance: - Suppose `reducer` is a weighted sum, then the mean estimators are - \mu_x = \sum_i w_i x_i, - \mu_y = \sum_i w_i y_i, - where w_i's are the weighted-sum weights, and covariance estimator is - cov_{xy} = \sum_i w_i (x_i - \mu_x) (y_i - \mu_y) - with assumption \sum_i w_i = 1. This covariance estimator is biased, since - E[cov_{xy}] = (1 - \sum_i w_i ^ 2) Cov(X, Y). - For SSIM measure with unbiased covariance estimators, pass as `compensation` - argument (1 - \sum_i w_i ^ 2). - Arguments: - x: First set of images. - y: Second set of images. - reducer: Function that computes 'local' averages from set of images. - For non-covolutional version, this is usually tf.reduce_mean(x, [1, 2]), - and for convolutional version, this is usually tf.nn.avg_pool or - tf.nn.conv2d with weighted-sum kernel. - max_val: The dynamic range (i.e., the difference between the maximum - possible allowed value and the minimum allowed value). - compensation: Compensation factor. See above. - Returns: - A pair containing the luminance measure, and the contrast-structure measure. - ''' - - def reducer(x, kernel): - shape = tf.shape(x) - x = tf.reshape(x, shape=tf.concat([[-1], shape[-3:]], 0)) - y = tf.nn.depthwise_conv2d(x, kernel, strides=[1, 1, 1, 1], padding='VALID') - return tf.reshape(y, tf.concat([shape[:-3],tf.shape(y)[1:]], 0)) - - _SSIM_K1 = 0.01 - _SSIM_K2 = 0.03 - - c1 = (_SSIM_K1 * max_val) ** 2 - c2 = (_SSIM_K2 * max_val) ** 2 - - # SSIM luminance measure is - # (2 * mu_x * mu_y + c1) / (mu_x ** 2 + mu_y ** 2 + c1). - mean0 = reducer(x, kernel) - mean1 = reducer(y, kernel) - num0 = mean0 * mean1 * 2.0 - den0 = tf.square(mean0) + tf.square(mean1) - luminance = (num0 + c1) / (den0 + c1) - - # SSIM contrast-structure measure is - # (2 * cov_{xy} + c2) / (cov_{xx} + cov_{yy} + c2). - # Note that `reducer` is a weighted sum with weight w_k, \sum_i w_i = 1, then - # cov_{xy} = \sum_i w_i (x_i - \mu_x) (y_i - \mu_y) - # = \sum_i w_i x_i y_i - (\sum_i w_i x_i) (\sum_j w_j y_j). - num1 = reducer(x * y, kernel) * 2.0 - den1 = reducer(tf.square(x) + tf.square(y), kernel) - c2 *= compensation - cs = (num1 - num0 + c2) / (den1 - den0 + c2) - - # SSIM score is the product of the luminance and contrast-structure measures. - return luminance, cs - - filter_size = tf.constant(9, dtype='int32') # changed from 11 to 9 due - filter_sigma = tf.constant(1.5, dtype=img1.dtype) - - shape1, shape2 = tf.shape_n([img1, img2]) - checks = [tf.Assert(tf.reduce_all(tf.greater_equal(shape1[-3:-1], filter_size)),[shape1, filter_size], summarize=8), - tf.Assert(tf.reduce_all(tf.greater_equal(shape2[-3:-1], filter_size)),[shape2, filter_size], summarize=8)] - - # Enforce the check to run before computation. - with tf.control_dependencies(checks): - img1 = tf.identity(img1) - - # TODO(sjhwang): Try to cache kernels and compensation factor. - kernel = _fspecial_gauss(filter_size, filter_sigma) - kernel = tf.tile(kernel, multiples=[1, 1, shape1[-1], 1]) - - # The correct compensation factor is `1.0 - tf.reduce_sum(tf.square(kernel))`, - # but to match MATLAB implementation of MS-SSIM, we use 1.0 instead. - compensation = 1.0 - - # TODO(sjhwang): Try FFT. - # TODO(sjhwang): Gaussian kernel is separable in space. 
Consider applying - # 1-by-n and n-by-1 Gaussain filters instead of an n-by-n filter. - - luminance, cs = _ssim_helper(img1, img2, max_val, kernel, compensation) - - # Average over the second and the third from the last: height, width. - axes = tf.constant([-3, -2], dtype='int32') - ssim_val = tf.reduce_mean(luminance * cs, axes) - cs = tf.reduce_mean(cs, axes) - return ssim_val, cs - - def do_pad(images, remainder): - padding = tf.expand_dims(remainder, -1) - padding = tf.pad(padding, [[1, 0], [1, 0]]) - return [tf.pad(x, padding, mode='SYMMETRIC') for x in images] - - # Shape checking. - shape1 = img1.get_shape().with_rank_at_least(3) - shape2 = img2.get_shape().with_rank_at_least(3) - shape1[-3:].merge_with(shape2[-3:]) - - with tf.name_scope(None, 'MS-SSIM', [img1, img2]): - shape1, shape2, checks = _verify_compatible_image_shapes(img1, img2) - with tf.control_dependencies(checks): - img1 = tf.identity(img1) - - # Need to convert the images to float32. Scale max_val accordingly so that - # SSIM is computed correctly. - max_val = tf.cast(max_val, img1.dtype) - max_val = tf.image.convert_image_dtype(max_val, 'float32') - img1 = tf.image.convert_image_dtype(img1, 'float32') - img2 = tf.image.convert_image_dtype(img2, 'float32') - - imgs = [img1, img2] - shapes = [shape1, shape2] - - # img1 and img2 are assumed to be a (multi-dimensional) batch of - # 3-dimensional images (height, width, channels). `heads` contain the batch - # dimensions, and `tails` contain the image dimensions. - heads = [s[:-3] for s in shapes] - tails = [s[-3:] for s in shapes] - - divisor = [1, 2, 2, 1] - divisor_tensor = tf.constant(divisor[1:], dtype='int32') - - mcs = [] - for k in range(len(power_factors)): - with tf.name_scope(None, 'Scale%d' % k, imgs): - if k > 0: - # Avg pool takes rank 4 tensors. Flatten leading dimensions. - flat_imgs = [tf.reshape(x, tf.concat([[-1], t], 0)) for x, t in zip(imgs, tails)] - - remainder = tails[0] % divisor_tensor - need_padding = tf.reduce_any(tf.not_equal(remainder, 0)) - padded = tf.cond(need_padding,lambda: do_pad(flat_imgs, remainder), - lambda: flat_imgs) - - downscaled = [tf.nn.avg_pool(x, ksize=divisor, strides=divisor, padding='VALID') - for x in padded] - tails = [x[1:] for x in tf.shape_n(downscaled)] - imgs = [tf.reshape(x, tf.concat([h, t], 0)) for x, h, t in zip(downscaled, heads, tails)] - - # Overwrite previous ssim value since we only need the last one. - ssim_per_channel, cs = _ssim_per_channel(*imgs, max_val=max_val) - mcs.append(tf.nn.relu(cs)) - - # Remove the cs score for the last scale. In the MS-SSIM calculation, - # we use the l(p) at the highest scale. l(p) * cs(p) is ssim(p). - mcs.pop() # Remove the cs score for the last scale. - mcs_and_ssim = tf.stack(mcs + [tf.nn.relu(ssim_per_channel)],axis=-1) - # Take weighted geometric mean across the scale axis. - ms_ssim = tf.reduce_prod(tf.pow(mcs_and_ssim, power_factors),[-1]) - - return tf.reduce_mean(ms_ssim, [-1]) # Avg over color channels. - - -def ms_ssim_loss(y_true,y_pred): - MSSSIM = K.expand_dims(K.expand_dims(1.0 - ms_ssim(y_true, y_pred),axis=-1), axis=-1) - return K.tile(MSSSIM,[1,64,64]) # need to expand to [1,height,width] dimensions for Keras ... 
modify to not be hard-coded diff --git a/lib/model/losses/__init__.py b/lib/model/losses/__init__.py new file mode 100644 index 0000000000..751e791011 --- /dev/null +++ b/lib/model/losses/__init__.py @@ -0,0 +1,7 @@ +#!/usr/bin/env python3 +""" Custom Loss Functions for Faceswap """ + +from .feature_loss import LPIPSLoss +from .loss import (FocalFrequencyLoss, GeneralizedLoss, GradientLoss, + LaplacianPyramidLoss, LInfNorm, LossWrapper) +from .perceptual_loss import DSSIMObjective, GMSDLoss, LDRFLIPLoss, MSSIMLoss diff --git a/lib/model/losses/feature_loss.py b/lib/model/losses/feature_loss.py new file mode 100644 index 0000000000..9e96841481 --- /dev/null +++ b/lib/model/losses/feature_loss.py @@ -0,0 +1,401 @@ +#!/usr/bin/env python3 +""" Custom Feature Map Loss Functions for faceswap.py """ +from __future__ import annotations +from dataclasses import dataclass, field +import logging +import typing as T + +import keras +from keras import applications as kapp, layers, Model, ops, Variable + +import numpy as np + +from lib.logger import parse_class_init +from lib.model.networks import AlexNet, SqueezeNet +from lib.utils import get_module_objects, GetModel + +if T.TYPE_CHECKING: + from collections.abc import Callable + from keras import KerasTensor + +logger = logging.getLogger(__name__) + + +@dataclass +class NetInfo: + """ Data class for holding information about Trunk and Linear Layer nets. + + Parameters + ---------- + model_id: int + The model ID for the model stored in the deepfakes Model repo + model_name: str + The filename of the decompressed model/weights file + net: callable, Optional + The net definition to load, if any. Default:``None`` + init_kwargs: dict, optional + Keyword arguments to initialize any :attr:`net`. Default: empty ``dict`` + needs_init: bool, optional + True if the net needs initializing otherwise False. Default: ``True`` + """ + model_id: int = 0 + model_name: str = "" + net: Callable | None = None + init_kwargs: dict[str, T.Any] = field(default_factory=dict) + needs_init: bool = True + outputs: list[str] = field(default_factory=list) + + +class _LPIPSTrunkNet(): + """ Trunk neural network loader for LPIPS Loss function. + + Parameters + ---------- + net_name: str + The name of the trunk network to load. 
One of "alex", "squeeze" or "vgg16" + eval_mode: bool + ``True`` for evaluation mode, ``False`` for training mode + load_weights: bool + ``True`` if pretrained trunk network weights should be loaded, otherwise ``False`` + """ + def __init__(self, + net_name: T.Literal["alex", "squeeze", "vgg16"], + eval_mode: bool, + load_weights: bool) -> None: + logger.debug(parse_class_init(locals())) + self._eval_mode = eval_mode + self._load_weights = load_weights + self._net_name = net_name + self._net = self._nets[net_name] + logger.debug("Initialized: %s ", self.__class__.__name__) + + @property + def _nets(self) -> dict[str, NetInfo]: + """ :class:`NetInfo`: The Information about the requested net.""" + return { + "alex": NetInfo(model_id=15, + model_name="alexnet_imagenet_no_top_v1.h5", + net=AlexNet, + outputs=[f"features_{idx}" for idx in (0, 3, 6, 8, 10)]), + "squeeze": NetInfo(model_id=16, + model_name="squeezenet_imagenet_no_top_v1.h5", + net=SqueezeNet, + outputs=[f"features_{idx}" for idx in (0, 4, 7, 9, 10, 11, 12)]), + "vgg16": NetInfo(model_id=17, + model_name="vgg16_imagenet_no_top_v1.h5", + net=kapp.vgg16.VGG16, + init_kwargs={"include_top": False, "weights": None}, + outputs=[f"block{i + 1}_conv{2 if i < 2 else 3}" for i in range(5)])} + + @classmethod + def _normalize_output(cls, inputs: KerasTensor, epsilon: float = 1e-10) -> KerasTensor: + """ Normalize the output tensors from the trunk network. + + Parameters + ---------- + inputs: :class:`keras.KerasTensor` + An output tensor from the trunk model + epsilon: float, optional + Epsilon to apply to the normalization operation. Default: `1e-10` + """ + norm_factor = ops.sqrt(ops.sum(ops.square(inputs), axis=-1, keepdims=True)) + return inputs / (norm_factor + epsilon) + + def _process_weights(self, model: Model) -> Model: + """ Save and lock weights if requested. + + Parameters + ---------- + model :class:`keras.models.Model` + The loaded trunk or linear network + + Returns + ------- + :class:`keras.models.Model` + The network with weights loaded/not loaded and layers locked/unlocked + """ + if self._load_weights: + weights = GetModel(self._net.model_name, self._net.model_id).model_path + model.load_weights(weights) + + if self._eval_mode: + model.trainable = False + for layer in model.layers: + layer.trainable = False + return model + + def __call__(self) -> Model: + """ Load the Trunk net, add normalization to feature outputs, load weights and set + trainable state. + + Returns + ------- + :class:`keras.models.Model` + The trunk net with normalized feature output layers + """ + if self._net.net is None: + raise ValueError("No net loaded") + + model = self._net.net(**self._net.init_kwargs) + model = model if self._net_name == "vgg16" else model() + out_layers = [self._normalize_output(model.get_layer(name).output) + for name in self._net.outputs] + model = Model(inputs=model.input, outputs=out_layers) + model = self._process_weights(model) + return model + + +class _LPIPSLinearNet(_LPIPSTrunkNet): + """ The Linear Network to be applied to the difference between the true and predicted outputs + of the trunk network. + + Parameters + ---------- + net_name: str + The name of the trunk network in use. One of "alex", "squeeze" or "vgg16" + eval_mode: bool + ``True`` for evaluation mode, ``False`` for training mode + load_weights: bool + ``True`` if pretrained linear network weights should be loaded, otherwise ``False`` + trunk_net: :class:`keras.models.Model` + The trunk net to place the linear layer on. 
+ use_dropout: bool + ``True`` if a dropout layer should be used in the Linear network otherwise ``False`` + """ + def __init__(self, + net_name: T.Literal["alex", "squeeze", "vgg16"], + eval_mode: bool, + load_weights: bool, + trunk_net: Model, + use_dropout: bool) -> None: + logger.debug(parse_class_init(locals())) + super().__init__(net_name=net_name, eval_mode=eval_mode, load_weights=load_weights) + + self._trunk = trunk_net + self._use_dropout = use_dropout + + logger.debug("Initialized: %s", self.__class__.__name__) + + @property + def _nets(self) -> dict[str, NetInfo]: + """ :class:`NetInfo`: The Information about the requested net.""" + return { + "alex": NetInfo(model_id=18, + model_name="alexnet_lpips_v1.h5",), + "squeeze": NetInfo(model_id=19, + model_name="squeezenet_lpips_v1.h5"), + "vgg16": NetInfo(model_id=20, + model_name="vgg16_lpips_v1.h5")} + + def _linear_block(self, net_output_layer: KerasTensor) -> tuple[KerasTensor, KerasTensor]: + """ Build a linear block for a trunk network output. + + Parameters + ---------- + net_output_layer: :class:`keras.KerasTensor` + An output from the selected trunk network + + Returns + ------- + :class:`keras.KerasTensor` + The input to the linear block + :class:`keras.KerasTensor` + The output from the linear block + """ + in_shape = net_output_layer.shape[1:] + input_ = T.cast("KerasTensor", layers.Input(in_shape)) + var_x = layers.Dropout(rate=0.5)(input_) if self._use_dropout else input_ + var_x = layers.Conv2D(1, 1, strides=1, padding="valid", use_bias=False)(var_x) + return input_, var_x + + def __call__(self) -> Model: + """ Build the linear network for the given trunk network's outputs. Load in trained weights + and set the model's trainable parameters. + + Returns + ------- + :class:`keras.models.Model` + The compiled Linear Net model + """ + inputs = [] + outputs = [] + + for input_ in self._trunk.outputs: + in_, out = self._linear_block(input_) + inputs.append(in_) + outputs.append(out) + + model = Model(inputs=inputs, outputs=outputs) + model = self._process_weights(model) + return model + + +class LPIPSLoss(keras.losses.Loss): + """ LPIPS Loss Function. + + A perceptual loss function that uses linear outputs from pretrained CNNs feature layers. + + Notes + ----- + Channels Last implementation. All trunks implemented from the original paper. + + References + ---------- + https://richzhang.github.io/PerceptualSimilarity/ + + Parameters + ---------- + trunk_network: str + The name of the trunk network to use. One of "alex", "squeeze" or "vgg16" + trunk_pretrained: bool, optional + ``True`` Load the imagenet pretrained weights for the trunk network. ``False`` randomly + initialize the trunk network. Default: ``True`` + trunk_eval_mode: bool, optional + ``True`` for running inference on the trunk network (standard mode), ``False`` for training + the trunk network. Default: ``True`` + linear_pretrained: bool, optional + ``True`` loads the pretrained weights for the linear network layers. ``False`` randomly + initializes the layers. Default: ``True`` + linear_eval_mode: bool, optional + ``True`` for running inference on the linear network (standard mode), ``False`` for + training the linear network. Default: ``True`` + linear_use_dropout: bool, optional + ``True`` if a dropout layer should be used in the Linear network otherwise ``False``. + Default: ``True`` + lpips: bool, optional + ``True`` to use linear network on top of the trunk network. ``False`` to just average the + output from the trunk network. 
Default: ``True``
+ spatial: bool, optional
+ ``True`` output the loss in the spatial domain (i.e. as a grayscale tensor of height and
+ width of the input image). ``False`` reduce the spatial dimensions for loss calculation.
+ Default: ``False``
+ normalize: bool, optional
+ ``True`` if the input Tensor needs to be normalized from the 0. to 1. range to the -1. to
+ 1. range. Default: ``True``
+ ret_per_layer: bool, optional
+ ``True`` to return the loss value per feature output layer otherwise ``False``.
+ Default: ``False``
+ """
+ def __init__(self, # pylint:disable=too-many-arguments,too-many-positional-arguments
+ trunk_network: T.Literal["alex", "squeeze", "vgg16"],
+ trunk_pretrained: bool = True,
+ trunk_eval_mode: bool = True,
+ linear_pretrained: bool = True,
+ linear_eval_mode: bool = True,
+ linear_use_dropout: bool = True,
+ lpips: bool = True,
+ spatial: bool = False,
+ normalize: bool = True,
+ ret_per_layer: bool = False) -> None:
+ logger.debug(parse_class_init(locals()))
+ super().__init__(name=self.__class__.__name__)
+ self._spatial = spatial
+ self._use_lpips = lpips
+ self._normalize = normalize
+ self._ret_per_layer = ret_per_layer
+ self._shift = Variable(np.array([-.030, -.088, -.188],
+ dtype="float32")[None, None, None, :],
+ trainable=False)
+ self._scale = Variable(np.array([.458, .448, .450], dtype="float32")[None, None, None, :],
+ trainable=False)
+
+ # Loss needs to be done as fp32. We could cast at output, but better to update the model
+ switch_mixed_precision = keras.mixed_precision.global_policy().name == "mixed_float16"
+ if switch_mixed_precision:
+ logger.debug("Temporarily disabling mixed precision")
+ keras.mixed_precision.set_global_policy("float32")
+
+ self._trunk_net = _LPIPSTrunkNet(trunk_network, trunk_eval_mode, trunk_pretrained)()
+ self._linear_net = _LPIPSLinearNet(trunk_network,
+ linear_eval_mode,
+ linear_pretrained,
+ self._trunk_net,
+ linear_use_dropout)()
+ if switch_mixed_precision:
+ logger.debug("Re-enabling mixed precision")
+ keras.mixed_precision.set_global_policy("mixed_float16")
+ logger.debug("Initialized: %s", self.__class__.__name__)
+
+ def _process_diffs(self, inputs: list[KerasTensor]) -> list[KerasTensor]:
+ """ Perform processing on the Trunk Network outputs.
+
+ If :attr:`lpips` is enabled, process the diff values through the linear network,
+ otherwise return the diff values summed on the channels axis.
+
+ Parameters
+ ----------
+ inputs: list[:class:`keras.KerasTensor`]
+ List of the squared difference of the true and predicted outputs from the trunk network
+
+ Returns
+ -------
+ list[:class:`keras.KerasTensor`]
+ List of either the linear network outputs (when using lpips) or summed network outputs
+ """
+ if self._use_lpips:
+ return self._linear_net(inputs)
+ return [T.cast("KerasTensor", ops.sum(x, axis=-1)) for x in inputs]
+
+ def _process_output(self, inputs: KerasTensor, output_dims: tuple) -> KerasTensor:
+ """ Process an individual output based on whether :attr:`spatial` has been selected.
+
+ When spatial output is selected, all outputs are sized to the shape of the original True
+ input Tensor.
When not selected, the mean across the spatial axes (h, w) is returned
+
+ Parameters
+ ----------
+ inputs: :class:`keras.KerasTensor`
+ An individual diff output tensor from the linear network or summed output
+ output_dims: tuple
+ The (height, width) of the original true image
+
+ Returns
+ -------
+ :class:`keras.KerasTensor`
+ Either the original tensor resized to the true image dimensions, or the mean
+ value across the height, width axes.
+ """
+ if self._spatial:
+ return layers.Resizing(*output_dims, interpolation="bilinear")(inputs)
+ return T.cast("KerasTensor", ops.mean(inputs, axis=(1, 2), keepdims=True))
+
+ def call(self, y_true: KerasTensor, y_pred: KerasTensor) -> KerasTensor:
+ """ Perform the LPIPS Loss Function.
+
+ Parameters
+ ----------
+ y_true: :class:`keras.KerasTensor`
+ The ground truth batch of images
+ y_pred: :class:`keras.KerasTensor`
+ The predicted batch of images
+
+ Returns
+ -------
+ :class:`keras.KerasTensor`
+ The final loss value
+ """
+ if self._normalize:
+ y_true = (y_true * 2.0) - 1.0
+ y_pred = (y_pred * 2.0) - 1.0
+
+ y_true = (y_true - self._shift) / self._scale
+ y_pred = (y_pred - self._shift) / self._scale
+
+ net_true = self._trunk_net(y_true)
+ net_pred = self._trunk_net(y_pred)
+
+ diffs = [(out_true - out_pred) ** 2
+ for out_true, out_pred in zip(net_true, net_pred)]
+
+ dims = y_true.shape[1:3]
+ res = [self._process_output(diff, dims) for diff in self._process_diffs(diffs)]
+
+ axis = 0 if self._spatial else None
+ val = T.cast("KerasTensor", ops.sum(res, axis=axis))
+
+ retval = (val, res) if self._ret_per_layer else val
+ assert not isinstance(retval, tuple)
+ return retval / 10.0 # Reduce by a factor of 10 as this loss is strong
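+
+ # A minimal usage sketch (illustrative only: the trunk choice, batch shape and constant
+ # tensors below are assumptions, not values used elsewhere in this module). With the
+ # default arguments the pretrained trunk and linear weights are fetched on first use:
+ #
+ # >>> loss_fn = LPIPSLoss("alex")
+ # >>> y_true = ops.ones((4, 64, 64, 3)) * 0.5 # 0-1 ranged, channels-last images
+ # >>> y_pred = ops.zeros((4, 64, 64, 3))
+ # >>> loss = loss_fn(y_true, y_pred) # summed layer diffs, scaled by 1/10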
+
+
+__all__ = get_module_objects(__name__)
diff --git a/lib/model/losses/loss.py b/lib/model/losses/loss.py
new file mode 100644
index 0000000000..f8d375df67
--- /dev/null
+++ b/lib/model/losses/loss.py
@@ -0,0 +1,708 @@
+#!/usr/bin/env python3
+""" Custom Loss Functions for faceswap.py """
+
+from __future__ import annotations
+import logging
+import typing as T
+
+import numpy as np
+from keras import Loss, backend as K
+from keras import ops, Variable
+
+from lib.logger import parse_class_init
+from lib.utils import get_module_objects
+
+if K.backend() == "torch":
+ import torch # pylint:disable=import-error
+else:
+ import tensorflow as tf # pylint:disable=import-error # type:ignore
+
+if T.TYPE_CHECKING:
+ from collections.abc import Callable
+ from keras import KerasTensor
+
+logger = logging.getLogger(__name__)
+
+
+class FocalFrequencyLoss(Loss):
+ """ Focal Frequency Loss Function.
+
+ A channels last implementation.
+
+ Notes
+ -----
+ There is a bug in this implementation that will do an incorrect FFT if
+ :attr:`patch_factor` > ``1``, which means incorrect loss will be returned, so keep
+ patch factor at 1.
+
+ Parameters
+ ----------
+ alpha: float, Optional
+ Scaling factor of the spectrum weight matrix for flexibility. Default: ``1.0``
+ patch_factor: int, Optional
+ Factor to crop image patches for patch-based focal frequency loss.
+ Default: ``1``
+ ave_spectrum: bool, Optional
+ ``True`` to use minibatch average spectrum otherwise ``False``. Default: ``False``
+ log_matrix: bool, Optional
+ ``True`` to adjust the spectrum weight matrix by logarithm otherwise ``False``.
+ Default: ``False``
+ batch_matrix: bool, Optional
+ ``True`` to calculate the spectrum weight matrix using batch-based statistics otherwise
+ ``False``. Default: ``False``
+ epsilon: float, Optional
+ Small epsilon for safer weights scaling division. Default: `1e-6`
+
+ References
+ ----------
+ https://arxiv.org/pdf/2012.12821.pdf
+ https://github.com/EndlessSora/focal-frequency-loss
+ """
+
+ def __init__(self,
+ alpha: float = 1.0,
+ patch_factor: int = 1,
+ ave_spectrum: bool = False,
+ log_matrix: bool = False,
+ batch_matrix: bool = False,
+ epsilon: float = 1e-6) -> None:
+ logger.debug(parse_class_init(locals()))
+ super().__init__(name=self.__class__.__name__)
+ self._alpha = alpha
+ # TODO Fix bug where FFT will be incorrect if patch_factor > 1 for tensorflow
+ self._patch_factor = patch_factor
+ self._ave_spectrum = ave_spectrum
+ self._log_matrix = log_matrix
+ self._batch_matrix = batch_matrix
+ self._epsilon = epsilon
+ self._dims: tuple[int, int] = (0, 0)
+ logger.debug("Initialized: %s", self.__class__.__name__)
+
+ def _get_patches(self, inputs: KerasTensor) -> KerasTensor:
+ """ Crop the incoming batch of images into patches as defined by :attr:`_patch_factor`.
+
+ Parameters
+ ----------
+ inputs: :class:`keras.KerasTensor`
+ A batch of images to be converted into patches
+
+ Returns
+ -------
+ :class:`keras.KerasTensor`
+ The incoming batch converted into patches
+ """
+ patch_list = []
+ patch_rows = self._dims[0] // self._patch_factor
+ patch_cols = self._dims[1] // self._patch_factor
+ for i in range(self._patch_factor):
+ for j in range(self._patch_factor):
+ row_from = i * patch_rows
+ row_to = (i + 1) * patch_rows
+ col_from = j * patch_cols
+ col_to = (j + 1) * patch_cols
+ patch_list.append(inputs[:, row_from: row_to, col_from: col_to, :])
+
+ retval = ops.stack(patch_list, axis=1)
+ return T.cast("KerasTensor", retval)
+
+ def _tensor_to_frequency_spectrum(self, patch: KerasTensor) -> KerasTensor:
+ """ Perform FFT to create the orthonormalized DFT frequencies.
+
+ Parameters
+ ----------
+ patch: :class:`keras.KerasTensor`
+ The incoming batch of patches to convert to the frequency spectrum
+
+ Returns
+ -------
+ :class:`keras.KerasTensor`
+ The DFT frequencies split into real and imaginary numbers as float32
+ """
+ patch = T.cast("KerasTensor",
+ ops.transpose(patch, (0, 1, 4, 2, 3))) # move channels to first
+
+ assert K.backend() in ("torch", "tensorflow"), "Only Torch and Tensorflow are supported"
+ if K.backend() == "torch":
+ freq = torch.fft.fft2(patch, # pylint:disable=not-callable # type:ignore
+ norm="ortho")
+ else:
+ patch = patch / np.sqrt(self._dims[0] * self._dims[1]) # Orthonormalization
+ patch = T.cast("KerasTensor", ops.cast(patch, "complex64"))
+ freq = tf.signal.fft2d(patch)[..., None] # type:ignore
+
+ freq = ops.stack([freq.real, freq.imag], axis=-1)
+
+ if K.backend() == "tensorflow":
+ freq = ops.cast(freq, "float32")
+
+ freq = ops.transpose(freq, (0, 1, 3, 4, 2, 5)) # channels to last
+ return T.cast("KerasTensor", freq)
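+
+ # Shape walk-through for a hypothetical (4, 64, 64, 3) batch with patch_factor=1 on the
+ # torch backend (the sizes here are assumptions for illustration):
+ # _get_patches -> (4, 1, 64, 64, 3)
+ # channels first + fft2 -> complex (4, 1, 3, 64, 64)
+ # stack real/imag, channels last -> float32 (4, 1, 64, 64, 3, 2)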
+ def _get_weight_matrix(self, freq_true: KerasTensor, freq_pred: KerasTensor) -> KerasTensor:
+ """ Calculate a continuous, dynamic weight matrix based on current Euclidean distance.
+
+ Parameters
+ ----------
+ freq_true: :class:`keras.KerasTensor`
+ The real and imaginary DFT frequencies for the true batch of images
+ freq_pred: :class:`keras.KerasTensor`
+ The real and imaginary DFT frequencies for the predicted batch of images
+
+ Returns
+ -------
+ :class:`keras.KerasTensor`
+ The weights matrix for prioritizing hard frequencies
+ """
+ weights = ops.square(freq_pred - freq_true)
+ weights = ops.sqrt(weights[..., 0] + weights[..., 1])
+ weights = ops.power(weights, self._alpha)
+
+ if self._log_matrix: # adjust the spectrum weight matrix by logarithm
+ weights = ops.log(weights + 1.0)
+
+ if self._batch_matrix: # calculate the spectrum weight matrix using batch-based statistics
+ scale = ops.max(weights)
+ else:
+ scale = ops.max(weights, axis=(-2, -3), keepdims=True)
+ weights = weights / ops.maximum(scale, self._epsilon)
+
+ weights = ops.clip(weights, x_min=0.0, x_max=1.0)
+
+ return T.cast("KerasTensor", weights)
+
+ @classmethod
+ def _calculate_loss(cls,
+ freq_true: KerasTensor,
+ freq_pred: KerasTensor,
+ weight_matrix: KerasTensor) -> KerasTensor:
+ """ Perform the loss calculation on the DFT spectrum applying the weights matrix.
+
+ Parameters
+ ----------
+ freq_true: :class:`keras.KerasTensor`
+ The real and imaginary DFT frequencies for the true batch of images
+ freq_pred: :class:`keras.KerasTensor`
+ The real and imaginary DFT frequencies for the predicted batch of images
+ weight_matrix: :class:`keras.KerasTensor`
+ The weight matrix from :func:`_get_weight_matrix` for prioritizing hard frequencies
+
+ Returns
+ -------
+ :class:`keras.KerasTensor`
+ The final loss matrix
+ """
+
+ tmp = ops.square(freq_pred - freq_true) # freq distance using squared Euclidean distance
+
+ freq_distance = tmp[..., 0] + tmp[..., 1]
+ loss = weight_matrix * freq_distance # dynamic spectrum weighting (Hadamard product)
+
+ return T.cast("KerasTensor", ops.mean(loss))
+
+ def call(self, y_true: KerasTensor, y_pred: KerasTensor) -> KerasTensor:
+ """ Call the Focal Frequency Loss Function.
+
+ Parameters
+ ----------
+ y_true: :class:`keras.KerasTensor`
+ The ground truth batch of images
+ y_pred: :class:`keras.KerasTensor`
+ The predicted batch of images
+
+ Returns
+ -------
+ :class:`keras.KerasTensor`
+ The loss for this batch of images
+ """
+ if not all(self._dims):
+ rows, cols = y_true.shape[1:3]
+ assert rows is not None and cols is not None
+ assert cols % self._patch_factor == 0 and rows % self._patch_factor == 0, (
+ "Patch factor must be a divisor of the image height and width")
+ self._dims = (rows, cols)
+
+ patches_true = self._get_patches(y_true)
+ patches_pred = self._get_patches(y_pred)
+
+ freq_true = self._tensor_to_frequency_spectrum(patches_true)
+ freq_pred = self._tensor_to_frequency_spectrum(patches_pred)
+
+ if self._ave_spectrum: # whether to use minibatch average spectrum
+ freq_true = T.cast("KerasTensor", ops.mean(freq_true, axis=0, keepdims=True))
+ freq_pred = T.cast("KerasTensor", ops.mean(freq_pred, axis=0, keepdims=True))
+
+ weight_matrix = self._get_weight_matrix(freq_true, freq_pred)
+ return self._calculate_loss(freq_true, freq_pred, weight_matrix)
+
+
+class GeneralizedLoss(Loss):
+ """ Generalized function used to return a large variety of mathematical loss functions.
+
+ The primary benefit is a smooth, differentiable version of L1 loss.
+
+ References
+ ----------
+ Barron, J.
A General and Adaptive Robust Loss Function - https://arxiv.org/pdf/1701.03077.pdf
+
+ Example
+ -------
+ >>> alpha=1.0, beta=1.0/255.0 # will give a smoothly differentiable version of L1 / MAE loss
+ >>> alpha=1.999999 (limit as alpha->2), beta=1.0/255.0 # will give L2 / RMSE loss
+
+ Parameters
+ ----------
+ alpha: float, optional
+ Penalty factor. Larger numbers give larger weight to large deviations. Default: `1.0`
+ beta: float, optional
+ Scale factor used to adjust to the input scale (i.e. inputs of mean `1e-4` or `256`).
+ Default: `1.0/255.0`
+ """
+ def __init__(self, alpha: float = 1.0, beta: float = 1.0/255.0) -> None:
+ logger.debug(parse_class_init(locals()))
+ super().__init__(name=self.__class__.__name__)
+ self._alpha = alpha
+ self._beta = beta
+ logger.debug("Initialized: %s", self.__class__.__name__)
+
+ def call(self, y_true: KerasTensor, y_pred: KerasTensor) -> KerasTensor:
+ """ Call the Generalized Loss Function
+
+ Parameters
+ ----------
+ y_true: :class:`keras.KerasTensor`
+ The ground truth value
+ y_pred: :class:`keras.KerasTensor`
+ The predicted value
+
+ Returns
+ -------
+ :class:`keras.KerasTensor`
+ The loss value from the results of function(y_pred - y_true)
+ """
+ diff = y_pred - y_true
+ second = (ops.power(ops.power(diff/self._beta, 2.) / ops.abs(2. - self._alpha) + 1.,
+ (self._alpha / 2.)) - 1.)
+ loss = (ops.abs(2. - self._alpha)/self._alpha) * second
+ loss = ops.mean(loss, axis=-1) * self._beta
+ return T.cast("KerasTensor", loss)
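+
+ # Worked check of the Barron formula implemented above,
+ # f(x) = (|2 - a| / a) * (((x / b)**2 / |2 - a| + 1)**(a / 2) - 1):
+ # a = 1 -> sqrt((x / b)**2 + 1) - 1, a smooth (pseudo-Huber) L1
+ # a -> 2 -> 0.5 * (x / b)**2, an L2 penalty
+ # which is why GradientLoss below instantiates this class with alpha=1.9999 to
+ # approximate MSE on the image gradients.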
+
+
+class GradientLoss(Loss):
+ """ Gradient Loss Function.
+
+ Calculates the first and second order gradient difference between pixels of an image in the x
+ and y dimensions. These gradients are then compared between the ground truth and the predicted
+ image and the difference is taken. When used as a loss, its minimization will result in
+ predicted images approaching the same level of sharpness / blurriness as the ground truth.
+
+ References
+ ----------
+ TV+TV2 Regularization with Non-Convex Sparseness-Inducing Penalty for Image Restoration,
+ Chengwu Lu & Hua Huang, 2014 - http://downloads.hindawi.com/journals/mpe/2014/790547.pdf
+ """
+ def __init__(self) -> None:
+ logger.debug(parse_class_init(locals()))
+ super().__init__(name=self.__class__.__name__)
+ self.generalized_loss = GeneralizedLoss(alpha=1.9999)
+ self._tv_weight = 1.0
+ self._tv2_weight = 1.0
+ logger.debug("Initialized: %s", self.__class__.__name__)
+
+ @classmethod
+ def _diff_x(cls, img: KerasTensor) -> KerasTensor:
+ """ X Difference """
+ x_left = img[:, :, 1:2, :] - img[:, :, 0:1, :]
+ x_inner = img[:, :, 2:, :] - img[:, :, :-2, :]
+ x_right = img[:, :, -1:, :] - img[:, :, -2:-1, :]
+ x_out = ops.concatenate([x_left, x_inner, x_right], axis=2)
+ return T.cast("KerasTensor", x_out) * 0.5
+
+ @classmethod
+ def _diff_y(cls, img: KerasTensor) -> KerasTensor:
+ """ Y Difference """
+ y_top = img[:, 1:2, :, :] - img[:, 0:1, :, :]
+ y_inner = img[:, 2:, :, :] - img[:, :-2, :, :]
+ y_bot = img[:, -1:, :, :] - img[:, -2:-1, :, :]
+ y_out = ops.concatenate([y_top, y_inner, y_bot], axis=1)
+ return T.cast("KerasTensor", y_out) * 0.5
+
+ @classmethod
+ def _diff_xx(cls, img: KerasTensor) -> KerasTensor:
+ """ X-X Difference """
+ x_left = img[:, :, 1:2, :] + img[:, :, 0:1, :]
+ x_inner = img[:, :, 2:, :] + img[:, :, :-2, :]
+ x_right = img[:, :, -1:, :] + img[:, :, -2:-1, :]
+ x_out = ops.concatenate([x_left, x_inner, x_right], axis=2)
+ return x_out - 2.0 * img
+
+ @classmethod
+ def _diff_yy(cls, img: KerasTensor) -> KerasTensor:
+ """ Y-Y Difference """
+ y_top = img[:, 1:2, :, :] + img[:, 0:1, :, :]
+ y_inner = img[:, 2:, :, :] + img[:, :-2, :, :]
+ y_bot = img[:, -1:, :, :] + img[:, -2:-1, :, :]
+ y_out = ops.concatenate([y_top, y_inner, y_bot], axis=1)
+ return y_out - 2.0 * img
+
+ @classmethod
+ def _diff_xy(cls, img: KerasTensor) -> KerasTensor:
+ """ X-Y Difference """
+ # Xout1
+ # Left
+ top = img[:, 1:2, 1:2, :] + img[:, 0:1, 0:1, :]
+ inner = img[:, 2:, 1:2, :] + img[:, :-2, 0:1, :]
+ bottom = img[:, -1:, 1:2, :] + img[:, -2:-1, 0:1, :]
+ xy_left = ops.concatenate([top, inner, bottom], axis=1)
+ # Mid
+ top = img[:, 1:2, 2:, :] + img[:, 0:1, :-2, :]
+ mid = img[:, 2:, 2:, :] + img[:, :-2, :-2, :]
+ bottom = img[:, -1:, 2:, :] + img[:, -2:-1, :-2, :]
+ xy_mid = ops.concatenate([top, mid, bottom], axis=1)
+ # Right
+ top = img[:, 1:2, -1:, :] + img[:, 0:1, -2:-1, :]
+ inner = img[:, 2:, -1:, :] + img[:, :-2, -2:-1, :]
+ bottom = img[:, -1:, -1:, :] + img[:, -2:-1, -2:-1, :]
+ xy_right = ops.concatenate([top, inner, bottom], axis=1)
+ # Concatenate the first diagonal before its component variables are re-used below
+ xy_out1 = T.cast("KerasTensor", ops.concatenate([xy_left, xy_mid, xy_right], axis=2))
+
+ # Xout2
+ # Left
+ top = img[:, 0:1, 1:2, :] + img[:, 1:2, 0:1, :]
+ inner = img[:, :-2, 1:2, :] + img[:, 2:, 0:1, :]
+ bottom = img[:, -2:-1, 1:2, :] + img[:, -1:, 0:1, :]
+ xy_left = ops.concatenate([top, inner, bottom], axis=1)
+ # Mid
+ top = img[:, 0:1, 2:, :] + img[:, 1:2, :-2, :]
+ mid = img[:, :-2, 2:, :] + img[:, 2:, :-2, :]
+ bottom = img[:, -2:-1, 2:, :] + img[:, -1:, :-2, :]
+ xy_mid = ops.concatenate([top, mid, bottom], axis=1)
+ # Right
+ top = img[:, 0:1, -1:, :] + img[:, 1:2, -2:-1, :]
+ inner = img[:, :-2, -1:, :] + img[:, 2:, -2:-1, :]
+ bottom = img[:, -2:-1, -1:, :] + img[:, -1:, -2:-1, :]
+ xy_right = ops.concatenate([top, inner, bottom], axis=1)
+ xy_out2 = T.cast("KerasTensor", ops.concatenate([xy_left, xy_mid, xy_right], axis=2))
+ return (xy_out1 - xy_out2) * 0.25
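+
+ # Sanity check on a hypothetical 3px-wide unit ramp (img[..., x, :] = x):
+ # _diff_x -> 0.5 * [x1 - x0, x2 - x0, x2 - x1] = [0.5, 1.0, 0.5]
+ # i.e. the exact slope on the interior column, halved one-sided differences at the two
+ # borders; the y / xx / yy / xy helpers behave analogously on their respective axes.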
+
+ def call(self, y_true: KerasTensor, y_pred: KerasTensor) -> KerasTensor:
+ """ Call the gradient loss function.
+
+ Parameters
+ ----------
+ y_true: :class:`keras.KerasTensor`
+ The ground truth value
+ y_pred: :class:`keras.KerasTensor`
+ The predicted value
+
+ Returns
+ -------
+ :class:`keras.KerasTensor`
+ The loss value
+ """
+ loss = 0.0
+ loss += self._tv_weight * (self.generalized_loss(self._diff_x(y_true),
+ self._diff_x(y_pred)) +
+ self.generalized_loss(self._diff_y(y_true),
+ self._diff_y(y_pred)))
+ loss += self._tv2_weight * (self.generalized_loss(self._diff_xx(y_true),
+ self._diff_xx(y_pred)) +
+ self.generalized_loss(self._diff_yy(y_true),
+ self._diff_yy(y_pred)) +
+ self.generalized_loss(self._diff_xy(y_true),
+ self._diff_xy(y_pred)) * 2.)
+ loss = loss / (self._tv_weight + self._tv2_weight)
+ # TODO simplify to use MSE instead
+ return T.cast("KerasTensor", loss)
+
+
+class LaplacianPyramidLoss(Loss):
+ """ Laplacian Pyramid Loss Function
+
+ Notes
+ -----
+ Channels last implementation on square images only.
+
+ Parameters
+ ----------
+ max_levels: int, Optional
+ The max number of laplacian pyramid levels to use. Default: `5`
+ gaussian_size: int, Optional
+ The size of the gaussian kernel. Default: `5`
+ gaussian_sigma: float, optional
+ The gaussian sigma. Default: `1.0`
+
+ References
+ ----------
+ https://arxiv.org/abs/1707.05776
+ https://github.com/nathanaelbosch/generative-latent-optimization/blob/master/utils.py
+ """
+ def __init__(self,
+ max_levels: int = 5,
+ gaussian_size: int = 5,
+ gaussian_sigma: float = 1.0) -> None:
+ logger.debug(parse_class_init(locals()))
+ super().__init__(name=self.__class__.__name__)
+ self._max_levels = max_levels
+ self._weights = Variable([np.power(2., -2 * idx) for idx in range(max_levels + 1)],
+ trainable=False)
+ self._gaussian_kernel = self._get_gaussian_kernel(gaussian_size, gaussian_sigma)
+ logger.debug("Initialized: %s", self.__class__.__name__)
+
+ @classmethod
+ def _get_gaussian_kernel(cls, size: int, sigma: float) -> KerasTensor:
+ """ Obtain the base gaussian kernel for the Laplacian Pyramid.
+
+ Parameters
+ ----------
+ size: int
+ The size of the gaussian kernel
+ sigma: float
+ The gaussian sigma
+
+ Returns
+ -------
+ :class:`keras.KerasTensor`
+ The base single channel Gaussian kernel
+ """
+ assert size % 2 == 1, "kernel size must be odd"
+ x_1 = np.linspace(- (size // 2), size // 2, size, dtype="float32")
+ x_1 /= np.sqrt(2)*sigma
+ x_2 = x_1 ** 2
+ kernel = np.exp(- x_2[:, None] - x_2[None, :])
+ kernel /= kernel.sum()
+ kernel = np.reshape(kernel, (size, size, 1, 1))
+ return Variable(kernel, trainable=False)
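+
+ # For example (values computed from the formula above, assuming size=3, sigma=1.0):
+ # the un-normalized weights are exp(0) = 1.0 at the centre, exp(-0.5) ~= 0.61 on the
+ # edges and exp(-1) ~= 0.37 in the corners; after the division the kernel sums to
+ # exactly 1.0, so the blur never changes the overall image intensity.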
+
+ def _conv_gaussian(self, inputs: KerasTensor) -> KerasTensor:
+ """ Perform Gaussian convolution on a batch of images.
+
+ Parameters
+ ----------
+ inputs: :class:`keras.KerasTensor`
+ The input batch of images to perform Gaussian convolution on.
+
+ Returns
+ -------
+ :class:`keras.KerasTensor`
+ The convolved images
+ """
+ channels = inputs.shape[-1]
+ gauss = ops.tile(self._gaussian_kernel, (1, 1, 1, channels))
+
+ # TF doesn't implement replication padding like pytorch. This is an inefficient way to
+ # implement it for a square gaussian kernel
+ # TODO Make this pure pytorch code
+ gauss_shape = self._gaussian_kernel.shape[1]
+ assert gauss_shape is not None
+ size = gauss_shape // 2
+ padded_inputs = inputs
+ for _ in range(size):
+ padded_inputs = ops.pad(padded_inputs,
+ ([0, 0], [1, 1], [1, 1], [0, 0]),
+ mode="symmetric")
+
+ retval = ops.conv(padded_inputs, gauss, strides=1, padding="valid")
+ return T.cast("KerasTensor", retval)
+
+ def _get_laplacian_pyramid(self, inputs: KerasTensor) -> list[KerasTensor]:
+ """ Obtain the Laplacian Pyramid.
+
+ Parameters
+ ----------
+ inputs: :class:`keras.KerasTensor`
+ The input batch of images to run through the Laplacian Pyramid
+
+ Returns
+ -------
+ list
+ The tensors produced from the Laplacian Pyramid
+ """
+ pyramid = []
+ current = inputs
+ for _ in range(self._max_levels):
+ gauss = self._conv_gaussian(current)
+ diff = current - gauss
+ pyramid.append(diff)
+ current = ops.average_pool(gauss, (2, 2), strides=(2, 2), padding="valid")
+ pyramid.append(current)
+ return pyramid
+
+ def call(self, y_true: KerasTensor, y_pred: KerasTensor) -> KerasTensor:
+ """ Calculate the Laplacian Pyramid Loss.
+
+ Parameters
+ ----------
+ y_true: :class:`keras.KerasTensor`
+ The ground truth value
+ y_pred: :class:`keras.KerasTensor`
+ The predicted value
+
+ Returns
+ -------
+ :class:`keras.KerasTensor`
+ The loss value
+ """
+ pyramid_true = self._get_laplacian_pyramid(y_true)
+ pyramid_pred = self._get_laplacian_pyramid(y_pred)
+
+ losses = ops.stack(
+ [ops.sum(ops.abs(ppred - ptrue)) / ops.cast(ops.prod(ops.shape(ptrue)), "float32")
+ for ptrue, ppred in zip(pyramid_true, pyramid_pred)])
+ loss = ops.sum(losses * self._weights)
+ return T.cast("KerasTensor", loss)
+
+
+class LInfNorm(Loss):
+ """ Calculate the L-inf norm as a loss function. """
+ def __init__(self, *args, **kwargs) -> None:
+ logger.debug(parse_class_init(locals()))
+ super().__init__(*args, name=self.__class__.__name__, **kwargs)
+ logger.debug("Initialized: %s", self.__class__.__name__)
+
+ def call(self, y_true: KerasTensor, y_pred: KerasTensor) -> KerasTensor:
+ """ Call the L-inf norm loss function.
+
+ Parameters
+ ----------
+ y_true: :class:`keras.KerasTensor`
+ The ground truth value
+ y_pred: :class:`keras.KerasTensor`
+ The predicted value
+
+ Returns
+ -------
+ :class:`keras.KerasTensor`
+ The loss value
+ """
+ diff = ops.abs(y_true - y_pred)
+ max_loss = ops.max(diff, axis=(1, 2), keepdims=True)
+ loss = ops.mean(max_loss, axis=-1)
+ return T.cast("KerasTensor", loss)
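+
+ # Illustration of the reduction above: if a single pixel in one channel is wrong by 0.5
+ # and everything else is perfect, the per-channel maxima over (h, w) are (0.5, 0.0, 0.0)
+ # and the returned loss is their channel mean, 0.5 / 3.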
+
+
+class LossWrapper(Loss):
+ """ A wrapper class for multiple keras losses to enable multiple masked weighted loss
+ functions on a single output.
+
+ Notes
+ -----
+ Whilst Keras does allow for applying multiple weighted loss functions, it does not allow
+ for an easy mechanism to add additional data (in our case masks) that are batch specific
+ but are not fed in to the model.
+
+ This wrapper receives this additional mask data for the batch stacked onto the end of the
+ color channels of the received :attr:`y_true` batch of images. These masks are then split
+ off the batch of images and applied to both the :attr:`y_true` and :attr:`y_pred` tensors
+ prior to feeding into the loss functions.
+
+ For example, for an image of shape (4, 128, 128, 3) 3 additional masks may be stacked onto
+ the end of y_true, meaning we receive an input of shape (4, 128, 128, 6). This wrapper then
+ splits off (4, 128, 128, 3:6) from the end of the tensor, leaving the original y_true of
+ shape (4, 128, 128, 3) ready for masking and feeding through the loss functions.
+ """
+ def __init__(self, name="LossWrapper", reduction="sum_over_batch_size") -> None:
+ logger.debug(parse_class_init(locals()))
+ super().__init__(name=name, reduction=reduction)
+ self._loss_functions: list[Loss | Callable] = []
+ self._loss_weights: list[float] = []
+ self._mask_channels: list[int] = []
+ logger.debug("Initialized: %s", self.__class__.__name__)
+
+ def add_loss(self,
+ function: Callable | Loss,
+ weight: float = 1.0,
+ mask_channel: int = -1) -> None:
+ """ Add the given loss function with the given weight to the loss function chain.
+
+ Parameters
+ ----------
+ function: :class:`keras.losses.Loss`
+ The loss function to add to the loss chain
+ weight: float, optional
+ The weighting to apply to the loss function. Default: `1.0`
+ mask_channel: int, optional
+ The channel in the `y_true` image that the mask exists in. Set to `-1` if there is no
+ mask for the given loss function. Default: `-1`
+ """
+ logger.debug("Adding loss: (function: %s, weight: %s, mask_channel: %s)",
+ function, weight, mask_channel)
+ # Loss must be compiled inside LossContainer for keras to handle distributed strategies
+ self._loss_functions.append(function)
+ self._loss_weights.append(weight)
+ self._mask_channels.append(mask_channel)
+
+ def call(self, y_true: KerasTensor, y_pred: KerasTensor) -> KerasTensor:
+ """ Call the sub loss functions for the loss wrapper.
+
+ Loss is returned as the weighted sum of the chosen losses.
+
+ If masks are being applied to the loss function inputs, then they should be included as
+ additional channels at the end of :attr:`y_true`, so that they can be split off and
+ applied to the actual inputs to the selected loss function(s).
+
+ Parameters
+ ----------
+ y_true: :class:`keras.KerasTensor`
+ The ground truth batch of images, with any required masks stacked on the end
+ y_pred: :class:`keras.KerasTensor`
+ The batch of model predictions
+
+ Returns
+ -------
+ :class:`keras.KerasTensor`
+ The final weighted loss
+ """
+ loss = 0.0
+ for func, weight, mask_channel in zip(self._loss_functions,
+ self._loss_weights,
+ self._mask_channels):
+ logger.trace("Processing loss function: " # type:ignore[attr-defined]
+ "(func: %s, weight: %s, mask_channel: %s)",
+ func, weight, mask_channel)
+ n_true, n_pred = self._apply_mask(y_true, y_pred, mask_channel)
+ loss += (func(n_true, n_pred) * weight)
+ return T.cast("KerasTensor", loss)
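+
+ # A minimal composition sketch (the loss choices, weights and mask channel here are
+ # assumptions for illustration, not the application's defaults):
+ #
+ # >>> wrapper = LossWrapper()
+ # >>> wrapper.add_loss(LaplacianPyramidLoss(), weight=1.0, mask_channel=-1)
+ # >>> wrapper.add_loss(GeneralizedLoss(), weight=2.0, mask_channel=3)
+ # >>> # y_true must arrive as (batch, h, w, 3 colors + masks); channel 3 gates the
+ # >>> # second loss while the first sees the unmasked colors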
+
+ @classmethod
+ def _apply_mask(cls,
+ y_true: KerasTensor,
+ y_pred: KerasTensor,
+ mask_channel: int,
+ mask_prop: float = 1.0) -> tuple[KerasTensor, KerasTensor]:
+ """ Apply the mask to the input y_true and y_pred. If a mask is not required then
+ return the unmasked inputs.
+
+ Parameters
+ ----------
+ y_true: :class:`keras.KerasTensor`
+ The ground truth value
+ y_pred: :class:`keras.KerasTensor`
+ The predicted value
+ mask_channel: int
+ The channel within y_true that the required mask resides in
+ mask_prop: float, optional
+ The amount of mask propagation. Default: `1.0`
+
+ Returns
+ -------
+ :class:`keras.KerasTensor`
+ The ground truth batch of images, with the required mask applied
+ :class:`keras.KerasTensor`
+ The predicted batch of images with the required mask applied
+ """
+ if mask_channel == -1:
+ logger.trace("No mask to apply") # type:ignore[attr-defined]
+ return y_true[..., :3], y_pred[..., :3]
+
+ logger.trace("Applying mask from channel %s", mask_channel) # type:ignore[attr-defined]
+
+ mask = ops.tile(ops.expand_dims(y_true[..., mask_channel], axis=-1), (1, 1, 1, 3))
+ mask_as_k_inv_prop = 1 - mask_prop
+ mask = (mask * mask_prop) + mask_as_k_inv_prop
+
+ m_true = y_true[..., :3] * mask
+ m_pred = y_pred[..., :3] * mask
+
+ return m_true, m_pred
+
+
+__all__ = get_module_objects(__name__)
diff --git a/lib/model/losses/perceptual_loss.py b/lib/model/losses/perceptual_loss.py
new file mode 100644
index 0000000000..cbdfa6eed7
--- /dev/null
+++ b/lib/model/losses/perceptual_loss.py
@@ -0,0 +1,978 @@
+#!/usr/bin/env python3
+""" Keras implementation of Perceptual Loss Functions for faceswap.py """
+from __future__ import annotations
+
+import logging
+import typing as T
+
+import numpy as np
+import torch
+
+import keras
+from keras import ops, Variable
+
+from lib.keras_utils import ColorSpaceConvert, frobenius_norm, replicate_pad
+from lib.logger import parse_class_init
+from lib.utils import get_module_objects
+
+if T.TYPE_CHECKING:
+ from keras import KerasTensor
+ from torch import Tensor
+
+logger = logging.getLogger(__name__)
+
+
+class DSSIMObjective(keras.losses.Loss):
+ """ DSSIM Loss Function
+
+ Difference of Structural Similarity (DSSIM loss function).
+
+ Adapted from :func:`tensorflow.image.ssim` for a pure keras implementation.
+
+ Notes
+ -----
+ Channels last only. Assumes all input images are the same size and square. You should add
+ a regularization term such as an l2 loss in addition to this one.
+
+ Parameters
+ ----------
+ k_1: float, optional
+ Parameter of the SSIM. Default: `0.01`
+ k_2: float, optional
+ Parameter of the SSIM. Default: `0.03`
+ filter_size: int, optional
+ Size of the gaussian filter. Default: `11`
+ filter_sigma: float, optional
+ Width of the gaussian filter. Default: `1.5`
+ max_value: float, optional
+ Max value of the output. Default: `1.0`
+ """
+ def __init__(self,
+ k_1: float = 0.01,
+ k_2: float = 0.03,
+ filter_size: int = 11,
+ filter_sigma: float = 1.5,
+ max_value: float = 1.0) -> None:
+ logger.debug(parse_class_init(locals()))
+ super().__init__(name=self.__class__.__name__)
+ self._filter_size = filter_size
+ self._filter_sigma = filter_sigma
+ self._kernel = self._get_kernel()
+
+ compensation = 1.0
+ self._c1 = (k_1 * max_value) ** 2
+ self._c2 = ((k_2 * max_value) ** 2) * compensation
+ logger.debug("Initialized: %s", self.__class__.__name__)
+
+ def _get_kernel(self) -> KerasTensor:
+ """ Obtain the base kernel for performing depthwise convolution.
+
+ Returns
+ -------
+ :class:`keras.KerasTensor`
+ The gaussian kernel based on selected size and sigma
+ """
+ coords = np.arange(self._filter_size, dtype="float32")
+ coords -= (self._filter_size - 1) / 2.
+
+ kernel = np.square(coords)
+ kernel *= -0.5 / np.square(self._filter_sigma)
+ kernel = np.reshape(kernel, (1, -1)) + np.reshape(kernel, (-1, 1))
+ kernel = Variable(np.reshape(kernel, (1, -1)), trainable=False)
+ kernel = ops.softmax(kernel)
+ kernel = ops.reshape(kernel, (self._filter_size, self._filter_size, 1, 1))
+ return T.cast("KerasTensor", kernel)
+
+ @classmethod
+ def _depthwise_conv2d(cls, image: KerasTensor, kernel: KerasTensor) -> KerasTensor:
+ """ Perform a standardized depthwise convolution.
+
+ Parameters
+ ----------
+ image: :class:`keras.KerasTensor`
+ Batch of images, channels last, to perform depthwise convolution
+ kernel: :class:`keras.KerasTensor`
+ convolution kernel
+
+ Returns
+ -------
+ :class:`keras.KerasTensor`
+ The output from the convolution
+ """
+ return T.cast("KerasTensor", ops.depthwise_conv(image, kernel, strides=1, padding="valid"))
+
+ def _get_ssim(self,
+ y_true: KerasTensor,
+ y_pred: KerasTensor) -> tuple[KerasTensor, KerasTensor]:
+ """ Obtain the structural similarity between a batch of true and predicted images.
+
+ Parameters
+ ----------
+ y_true: :class:`keras.KerasTensor`
+ The input batch of ground truth images
+ y_pred: :class:`keras.KerasTensor`
+ The input batch of predicted images
+
+ Returns
+ -------
+ :class:`keras.KerasTensor`
+ The SSIM for the given images
+ :class:`keras.KerasTensor`
+ The Contrast for the given images
+ """
+ channels = y_true.shape[-1]
+ kernel = ops.tile(self._kernel, (1, 1, channels, 1))
+
+ # SSIM luminance measure is (2 * mu_x * mu_y + c1) / (mu_x ** 2 + mu_y ** 2 + c1)
+ mean_true = self._depthwise_conv2d(y_true, kernel)
+ mean_pred = self._depthwise_conv2d(y_pred, kernel)
+ num_lum = mean_true * mean_pred * 2.0
+ den_lum = ops.square(mean_true) + ops.square(mean_pred)
+ luminance = (num_lum + self._c1) / (den_lum + self._c1)
+
+ # SSIM contrast-structure measure is (2 * cov_{xy} + c2) / (cov_{xx} + cov_{yy} + c2)
+ num_con = self._depthwise_conv2d(y_true * y_pred, kernel) * 2.0
+ den_con = self._depthwise_conv2d(
+ T.cast("KerasTensor", ops.square(y_true) + ops.square(y_pred)), kernel)
+
+ contrast = (num_con - num_lum + self._c2) / (den_con - den_lum + self._c2)
+
+ # Average over the height x width dimensions
+ axes = (-3, -2)
+ ssim = T.cast("KerasTensor", ops.mean(luminance * contrast, axis=axes))
+ contrast = T.cast("KerasTensor", ops.mean(contrast, axis=axes))
+
+ return ssim, contrast
+
+ def call(self, y_true: KerasTensor, y_pred: KerasTensor) -> KerasTensor:
+ """ Call the DSSIM Loss Function.
+
+ Parameters
+ ----------
+ y_true: :class:`keras.KerasTensor`
+ The input batch of ground truth images
+ y_pred: :class:`keras.KerasTensor`
+ The input batch of predicted images
+
+ Returns
+ -------
+ :class:`keras.KerasTensor`
+ The DSSIM for the given images
+ """
+ ssim = self._get_ssim(y_true, y_pred)[0]
+ retval = (1. - ssim) / 2.0
+ return T.cast("KerasTensor", ops.mean(retval))
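+
+ # Because SSIM lies in [-1, 1], (1 - ssim) / 2 maps identical images to 0 and fully
+ # anti-correlated ones to 1. A quick check (the shapes here are assumptions for
+ # illustration):
+ #
+ # >>> dssim = DSSIMObjective()
+ # >>> batch = ops.ones((2, 32, 32, 3)) * 0.5
+ # >>> float(dssim(batch, batch)) # identical images
+ # 0.0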
+
+
+class GMSDLoss(keras.losses.Loss):
+ """ Gradient Magnitude Similarity Deviation Loss.
+
+ Improved image quality metric over MS-SSIM with easier calculations
+
+ References
+ ----------
+ http://www4.comp.polyu.edu.hk/~cslzhang/IQA/GMSD/GMSD.htm
+ https://arxiv.org/ftp/arxiv/papers/1308/1308.3052.pdf
+ """
+
+ def __init__(self, *args, **kwargs) -> None:
+ logger.debug(parse_class_init(locals()))
+ super().__init__(*args, name=self.__class__.__name__, **kwargs)
+ self._scharr_edges = Variable(np.array([[[[0.00070, 0.00070]],
+ [[0.00520, 0.00370]],
+ [[0.03700, 0.00000]],
+ [[0.00520, -0.0037]],
+ [[0.00070, -0.0007]]],
+ [[[0.00370, 0.00520]],
+ [[0.11870, 0.11870]],
+ [[0.25890, 0.00000]],
+ [[0.11870, -0.1187]],
+ [[0.00370, -0.0052]]],
+ [[[0.00000, 0.03700]],
+ [[0.00000, 0.25890]],
+ [[0.00000, 0.00000]],
+ [[0.00000, -0.2589]],
+ [[0.00000, -0.0370]]],
+ [[[-0.0037, 0.00520]],
+ [[-0.1187, 0.11870]],
+ [[-0.2589, 0.00000]],
+ [[-0.1187, -0.1187]],
+ [[-0.0037, -0.0052]]],
+ [[[-0.0007, 0.00070]],
+ [[-0.0052, 0.00370]],
+ [[-0.0370, 0.00000]],
+ [[-0.0052, -0.0037]],
+ [[-0.0007, -0.0007]]]]),
+ dtype="float32",
+ trainable=False)
+ logger.debug("Initialized: %s", self.__class__.__name__)
+
+ def _map_scharr_edges(self, image: KerasTensor, magnitude: bool) -> KerasTensor:
+ """ Returns a tensor holding modified Scharr edge maps.
+
+ Parameters
+ ----------
+ image: :class:`keras.KerasTensor`
+ Image tensor with shape [batch_size, h, w, d] and type float32. The image(s) must be
+ 2x2 or larger.
+ magnitude: bool
+ Boolean to determine if the edge magnitude or edge direction is returned
+
+ Returns
+ -------
+ :class:`keras.KerasTensor`
+ Tensor holding edge maps for each channel. Returns a tensor with shape `[batch_size, h,
+ w, d, 2]` where the last two dimensions hold `[[dy[0], dx[0]], [dy[1], dx[1]], ...,
+ [dy[d-1], dx[d-1]]]` calculated using the Scharr filter.
+ """
+ # Define vertical and horizontal Scharr filters.
+ image_shape = image.shape
+ num_kernels = [2]
+
+ kernels = ops.tile(self._scharr_edges, [1, 1, image_shape[-1], 1])
+
+ # Use depth-wise convolution to calculate edge maps per channel.
+ # Output tensor has shape [batch_size, h, w, d * num_kernels].
+ pad_sizes = [[0, 0], [2, 2], [2, 2], [0, 0]]
+ padded = ops.pad(image, pad_sizes, mode="reflect")
+ output = ops.depthwise_conv(padded, kernels)
+
+ if not magnitude: # direction of edges
+ # Reshape to [batch_size, h, w, d, num_kernels].
+ output = ops.reshape(output, ops.concatenate([image_shape, num_kernels]))
+ output = torch.atan(T.cast("Tensor",
+ ops.squeeze(output[:, :, :, :, 0] / output[:, :, :, :, 1],
+ axis=None)))
+ # magnitude of edges -- unified x & y edges don't work well with Neural Networks
+ return T.cast("KerasTensor", output)
+
+ def call(self, y_true: KerasTensor, y_pred: KerasTensor) -> KerasTensor:
+ """ Return the Gradient Magnitude Similarity Deviation Loss.
+
+ Parameters
+ ----------
+ y_true: :class:`keras.KerasTensor`
+ The ground truth value
+ y_pred: :class:`keras.KerasTensor`
+ The predicted value
+
+ Returns
+ -------
+ :class:`keras.KerasTensor`
+ The loss value
+ """
+ true_edge = self._map_scharr_edges(y_true, True)
+ pred_edge = self._map_scharr_edges(y_pred, True)
+ epsilon = 0.0025
+ upper = 2.0 * true_edge * pred_edge
+ lower = ops.square(true_edge) + ops.square(pred_edge)
+ gms = (upper + epsilon) / (lower + epsilon)
+ gmsd = ops.std(gms, axis=(1, 2, 3), keepdims=True)
+ gmsd = ops.squeeze(gmsd, axis=-1)
+ return T.cast("KerasTensor", gmsd)
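+
+ # The reduction above takes the deviation of the gradient-magnitude-similarity map:
+ # identical edge maps give gms == 1 everywhere, so the standard deviation (and hence the
+ # loss) is 0; e.g. a hypothetical map of [1.0, 0.5] has mean 0.75 and deviation 0.25, so
+ # patchy disagreement is penalized more heavily than a uniform one.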
+
+
+class LDRFLIPLoss(keras.losses.Loss): # pylint:disable=too-many-instance-attributes
+ """ Computes the LDR-FLIP error map between two LDR images, assuming the images are observed
+ at a certain number of pixels per degree of visual angle.
+
+ References
+ ----------
+ https://research.nvidia.com/sites/default/files/node/3260/FLIP_Paper.pdf
+ https://github.com/NVlabs/flip
+
+ License
+ -------
+ BSD 3-Clause License
+ Copyright (c) 2020-2022, NVIDIA Corporation & AFFILIATES. All rights reserved.
+ Redistribution and use in source and binary forms, with or without modification, are permitted
+ provided that the following conditions are met:
+ Redistributions of source code must retain the above copyright notice, this list of conditions
+ and the following disclaimer.
+ Redistributions in binary form must reproduce the above copyright notice, this list of
+ conditions and the following disclaimer in the documentation and/or other materials provided
+ with the distribution.
+ Neither the name of the copyright holder nor the names of its contributors may be used to
+ endorse or promote products derived from this software without specific prior written
+ permission.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
+ IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY
+ AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
+ CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+ CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+ THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
+ OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+ POSSIBILITY OF SUCH DAMAGE.
+
+ Parameters
+ ----------
+ computed_distance_exponent: float, Optional
+ The computed distance exponent to apply to Hunt adjusted, filtered colors.
+ (`qc` in original paper). Default: `0.7`
+ feature_exponent: float, Optional
+ The feature exponent to apply for increasing the impact of feature difference on the
+ final loss value. (`qf` in original paper). Default: `0.5`
+ lower_threshold_exponent: float, Optional
+ The `pc` exponent for the color pipeline as described in the original paper. Default: `0.4`
+ upper_threshold_exponent: float, Optional
+ The `pt` exponent for the color pipeline as described in the original paper.
+ Default: `0.95`
+ epsilon: float
+ A small value to improve training stability. Default: `1e-15`
+ pixels_per_degree: float, Optional
+ The estimated number of pixels per degree of visual angle of the observer. This effectively
+ impacts the tolerance when calculating loss. The default corresponds to viewing images on a
+ 0.7m wide 4K monitor at 0.7m from the display. Default: ``None``
+ color_order: str
+ The `"bgr"` or `"rgb"` color order of the incoming images
+ """
+ def __init__(self,
+ computed_distance_exponent: float = 0.7,
+ feature_exponent: float = 0.5,
+ lower_threshold_exponent: float = 0.4,
+ upper_threshold_exponent: float = 0.95,
+ epsilon: float = 1e-15,
+ pixels_per_degree: float | None = None,
+ color_order: T.Literal["bgr", "rgb"] = "bgr") -> None:
+ logger.debug(parse_class_init(locals()))
+ super().__init__(name=self.__class__.__name__)
+ self._computed_distance_exponent = computed_distance_exponent
+ self._feature_exponent = feature_exponent
+ self._pc = lower_threshold_exponent
+ self._pt = upper_threshold_exponent
+ self._epsilon = epsilon
+ self._color_order = color_order.lower()
+
+ if pixels_per_degree is None:
+ pixels_per_degree = (0.7 * 3840 / 0.7) * np.pi / 180
+ self._pixels_per_degree = pixels_per_degree
+ self._spatial_filters = _SpatialFilters(pixels_per_degree)
+ self._feature_detector = _FeatureDetection(pixels_per_degree)
+ self._col_conv = {"rgb2lab": ColorSpaceConvert(from_space="rgb", to_space="lab"),
+ "rgb2ycxcz": ColorSpaceConvert("srgb", "ycxcz")}
+ self._hunt = {"green": Variable([[[[0.0, 1.0, 0.0]]]], dtype="float32", trainable=False),
+ "blue": Variable([[[[0.0, 0.0, 1.0]]]], dtype="float32", trainable=False)}
+
+ logger.debug("Initialized: %s", self.__class__.__name__)
+
+ def call(self, y_true: KerasTensor, y_pred: KerasTensor) -> KerasTensor:
+ """ Call the LDR Flip Loss Function
+
+ Parameters
+ ----------
+ y_true: :class:`keras.KerasTensor`
+ The ground truth batch of images
+ y_pred: :class:`keras.KerasTensor`
+ The predicted batch of images
+
+ Returns
+ -------
+ :class:`keras.KerasTensor`
+ The calculated Flip loss value
+ """
+ if self._color_order == "bgr": # Switch models training in bgr order to rgb
+ y_true = y_true[..., [2, 1, 0]]
+ y_pred = y_pred[..., [2, 1, 0]]
+
+ y_true = T.cast("KerasTensor", ops.clip(y_true, 0, 1.))
+ y_pred = T.cast("KerasTensor", ops.clip(y_pred, 0, 1.))
+
+ true_ycxcz = self._col_conv["rgb2ycxcz"](y_true)
+ pred_ycxcz = self._col_conv["rgb2ycxcz"](y_pred)
+
+ delta_e_color = self._color_pipeline(true_ycxcz, pred_ycxcz)
+ delta_e_features = self._process_features(true_ycxcz, pred_ycxcz)
+
+ loss = ops.power(delta_e_color, 1 - delta_e_features)
+ return T.cast("KerasTensor", loss)
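+
+ # Worked default from __init__ above: (0.7 * 3840 / 0.7) * pi / 180 = 3840 * pi / 180
+ # ~= 67 pixels per degree, i.e. a 3840px panel spanning 0.7m viewed from 0.7m away.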
T.cast("KerasTensor", ops.power(self._hyab(hunt_adjusted_green, hunt_adjusted_blue), + self._computed_distance_exponent)) + return self._redistribute_errors(power_delta, cmax) + + def _process_features(self, y_true: KerasTensor, y_pred: KerasTensor) -> KerasTensor: + """ Perform the color processing part of the FLIP loss function + + Parameters + ---------- + y_true: :class:`keras.KerasTensor` + The ground truth batch of images in YCxCz color space + y_pred: :class:`keras.KerasTensor` + The predicted batch of images in YCxCz color space + + Returns + ------- + :class:`keras.KerasTensor` + The exponentiated features delta + """ + col_y_true = (y_true[..., 0:1] + 16) / 116. + col_y_pred = (y_pred[..., 0:1] + 16) / 116. + + edges_true = self._feature_detector(col_y_true, "edge") + points_true = self._feature_detector(col_y_true, "point") + edges_pred = self._feature_detector(col_y_pred, "edge") + points_pred = self._feature_detector(col_y_pred, "point") + + delta = ops.maximum(ops.abs(frobenius_norm(edges_true) - frobenius_norm(edges_pred)), + ops.abs(frobenius_norm(points_pred) - frobenius_norm(points_true))) + + delta = ops.clip(delta, x_min=self._epsilon, x_max=np.inf) + return T.cast("KerasTensor", ops.power(((1 / np.sqrt(2)) * delta), self._feature_exponent)) + + @classmethod + def _hunt_adjustment(cls, image: KerasTensor) -> KerasTensor: + """ Apply Hunt-adjustment to an image in L*a*b* color space + + Parameters + ---------- + image: :class:`keras.KerasTensor` + The batch of images in L*a*b* to adjust + + Returns + ------- + :class:`keras.KerasTensor` + The hunt adjusted batch of images in L*a*b color space + """ + ch_l = image[..., 0:1] + adjusted = ops.concatenate([ch_l, image[..., 1:] * (ch_l * 0.01)], axis=-1) + return T.cast("KerasTensor", adjusted) + + def _hyab(self, y_true: KerasTensor, y_pred: KerasTensor) -> KerasTensor: + """ Compute the HyAB distance between true and predicted images. + + Parameters + ---------- + y_true: :class:`keras.KerasTensor` + The ground truth batch of images in standard or Hunt-adjusted L*A*B* color space + y_pred: :class:`keras.KerasTensor` + The predicted batch of images in in standard or Hunt-adjusted L*A*B* color space + + Returns + ------- + :class:`keras.KerasTensor` + image tensor containing the per-pixel HyAB distances between true and predicted images + """ + delta = y_true - y_pred + root = T.cast("KerasTensor", ops.sqrt(ops.clip(ops.power(delta[..., 0:1], 2), + x_min=self._epsilon, + x_max=np.inf))) + delta_norm = frobenius_norm(delta[..., 1:3]) + return root + delta_norm + + def _redistribute_errors(self, + power_delta_e_hyab: KerasTensor, + cmax: KerasTensor) -> KerasTensor: + """ Redistribute exponentiated HyAB errors to the [0,1] range + + Parameters + ---------- + power_delta_e_hyab: :class:`keras.KerasTensor` + The exponentiated HyAb distance + cmax: :class:`keras.KerasTensor` + The exponentiated, maximum HyAB difference between two colors in Hunt-adjusted + L*A*B* space + + Returns + ------- + :class:`keras.KerasTensor` + The redistributed per-pixel HyAB distances (in range [0,1]) + """ + pccmax = self._pc * cmax + delta_e_c = ops.where( + power_delta_e_hyab < pccmax, + (self._pt / pccmax) * power_delta_e_hyab, + self._pt + ((power_delta_e_hyab - pccmax) / (cmax - pccmax)) * (1.0 - self._pt)) + return T.cast("KerasTensor", delta_e_c) + + +class _SpatialFilters(): + """ Filters an image with channel specific spatial contrast sensitivity functions and clips + result to the unit cube in linear RGB. + + For use with LDRFlipLoss. 
+
+    Parameters
+    ----------
+    pixels_per_degree: float
+        The estimated number of pixels per degree of visual angle of the observer. This
+        effectively impacts the tolerance when calculating loss.
+    """
+    def __init__(self, pixels_per_degree: float) -> None:
+        logger.debug(parse_class_init(locals()))
+        self._pixels_per_degree = pixels_per_degree
+        self._spatial_filters, self._radius = self._generate_spatial_filters()
+        self._ycxcz2rgb = ColorSpaceConvert(from_space="ycxcz", to_space="rgb")
+        logger.debug("Initialized: %s", self.__class__.__name__)
+
+    def _generate_spatial_filters(self) -> tuple[KerasTensor, int]:
+        """ Generates spatial contrast sensitivity filters with width depending on the number of
+        pixels per degree of visual angle of the observer for channels "A", "RG" and "BY"
+
+        Returns
+        -------
+        :class:`keras.Variable`
+            The filter kernels corresponding to the spatial contrast sensitivity functions of
+            the "A" (Achromatic CSF), "RG" (Red-Green CSF) and "BY" (Blue-Yellow CSF) channels,
+            stacked along the final axis
+        int
+            The radius of the filter kernels
+        """
+        mapping = {"A": {"a1": 1, "b1": 0.0047, "a2": 0, "b2": 1e-5},
+                   "RG": {"a1": 1, "b1": 0.0053, "a2": 0, "b2": 1e-5},
+                   "BY": {"a1": 34.1, "b1": 0.04, "a2": 13.5, "b2": 0.025}}
+
+        domain, radius = self._get_evaluation_domain(mapping["A"]["b1"],
+                                                     mapping["A"]["b2"],
+                                                     mapping["RG"]["b1"],
+                                                     mapping["RG"]["b2"],
+                                                     mapping["BY"]["b1"],
+                                                     mapping["BY"]["b2"])
+
+        weights = np.array([self._generate_weights(mapping[channel], domain)
+                            for channel in ("A", "RG", "BY")])
+        vweights = Variable(np.moveaxis(weights, 0, -1), dtype="float32", trainable=False)
+
+        return vweights, radius
+
+    def _get_evaluation_domain(self,
+                               b1_a: float,
+                               b2_a: float,
+                               b1_rg: float,
+                               b2_rg: float,
+                               b1_by: float,
+                               b2_by: float) -> tuple[np.ndarray, int]:
+        """ Generate the spatial domain over which the CSF kernels are evaluated.
+
+        The kernel radius is derived from the largest scale parameter so that the widest
+        filter is fully supported.
+
+        Parameters
+        ----------
+        b1_a, b2_a, b1_rg, b2_rg, b1_by, b2_by: float
+            The scale parameters for the "A", "RG" and "BY" channels respectively
+
+        Returns
+        -------
+        :class:`numpy.ndarray`
+            The squared radial distances (in degrees of visual angle) over the kernel grid
+        int
+            The radius of the kernel grid
+        """
+        max_scale_parameter = max([b1_a, b2_a, b1_rg, b2_rg, b1_by, b2_by])
+        delta_x = 1.0 / self._pixels_per_degree
+        radius = int(np.ceil(3 * np.sqrt(max_scale_parameter / (2 * np.pi**2))
+                             * self._pixels_per_degree))
+        ax_x, ax_y = np.meshgrid(range(-radius, radius + 1), range(-radius, radius + 1))
+        domain = (ax_x * delta_x) ** 2 + (ax_y * delta_x) ** 2
+        return domain, radius
+
+    @classmethod
+    def _generate_weights(cls, channel: dict[str, float], domain: np.ndarray) -> np.ndarray:
+        """ Generate the normalized CSF kernel for a single channel as a sum of two Gaussians.
+
+        Parameters
+        ----------
+        channel: dict[str, float]
+            The "a1", "b1", "a2" and "b2" CSF parameters for the channel
+        domain: :class:`numpy.ndarray`
+            The squared radial distances over which to evaluate the kernel
+
+        Returns
+        -------
+        :class:`numpy.ndarray`
+            The normalized kernel weights with a trailing channel axis
+        """
+        a_1, b_1, a_2, b_2 = channel["a1"], channel["b1"], channel["a2"], channel["b2"]
+        grad = (a_1 * np.sqrt(np.pi / b_1) * np.exp(-np.pi ** 2 * domain / b_1) +
+                a_2 * np.sqrt(np.pi / b_2) * np.exp(-np.pi ** 2 * domain / b_2))
+        grad = grad / np.sum(grad)
+        grad = np.reshape(grad, (*grad.shape, 1))
+        return grad
+
+    def __call__(self, image: KerasTensor) -> KerasTensor:
+        """ Call the spatial filtering.
+
+        Parameters
+        ----------
+        image: :class:`keras.KerasTensor`
+            Image tensor to filter in YCxCz color space
+
+        Returns
+        -------
+        :class:`keras.KerasTensor`
+            The input image transformed to linear RGB after filtering with spatial contrast
+            sensitivity functions
+        """
+        padded_image = replicate_pad(image, self._radius)
+        image_tilde_opponent = T.cast("KerasTensor", ops.conv(padded_image,
+                                                              self._spatial_filters,
+                                                              strides=1,
+                                                              padding="valid"))
+        rgb = ops.clip(self._ycxcz2rgb(image_tilde_opponent), 0., 1.)
+        return T.cast("KerasTensor", rgb)
+
+
+class _FeatureDetection():
+    """ Detect features (i.e. edges and points) in an achromatic YCxCz image.
+
+    For use with :class:`LDRFLIPLoss`.
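+
+    Edges are detected with a first-order Gaussian derivative kernel and points with a
+    second-order one; at call time each kernel is normalized so that its positive and negative
+    weights each sum to unit magnitude.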
+
+    Parameters
+    ----------
+    pixels_per_degree: float
+        The number of pixels per degree of visual angle of the observer
+    """
+    def __init__(self, pixels_per_degree: float) -> None:
+        logger.debug(parse_class_init(locals()))
+        width = 0.082
+        self._std = 0.5 * width * pixels_per_degree
+        self._radius = int(np.ceil(3 * self._std))
+        grid = np.meshgrid(range(-self._radius, self._radius + 1),
+                           range(-self._radius, self._radius + 1))
+
+        gradient = np.exp(-(grid[0] ** 2 + grid[1] ** 2) / (2 * (self._std ** 2)))
+        self._grads = {
+            "edge": Variable(np.multiply(-grid[0], gradient), trainable=False, dtype="float32"),
+            "point": Variable(np.multiply(grid[0] ** 2 / (self._std ** 2) - 1, gradient),
+                              trainable=False,
+                              dtype="float32")}
+
+        logger.debug("Initialized: %s", self.__class__.__name__)
+
+    def __call__(self, image: KerasTensor, feature_type: str) -> KerasTensor:
+        """ Run the feature detection
+
+        Parameters
+        ----------
+        image: :class:`keras.KerasTensor`
+            Batch of images in YCxCz color space with normalized Y values
+        feature_type: str
+            Type of features to detect (`"edge"` or `"point"`)
+
+        Returns
+        -------
+        :class:`keras.KerasTensor`
+            Detected features in the 0-1 range
+        """
+        feature_type = feature_type.lower()
+
+        grad_x = self._grads[feature_type]
+        negative_weights_sum = -ops.sum(grad_x[grad_x < 0])
+        positive_weights_sum = ops.sum(grad_x[grad_x > 0])
+
+        grad_x = ops.where(grad_x < 0,
+                           grad_x / negative_weights_sum,
+                           grad_x / positive_weights_sum)
+        kernel = ops.expand_dims(ops.expand_dims(grad_x, axis=-1), axis=-1)
+        features_x = ops.conv(replicate_pad(image, self._radius),
+                              kernel,
+                              strides=1,
+                              padding="valid")
+        kernel = ops.transpose(kernel, (1, 0, 2, 3))
+        features_y = ops.conv(replicate_pad(image, self._radius),
+                              kernel,
+                              strides=1,
+                              padding="valid")
+        features = ops.concatenate([features_x, features_y], axis=-1)
+        return T.cast("KerasTensor", features)
+
+
+class MSSIMLoss(keras.losses.Loss):
+    """ Multiscale Structural Similarity Loss Function
+
+    Parameters
+    ----------
+    k_1: float, optional
+        Parameter of the SSIM. Default: `0.01`
+    k_2: float, optional
+        Parameter of the SSIM. Default: `0.03`
+    filter_size: int, optional
+        Size of the gaussian filter. Default: `11`
+    filter_sigma: float, optional
+        Width of the gaussian filter. Default: `1.5`
+    max_value: float, optional
+        Max value of the output. Default: `1.0`
+    power_factors: tuple, optional
+        Iterable of weights for each of the scales. The number of scales used is the length of the
+        list. Index 0 is the unscaled resolution's weight and each increasing scale corresponds to
+        the image being downsampled by 2. Defaults to the values obtained in the original paper.
+        Default: (0.0448, 0.2856, 0.3001, 0.2363, 0.1333)
+
+    Notes
+    -----
+    You should add a regularization term like an L2 loss in addition to this one.
+    Adapted from TensorFlow's ssim_multiscale implementation
+    """
+    def __init__(self,
+                 k_1: float = 0.01,
+                 k_2: float = 0.03,
+                 filter_size: int = 11,
+                 filter_sigma: float = 1.5,
+                 max_value: float = 1.0,
+                 power_factors: tuple[float, ...]
+                 = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333)
+                 ) -> None:
+        logger.debug(parse_class_init(locals()))
+        super().__init__(name=self.__class__.__name__)
+        self.filter_size = filter_size
+        self._filter_sigma = Variable(filter_sigma, dtype="float32", trainable=False)
+        self._k_1 = k_1
+        self._k_2 = k_2
+        self._max_value = max_value
+        self._power_factors = power_factors
+        self._divisor = [1, 2, 2, 1]
+        self._divisor_tensor = Variable(self._divisor[1:], dtype="int32", trainable=False)
+        logger.debug("Initialized: %s", self.__class__.__name__)
+
+    @classmethod
+    def _reducer(cls, image: KerasTensor, kernel: KerasTensor) -> KerasTensor:
+        """ Computes local averages from a set of images
+
+        Parameters
+        ----------
+        image: :class:`keras.KerasTensor`
+            The images to be processed
+        kernel: :class:`keras.KerasTensor`
+            The kernel to apply
+
+        Returns
+        -------
+        :class:`keras.KerasTensor`
+            The reduced image
+        """
+        shape = image.shape
+        var_x = ops.reshape(image, (-1, *shape[-3:]))
+        var_y = ops.nn.depthwise_conv(var_x, kernel, strides=1, padding="valid")
+        return T.cast("KerasTensor", ops.reshape(var_y, (*shape[:-3], *var_y.shape[1:])))
+
+    def _ssim_helper(self,
+                     image1: KerasTensor,
+                     image2: KerasTensor,
+                     kernel: KerasTensor) -> tuple[KerasTensor, KerasTensor]:
+        """ Helper function for computing SSIM
+
+        Parameters
+        ----------
+        image1: :class:`keras.KerasTensor`
+            The first set of images
+        image2: :class:`keras.KerasTensor`
+            The second set of images
+        kernel: :class:`keras.KerasTensor`
+            The gaussian kernel
+
+        Returns
+        -------
+        :class:`keras.KerasTensor`:
+            The channel-wise SSIM
+        :class:`keras.KerasTensor`:
+            The channel-wise contrast-structure
+        """
+        c_1 = (self._k_1 * self._max_value) ** 2
+        c_2 = (self._k_2 * self._max_value) ** 2
+
+        mean0 = self._reducer(image1, kernel)
+        mean1 = self._reducer(image2, kernel)
+        num0 = mean0 * mean1 * 2.0
+        den0 = ops.square(mean0) + ops.square(mean1)
+        luminance = (num0 + c_1) / (den0 + c_1)
+
+        num1 = self._reducer(image1 * image2, kernel) * 2.0
+        den1 = self._reducer(T.cast("KerasTensor", ops.square(image1) + ops.square(image2)),
+                             kernel)
+        cs_ = (num1 - num0 + c_2) / (den1 - den0 + c_2)
+
+        return luminance, cs_
+
+    def _fspecial_gauss(self, size: int) -> KerasTensor:
+        """Function to mimic the 'fspecial' gaussian MATLAB function.
+
+        Parameters
+        ----------
+        size: int
+            The size of the gaussian filter
+
+        Returns
+        -------
+        :class:`keras.KerasTensor`
+            The gaussian kernel
+        """
+        coords = ops.cast(range(size), self._filter_sigma.dtype)
+        coords -= ops.cast(size - 1, self._filter_sigma.dtype) / 2.0
+
+        gauss = ops.square(coords)
+        gauss *= -0.5 / ops.square(self._filter_sigma)
+
+        gauss = ops.reshape(gauss, [1, -1]) + ops.reshape(gauss, [-1, 1])
+        gauss = ops.reshape(gauss, [1, -1])  # For ops.softmax().
+        gauss = ops.softmax(gauss)
+        return T.cast("KerasTensor", ops.reshape(gauss, [size, size, 1, 1]))
+
+    def _ssim_per_channel(self,
+                          image1: KerasTensor,
+                          image2: KerasTensor,
+                          filter_size: int) -> tuple[KerasTensor, KerasTensor]:
+        """Computes SSIM index between image1 and image2 per color channel.
+
+        This function matches the standard SSIM implementation from:
+        Wang, Z., Bovik, A. C., Sheikh, H. R., & Simoncelli, E. P. (2004). Image
+        quality assessment: from error visibility to structural similarity. IEEE
+        transactions on image processing.
+
+        Parameters
+        ----------
+        image1: :class:`keras.KerasTensor`
+            The first image batch
+        image2: :class:`keras.KerasTensor`
+            The second image batch.
+        filter_size: int
+            The size of the gaussian filter
+
+        Returns
+        -------
+        :class:`keras.KerasTensor`:
+            The channel-wise SSIM
+        :class:`keras.KerasTensor`:
+            The channel-wise contrast-structure
+        """
+        shape = image1.shape
+
+        kernel = self._fspecial_gauss(filter_size)
+        kernel = ops.tile(kernel, [1, 1, shape[-1], 1])
+
+        luminance, cs_ = self._ssim_helper(image1, image2, kernel)
+
+        # Average over the second and the third from the last: height, width.
+        ssim_val = T.cast("KerasTensor", ops.mean(luminance * cs_, [-3, -2]))
+        cs_ = T.cast("KerasTensor", ops.mean(cs_, [-3, -2]))
+        return ssim_val, cs_
+
+    @classmethod
+    def _do_pad(cls, images: list[KerasTensor], remainder: KerasTensor) -> list[KerasTensor]:
+        """ Pad images
+
+        Parameters
+        ----------
+        images: list[:class:`keras.KerasTensor`]
+            Images to pad
+        remainder: :class:`keras.KerasTensor`
+            Remaining images to pad
+
+        Returns
+        -------
+        list[:class:`keras.KerasTensor`]
+            Padded images
+        """
+        padding = ops.expand_dims(remainder, axis=-1)
+        padding = ops.pad(padding, [[1, 0], [1, 0]], mode="constant")
+        return [ops.pad(x, padding, mode="symmetric") for x in images]
+
+    def _mssism(self,  # pylint:disable=too-many-locals
+                y_true: KerasTensor,
+                y_pred: KerasTensor,
+                filter_size: int) -> KerasTensor:
+        """ Perform the MS-SSIM calculation.
+
+        Ported from the TensorFlow implementation `image.ssim_multiscale`
+
+        Parameters
+        ----------
+        y_true: :class:`keras.KerasTensor`
+            The ground truth value
+        y_pred: :class:`keras.KerasTensor`
+            The predicted value
+        filter_size: int
+            The filter size to use
+
+        Returns
+        -------
+        :class:`keras.KerasTensor`
+            The MS-SSIM value, averaged over the color channels
+        """
+        images = [y_true, y_pred]
+        shapes = [y_true.shape, y_pred.shape]
+        heads = [s[:-3] for s in shapes]
+        tails = [s[-3:] for s in shapes]
+
+        mcs = []
+        ssim_per_channel = None
+        for k in range(len(self._power_factors)):
+            if k > 0:
+                # Avg pool takes rank 4 tensors. Flatten leading dimensions.
+                flat_images = [T.cast("KerasTensor", ops.reshape(x, (-1, *t)))
+                               for x, t in zip(images, tails)]
+                remainder = tails[0] % self._divisor_tensor
+
+                need_padding = ops.any(ops.not_equal(remainder, 0))
+                padded = ops.cond(
+                    need_padding,
+                    lambda: self._do_pad(flat_images,  # pylint:disable=cell-var-from-loop
+                                         remainder),  # pylint:disable=cell-var-from-loop
+                    lambda: flat_images)  # pylint:disable=cell-var-from-loop
+
+                downscaled = [ops.average_pool(x,
+                                               self._divisor[1:3],
+                                               strides=self._divisor[1:3],
+                                               padding='valid')
+                              for x in padded]
+
+                tails = [x.shape[1:] for x in downscaled]
+                images = [T.cast("KerasTensor", ops.reshape(x, (*h, *t)))
+                          for x, h, t in zip(downscaled, heads, tails)]
+
+            # Overwrite previous ssim value since we only need the last one.
+            ssim_per_channel, cs_ = self._ssim_per_channel(images[0], images[1], filter_size)
+            mcs.append(ops.relu(cs_))
+
+        mcs.pop()  # Remove the cs score for the last scale.
+
+        mcs_and_ssim = ops.stack(mcs + [ops.relu(ssim_per_channel)], axis=-1)
+        ms_ssim = ops.prod(ops.power(mcs_and_ssim, self._power_factors), [-1])
+
+        return T.cast("KerasTensor", ops.mean(ms_ssim, [-1]))  # Avg over color channels.
+
+    def call(self, y_true: KerasTensor, y_pred: KerasTensor) -> KerasTensor:
+        """ Call the MS-SSIM Loss Function.
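+
+        The returned loss is ``1 - MS-SSIM`` averaged over the batch, so a perfect
+        reconstruction yields a loss of 0.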
+ + Parameters + ---------- + y_true: :class:`keras.KerasTensor` + The ground truth value + y_pred: :class:`keras.KerasTensor` + The predicted value + + Returns + ------- + :class:`keras.KerasTensor` + The MS-SSIM Loss value + """ + im_size = y_true.shape[1] + assert isinstance(im_size, int) + # filter size cannot be larger than the smallest scale + smallest_scale = self._get_smallest_size(im_size, len(self._power_factors) - 1) + filter_size = min(self.filter_size, smallest_scale) + + ms_ssim = self._mssism(y_true, y_pred, filter_size) + ms_ssim_loss = 1. - ms_ssim + return T.cast("KerasTensor", ops.mean(ms_ssim_loss)) + + def _get_smallest_size(self, size: int, idx: int) -> int: + """ Recursive function to obtain the smallest size that the image will be scaled to. + + Parameters + ---------- + size: int + The current scaled size to iterate through + idx: int + The current iteration to be performed. When iteration hits zero the value will + be returned + + Returns + ------- + int + The smallest size the image will be scaled to based on the original image size and + the amount of scaling factors that will occur + """ + logger.trace("scale id: %s, size: %s", idx, size) # type:ignore[attr-defined] + if idx > 0: + size = self._get_smallest_size(size // 2, idx - 1) + return size + + +__all__ = get_module_objects(__name__) diff --git a/lib/model/masks.py b/lib/model/masks.py deleted file mode 100644 index 6fc6ee151b..0000000000 --- a/lib/model/masks.py +++ /dev/null @@ -1,101 +0,0 @@ -#!/usr/bin/env python3 -""" Masks functions for faceswap.py - Masks from: - dfaker: https://github.com/dfaker/df""" - -import logging - -import cv2 -import numpy as np - -from lib.umeyama import umeyama - -logger = logging.getLogger(__name__) # pylint: disable=invalid-name - - -def dfaker(landmarks, face, channels=4): - """ Dfaker model mask - Embeds the mask into the face alpha channel - - channels: 1, 3 or 4: - 1 - Return a single channel mask - 3 - Return a 3 channel mask - 4 - Return the original image with the mask in the alpha channel - """ - padding = int(face.shape[0] * 0.1875) - coverage = face.shape[0] - (padding * 2) - logger.trace("face_shape: %s, coverage: %s, landmarks: %s", face.shape, coverage, landmarks) - - mat = umeyama(landmarks[17:], True)[0:2] - mat = np.array(mat.ravel()).reshape(2, 3) - mat = mat * coverage - mat[:, 2] += padding - - points = np.array(landmarks).reshape((-1, 2)) - facepoints = np.array(points).reshape((-1, 2)) - - mask = np.zeros_like(face, dtype=np.uint8) - - hull = cv2.convexHull(facepoints.astype(int)) # pylint: disable=no-member - hull = cv2.transform(hull.reshape(1, -1, 2), # pylint: disable=no-member - mat).reshape(-1, 2).astype(int) - cv2.fillConvexPoly(mask, hull, (255, 255, 255)) # pylint: disable=no-member - - kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (15, 15)) # pylint: disable=no-member - mask = cv2.dilate(mask, # pylint: disable=no-member - kernel, - iterations=1, - borderType=cv2.BORDER_REFLECT) # pylint: disable=no-member - mask = mask[:, :, :1] - - return merge_mask(face, mask, channels) - - -def dfl_full(landmarks, face, channels=4): - """ DFL Face Full Mask - - channels: 1, 3 or 4: - 1 - Return a single channel mask - 3 - Return a 3 channel mask - 4 - Return the original image with the mask in the alpha channel - """ - logger.trace("face_shape: %s, landmarks: %s", face.shape, landmarks) - mask = np.zeros(face.shape[0:2] + (1, ), dtype=np.float32) - jaw = cv2.convexHull(np.concatenate(( # pylint: disable=no-member - landmarks[0:17], # jawline - 
landmarks[48:68], # mouth - [landmarks[0]], # temple - [landmarks[8]], # chin - [landmarks[16]]))) # temple - nose_ridge = cv2.convexHull(np.concatenate(( # pylint: disable=no-member - landmarks[27:31], # nose line - [landmarks[33]]))) # nose point - eyes = cv2.convexHull(np.concatenate(( # pylint: disable=no-member - landmarks[17:27], # eyebrows - [landmarks[0]], # temple - [landmarks[27]], # nose top - [landmarks[16]], # temple - [landmarks[33]]))) # nose point - - cv2.fillConvexPoly(mask, jaw, (255, 255, 255)) # pylint: disable=no-member - cv2.fillConvexPoly(mask, nose_ridge, (255, 255, 255)) # pylint: disable=no-member - cv2.fillConvexPoly(mask, eyes, (255, 255, 255)) # pylint: disable=no-member - return merge_mask(face, mask, channels) - - -def merge_mask(image, mask, channels): - """ Return the mask in requested shape """ - logger.trace("image_shape: %s, mask_shape: %s, channels: %s", - image.shape, mask.shape, channels) - assert channels in (1, 3, 4), "Channels should be 1, 3 or 4" - assert mask.shape[2] == 1 and mask.ndim == 3, "Input mask be 3 dimensions with 1 channel" - - if channels == 3: - retval = np.tile(mask, 3) - elif channels == 4: - retval = np.concatenate((image, mask), -1) - else: - retval = mask - - logger.trace("Final mask shape: %s", retval.shape) - return retval diff --git a/lib/model/networks/__init__.py b/lib/model/networks/__init__.py new file mode 100644 index 0000000000..e2be872d73 --- /dev/null +++ b/lib/model/networks/__init__.py @@ -0,0 +1,4 @@ +#!/usr/bin/env python3 +""" Pre-defined networks for use in faceswap """ +from .simple_nets import AlexNet, SqueezeNet +from .clip import ViT, ViTConfig, TypeModels as TypeModelsViT diff --git a/lib/model/networks/clip.py b/lib/model/networks/clip.py new file mode 100644 index 0000000000..d325624f57 --- /dev/null +++ b/lib/model/networks/clip.py @@ -0,0 +1,859 @@ +#!/usr/bin/env python3 +""" CLIP: https://github.com/openai/CLIP. This implementation only ports the visual transformer +part of the model. +""" +# TODO Fix Resnet. It is correct until final MHA +from __future__ import annotations +import inspect +import logging +import typing as T +import sys +import warnings + +from dataclasses import dataclass + +from keras import layers, ops, Variable, models, saving +import numpy as np + +from lib.model.layers import QuickGELU +from lib.utils import get_module_objects, GetModel + +if T.TYPE_CHECKING: + from keras import KerasTensor + + +logger = logging.getLogger(__name__) + +TypeModels = T.Literal["RN50", "RN101", "RN50x4", "RN50x16", "RN50x64", "ViT-B-16", + "ViT-B-32", "ViT-L-14", "ViT-L-14-336px", "FaRL-B-16-16", "FaRL-B-16-64"] + + +@dataclass +class ViTConfig: + """ Configuration settings for ViT + + Parameters + ---------- + embed_dim: int + Dimensionality of the final shared embedding space + resolution: int + Spatial resolution of the input images + layer_conf: tuple[int, int, int, int] | int + Number of layers in the visual encoder, or a tuple of layer configurations for a custom + ResNet visual encoder + width: int + Width of the visual encoder layers + patch: int + Size of the patches to be extracted from the images. Only used for Visual encoder. + git_id: int, optional + The id of the model weights file stored in deepfakes_models repo if they exist. 
Default: 0
+    """
+    embed_dim: int
+    resolution: int
+    layer_conf: int | tuple[int, int, int, int]
+    width: int
+    patch: int
+    git_id: int = 0
+
+    def __post_init__(self):
+        """ Validate that patch_size is given correctly """
+        assert (isinstance(self.layer_conf, (tuple, list)) and self.patch == 0) or (
+            isinstance(self.layer_conf, int) and self.patch > 0)
+
+
+MODEL_CONFIG: dict[TypeModels, ViTConfig] = {  # Each model has a different set of parameters
+    "RN50": ViTConfig(
+        embed_dim=1024, resolution=224, layer_conf=(3, 4, 6, 3), width=64, patch=0, git_id=21),
+    "RN101": ViTConfig(
+        embed_dim=512, resolution=224, layer_conf=(3, 4, 23, 3), width=64, patch=0, git_id=22),
+    "RN50x4": ViTConfig(
+        embed_dim=640, resolution=288, layer_conf=(4, 6, 10, 6), width=80, patch=0, git_id=23),
+    "RN50x16": ViTConfig(
+        embed_dim=768, resolution=384, layer_conf=(6, 8, 18, 8), width=96, patch=0, git_id=24),
+    "RN50x64": ViTConfig(
+        embed_dim=1024, resolution=448, layer_conf=(3, 15, 36, 10), width=128, patch=0, git_id=25),
+    "ViT-B-16": ViTConfig(
+        embed_dim=512, resolution=224, layer_conf=12, width=768, patch=16, git_id=26),
+    "ViT-B-32": ViTConfig(
+        embed_dim=512, resolution=224, layer_conf=12, width=768, patch=32, git_id=27),
+    "ViT-L-14": ViTConfig(
+        embed_dim=768, resolution=224, layer_conf=24, width=1024, patch=14, git_id=28),
+    "ViT-L-14-336px": ViTConfig(
+        embed_dim=768, resolution=336, layer_conf=24, width=1024, patch=14, git_id=29),
+    "FaRL-B-16-16": ViTConfig(
+        embed_dim=512, resolution=224, layer_conf=12, width=768, patch=16, git_id=30),
+    "FaRL-B-16-64": ViTConfig(
+        embed_dim=512, resolution=224, layer_conf=12, width=768, patch=16, git_id=31)}
+
+
+# ################## #
+# VISUAL TRANSFORMER #
+# ################## #
+
+class Transformer():
+    """ A class representing a Transformer model with attention mechanism and residual connections.
+
+    Parameters
+    ----------
+    width: int
+        The dimension of the input and output vectors.
+    num_layers: int
+        The number of layers in the Transformer.
+    heads: int
+        The number of attention heads.
+    attn_mask: :class:`keras.KerasTensor`, optional
+        The attention mask, by default None.
+    name: str, optional
+        The name of the Transformer model, by default "transformer".
+
+    Methods
+    -------
+    __call__(inputs) -> :class:`keras.KerasTensor`:
+        Calls the Transformer layers.
+    """
+    _layer_names: dict[str, int] = {}
+    """ dict[str, int] for tracking unique layer names"""
+
+    def __init__(self,
+                 width: int,
+                 num_layers: int,
+                 heads: int,
+                 attn_mask: KerasTensor | None = None,
+                 name: str = "transformer") -> None:
+        logger.debug("Initializing: %s (width: %s, num_layers: %s, heads: %s, attn_mask: %s, "
+                     "name: %s)",
+                     self.__class__.__name__, width, num_layers, heads, attn_mask, name)
+        self._width = width
+        self._num_layers = num_layers
+        self._heads = heads
+        self._attn_mask = attn_mask
+        self._name = name
+        logger.debug("Initialized: %s", self.__class__.__name__)
+
+    @classmethod
+    def _get_name(cls, name: str) -> str:
+        """ Return unique layer name for requested block.
+
+        As blocks can be used multiple times, auto appends an integer to the end of the requested
+        name to keep all block names unique
+
+        Parameters
+        ----------
+        name: str
+            The requested name for the layer
+
+        Returns
+        -------
+        str
+            The unique name for this layer
+        """
+        cls._layer_names[name] = cls._layer_names.setdefault(name, -1) + 1
+        name = f"{name}_{cls._layer_names[name]}"
+        logger.debug("Generating block name: %s", name)
+        return name
+
+    @classmethod
+    def _mlp(cls, inputs: KerasTensor, key_dim: int, name: str) -> KerasTensor:
+        """ Multilayer Perceptron for Block Attention
+
+        Parameters
+        ----------
+        inputs: :class:`keras.KerasTensor`
+            The input to the MLP
+        key_dim: int
+            key dimension per head for MultiHeadAttention
+        name: str
+            The name to prefix on the layers
+
+        Returns
+        -------
+        :class:`keras.KerasTensor`
+            The output from the MLP
+        """
+        name = f"{name}_mlp"
+        var_x = layers.Dense(key_dim * 4, name=f"{name}_c_fc")(inputs)
+        var_x = QuickGELU(name=f"{name}_gelu")(var_x)
+        var_x = layers.Dense(key_dim, name=f"{name}_c_proj")(var_x)
+        return var_x
+
+    def residual_attention_block(self,
+                                 inputs: KerasTensor,
+                                 key_dim: int,
+                                 num_heads: int,
+                                 attn_mask: KerasTensor,
+                                 name: str = "ResidualAttentionBlock") -> KerasTensor:
+        """ Call the residual attention block
+
+        Parameters
+        ----------
+        inputs: :class:`keras.KerasTensor`
+            The input Tensor
+        key_dim: int
+            key dimension per head for MultiHeadAttention
+        num_heads: int
+            Number of heads for MultiHeadAttention
+        attn_mask: :class:`keras.KerasTensor`, optional
+            The attention mask to apply. Default: ``None``
+        name: str, optional
+            The name for the layer. Default: "ResidualAttentionBlock"
+
+        Returns
+        -------
+        :class:`keras.KerasTensor`
+            The return Tensor
+        """
+        name = self._get_name(name)
+
+        var_x = layers.LayerNormalization(epsilon=1e-05, name=f"{name}_ln_1")(inputs)
+        var_x = layers.MultiHeadAttention(
+            num_heads=num_heads,
+            key_dim=key_dim // num_heads,
+            name=f"{name}_attn")(var_x, var_x, var_x, attention_mask=attn_mask)
+        var_x = layers.Add()([inputs, var_x])
+        var_y = var_x
+        var_x = layers.LayerNormalization(epsilon=1e-05, name=f"{name}_ln_2")(var_x)
+        var_x = layers.Add()([var_y, self._mlp(var_x, key_dim, name)])
+        return var_x
+
+    def __call__(self, inputs: KerasTensor) -> KerasTensor:
+        """ Call the Transformer layers
+
+        Parameters
+        ----------
+        inputs: :class:`keras.KerasTensor`
+            The input Tensor
+
+        Returns
+        -------
+        :class:`keras.KerasTensor`
+            The return Tensor
+        """
+        logger.debug("Calling %s with input: %s", self.__class__.__name__, inputs.shape)
+        var_x = inputs
+        for _ in range(self._num_layers):
+            var_x = self.residual_attention_block(var_x,
+                                                  self._width,
+                                                  self._heads,
+                                                  self._attn_mask,
+                                                  name=f"{self._name}_resblocks")
+        return var_x
+
+
+class EmbeddingLayer(layers.Layer):  # pylint:disable=too-many-ancestors,abstract-method
+    """ Parent class for trainable embedding variables
+
+    Parameters
+    ----------
+    input_shape: tuple[int, ...]
+        The shape of the variable
+    scale: int
+        Amount to scale the random initialization by
+    name: str
+        The name of the layer
+    dtype: str, optional
+        The datatype for the layer. Mixed precision can mess up the embeddings.
Default: "float32" + """ + def __init__(self, + input_shape: tuple[int, ...], + scale: int, + name: str, + *args, + dtype="float32", + **kwargs) -> None: + super().__init__(name=name, dtype=dtype, *args, **kwargs) + self._input_shape = input_shape + self._scale = scale + self._var: KerasTensor + + def build(self, input_shape: tuple[int, ...]) -> None: + """ Add the weights + + Parameters + ---------- + input_shape: tuple[int, ... + The input shape of the incoming tensor + """ + self._var = Variable(self._scale * np.random.normal(size=self._input_shape), + trainable=True, + dtype=self.dtype) + super().build(input_shape) + + def get_config(self) -> dict[str, T.Any]: + """ Get the config dictionary for the layer + + Returns + ------- + dict[str, Any] + The config dictionary for the layer + """ + retval = super().get_config() + retval["input_shape"] = self._input_shape + retval["scale"] = self._scale + return retval + + +class ClassEmbedding(EmbeddingLayer): # pylint:disable=too-many-ancestors,abstract-method + """ Trainable Class Embedding layer """ + def call(self, inputs: KerasTensor, *args, **kwargs # pylint:disable=arguments-differ + ) -> KerasTensor: + """ Get the Class Embedding layer + + Parameters + ---------- + inputs: :class:`keras.KerasTensor` + Input tensor to the embedding layer + + Returns + ------- + :class:`keras.KerasTensor` + The class embedding layer shaped for the input tensor + """ + return ops.tile(self._var[None, None], [inputs.shape[0], 1, 1]) + + +class PositionalEmbedding(EmbeddingLayer): # pylint:disable=too-many-ancestors,abstract-method + """ Trainable Positional Embedding layer """ + def call(self, inputs: KerasTensor, *args, **kwargs # pylint:disable=arguments-differ + ) -> KerasTensor: + """ Get the Positional Embedding layer + + Parameters + ---------- + inputs: :class:`keras.KerasTensor` + Input tensor to the embedding layer + + Returns + ------- + :class:`keras.KerasTensor` + The positional embedding layer shaped for the input tensor + """ + return ops.tile(self._var[None], [inputs.shape[0], 1, 1]) + + +class Projection(EmbeddingLayer): # pylint:disable=too-many-ancestors,abstract-method + """ Trainable Projection Embedding Layer """ + def call(self, inputs: KerasTensor, *args, **kwargs # pylint:disable=arguments-differ + ) -> KerasTensor: + """ Get the Projection layer + + Parameters + ---------- + inputs: :class:`keras.KerasTensor` + Input tensor to the embedding layer + + Returns + ------- + :class:`keras.KerasTensor` + The Projection layer expanded to the batch dimension and transposed for matmul + """ + return ops.tile(ops.transpose(self._var)[None], [inputs.shape[0], 1, 1]) + + +class VisualTransformer(): + """ A class representing a Visual Transformer model for image classification tasks. + + Parameters + ---------- + input_resolution: int + The input resolution of the images. + patch_size: int + The size of the patches to be extracted from the images. + width: int + The dimension of the input and output vectors. + num_layers: int + The number of layers in the Transformer. + heads: int + The number of attention heads. + output_dim: int + The dimension of the output vector. + name: str, optional + The name of the Visual Transformer model, Default: "VisualTransformer". + + Methods + ------- + __call__() -> :class:`keras.models.Model`: + Builds and returns the Visual Transformer model. 
+ """ + def __init__(self, + input_resolution: int, + patch_size: int, + width: int, + num_layers: int, + heads: int, + output_dim: int, + name: str = "VisualTransformer") -> None: + logger.debug("Initializing: %s (input_resolution: %s, patch_size: %s, width: %s, " + "layers: %s, heads: %s, output_dim: %s, name: %s)", + self.__class__.__name__, input_resolution, patch_size, width, num_layers, + heads, output_dim, name) + self._input_resolution = input_resolution + self._patch_size = patch_size + self._width = width + self._num_layers = num_layers + self._heads = heads + self._output_dim = output_dim + self._name = name + logger.debug("Initialized: %s", self.__class__.__name__) + + def __call__(self) -> models.Model: + """ Builds and returns the Visual Transformer model. + + Returns + ------- + :class:`keras.models.Model` + The Visual Transformer model. + """ + inputs = layers.Input([self._input_resolution, self._input_resolution, 3]) + var_x: KerasTensor = layers.Conv2D(self._width, # shape = [*, grid, grid, width] + self._patch_size, + strides=self._patch_size, + use_bias=False, + name=f"{self._name}_conv1")(inputs) + + var_x = layers.Reshape((-1, self._width))(var_x) # shape = [*, grid ** 2, width] + + class_embed = ClassEmbedding((self._width, ), + self._width ** -0.5, + name=f"{self._name}_class_embedding")(var_x) + var_x = layers.Concatenate(axis=1)([class_embed, var_x]) + + pos_embed = PositionalEmbedding(((self._input_resolution // self._patch_size) ** 2 + 1, + self._width), + self._width ** -0.5, + name=f"{self._name}_positional_embedding")(var_x) + var_x = layers.Add()([var_x, pos_embed]) + var_x = layers.LayerNormalization(epsilon=1e-05, name=f"{self._name}_ln_pre")(var_x) + var_x = Transformer(self._width, + self._num_layers, + self._heads, + name=f"{self._name}_transformer")(var_x) + var_x = layers.LayerNormalization(epsilon=1e-05, + name=f"{self._name}_ln_post")(var_x[:, 0, :]) + proj = Projection((self._width, self._output_dim), + self._width ** -0.5, + name=f"{self._name}_proj")(var_x) + var_x = layers.Dot(axes=-1)([var_x, proj]) + return models.Model(inputs=inputs, outputs=var_x, name=self._name) + + +# ################ # +# MODIEFIED RESNET # +# ################ # +class Bottleneck(): + """ A ResNet bottleneck block that performs a sequence of convolutions, batch normalization, + and ReLU activation operations on an input tensor. + + Parameters + ---------- + inplanes: int + The number of input channels. + planes: int + The number of output channels. + stride: int, optional + The stride of the bottleneck block. Default: 1 + name: str, optional + The name of the bottleneck block. 
Default: "bottleneck" + """ + expansion = 4 + """ int: The factor by which the number of input channels is expanded to get the number of + output channels.""" + + def __init__(self, + inplanes: int, + planes: int, + stride: int = 1, + name: str = "bottleneck") -> None: + logger.debug("Initializing: %s (inplanes: %s, planes: %s, stride: %s, name: %s)", + self.__class__.__name__, inplanes, planes, stride, name) + self._inplanes = inplanes + self._planes = planes + self._stride = stride + self._name = name + logger.debug("Initialized: %s", self.__class__.__name__) + + def _downsample(self, inputs: KerasTensor) -> KerasTensor: + """ Perform downsample if required + + Parameters + ---------- + inputs: :class:`keras.KerasTensor` + The input the downsample + + Returns + ------- + :class:`keras.KerasTensor` + The original tensor, if downsizing not required, otherwise the downsized tensor + """ + if self._stride <= 1 and self._inplanes == self._planes * self.expansion: + return inputs + + name = f"{self._name}_downsample" + out = layers.AveragePooling2D(self._stride, name=f"{name}_avgpool")(inputs) + out = layers.Conv2D(self._planes * self.expansion, + 1, + strides=1, + use_bias=False, + name=f"{name}_0")(out) + out = layers.BatchNormalization(name=f"{name}_1", epsilon=1e-5)(out) + return out + + def __call__(self, inputs: KerasTensor) -> KerasTensor: + """ Performs the forward pass for a Bottleneck block. + + All conv layers have stride 1. an avgpool is performed after the second convolution when + stride > 1 + + Parameters + ---------- + inputs: :class:`keras.KerasTensor` + The input tensor to the Bottleneck block. + + Returns + ------- + :class:`keras.KerasTensor` + The result of the forward pass through the Bottleneck block. + """ + out = layers.Conv2D(self._planes, 1, use_bias=False, name=f"{self._name}_conv1")(inputs) + out = layers.BatchNormalization(name=f"{self._name}_bn1", epsilon=1e-5)(out) + out = layers.ReLU()(out) + + out = layers.ZeroPadding2D(padding=((1, 1), (1, 1)))(out) + out = layers.Conv2D(self._planes, 3, use_bias=False, name=f"{self._name}_conv2")(out) + out = layers.BatchNormalization(name=f"{self._name}_bn2", epsilon=1e-5)(out) + out = layers.ReLU()(out) + + if self._stride > 1: + out = layers.AveragePooling2D(self._stride)(out) + + out = layers.Conv2D(self._planes * self.expansion, + 1, + use_bias=False, + name=f"{self._name}_conv3")(out) + out = layers.BatchNormalization(name=f"{self._name}_bn3", epsilon=1e-5)(out) + + identity = self._downsample(inputs) + + out += identity + out = layers.ReLU()(out) + return out + + +class AttentionPool2d(): + """ An Attention Pooling layer that applies a multi-head self-attention mechanism over a + spatial grid of features. + + Parameters + ---------- + spatial_dim: int + The dimensionality of the spatial grid of features. + embed_dim: int + The dimensionality of the feature embeddings. + num_heads: int + The number of attention heads. + output_dim: int + The output dimensionality of the attention layer. If None, it defaults to embed_dim. + name: str + The name of the layer. 
+ """ + def __init__(self, + spatial_dim: int, + embed_dim: int, + num_heads: int, + output_dim: int | None = None, + name="AttentionPool2d"): + logger.debug("Initializing: %s (spatial_dim: %s, embed_dim: %s, num_heads: %s, " + "output_dim: %s, name: %s)", + self.__class__.__name__, spatial_dim, embed_dim, num_heads, output_dim, name) + + self._spatial_dim = spatial_dim + self._embed_dim = embed_dim + self._num_heads = num_heads + self._output_dim = output_dim + self._name = name + logger.debug("Initialized: %s", self.__class__.__name__) + + def __call__(self, inputs: KerasTensor) -> KerasTensor: + """Performs the attention pooling operation on the input tensor. + + Parameters + ---------- + inputs: :class:`keras.KerasTensor`: + The input tensor of shape [batch_size, height, width, embed_dim]. + + Returns + ------- + :class:`keras.KerasTensor`:: The result of the attention pooling operation + """ + var_x: KerasTensor + var_x = layers.Reshape((-1, inputs.shape[-1]))(inputs) # NHWC -> N(HW)C + var_x = layers.Concatenate(axis=1)([ops.mean(var_x, axis=1, # N(HW)C -> N(HW+1)C + keepdims=True), var_x]) + pos_embed = PositionalEmbedding((self._spatial_dim ** 2 + 1, self._embed_dim), # N(HW+1)C + self._embed_dim ** 0.5, + name=f"{self._name}_positional_embedding")(var_x) + var_x = layers.Add()([var_x, pos_embed]) + # TODO At this point torch + keras match. They mismatch after MHA + var_x = layers.MultiHeadAttention(num_heads=self._num_heads, + key_dim=self._embed_dim // self._num_heads, + output_shape=self._output_dim or self._embed_dim, + use_bias=True, + name=f"{self._name}_mha")(var_x[:, :1, ...], + var_x, + var_x) + # only return the first element in the sequence + return var_x[:, 0, ...] + + +class ModifiedResNet(): + """ A ResNet class that is similar to torchvision's but contains the following changes: + + - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max + pool. + - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions + with stride > 1 + - The final pooling layer is a QKV attention instead of an average pool + + Parameters + ---------- + input_resolution: int + The input resolution of the model. Default is 224. + width: int + The width of the model. Default is 64. + layer_config: list + A list containing the number of Bottleneck blocks for each layer. + output_dim: int + The output dimension of the model. + heads: int + The number of heads for the QKV attention. + name: str + The name of the model. Default is "ModifiedResNet". + """ + def __init__(self, + input_resolution: int, + width: int, + layer_config: tuple[int, int, int, int], + output_dim: int, + heads: int, + name="ModifiedResNet"): + self._input_resolution = input_resolution + self._width = width + self._layer_config = layer_config + self._heads = heads + self._output_dim = output_dim + self._name = name + + def _stem(self, inputs: KerasTensor) -> KerasTensor: + """ Applies the stem operation to the input tensor, which consists of 3 convolutional + layers with BatchNormalization and ReLU activation, followed by an average pooling + layer. + + Parameters + ---------- + inputs: :class:`keras.KerasTensor` + The input tensor + + Returns + ------- + :class:`keras.KerasTensor` + The output tensor after applying the stem operation. 
+        """
+        var_x = inputs
+        for i in range(1, 4):
+            width = self._width if i == 3 else self._width // 2
+            strides = 2 if i == 1 else 1
+            var_x = layers.ZeroPadding2D(padding=((1, 1), (1, 1)), name=f"conv{i}_padding")(var_x)
+            var_x = layers.Conv2D(width,
+                                  3,
+                                  strides=strides,
+                                  use_bias=False,
+                                  name=f"conv{i}")(var_x)
+            var_x = layers.BatchNormalization(name=f"bn{i}", epsilon=1e-5)(var_x)
+            var_x = layers.ReLU()(var_x)
+        var_x = layers.AveragePooling2D(2, name="avgpool")(var_x)
+        return var_x
+
+    def _bottleneck(self,
+                    inputs: KerasTensor,
+                    planes: int,
+                    blocks: int,
+                    stride: int = 1,
+                    name: str = "layer") -> KerasTensor:
+        """ A private method that creates a sequential layer of Bottleneck blocks for the
+        ModifiedResNet model.
+
+        Parameters
+        ----------
+        inputs: :class:`keras.KerasTensor`
+            The input tensor
+        planes: int
+            The number of output channels for the layer.
+        blocks: int
+            The number of Bottleneck blocks in the layer.
+        stride: int
+            The stride for the first Bottleneck block in the layer. Default is 1.
+        name: str
+            The name of the layer. Default is "layer".
+
+        Returns
+        -------
+        :class:`keras.KerasTensor`
+            Sequential block of bottlenecks
+        """
+        retval: KerasTensor
+        retval = Bottleneck(planes, planes, stride, name=f"{name}_0")(inputs)
+        for i in range(1, blocks):
+            retval = Bottleneck(planes * Bottleneck.expansion,
+                                planes,
+                                name=f"{name}_{i}")(retval)
+        return retval
+
+    def __call__(self) -> models.Model:
+        """ Implements the forward pass of the ModifiedResNet model.
+
+        Returns
+        -------
+        :class:`keras.models.Model`
+            The modified resnet model.
+        """
+        inputs = layers.Input((self._input_resolution, self._input_resolution, 3))
+        var_x = self._stem(inputs)
+
+        for i in range(4):
+            stride = 1 if i == 0 else 2
+            var_x = self._bottleneck(var_x,
+                                     self._width * (2 ** i),
+                                     self._layer_config[i],
+                                     stride=stride,
+                                     name=f"{self._name}_layer{i + 1}")
+
+        var_x = AttentionPool2d(self._input_resolution // 32,
+                                self._width * 32,  # the ResNet feature dimension
+                                self._heads,
+                                self._output_dim,
+                                name=f"{self._name}_attnpool")(var_x)
+        return models.Model(inputs, outputs=var_x, name=self._name)
+
+
+# ### #
+# VIT #
+# ### #
+class ViT():
+    """ Visual Transformer from CLIP
+
+    A Contrastive Language-Image Pre-training (CLIP) model that encodes images and text into a
+    shared latent space. Only the vision portion of the model is ported here.
+
+    Reference
+    ---------
+    https://arxiv.org/abs/2103.00020
+
+    Parameters
+    ----------
+    name: ["RN50", "RN101", "RN50x4", "RN50x16", "RN50x64", "ViT-B-16", "ViT-B-32",
+           "ViT-L-14", "ViT-L-14-336px", "FaRL-B-16-16", "FaRL-B-16-64"]
+        The model configuration to use
+    input_size: int, optional
+        The required resolution size for the model. ``None`` for default preset size
+    load_weights: bool, optional
+        ``True`` to load pretrained weights. Default: ``False``
+    """
+    def __init__(self,
+                 name: TypeModels,
+                 input_size: int | None = None,
+                 load_weights: bool = False) -> None:
+        logger.debug("Initializing: %s (name: %s, input_size: %s, load_weights: %s)",
+                     self.__class__.__name__, name, input_size, load_weights)
+        assert name in MODEL_CONFIG, f"Name must be one of {list(MODEL_CONFIG)}"
+
+        self._name = name
+        self._load_weights = load_weights
+
+        config = MODEL_CONFIG[name]
+        self._git_id = config.git_id
+
+        res = input_size if input_size is not None else config.resolution
+        self._net = self._get_vision_net(config.layer_conf,
+                                         config.width,
+                                         config.embed_dim,
+                                         res,
+                                         config.patch)
+        logger.debug("Initialized: %s", self.__class__.__name__)
+
+    def _get_vision_net(self,
+                        layer_config: int | tuple[int, int, int, int],
+                        width: int,
+                        embed_dim: int,
+                        resolution: int,
+                        patch_size: int) -> models.Model:
+        """ Obtain the network for the vision layers
+
+        Parameters
+        ----------
+        layer_config: tuple[int, int, int, int] | int
+            Number of layers in the visual encoder, or a tuple of layer configurations for a
+            custom ResNet visual encoder.
+        width: int
+            Width of the visual encoder layers.
+        embed_dim: int
+            Dimensionality of the final shared embedding space.
+        resolution: int
+            Spatial resolution of the input images.
+        patch_size: int
+            Size of the patches to be extracted from the images.
+
+        Returns
+        -------
+        :class:`keras.models.Model`
+            The :class:`ModifiedResNet` or :class:`VisualTransformer` vision model to use
+        """
+        if isinstance(layer_config, (tuple, list)):
+            vision_heads = width * 32 // 64
+            return ModifiedResNet(input_resolution=resolution,
+                                  width=width,
+                                  layer_config=layer_config,
+                                  output_dim=embed_dim,
+                                  heads=vision_heads,
+                                  name="visual")
+        vision_heads = width // 64
+        return VisualTransformer(input_resolution=resolution,
+                                 width=width,
+                                 num_layers=layer_config,
+                                 output_dim=embed_dim,
+                                 heads=vision_heads,
+                                 patch_size=patch_size,
+                                 name="visual")
+
+    def __call__(self) -> models.Model:
+        """ Get the configured ViT model
+
+        Returns
+        -------
+        :class:`keras.models.Model`
+            The requested Visual Transformer model
+        """
+        net: models.Model = self._net()
+        if self._load_weights and not self._git_id:
+            logger.warning("Trained weights are not available for '%s'", self._name)
+            return net
+        if self._load_weights:
+            model_path = GetModel(f"CLIPv_{self._name}_v1.h5", self._git_id).model_path
+            logger.info("Loading CLIPv trained weights for '%s'", self._name)
+            with warnings.catch_warnings():
+                # TODO There is a potential bug in keras load_weights_by_name that tries to load
+                # top_level_weights where they don't exist.
+                # This always generates a scary looking warning, so it is suppressed for now
+                warnings.simplefilter("ignore")
+                # NOTE: Don't load by name as we had to replace local dots with underscores
+                net.load_weights(model_path, by_name=False, skip_mismatch=True)
+
+        return net
+
+
+# Update layers into Keras custom objects
+for name_, obj in inspect.getmembers(sys.modules[__name__]):
+    if (inspect.isclass(obj) and issubclass(obj, layers.Layer)
+            and obj.__module__ == __name__):
+        saving.get_custom_objects().update({name_: obj})
+
+
+__all__ = get_module_objects(__name__)
diff --git a/lib/model/networks/simple_nets.py b/lib/model/networks/simple_nets.py
new file mode 100644
index 0000000000..a33e0337a3
--- /dev/null
+++ b/lib/model/networks/simple_nets.py
@@ -0,0 +1,217 @@
+#!/usr/bin/env python3
+""" Ports of existing NN architectures for use in faceswap.py """
+from __future__ import annotations
+import logging
+import typing as T
+
+from keras import layers
+from keras.models import Model
+
+from lib.logger import parse_class_init
+from lib.utils import get_module_objects
+
+if T.TYPE_CHECKING:
+    from keras import KerasTensor
+
+logger = logging.getLogger(__name__)
+
+
+class _net():  # pylint:disable=too-few-public-methods
+    """ Base class for existing NeuralNet architecture
+
+    Notes
+    -----
+    All architectures assume channels_last format
+
+    Parameters
+    ----------
+    input_shape: tuple, optional
+        The input shape for the model. Default: ``None``
+    """
+    def __init__(self,
+                 input_shape: tuple[int, int, int] | None = None) -> None:
+        logger.debug(parse_class_init(locals()))
+        self._input_shape = (None, None, 3) if input_shape is None else input_shape
+        assert len(self._input_shape) == 3 and self._input_shape[-1] == 3, (
+            "Input shape must be in the format (height, width, channels) and the number of "
+            f"channels must equal 3. Received: {self._input_shape}")
+        logger.debug("Initialized: %s", self.__class__.__name__)
+
+
+class AlexNet(_net):
+    """ AlexNet ported from torchvision version.
+
+    Notes
+    -----
+    This port only contains the features portion of the model.
+
+    References
+    ----------
+    https://papers.nips.cc/paper/2012/file/c399862d3b9d6b76c8436e924a68c45b-Paper.pdf
+
+    Parameters
+    ----------
+    input_shape: tuple, optional
+        The input shape for the model. Default: ``None``
+    """
+    def __init__(self, input_shape: tuple[int, int, int] | None = None) -> None:
+        super().__init__(input_shape)
+        self._feature_indices = [0, 3, 6, 8, 10]  # For naming equivalent to PyTorch
+        self._filters = [64, 192, 384, 256, 256]  # Filters at each block
+
+    @classmethod
+    def _conv_block(cls,
+                    inputs: KerasTensor,
+                    padding: int,
+                    filters: int,
+                    kernel_size: int,
+                    strides: int,
+                    block_idx: int,
+                    max_pool: bool) -> KerasTensor:
+        """
+        The Convolutional block for AlexNet
+
+        Parameters
+        ----------
+        inputs: :class:`keras.KerasTensor`
+            The input tensor to the block
+        padding: int
+            The amount of zero padding to apply prior to convolution
+        filters: int
+            The number of filters to apply during convolution
+        kernel_size: int
+            The kernel size of the convolution
+        strides: int
+            The number of strides for the convolution
+        block_idx: int
+            The index of the current block (for standardized naming convention)
+        max_pool: bool
+            ``True`` to apply a max pooling layer at the beginning of the block otherwise ``False``
+
+        Returns
+        -------
+        :class:`keras.KerasTensor`
+            The output of the Convolutional block
+        """
+        name = f"features_{block_idx}"
+        var_x = inputs
+        if max_pool:
+            var_x = layers.MaxPooling2D(pool_size=3, strides=2, name=f"{name}_pool")(var_x)
+        var_x = layers.ZeroPadding2D(padding=padding, name=f"{name}_pad")(var_x)
+        var_x = layers.Conv2D(filters,
+                              kernel_size=kernel_size,
+                              strides=strides,
+                              padding="valid",
+                              activation="relu",
+                              name=name)(var_x)
+        return var_x
+
+    def __call__(self) -> Model:
+        """ Create the AlexNet Model
+
+        Returns
+        -------
+        :class:`keras.models.Model`
+            The compiled AlexNet model
+        """
+        inputs = layers.Input(self._input_shape)
+        var_x = T.cast("KerasTensor", inputs)
+        kernel_size = 11
+        strides = 4
+
+        for idx, (filters, block_idx) in enumerate(zip(self._filters, self._feature_indices)):
+            padding = 2 if idx < 2 else 1
+            do_max_pool = 0 < idx < 3
+            var_x = self._conv_block(var_x,
+                                     padding,
+                                     filters,
+                                     kernel_size,
+                                     strides,
+                                     block_idx,
+                                     do_max_pool)
+            kernel_size = max(3, kernel_size // 2)
+            strides = 1
+        return Model(inputs=inputs, outputs=[var_x])
+
+
+class SqueezeNet(_net):
+    """ SqueezeNet ported from torchvision version.
+
+    Notes
+    -----
+    This port only contains the features portion of the model.
+
+    References
+    ----------
+    https://arxiv.org/abs/1602.07360
+
+    Parameters
+    ----------
+    input_shape: tuple, optional
+        The input shape for the model. Default: ``None``
+    """
+
+    @classmethod
+    def _fire(cls,
+              inputs: KerasTensor,
+              squeeze_planes: int,
+              expand_planes: int,
+              block_idx: int) -> KerasTensor:
+        """ The fire block for SqueezeNet.
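+
+        The block squeezes the input with a 1x1 convolution, then expands it through parallel
+        1x1 and 3x3 convolutions whose outputs are concatenated on the channel axis.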
+ + Parameters + ---------- + inputs: :class:`keras.KerasTensor` + The input to the fire block + squeeze_planes: int + The number of filters for the squeeze convolution + expand_planes: int + The number of filters for the expand convolutions + block_idx: int + The index of the current block (for standardized naming convention) + + Returns + ------- + :class:`keras.KerasTensor` + The output of the SqueezeNet fire block + """ + name = f"features_{block_idx}" + squeezed = layers.Conv2D(squeeze_planes, 1, + activation="relu", name=f"{name}_squeeze")(inputs) + expand1 = layers.Conv2D(expand_planes, 1, + activation="relu", name=f"{name}_expand1x1")(squeezed) + expand3 = layers.Conv2D(expand_planes, + 3, + activation="relu", + padding="same", + name=f"{name}_expand3x3")(squeezed) + return layers.Concatenate(axis=-1, name=name)([expand1, expand3]) + + def __call__(self) -> Model: + """ Create the SqueezeNet Model + + Returns + ------- + :class:`keras.models.Model` + The compiled SqueezeNet model + """ + inputs = layers.Input(self._input_shape) + var_x = layers.Conv2D(64, 3, strides=2, activation="relu", name="features_0")(inputs) + + block_idx = 2 + squeeze = 16 + expand = 64 + for idx in range(4): + if idx < 3: + var_x = layers.MaxPooling2D(pool_size=3, strides=2)(var_x) + block_idx += 1 + var_x = self._fire(var_x, squeeze, expand, block_idx) + block_idx += 1 + var_x = self._fire(var_x, squeeze, expand, block_idx) + block_idx += 1 + squeeze += 16 + expand += 64 + return Model(inputs=inputs, outputs=[var_x]) + + +__all__ = get_module_objects(__name__) diff --git a/lib/model/nn_blocks.py b/lib/model/nn_blocks.py index 9a64a342fb..b70a64b8c7 100644 --- a/lib/model/nn_blocks.py +++ b/lib/model/nn_blocks.py @@ -1,279 +1,911 @@ #!/usr/bin/env python3 -""" Neural Network Blocks for faceswap.py - Blocks from: - the original https://www.reddit.com/r/deepfakes/ code sample + contribs - dfaker: https://github.com/dfaker/df - shoanlu GAN: https://github.com/shaoanlu/faceswap-GAN""" - +""" Neural Network Blocks for faceswap.py. 
""" +from __future__ import annotations import logging -import tensorflow as tf -import keras.backend as K - -from keras.layers import (add, Add, BatchNormalization, concatenate, Lambda, regularizers, - Permute, Reshape, SeparableConv2D, Softmax, UpSampling2D) -from keras.layers.advanced_activations import LeakyReLU -from keras.layers.convolutional import Conv2D -from keras.layers.core import Activation -from keras.initializers import he_uniform, Constant -from .initializers import ICNR -from .layers import PixelShuffler, Scale, SubPixelUpscaling, ReflectionPadding2D -from .normalization import GroupNormalization, InstanceNormalization - -logger = logging.getLogger(__name__) # pylint: disable=invalid-name - - -class NNBlocks(): - """ Blocks to use for creating models """ - def __init__(self, use_subpixel=False, use_icnr_init=False, use_reflect_padding=False): - logger.debug("Initializing %s: (use_subpixel: %s, use_icnr_init: %s, use_reflect_padding: %s)", - self.__class__.__name__, use_subpixel, use_icnr_init, use_reflect_padding) - self.use_subpixel = use_subpixel - self.use_icnr_init = use_icnr_init - self.use_reflect_padding = use_reflect_padding +import typing as T + +from keras import initializers, layers + +from lib.logger import parse_class_init +from lib.utils import get_module_objects +from plugins.train import train_config as cfg + +from .initializers import ICNR, ConvolutionAware +from .layers import PixelShuffler, ReflectionPadding2D, Swish, KResizeImages +from .normalization import InstanceNormalization + +if T.TYPE_CHECKING: + from keras import KerasTensor + +logger = logging.getLogger(__name__) + + +_names: dict[str, int] = {} + + +def _get_name(name: str) -> str: + """ Return unique layer name for requested block. + + As blocks can be used multiple times, auto appends an integer to the end of the requested + name to keep all block names unique + + Parameters + ---------- + name: str + The requested name for the layer + + Returns + ------- + str + The unique name for this layer + """ + _names[name] = _names.setdefault(name, -1) + 1 + name = f"{name}_{_names[name]}" + logger.debug("Generating block name: %s", name) + return name + + +def reset_naming() -> None: + """ Reset the naming convention for nn_block layers to start from 0 + + Used when a model needs to be rebuilt and the names for each build should be identical + """ + logger.debug("Resetting nn_block layer naming") + global _names # pylint:disable=global-statement + _names = {} + + +# << CONVOLUTIONS >> +def _get_default_initializer( + initializer: initializers.Initializer) -> initializers.Initializer: + """ Returns a default initializer of Convolutional Aware or HeUniform for convolutional + layers. + + Parameters + ---------- + initializer: :class:`keras.initializers.Initializer` or None + The initializer that has been passed into the model. If this value is ``None`` then a + default initializer will be set to 'HeUniform'. If Convolutional Aware initialization + has been enabled, then any passed through initializer will be replaced with the + Convolutional Aware initializer. + + Returns + ------- + :class:`keras.initializers.Initializer` + The kernel initializer to use for this convolutional layer. 
Either the original given + initializer, HeUniform or convolutional aware (if selected in config options) + """ + if isinstance(initializer, dict) and initializer.get("class_name", "") == "ConvolutionAware": + logger.debug("Returning serialized initialized ConvAware initializer: %s", initializer) + return initializer + + if cfg.conv_aware_init(): + retval = ConvolutionAware() + elif initializer is None: + retval = initializers.HeUniform() + else: + retval = initializer + logger.debug("Using model supplied initializer: %s", retval) + logger.debug("Set default kernel_initializer: (original: %s current: %s)", initializer, retval) + + return retval + + +class Conv2D(): # pylint:disable=too-many-ancestors,abstract-method + """ A standard Keras Convolution 2D layer with parameters updated to be more appropriate for + Faceswap architecture. + + Parameters are the same, with the same defaults, as a standard :class:`keras.layers.Conv2D` + except where listed below. The default initializer is updated to `HeUniform` or `convolutional + aware` based on user configuration settings. + + Parameters + ---------- + padding: str, optional + One of `"valid"` or `"same"` (case-insensitive). Default: `"same"`. Note that `"same"` is + slightly inconsistent across backends with `strides` != 1, as described + `here `_. + is_upscale: `bool`, optional + ``True`` if the convolution is being called from an upscale layer. This causes the instance + to check the user configuration options to see if ICNR initialization has been selected and + should be applied. This should only be passed in as ``True`` from :class:`UpscaleBlock` + layers. Default: ``False`` + """ + def __init__(self, *args, padding: str = "same", is_upscale: bool = False, **kwargs) -> None: + logger.debug(parse_class_init(locals())) + if kwargs.get("name", None) is None: + filters = kwargs["filters"] if "filters" in kwargs else args[0] + kwargs["name"] = _get_name(f"conv2d_{filters}") + initializer = _get_default_initializer(kwargs.pop("kernel_initializer", None)) + if is_upscale and cfg.icnr_init(): + initializer = ICNR(initializer=initializer) + logger.debug("Using ICNR Initializer: %s", initializer) + self._conv2d = layers.Conv2D( + *args, + padding=padding, + kernel_initializer=initializer, # pyright:ignore[reportArgumentType] + **kwargs) logger.debug("Initialized %s", self.__class__.__name__) - @staticmethod - def update_kwargs(kwargs): - """ Set the default kernel initializer to he_uniform() """ - kwargs["kernel_initializer"] = kwargs.get("kernel_initializer", he_uniform()) - return kwargs - - # <<< Original Model Blocks >>> # - def conv(self, inp, filters, kernel_size=5, strides=2, padding='same', use_instance_norm=False, res_block_follows=False, **kwargs): - """ Convolution Layer""" - logger.debug("inp: %s, filters: %s, kernel_size: %s, strides: %s, use_instance_norm: %s, " - "kwargs: %s)", inp, filters, kernel_size, strides, use_instance_norm, kwargs) - kwargs = self.update_kwargs(kwargs) - if self.use_reflect_padding: - inp = ReflectionPadding2D(stride=strides, kernel_size=kernel_size)(inp) - padding = 'valid' - var_x = Conv2D(filters, - kernel_size=kernel_size, - strides=strides, - padding=padding, - **kwargs)(inp) - if use_instance_norm: - var_x = InstanceNormalization()(var_x) - if not res_block_follows: - var_x = LeakyReLU(0.1)(var_x) + def __call__(self, *args, **kwargs) -> KerasTensor: + """ Call the Conv2D layer + + Parameters + ---------- + args : tuple + Standard Conv2D layer call arguments + kwargs : dict[str, Any] + Standard Conv2D 
layer call keyword arguments + + Returns + ------- + :class: `keras.KerasTensor` + The Tensor from the Conv2D layer + """ + return self._conv2d(*args, **kwargs) + + +class DepthwiseConv2D(): # noqa,pylint:disable=too-many-ancestors,abstract-method + """ A standard Keras Depthwise Convolution 2D layer with parameters updated to be more + appropriate for Faceswap architecture. + + Parameters are the same, with the same defaults, as a standard + :class:`keras.layers.DepthwiseConv2D` except where listed below. The default initializer is + updated to `HeUniform` or `convolutional aware` based on user configuration settings. + + Parameters + ---------- + padding: str, optional + One of `"valid"` or `"same"` (case-insensitive). Default: `"same"`. Note that `"same"` is + slightly inconsistent across backends with `strides` != 1, as described + `here `_. + is_upscale: `bool`, optional + ``True`` if the convolution is being called from an upscale layer. This causes the instance + to check the user configuration options to see if ICNR initialization has been selected and + should be applied. This should only be passed in as ``True`` from :class:`UpscaleBlock` + layers. Default: ``False`` + """ + def __init__(self, *args, padding: str = "same", is_upscale: bool = False, **kwargs) -> None: + logger.debug(parse_class_init(locals())) + if kwargs.get("name", None) is None: + kwargs["name"] = _get_name("dwconv2d") + initializer = _get_default_initializer(kwargs.pop("depthwise_initializer", None)) + if is_upscale and cfg.icnr_init(): + initializer = ICNR(initializer=initializer) + logger.debug("Using ICNR Initializer: %s", initializer) + self._depthwiseconv2d = layers.DepthwiseConv2D( + *args, + padding=padding, + depthwise_initializer=initializer, # pyright:ignore[reportArgumentType] + **kwargs) + logger.debug("Initialized %s", self.__class__.__name__) + + def __call__(self, *args, **kwargs) -> KerasTensor: + """ Call the DepthwiseConv2D layer + + Parameters + ---------- + args : tuple + Standard DepthwiseConv2D layer call arguments + kwargs : dict[str, Any] + Standard DepthwiseConv2D layer call keyword arguments + + Returns + ------- + :class: `keras.KerasTensor` + The Tensor from the DepthwiseConv2D layer + """ + return self._depthwiseconv2d(*args, **kwargs) + + +class Conv2DOutput(): + """ A Convolution 2D layer that separates out the activation layer to explicitly set the data + type on the activation to float 32 to fully support mixed precision training. + + The Convolution 2D layer uses default parameters to be more appropriate for Faceswap + architecture. + + Parameters are the same, with the same defaults, as a standard :class:`keras.layers.Conv2D` + except where listed below. The default initializer is updated to HeUniform or convolutional + aware based on user config settings. + + Parameters + ---------- + filters: int + The dimensionality of the output space (i.e. the number of output filters in the + convolution) + kernel_size: int or tuple/list of 2 ints + The height and width of the 2D convolution window. Can be a single integer to specify the + same value for all spatial dimensions. + activation: str, optional + The activation function to apply to the output. Default: `"sigmoid"` + padding: str, optional + One of `"valid"` or `"same"` (case-insensitive). Default: `"same"`. Note that `"same"` is + slightly inconsistent across backends with `strides` != 1, as described + `here `_. 
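The two wrappers above share a pattern worth noting: they compose the Keras layer rather than subclass it, so the faceswap defaults (auto-generated name, swapped-in initializer) are applied once at construction and every call is then delegated. A minimal sketch of the same pattern in plain Keras, without the faceswap config dependency (class and defaults here are illustrative, not faceswap's):

```python
from keras import initializers, layers

class DefaultedConv2D:
    """ Compose a keras Conv2D, injecting house defaults without subclassing """

    def __init__(self, *args, padding: str = "same", **kwargs) -> None:
        # Only swap in the default initializer when the caller supplied none
        kwargs.setdefault("kernel_initializer", initializers.HeUniform())
        self._conv2d = layers.Conv2D(*args, padding=padding, **kwargs)

    def __call__(self, *args, **kwargs):
        return self._conv2d(*args, **kwargs)  # delegate the actual layer call
```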
+ kwargs: dict + Any additional Keras standard layer keyword arguments to pass to the Convolutional 2D layer + """ + def __init__(self, + filters: int, + kernel_size: int | tuple[int], + activation: str = "sigmoid", + padding: str = "same", **kwargs) -> None: + logger.debug(parse_class_init(locals())) + name = _get_name(kwargs.pop("name")) if "name" in kwargs else _get_name( + f"conv_output_{filters}") + self._conv = Conv2D(filters, + kernel_size, + padding=padding, + name=f"{name}_conv2d", + **kwargs) + self._activation = layers.Activation(activation, dtype="float32", name=name) + logger.debug("Initialized %s", self.__class__.__name__) + + def __call__(self, inputs: KerasTensor) -> KerasTensor: + """ Call the Faceswap Convolutional Output Layer. + + Parameters + ---------- + inputs: :class:`keras.KerasTensor` + The input to the layer + + Returns + ------- + :class:`keras.KerasTensor` + The output tensor from the Convolution 2D Layer + """ + var_x = self._conv(inputs) + return self._activation(var_x) + + +class Conv2DBlock(): # pylint:disable=too-many-instance-attributes + """ A standard Convolution 2D layer which applies user specified configuration to the + layer. + + Adds reflection padding if it has been selected by the user, and other post-processing + if requested by the plugin. + + Adds instance normalization if requested. Adds a LeakyReLU if a residual block follows. + + Parameters + ---------- + filters: int + The dimensionality of the output space (i.e. the number of output filters in the + convolution) + kernel_size: int, optional + An integer or tuple/list of 2 integers, specifying the height and width of the 2D + convolution window. Can be a single integer to specify the same value for all spatial + dimensions. NB: If `use_depthwise` is ``True`` then a value must still be provided here, + but it will be ignored. Default: 5 + strides: tuple or int, optional + An integer or tuple/list of 2 integers, specifying the strides of the convolution along the + height and width. Can be a single integer to specify the same value for all spatial + dimensions. Default: `2` + padding: ["valid", "same"], optional + The padding to use. NB: If reflect padding has been selected in the user configuration + options, then this argument will be ignored in favor of reflect padding. Default: `"same"` + normalization: str or ``None``, optional + Normalization to apply after the Convolution Layer. Select one of "batch" or "instance". + Set to ``None`` to not apply normalization. Default: ``None`` + activation: str or ``None``, optional + The activation function to use. This is applied at the end of the convolution block. Select + one of `"leakyrelu"`, `"prelu"` or `"swish"`. Set to ``None`` to not apply an activation + function. Default: `"leakyrelu"` + use_depthwise: bool, optional + Set to ``True`` to use a Depthwise Convolution 2D layer rather than a standard Convolution + 2D layer. Default: ``False`` + relu_alpha: float + The alpha to use for LeakyRelu Activation. 
Default=`0.1` + kwargs: dict + Any additional Keras standard layer keyword arguments to pass to the Convolutional 2D layer + """ + def __init__(self, + filters: int, + kernel_size: int | tuple[int, int] = 5, + strides: int | tuple[int, int] = 2, + padding: str = "same", + normalization: str | None = None, + activation: str | None = "leakyrelu", + use_depthwise: bool = False, + relu_alpha: float = 0.1, + **kwargs) -> None: + logger.debug(parse_class_init(locals())) + + self._name = kwargs.pop("name") if "name" in kwargs else _get_name(f"conv_{filters}") + self._use_reflect_padding = cfg.reflect_padding() + + kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size + self._args = (kernel_size, ) if use_depthwise else (filters, kernel_size) + self._strides = (strides, strides) if isinstance(strides, int) else strides + self._padding = "valid" if self._use_reflect_padding else padding + self._kwargs = kwargs + self._normalization = None if not normalization else normalization.lower() + self._activation = None if not activation else activation.lower() + self._use_depthwise = use_depthwise + self._relu_alpha = relu_alpha + + self._assert_arguments() + self._layers = self._get_layers() + logger.debug("Initialized %s", self.__class__.__name__) + + def _assert_arguments(self) -> None: + """ Validate the given arguments. """ + assert self._normalization in ("batch", "instance", None), ( + "normalization should be 'batch', 'instance' or None") + assert self._activation in ("leakyrelu", "swish", "prelu", None), ( + "activation should be 'leakyrelu', 'prelu', 'swish' or None") + + def _get_layers(self) -> list[layers.Layer]: + """ Obtain the layer chain for the block + + Returns + ------- + list[:class:`keras.layers.Layer] + The layers, in the correct order, to pass the tensor through + """ + retval = [] + if self._use_reflect_padding: + retval.append(ReflectionPadding2D(stride=self._strides[0], + kernel_size=self._args[-1][0], # type:ignore[index] + name=f"{self._name}_reflectionpadding2d")) + + conv: layers.Layer = ( + DepthwiseConv2D if self._use_depthwise + else Conv2D) # pyright:ignore[reportAssignmentType] + + retval.append(conv(*self._args, + strides=self._strides, + padding=self._padding, + name=f"{self._name}_{'dw' if self._use_depthwise else ''}conv2d", + **self._kwargs)) + + # normalization + if self._normalization == "instance": + retval.append(InstanceNormalization(name=f"{self._name}_instancenorm")) + + if self._normalization == "batch": + retval.append(layers.BatchNormalization(axis=3, name=f"{self._name}_batchnorm")) + + # activation + if self._activation == "leakyrelu": + retval.append(layers.LeakyReLU(self._relu_alpha, name=f"{self._name}_leakyrelu")) + if self._activation == "swish": + retval.append(Swish(name=f"{self._name}_swish")) + if self._activation == "prelu": + retval.append(layers.PReLU(name=f"{self._name}_prelu")) + + logger.debug("%s layers: %s", self.__class__.__name__, retval) + return retval + + def __call__(self, inputs: KerasTensor) -> KerasTensor: + """ Call the Faceswap Convolutional Layer. 
+ + Parameters + ---------- + inputs: :class:`keras.KerasTensor` + The input to the layer + + Returns + ------- + :class:`keras.KerasTensor` + The output tensor from the Convolution 2D Layer + """ + var_x = inputs + for layer in self._layers: + var_x = layer(var_x) return var_x - def upscale(self, inp, filters, kernel_size=3, padding= 'same', use_instance_norm=False, res_block_follows=False, **kwargs): - """ Upscale Layer """ - logger.debug("inp: %s, filters: %s, kernel_size: %s, use_instance_norm: %s, kwargs: %s)", - inp, filters, kernel_size, use_instance_norm, kwargs) - kwargs = self.update_kwargs(kwargs) - if self.use_reflect_padding: - inp = ReflectionPadding2D(stride=1, kernel_size=kernel_size)(inp) - padding = 'valid' - if self.use_icnr_init: - kwargs["kernel_initializer"] = ICNR(initializer=kwargs["kernel_initializer"]) - var_x = Conv2D(filters * 4, - kernel_size=kernel_size, - padding=padding, - **kwargs)(inp) - if use_instance_norm: - var_x = InstanceNormalization()(var_x) - if not res_block_follows: - var_x = LeakyReLU(0.1)(var_x) - if self.use_subpixel: - var_x = SubPixelUpscaling()(var_x) + +class SeparableConv2DBlock(): + """ Separable Convolution Block. + + Parameters + ---------- + filters: int + The dimensionality of the output space (i.e. the number of output filters in the + convolution) + kernel_size: int, optional + An integer or tuple/list of 2 integers, specifying the height and width of the 2D + convolution window. Can be a single integer to specify the same value for all spatial + dimensions. Default: 5 + strides: tuple or int, optional + An integer or tuple/list of 2 integers, specifying the strides of the convolution along + the height and width. Can be a single integer to specify the same value for all spatial + dimensions. Default: `2` + kwargs: dict + Any additional Keras standard layer keyword arguments to pass to the Separable + Convolutional 2D layer + """ + def __init__(self, + filters: int, + kernel_size: int | tuple[int, int] = 5, + strides: int | tuple[int, int] = 2, **kwargs) -> None: + logger.debug(parse_class_init(locals())) + + initializer = _get_default_initializer(kwargs.pop("kernel_initializer", None)) + + name = _get_name(f"separableconv2d_{filters}") + self._conv = layers.SeparableConv2D( + filters, + kernel_size=kernel_size, + strides=strides, + padding="same", + depthwise_initializer=initializer, # pyright:ignore[reportArgumentType] + pointwise_initializer=initializer, # pyright:ignore[reportArgumentType] + name=f"{name}_separableconv2d", + **kwargs) + self._activation = layers.Activation("relu", name=f"{name}_relu") + logger.debug("Initialized %s", self.__class__.__name__) + + def __call__(self, inputs: KerasTensor) -> KerasTensor: + """ Call the Faceswap Separable Convolutional 2D Block. + + Parameters + ---------- + inputs: :class:`keras.KerasTensor` + The input to the layer + + Returns + ------- + :class:`keras.KerasTensor` + The output tensor from the Separable Convolution 2D Block + """ + var_x = self._conv(inputs) + return self._activation(var_x) + + +# << UPSCALING >> + +class UpscaleBlock(): + """ An upscale layer for sub-pixel up-scaling. + + Adds reflection padding if it has been selected by the user, and other post-processing + if requested by the plugin. + + Parameters + ---------- + filters: int + The dimensionality of the output space (i.e. the number of output filters in the + convolution) + kernel_size: int, optional + An integer or tuple/list of 2 integers, specifying the height and width of the 2D + convolution window. 
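For orientation, this is roughly how the convolution blocks defined above are consumed when assembling a model. A hypothetical mini-encoder, assuming a faceswap checkout where `lib.model.nn_blocks` is importable and the training config has been loaded; the filter counts and input size are illustrative only:

```python
from keras import layers, Model

from lib.model.nn_blocks import Conv2DBlock, Conv2DOutput

inputs = layers.Input((64, 64, 3))
var_x = Conv2DBlock(128)(inputs)     # 64x64 -> 32x32 (default strides=2)
var_x = Conv2DBlock(256)(var_x)      # 32x32 -> 16x16
outputs = Conv2DOutput(3, 5)(var_x)  # sigmoid output forced to float32
model = Model(inputs, outputs)
```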
Can be a single integer to specify the same value for all spatial + dimensions. Default: 3 + padding: ["valid", "same"], optional + The padding to use. NB: If reflect padding has been selected in the user configuration + options, then this argument will be ignored in favor of reflect padding. Default: `"same"` + scale_factor: int, optional + The amount to upscale the image. Default: `2` + normalization: str or ``None``, optional + Normalization to apply after the Convolution Layer. Select one of "batch" or "instance". + Set to ``None`` to not apply normalization. Default: ``None`` + activation: str or ``None``, optional + The activation function to use. This is applied at the end of the convolution block. Select + one of `"leakyrelu"`, `"prelu"` or `"swish"`. Set to ``None`` to not apply an activation + function. Default: `"leakyrelu"` + kwargs: dict + Any additional Keras standard layer keyword arguments to pass to the Convolutional 2D layer + """ + + def __init__(self, + filters: int, + kernel_size: int | tuple[int, int] = 3, + padding: str = "same", + scale_factor: int = 2, + normalization: str | None = None, + activation: str | None = "leakyrelu", + **kwargs) -> None: + logger.debug(parse_class_init(locals())) + name = _get_name(f"upscale_{filters}") + self._conv = Conv2DBlock(filters * scale_factor * scale_factor, + kernel_size, + strides=(1, 1), + padding=padding, + normalization=normalization, + activation=activation, + name=f"{name}_conv2d", + is_upscale=True, + **kwargs) + self._shuffle = PixelShuffler(name=f"{name}_pixelshuffler", size=scale_factor) + logger.debug("Initialized %s", self.__class__.__name__) + + def __call__(self, inputs: KerasTensor) -> KerasTensor: + """ Call the Faceswap Upscale Block. + + Parameters + ---------- + inputs: :class:`keras.KerasTensor` + The input to the layer + + Returns + ------- + :class:`keras.KerasTensor` + The output tensor from the Upscale Layer + """ + var_x = self._conv(inputs) + return self._shuffle(var_x) + + +class Upscale2xBlock(): + """ Custom hybrid upscale layer for sub-pixel up-scaling. + + Most of the work in up-scaling is approximating lighting gradients, which can be achieved + accurately with linear fitting. This layer attempts to reduce memory consumption by splitting + the work between bilinear and sub-pixel convolutional layers, so that the sub-pixel update + recovers the detail whilst the bilinear filter recovers the lighting. + + Adds reflection padding if it has been selected by the user, and other post-processing + if requested by the plugin. + + Parameters + ---------- + filters: int + The dimensionality of the output space (i.e. the number of output filters in the + convolution) + kernel_size: int, optional + An integer or tuple/list of 2 integers, specifying the height and width of the 2D + convolution window. Can be a single integer to specify the same value for all spatial + dimensions. Default: 3 + padding: ["valid", "same"], optional + The padding to use. Default: `"same"` + activation: str or ``None``, optional + The activation function to use. This is applied at the end of the convolution block. Select + one of `"leakyrelu"`, `"prelu"` or `"swish"`. Set to ``None`` to not apply an activation + function. Default: `"leakyrelu"` + interpolation: ["nearest", "bilinear"], optional + Interpolation to use for up-sampling. Default: `"bilinear"` + scale_factor: int, optional + The amount to upscale the image. Default: `2` + sr_ratio: float, optional + The proportion of super resolution (pixel shuffler) filters to use. Non-fast mode only. 
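The shape arithmetic behind `UpscaleBlock` is: the convolution expands to `filters * scale_factor**2` channels, then the pixel shuffler trades that channel depth for spatial resolution. A numpy sketch of the depth-to-space rearrangement (the exact sub-pixel ordering in faceswap's `PixelShuffler` may differ, but the shape arithmetic is the same):

```python
import numpy as np

def depth_to_space(var_x: np.ndarray, scale: int) -> np.ndarray:
    """ Rearrange (B, H, W, C*scale*scale) -> (B, H*scale, W*scale, C) """
    batch, height, width, channels = var_x.shape
    out_c = channels // (scale * scale)
    var_x = var_x.reshape(batch, height, width, scale, scale, out_c)
    var_x = var_x.transpose(0, 1, 3, 2, 4, 5)   # interleave the sub-pixels
    return var_x.reshape(batch, height * scale, width * scale, out_c)

conv_out = np.zeros((1, 16, 16, 128 * 2 * 2))   # filters=128, scale_factor=2
assert depth_to_space(conv_out, 2).shape == (1, 32, 32, 128)
```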
+ Default: `0.5` + fast: bool, optional + Use a faster up-scaling method that may appear more rugged. Default: ``False`` + kwargs: dict + Any additional Keras standard layer keyword arguments to pass to the Convolutional 2D layer + """ + # TODO Class function this + def __init__(self, + filters: int, + kernel_size: int | tuple[int, int] = 3, + padding: str = "same", + activation: str | None = "leakyrelu", + interpolation: str = "bilinear", + sr_ratio: float = 0.5, + scale_factor: int = 2, + fast: bool = False, **kwargs) -> None: + logger.debug(parse_class_init(locals())) + + self._fast = fast + self._filters = filters if fast else filters - int(filters * sr_ratio) + + name = _get_name(f"upscale2x_{filters}_{'fast' if fast else 'hyb'}") + + self._upscale = UpscaleBlock(self._filters, + kernel_size=kernel_size, + padding=padding, + scale_factor=scale_factor, + activation=activation, + **kwargs) + + if self._fast or (not self._fast and self._filters > 0): + self._conv = Conv2D(self._filters, + 3, + padding=padding, + is_upscale=True, + name=f"{name}_conv2d", + **kwargs) + self._upsample = layers.UpSampling2D(size=(scale_factor, scale_factor), + interpolation=interpolation, + name=f"{name}_upsampling2D") + + self._joiner = layers.Add() if self._fast else layers.Concatenate( + name=f"{name}_concatenate") + + logger.debug("Initialized %s", self.__class__.__name__) + + def __call__(self, inputs: KerasTensor) -> KerasTensor: + """ Call the Faceswap Upscale 2x Layer. + + Parameters + ---------- + inputs: :class:`keras.KerasTensor` + The input to the layer + + Returns + ------- + :class:`keras.KerasTensor` + The output tensor from the Upscale Layer + """ + var_x = inputs + var_x_sr = None + if not self._fast: + var_x_sr = self._upscale(var_x) + if self._fast or (not self._fast and self._filters > 0): + + var_x2 = self._conv(var_x) + var_x2 = self._upsample(var_x2) + + if self._fast: + var_x1 = self._upscale(var_x) + var_x = self._joiner([var_x2, var_x1]) + else: + var_x = self._joiner([var_x_sr, var_x2]) + else: - var_x = PixelShuffler()(var_x) - return var_x + assert var_x_sr is not None + var_x = var_x_sr - # <<< DFaker Model Blocks >>> # - def res_block(self, inp, filters, kernel_size=3, padding= 'same', **kwargs): - """ Residual block """ - logger.debug("inp: %s, filters: %s, kernel_size: %s, kwargs: %s)", - inp, filters, kernel_size, kwargs) - kwargs = self.update_kwargs(kwargs) - var_x = LeakyReLU(alpha=0.2)(inp) - if self.use_reflect_padding: - var_x = ReflectionPadding2D(stride=1, kernel_size=kernel_size)(var_x) - padding = 'valid' - var_x = Conv2D(filters, - kernel_size=kernel_size, - padding=padding, - **kwargs)(var_x) - var_x = LeakyReLU(alpha=0.2)(var_x) - if self.use_reflect_padding: - var_x = ReflectionPadding2D(stride=1, kernel_size=kernel_size)(var_x) - padding = 'valid' - var_x = Conv2D(filters, - kernel_size=kernel_size, - padding=padding, - **kwargs)(var_x) - var_x = Scale(gamma_init=Constant(value=0.1))(var_x) - var_x = Add()([var_x, inp]) - var_x = LeakyReLU(alpha=0.2)(var_x) return var_x - # <<< Unbalanced Model Blocks >>> # - def conv_sep(self, inp, filters, kernel_size=5, strides=2, **kwargs): - """ Seperable Convolution Layer """ - logger.debug("inp: %s, filters: %s, kernel_size: %s, strides: %s, kwargs: %s)", - inp, filters, kernel_size, strides, kwargs) - kwargs = self.update_kwargs(kwargs) - var_x = SeparableConv2D(filters, - kernel_size=kernel_size, - strides=strides, - padding='same', - **kwargs)(inp) - var_x = Activation("relu")(var_x) + +class 
UpscaleResizeImagesBlock(): + """ Upscale block that uses the Keras Backend function resize_images to perform the up-scaling. + Similar in methodology to the :class:`Upscale2xBlock` + + Adds reflection padding if it has been selected by the user, and other post-processing + if requested by the plugin. + + Parameters + ---------- + filters: int + The dimensionality of the output space (i.e. the number of output filters in the + convolution) + kernel_size: int, optional + An integer or tuple/list of 2 integers, specifying the height and width of the 2D + convolution window. Can be a single integer to specify the same value for all spatial + dimensions. Default: 3 + padding: ["valid", "same"], optional + The padding to use. Default: `"same"` + activation: str or ``None``, optional + The activation function to use. This is applied at the end of the convolution block. Select + one of `"leakyrelu"`, `"prelu"` or `"swish"`. Set to ``None`` to not apply an activation + function. Default: `"leakyrelu"` + scale_factor: int, optional + The amount to upscale the image. Default: `2` + interpolation: ["nearest", "bilinear"], optional + Interpolation to use for up-sampling. Default: `"bilinear"` + kwargs: dict + Any additional Keras standard layer keyword arguments to pass to the Convolutional 2D layer + """ + def __init__(self, + filters: int, + kernel_size: int | tuple[int, int] = 3, + padding: str = "same", + activation: str | None = "leakyrelu", + scale_factor: int = 2, + interpolation: T.Literal["nearest", "bilinear"] = "bilinear") -> None: + logger.debug(parse_class_init(locals())) + name = _get_name(f"upscale_ri_{filters}") + + self._resize = KResizeImages(size=scale_factor, + interpolation=interpolation, + name=f"{name}_resize") + self._conv = Conv2D(filters, + kernel_size, + strides=1, + padding=padding, + is_upscale=True, + name=f"{name}_conv") + self._conv_trans = layers.Conv2DTranspose(filters, + 3, + strides=2, + padding=padding, + name=f"{name}_convtrans") + self._add = layers.Add() + + if activation == "leakyrelu": + self._activation = layers.LeakyReLU(0.2, name=f"{name}_leakyrelu") + if activation == "swish": + self._activation = Swish(name=f"{name}_swish") + if activation == "prelu": + self._activation = layers.PReLU(name=f"{name}_prelu") + logger.debug("Initialized %s", self.__class__.__name__) + + def __call__(self, inputs: KerasTensor) -> KerasTensor: + """ Call the Faceswap Resize Images Layer. + + Parameters + ---------- + inputs: :class:`keras.KerasTensor` + The input to the layer + + Returns + ------- + :class:`keras.KerasTensor` + The output tensor from the Upscale Layer + """ + var_x = inputs + + var_x_sr = self._resize(var_x) + var_x_sr = self._conv(var_x_sr) + + var_x_us = self._conv_trans(var_x) + + var_x = self._add([var_x_sr, var_x_us]) + + return self._activation(var_x) + + +class UpscaleDNYBlock(): + """ Upscale block that implements methodology similar to the Disney Research Paper using an + upsampling2D block and 2 x convolutions + + Adds reflection padding if it has been selected by the user, and other post-processing + if requested by the plugin. + + References + ---------- + https://studios.disneyresearch.com/2020/06/29/high-resolution-neural-face-swapping-for-visual-effects/ + + Parameters + ---------- + filters: int + The dimensionality of the output space (i.e. the number of output filters in the + convolution) + kernel_size: int, optional + An integer or tuple/list of 2 integers, specifying the height and width of the 2D + convolution window. 
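Schematically, `UpscaleResizeImagesBlock` sums two x2 branches: a resize-then-convolve branch and a transposed-convolution branch. A plain-Keras sketch of the same wiring, using `UpSampling2D` in place of faceswap's `KResizeImages` (a schematic, not the faceswap class itself):

```python
from keras import layers

def upscale_resize_sketch(inputs, filters: int):
    """ Two-branch x2 upscale: resize+conv summed with a transposed conv """
    branch_a = layers.UpSampling2D(2, interpolation="bilinear")(inputs)
    branch_a = layers.Conv2D(filters, 3, padding="same")(branch_a)
    branch_b = layers.Conv2DTranspose(filters, 3, strides=2, padding="same")(inputs)
    var_x = layers.Add()([branch_a, branch_b])   # fuse the two estimates
    return layers.LeakyReLU(0.2)(var_x)
```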
Can be a single integer to specify the same value for all spatial + dimensions. Default: 3 + activation: str or ``None``, optional + The activation function to use. This is applied at the end of the convolution block. Select + one of `"leakyrelu"`, `"prelu"` or `"swish"`. Set to ``None`` to not apply an activation + function. Default: `"leakyrelu"` + size: int, optional + The amount to upscale the image. Default: `2` + interpolation: ["nearest", "bilinear"], optional + Interpolation to use for up-sampling. Default: `"bilinear"` + kwargs: dict + Any additional Keras standard layer keyword arguments to pass to the Convolutional 2D + layers + """ + def __init__(self, + filters: int, + kernel_size: int | tuple[int, int] = 3, + padding: str = "same", + activation: str | None = "leakyrelu", + size: int = 2, + interpolation: str = "bilinear", + **kwargs) -> None: + logger.debug(parse_class_init(locals())) + name = _get_name(f"upscale_dny_{filters}") + self._upsample = layers.UpSampling2D(size=size, + interpolation=interpolation, + name=f"{name}_upsample2d") + self._convs = [Conv2DBlock(filters, + kernel_size, + strides=1, + padding=padding, + activation=activation, + relu_alpha=0.2, + name=f"{name}_conv2d_{idx + 1}", + is_upscale=True, + **kwargs) + for idx in range(2)] + logger.debug("Initialized %s", self.__class__.__name__) + + def __call__(self, inputs: KerasTensor) -> KerasTensor: + """ Call the UpscaleDNY block + + Parameters + ---------- + inputs: :class:`keras.KerasTensor` + The input to the block + + Returns + ------- + :class:`keras.KerasTensor` + The output from the block + """ + var_x = self._upsample(inputs) + for conv in (self._convs): + var_x = conv(var_x) return var_x -# <<< GAN V2.2 Blocks >>> # -# TODO Merge these into NNBLock class when porting GAN2.2 - - -# Gan Constansts: -GAN22_CONV_INIT = "he_normal" -GAN22_REGULARIZER = 1e-4 - - -# Gan Blocks: -def normalization(inp, norm='none', group='16'): - """ GAN Normalization """ - if norm == 'layernorm': - var_x = GroupNormalization(group=group)(inp) - elif norm == 'batchnorm': - var_x = BatchNormalization()(inp) - elif norm == 'groupnorm': - var_x = GroupNormalization(group=16)(inp) - elif norm == 'instancenorm': - var_x = InstanceNormalization()(inp) - elif norm == 'hybrid': - if group % 2 == 1: - raise ValueError("Output channels must be an even number for hybrid norm, " - "received {}.".format(group)) - filt = group - var_x_0 = Lambda(lambda var_x: var_x[..., :filt // 2])(var_x) - var_x_1 = Lambda(lambda var_x: var_x[..., filt // 2:])(var_x) - var_x_0 = Conv2D(filt // 2, - kernel_size=1, - kernel_regularizer=regularizers.l2(GAN22_REGULARIZER), - kernel_initializer=GAN22_CONV_INIT)(var_x_0) - var_x_1 = InstanceNormalization()(var_x_1) - var_x = concatenate([var_x_0, var_x_1], axis=-1) - else: - var_x = inp - return var_x - - -def upscale_ps(inp, filters, initializer, use_norm=False, norm="none"): - """ GAN Upscaler - Pixel Shuffler """ - var_x = Conv2D(filters * 4, - kernel_size=3, - kernel_regularizer=regularizers.l2(GAN22_REGULARIZER), - kernel_initializer=initializer, - padding="same")(inp) - var_x = LeakyReLU(0.2)(var_x) - var_x = normalization(var_x, norm, filters) if use_norm else var_x - var_x = PixelShuffler()(var_x) - return var_x - - -def upscale_nn(inp, filters, use_norm=False, norm="none"): - """ GAN Neural Network """ - var_x = UpSampling2D()(inp) - var_x = reflect_padding_2d(var_x, 1) - var_x = Conv2D(filters, - kernel_size=3, - kernel_regularizer=regularizers.l2(GAN22_REGULARIZER), - 
kernel_initializer="he_normal")(var_x) - var_x = normalization(var_x, norm, filters) if use_norm else var_x - return var_x - - -def reflect_padding_2d(inp, pad=1): - """ GAN Reflect Padding (2D) """ - var_x = Lambda(lambda var_x: tf.pad(var_x, - [[0, 0], [pad, pad], [pad, pad], [0, 0]], - mode="REFLECT"))(inp) - return var_x - - -def conv_gan(inp, filters, use_norm=False, strides=2, norm='none'): - """ GAN Conv Block """ - var_x = Conv2D(filters, - kernel_size=3, - strides=strides, - kernel_regularizer=regularizers.l2(GAN22_REGULARIZER), - kernel_initializer=GAN22_CONV_INIT, - use_bias=False, - padding="same")(inp) - var_x = Activation("relu")(var_x) - var_x = normalization(var_x, norm, filters) if use_norm else var_x - return var_x - - -def conv_d_gan(inp, filters, use_norm=False, norm='none'): - """ GAN Discriminator Conv Block """ - var_x = inp - var_x = Conv2D(filters, - kernel_size=4, - strides=2, - kernel_regularizer=regularizers.l2(GAN22_REGULARIZER), - kernel_initializer=GAN22_CONV_INIT, - use_bias=False, - padding="same")(var_x) - var_x = LeakyReLU(alpha=0.2)(var_x) - var_x = normalization(var_x, norm, filters) if use_norm else var_x - return var_x - - -def res_block_gan(inp, filters, use_norm=False, norm='none'): - """ GAN Res Block """ - var_x = Conv2D(filters, - kernel_size=3, - kernel_regularizer=regularizers.l2(GAN22_REGULARIZER), - kernel_initializer=GAN22_CONV_INIT, - use_bias=False, - padding="same")(inp) - var_x = LeakyReLU(alpha=0.2)(var_x) - var_x = normalization(var_x, norm, filters) if use_norm else var_x - var_x = Conv2D(filters, - kernel_size=3, - kernel_regularizer=regularizers.l2(GAN22_REGULARIZER), - kernel_initializer=GAN22_CONV_INIT, - use_bias=False, - padding="same")(var_x) - var_x = add([var_x, inp]) - var_x = LeakyReLU(alpha=0.2)(var_x) - var_x = normalization(var_x, norm, filters) if use_norm else var_x - return var_x - - -def self_attn_block(inp, n_c, squeeze_factor=8): - """ GAN Self Attention Block - Code borrows from https://github.com/taki0112/Self-Attention-GAN-Tensorflow + +# << OTHER BLOCKS >> +class ResidualBlock(): + """ Residual block from dfaker. + + Parameters + ---------- + filters: int + The dimensionality of the output space (i.e. the number of output filters in the + convolution) + kernel_size: int, optional + An integer or tuple/list of 2 integers, specifying the height and width of the 2D + convolution window. Can be a single integer to specify the same value for all spatial + dimensions. Default: 3 + padding: ["valid", "same"], optional + The padding to use. 
Default: `"same"` + kwargs: dict + Any additional Keras standard layer keyword arguments to pass to the Convolutional 2D layer + + Returns + ------- + tensor + The output tensor from the Residual Block """ - msg = "Input channels must be >= {}, recieved nc={}".format(squeeze_factor, n_c) - assert n_c // squeeze_factor > 0, msg - var_x = inp - shape_x = var_x.get_shape().as_list() - - var_f = Conv2D(n_c // squeeze_factor, 1, - kernel_regularizer=regularizers.l2(GAN22_REGULARIZER))(var_x) - var_g = Conv2D(n_c // squeeze_factor, 1, - kernel_regularizer=regularizers.l2(GAN22_REGULARIZER))(var_x) - var_h = Conv2D(n_c, 1, kernel_regularizer=regularizers.l2(GAN22_REGULARIZER))(var_x) - - shape_f = var_f.get_shape().as_list() - shape_g = var_g.get_shape().as_list() - shape_h = var_h.get_shape().as_list() - flat_f = Reshape((-1, shape_f[-1]))(var_f) - flat_g = Reshape((-1, shape_g[-1]))(var_g) - flat_h = Reshape((-1, shape_h[-1]))(var_h) - - var_s = Lambda(lambda var_x: K.batch_dot(var_x[0], - Permute((2, 1))(var_x[1])))([flat_g, flat_f]) - - beta = Softmax(axis=-1)(var_s) - var_o = Lambda(lambda var_x: K.batch_dot(var_x[0], var_x[1]))([beta, flat_h]) - var_o = Reshape(shape_x[1:])(var_o) - var_o = Scale()(var_o) - - out = add([var_o, inp]) - return out + def __init__(self, + filters: int, + kernel_size: int | tuple[int, int] = 3, + padding: str = "same", + **kwargs) -> None: + logger.debug(parse_class_init(locals())) + + self._name = _get_name(f"residual_{filters}") + self._use_reflect_padding = cfg.reflect_padding() + + self._filters = filters + self._kernel_size = (kernel_size, + kernel_size) if isinstance(kernel_size, int) else kernel_size + self._padding = "valid" if self._use_reflect_padding else padding + self._kwargs = kwargs + + self._layers = self._get_layers() + self._add = layers.Add() + self._activation = layers.LeakyReLU(negative_slope=0.2, name=f"{self._name}_leakyrelu_3") + logger.debug("Initialized %s", self.__class__.__name__) + + def _get_layers(self) -> list[layers.Layer]: + """ Obtain the layer chain for the block + + Returns + ------- + list[:class:`keras.layers.Layer`] + The layers, in the correct order, to pass the tensor through + """ + retval: list[layers.Layer] = [] + if self._use_reflect_padding: + retval.append(ReflectionPadding2D(stride=1, + kernel_size=self._kernel_size[0], + name=f"{self._name}_reflectionpadding2d_0")) + + retval.append(Conv2D(self._filters, # pyright:ignore[reportArgumentType] + kernel_size=self._kernel_size, + padding=self._padding, + name=f"{self._name}_conv2d_0", + **self._kwargs)) + retval.append(layers.LeakyReLU(negative_slope=0.2, name=f"{self._name}_leakyrelu_1")) + + if self._use_reflect_padding: + retval.append(ReflectionPadding2D(stride=1, + kernel_size=self._kernel_size[0], + name=f"{self._name}_reflectionpadding2d_1")) + + kwargs = {key: val for key, val in self._kwargs.items() if key != "kernel_initializer"} + if not cfg.conv_aware_init(): + kwargs["kernel_initializer"] = initializers.VarianceScaling(scale=0.2, + mode="fan_in", + distribution="uniform") + retval.append(Conv2D(self._filters, # pyright:ignore[reportArgumentType] + kernel_size=self._kernel_size, + padding=self._padding, + name=f"{self._name}_conv2d_1", + **kwargs)) + + logger.debug("%s layers: %s", self.__class__.__name__, retval) + return retval + + def __call__(self, inputs: KerasTensor) -> KerasTensor: + """ Call the Faceswap Residual Block. 
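Stripped of the naming, padding and initializer plumbing, the `__call__` below reduces to the classic identity-shortcut pattern `y = LeakyReLU(x + F(x))`. A schematic in plain Keras (note the `Add` requires the input channel count to already match `filters`):

```python
from keras import layers

def residual_sketch(inputs, filters: int):
    """ Schematic of the dfaker-style residual block: y = LReLU(x + F(x)) """
    var_x = layers.Conv2D(filters, 3, padding="same")(inputs)
    var_x = layers.LeakyReLU(negative_slope=0.2)(var_x)
    var_x = layers.Conv2D(filters, 3, padding="same")(var_x)
    var_x = layers.Add()([var_x, inputs])        # identity shortcut
    return layers.LeakyReLU(negative_slope=0.2)(var_x)
```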
+ + Parameters + ---------- + inputs: :class:`keras.KerasTensor` + The input to the layer + + Returns + ------- + :class:`keras.KerasTensor` + The output tensor from the Residual Block + """ + var_x = inputs + for layer in self._layers: + var_x = layer(var_x) + + var_x = self._add([var_x, inputs]) + return self._activation(var_x) + + +__all__ = get_module_objects(__name__) diff --git a/lib/model/normalization.py b/lib/model/normalization.py index ec4dbb1f5e..5cbde8f049 100644 --- a/lib/model/normalization.py +++ b/lib/model/normalization.py @@ -1,77 +1,431 @@ #!/usr/bin/env python3 -""" Normaliztion methods for faceswap.py - Code from: - shoanlu GAN: https://github.com/shaoanlu/faceswap-GAN""" +""" Normalization methods for faceswap.py specific to Torch backend """ +from __future__ import annotations -import sys import inspect +import logging +import sys +import typing as T + +from keras import constraints, initializers, InputSpec, layers, ops, regularizers, saving + +from lib.logger import parse_class_init +from lib.utils import get_module_objects + +if T.TYPE_CHECKING: + from keras import KerasTensor + +logger = logging.getLogger(__name__) + + +class AdaInstanceNormalization(layers.Layer): # pylint:disable=too-many-ancestors,abstract-method + """ Adaptive Instance Normalization Layer for Keras. + + Parameters + ---------- + axis: int, optional + The axis that should be normalized (typically the features axis). For instance, after a + `Conv2D` layer with `data_format="channels_first"`, set `axis=1` in + :class:`InstanceNormalization`. Setting `axis=None` will normalize all values in each + instance of the batch. Axis 0 is the batch dimension. `axis` cannot be set to 0 to avoid + errors. Default: `-1` + momentum: float, optional + Momentum for the moving mean and the moving variance. Default: `0.99` + epsilon: float, optional + Small float added to variance to avoid dividing by zero. Default: `1e-3` + center: bool, optional + If ``True``, add offset of `beta` to normalized tensor. If ``False``, `beta` is ignored. + Default: ``True`` + scale: bool, optional + If ``True``, multiply by `gamma`. If ``False``, `gamma` is not used. When the next layer + is linear (also e.g. `relu`), this can be disabled since the scaling will be done by + the next layer. Default: ``True`` + + References + ---------- + Arbitrary Style Transfer in Real-time with Adaptive Instance Normalization - \ + https://arxiv.org/abs/1703.06868 + """ + def __init__(self, + axis: int = -1, + momentum: float = 0.99, + epsilon: float = 1e-3, + center: bool = True, + scale: bool = True, + **kwargs) -> None: + logger.debug(parse_class_init(locals())) + super().__init__(**kwargs) + self.axis = axis + self.momentum = momentum + self.epsilon = epsilon + self.center = center + self.scale = scale + logger.debug("Initialized %s", self.__class__.__name__) + + def build(self, input_shape: tuple[tuple[int, ...], ...]) -> None: + """Creates the layer weights. + + Parameters + ---------- + input_shape: tuple[int, ...] + Keras tensor (future input to layer) or ``list``/``tuple`` of Keras tensors to + reference for weight shape computations. + """ + dim = input_shape[0][self.axis] + if dim is None: + raise ValueError('Axis ' + str(self.axis) + ' of ' + 'input tensor should have a defined dimension ' + 'but the layer received an input with shape ' + + str(input_shape[0]) + '.') + + super().build(input_shape) + + def call(self, inputs: KerasTensor # pylint:disable=arguments-differ + ) -> KerasTensor: + """This is where the layer's logic lives. 
+ + Parameters + ---------- + inputs: :class:`keras.KerasTensor` + Input tensor, or list/tuple of input tensors + + Returns + ------- + :class:`keras.KerasTensor` + A tensor or list/tuple of tensors + """ + input_shape = inputs[0].shape + reduction_axes = list(range(0, len(input_shape))) + + beta = inputs[1] + gamma = inputs[2] + + if self.axis is not None: + del reduction_axes[self.axis] + + del reduction_axes[0] + mean = ops.mean(inputs[0], reduction_axes, keepdims=True) + stddev = ops.std(inputs[0], reduction_axes, keepdims=True) + self.epsilon + normed = (inputs[0] - mean) / stddev + + return normed * gamma + beta + + def get_config(self) -> dict[str, T.Any]: + """Returns the config of the layer. + + The Keras configuration for the layer. + + Returns + -------- + dict[str, Any] + A python dictionary containing the layer configuration + """ + config = { + 'axis': self.axis, + 'momentum': self.momentum, + 'epsilon': self.epsilon, + 'center': self.center, + 'scale': self.scale + } + base_config = super().get_config() + return dict(list(base_config.items()) + list(config.items())) + + def compute_output_shape(self, input_shape: tuple[int, ...] # pylint:disable=arguments-differ + ) -> int: + """ Calculate the output shape from this layer. + + Parameters + ---------- + input_shape: tuple + The input shape to the layer + + Returns + ------- + int + The output shape to the layer + """ + return input_shape[0] + + +class GroupNormalization(layers.Layer): # pylint:disable=too-many-ancestors,abstract-method + """ Group Normalization + + Parameters + ---------- + axis: int, optional + The axis that should be normalized (typically the features axis). For instance, after a + `Conv2D` layer with `data_format="channels_first"`, set `axis=1` in + :class:`InstanceNormalization`. Setting `axis=None` will normalize all values in each + instance of the batch. Axis 0 is the batch dimension. `axis` cannot be set to 0 to avoid + errors. Default: ``None`` + gamma_init: str, optional + Initializer for the gamma weight. Default: `"one"` + beta_init: str, optional + Initializer for the beta weight. Default `"zero"` + gamma_regularizer: varies, optional + Optional regularizer for the gamma weight. Default: ``None`` + beta_regularizer: varies, optional + Optional regularizer for the beta weight. Default ``None`` + epsilon: float, optional + Small float added to variance to avoid dividing by zero. Default: `1e-3` + group: int, optional + The group size. Default: `32` + data_format: ["channels_first", "channels_last"], optional + The required data format. Optional. 
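In numpy terms, the AdaIN `call` above takes three inputs `[content, beta, gamma]`, normalizes the content per sample and per channel over the spatial axes (for the default `axis=-1` on NHWC data), then re-styles it with the supplied statistics. A sketch of that computation (the beta/gamma shapes are illustrative; in practice they usually come from a style encoder):

```python
import numpy as np

def adain(content: np.ndarray, beta: np.ndarray, gamma: np.ndarray,
          epsilon: float = 1e-3) -> np.ndarray:
    """ AdaIN on NHWC input: per-sample, per-channel spatial statistics """
    mean = content.mean(axis=(1, 2), keepdims=True)
    stddev = content.std(axis=(1, 2), keepdims=True) + epsilon
    return ((content - mean) / stddev) * gamma + beta

content = np.random.rand(2, 8, 8, 16)
beta = np.random.rand(2, 1, 1, 16)      # style offsets, broadcast over H and W
gamma = np.random.rand(2, 1, 1, 16)     # style scales
assert adain(content, beta, gamma).shape == content.shape
```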
Default: ``None`` + kwargs: dict + Any additional standard Keras Layer key word arguments + + References + ---------- + Shaoanlu GAN: https://github.com/shaoanlu/faceswap-GAN + """ + # pylint:disable=too-many-instance-attributes + def __init__(self, + axis: int = -1, + gamma_init: str = 'one', + beta_init: str = 'zero', + gamma_regularizer: T.Any = None, + beta_regularizer: T.Any = None, + epsilon: float = 1e-6, + group: int = 32, + data_format: str | None = None, + **kwargs) -> None: + logger.debug(parse_class_init(locals())) + self.beta = None + self.gamma = None + super().__init__(**kwargs) + self.axis = axis if isinstance(axis, (list, tuple)) else [axis] + self.gamma_init = initializers.get(gamma_init) + self.beta_init = initializers.get(beta_init) + self.gamma_regularizer = regularizers.get(gamma_regularizer) + self.beta_regularizer = regularizers.get(beta_regularizer) + self.epsilon = epsilon + self.group = group + self.data_format = "channels_last" if data_format is None else data_format + + self.supports_masking = True + logger.debug("Initialized %s", self.__class__.__name__) + + def build(self, input_shape: tuple[int, ...]) -> None: + """Creates the layer weights. + + Parameters + ---------- + input_shape: tuple[int, ...] + Keras tensor (future input to layer) or ``list``/``tuple`` of Keras tensors to + reference for weight shape computations. + """ + input_spec = [InputSpec(shape=input_shape)] + self.input_spec = input_spec # pylint:disable=attribute-defined-outside-init + shape = [1 for _ in input_shape] + if self.data_format == 'channels_last': + channel_axis = -1 + shape[channel_axis] = input_shape[channel_axis] + elif self.data_format == 'channels_first': + channel_axis = 1 + shape[channel_axis] = input_shape[channel_axis] + # for i in self.axis: + # shape[i] = input_shape[i] + self.gamma = self.add_weight(shape=shape, + initializer=self.gamma_init, + regularizer=self.gamma_regularizer, + name='gamma') + self.beta = self.add_weight(shape=shape, + initializer=self.beta_init, + regularizer=self.beta_regularizer, + name='beta') + self.built = True # pylint:disable=attribute-defined-outside-init + + def _process_4_channel(self, inputs: KerasTensor) -> KerasTensor: + """ Logic for processing 4 channel inputs + + Parameters + ---------- + inputs: :class:`keras.KerasTensor` + The input to the layer + + Returns + ------- + :class:`keras.KerasTensor` + A tensor or list/tuple of tensors + """ + input_shape = inputs.shape + if self.data_format == 'channels_last': + batch_size, height, width, channels = input_shape + if batch_size is None: + batch_size = -1 + + if channels < self.group: + raise ValueError('Input channels should be larger than group size' + + '; Received input channels: ' + str(channels) + + '; Group size: ' + str(self.group)) + + var_x = ops.reshape(inputs, (batch_size, + height, + width, + self.group, + channels // self.group)) + mean = ops.mean(var_x, axis=[1, 2, 4], keepdims=True) + std = ops.sqrt(ops.var(var_x, axis=[1, 2, 4], keepdims=True) + self.epsilon) + var_x = (var_x - mean) / std + + var_x = ops.reshape(var_x, (batch_size, height, width, channels)) + return self.gamma * var_x + self.beta + + # Channels first + batch_size, channels, height, width = input_shape + if batch_size is None: + batch_size = -1 + + if channels < self.group: + raise ValueError('Input channels should be larger than group size' + + '; Received input channels: ' + str(channels) + + '; Group size: ' + str(self.group)) + + var_x = ops.reshape(inputs, (batch_size, + self.group, + channels // 
self.group, + height, + width)) + mean = ops.mean(var_x, axis=[2, 3, 4], keepdims=True) + std = ops.sqrt(ops.var(var_x, axis=[2, 3, 4], keepdims=True) + self.epsilon) + var_x = (var_x - mean) / std + + var_x = ops.reshape(var_x, (batch_size, channels, height, width)) + return self.gamma * var_x + self.beta + + def compute_output_shape(self, input_shape: tuple[int, ...] # pylint:disable=arguments-differ + ) -> tuple[int, ...]: + """ Calculate the output shape from this layer. + + Parameters + ---------- + input_shape: tuple + The input shape to the layer + + Returns + ------- + int + The output shape to the layer + """ + return input_shape + + def call(self, inputs: KerasTensor, *args, **kwargs # pylint:disable=arguments-differ + ) -> KerasTensor: + """This is where the layer's logic lives. + + Parameters + ---------- + inputs: :class:`keras.KerasTensor` + Input tensor, or list/tuple of input tensors + + Returns + ------- + :class:`keras.KerasTensor` + A tensor or list/tuple of tensors + """ + input_shape = inputs.shape + if len(input_shape) != 4 and len(input_shape) != 2: + raise ValueError('Inputs should have rank ' + + str(4) + " or " + str(2) + + '; Received input shape:', str(input_shape)) + + if len(input_shape) == 4: + return self._process_4_channel(inputs) + + reduction_axes = list(range(0, len(input_shape))) + del reduction_axes[0] + batch_size, _ = input_shape + if batch_size is None: + batch_size = -1 + + mean = ops.mean(inputs, keepdims=True) + std = ops.sqrt(ops.var(inputs, keepdims=True) + self.epsilon) + var_x = (inputs - mean) / std -from keras.engine import Layer, InputSpec -from keras import initializers, regularizers, constraints -from keras import backend as K -from keras.utils.generic_utils import get_custom_objects + return self.gamma * var_x + self.beta + def get_config(self) -> dict[str, T.Any]: + """Returns the config of the layer. -def to_list(inp): - """ Convert to list """ - if not isinstance(inp, (list, tuple)): - return [inp] - return list(inp) + The Keras configuration for the layer. + Returns + -------- + dict[str, Any]: + A python dictionary containing the layer configuration + """ + config = {'epsilon': self.epsilon, + 'axis': self.axis, + 'gamma_init': initializers.serialize(self.gamma_init), + 'beta_init': initializers.serialize(self.beta_init), + 'gamma_regularizer': regularizers.serialize(self.gamma_regularizer), + 'beta_regularizer': regularizers.serialize(self.gamma_regularizer), + 'group': self.group} + base_config = super().get_config() + return dict(list(base_config.items()) + list(config.items())) -class InstanceNormalization(Layer): + +class InstanceNormalization(layers.Layer): # pylint:disable=too-many-ancestors,abstract-method """Instance normalization layer (Lei Ba et al, 2016, Ulyanov et al., 2016). - Normalize the activations of the previous layer at each step, - i.e. applies a transformation that maintains the mean activation - close to 0 and the activation standard deviation close to 1. - # Arguments - axis: Integer, the axis that should be normalized - (typically the features axis). - For instance, after a `Conv2D` layer with - `data_format="channels_first"`, - set `axis=1` in `InstanceNormalization`. - Setting `axis=None` will normalize all values in each instance of the batch. - Axis 0 is the batch dimension. `axis` cannot be set to 0 to avoid errors. - epsilon: Small float added to variance to avoid dividing by zero. - center: If True, add offset of `beta` to normalized tensor. - If False, `beta` is ignored. 
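The grouped statistics in `_process_4_channel` above can be summarized in a few lines of numpy for the `channels_last` path: reshape the channel axis into `(groups, channels // groups)`, normalize over height, width and the within-group channels, then reshape back (the gamma/beta scaling step is omitted for brevity):

```python
import numpy as np

def group_norm(var_x: np.ndarray, groups: int = 32,
               epsilon: float = 1e-6) -> np.ndarray:
    """ NHWC group normalization: statistics over (H, W, C // groups) """
    batch, height, width, channels = var_x.shape
    grouped = var_x.reshape(batch, height, width, groups, channels // groups)
    mean = grouped.mean(axis=(1, 2, 4), keepdims=True)
    std = np.sqrt(grouped.var(axis=(1, 2, 4), keepdims=True) + epsilon)
    grouped = (grouped - mean) / std
    return grouped.reshape(batch, height, width, channels)

assert group_norm(np.random.rand(2, 8, 8, 64)).shape == (2, 8, 8, 64)
```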
- scale: If True, multiply by `gamma`. - If False, `gamma` is not used. - When the next layer is linear (also e.g. `nn.relu`), - this can be disabled since the scaling - will be done by the next layer. - beta_initializer: Initializer for the beta weight. - gamma_initializer: Initializer for the gamma weight. - beta_regularizer: Optional regularizer for the beta weight. - gamma_regularizer: Optional regularizer for the gamma weight. - beta_constraint: Optional constraint for the beta weight. - gamma_constraint: Optional constraint for the gamma weight. - # Input shape - Arbitrary. Use the keyword argument `input_shape` - (tuple of integers, does not include the samples axis) - when using this layer as the first layer in a model. - # Output shape - Same shape as input. - # References - - [Layer Normalization](https://arxiv.org/abs/1607.06450) - - [Instance Normalization: The Missing Ingredient for Fast - Stylization](https://arxiv.org/abs/1607.08022) + + Normalize the activations of the previous layer at each step, i.e. applies a transformation + that maintains the mean activation close to 0 and the activation standard deviation close to 1. + + Parameters + ---------- + axis: int, optional + The axis that should be normalized (typically the features axis). For instance, after a + `Conv2D` layer with `data_format="channels_first"`, set `axis=1` in + :class:`InstanceNormalization`. Setting `axis=None` will normalize all values in each + instance of the batch. Axis 0 is the batch dimension. `axis` cannot be set to 0 to avoid + errors. Default: ``None`` + epsilon: float, optional + Small float added to variance to avoid dividing by zero. Default: `1e-3` + center: bool, optional + If ``True``, add offset of `beta` to normalized tensor. If ``False``, `beta` is ignored. + Default: ``True`` + scale: bool, optional + If ``True``, multiply by `gamma`. If ``False``, `gamma` is not used. When the next layer + is linear (also e.g. `relu`), this can be disabled since the scaling will be done by + the next layer. Default: ``True`` + beta_initializer: str, optional + Initializer for the beta weight. Default: `"zeros"` + gamma_initializer: str, optional + Initializer for the gamma weight. Default: `"ones"` + beta_regularizer: str, optional + Optional regularizer for the beta weight. Default: ``None`` + gamma_regularizer: str, optional + Optional regularizer for the gamma weight. Default: ``None`` + beta_constraint: float, optional + Optional constraint for the beta weight. Default: ``None`` + gamma_constraint: float, optional + Optional constraint for the gamma weight. 
Default: ``None`` + + References + ---------- + - Layer Normalization - https://arxiv.org/abs/1607.06450 + + - Instance Normalization: The Missing Ingredient for Fast Stylization - \ + https://arxiv.org/abs/1607.08022 """ + # pylint:disable=too-many-instance-attributes,too-many-arguments,too-many-positional-arguments def __init__(self, - axis=None, - epsilon=1e-3, - center=True, - scale=True, - beta_initializer='zeros', - gamma_initializer='ones', - beta_regularizer=None, - gamma_regularizer=None, - beta_constraint=None, - gamma_constraint=None, - **kwargs): + axis: int | None = None, + epsilon: float = 1e-3, + center: bool = True, + scale: bool = True, + beta_initializer: str = "zeros", + gamma_initializer: str = "ones", + beta_regularizer: T.Any = None, + gamma_regularizer: T.Any = None, + beta_constraint: T.Any = None, + gamma_constraint: T.Any = None, + **kwargs) -> None: + logger.debug(parse_class_init(locals())) self.beta = None self.gamma = None - super(InstanceNormalization, self).__init__(**kwargs) + super().__init__(**kwargs) self.supports_masking = True self.axis = axis self.epsilon = epsilon @@ -83,16 +437,25 @@ def __init__(self, self.gamma_regularizer = regularizers.get(gamma_regularizer) self.beta_constraint = constraints.get(beta_constraint) self.gamma_constraint = constraints.get(gamma_constraint) + logger.debug("Initialized %s", self.__class__.__name__) + + def build(self, input_shape: tuple[int, ...]) -> None: + """Creates the layer weights. - def build(self, input_shape): + Parameters + ---------- + input_shape: tuple[int, ...] + Keras tensor (future input to layer) or ``list``/``tuple`` of Keras tensors to + reference for weight shape computations. + """ ndim = len(input_shape) if self.axis == 0: - raise ValueError('Axis cannot be zero') + raise ValueError("Axis cannot be zero") if (self.axis is not None) and (ndim == 2): - raise ValueError('Cannot specify axis for rank 1 tensor') + raise ValueError("Cannot specify axis for rank 1 tensor") - self.input_spec = InputSpec(ndim=ndim) + self.input_spec = InputSpec(ndim=ndim) # pylint:disable=attribute-defined-outside-init if self.axis is None: shape = (1,) @@ -101,7 +464,7 @@ def build(self, input_shape): if self.scale: self.gamma = self.add_weight(shape=shape, - name='gamma', + name="gamma", initializer=self.gamma_initializer, regularizer=self.gamma_regularizer, constraint=self.gamma_constraint) @@ -109,16 +472,45 @@ def build(self, input_shape): self.gamma = None if self.center: self.beta = self.add_weight(shape=shape, - name='beta', + name="beta", initializer=self.beta_initializer, regularizer=self.beta_regularizer, constraint=self.beta_constraint) else: self.beta = None - self.built = True - - def call(self, inputs, training=None): - input_shape = K.int_shape(inputs) + self.built = True # pylint:disable=attribute-defined-outside-init + + def compute_output_shape(self, input_shape: tuple[int, ...] # pylint:disable=arguments-differ + ) -> tuple[int, ...]: + """ Calculate the output shape from this layer. + + Parameters + ---------- + input_shape: tuple + The input shape to the layer + + Returns + ------- + int + The output shape to the layer + """ + return input_shape + + def call(self, inputs: KerasTensor # pylint:disable=arguments-differ + ) -> KerasTensor: + """This is where the layer's logic lives. 
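One practical consequence of the `get_config` method further down: the layer can be serialized to a plain dict and rebuilt, which is what allows saved faceswap models to reload these custom layers. A round-trip sketch, assuming a faceswap checkout where `lib.model.normalization` is importable:

```python
from lib.model.normalization import InstanceNormalization

layer = InstanceNormalization(axis=-1, epsilon=1e-3)
config = layer.get_config()                       # plain, serializable dict
clone = InstanceNormalization.from_config(config)
assert clone.axis == layer.axis and clone.epsilon == layer.epsilon
```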
+ + Parameters + ---------- + inputs: :class:`keras.KerasTensor` + Input tensor, or list/tuple of input tensors + + Returns + ------- + :class:`keras.KerasTensor` + A tensor or list/tuple of tensors + """ + input_shape = inputs.shape reduction_axes = list(range(0, len(input_shape))) if self.axis is not None: @@ -126,8 +518,8 @@ def call(self, inputs, training=None): del reduction_axes[0] - mean = K.mean(inputs, reduction_axes, keepdims=True) - stddev = K.std(inputs, reduction_axes, keepdims=True) + self.epsilon + mean = ops.mean(inputs, reduction_axes, keepdims=True) + stddev = ops.std(inputs, reduction_axes, keepdims=True) + self.epsilon normed = (inputs - mean) / stddev broadcast_shape = [1] * len(input_shape) @@ -135,155 +527,208 @@ def call(self, inputs, training=None): broadcast_shape[self.axis] = input_shape[self.axis] if self.scale: - broadcast_gamma = K.reshape(self.gamma, broadcast_shape) + broadcast_gamma = ops.reshape(self.gamma, broadcast_shape) normed = normed * broadcast_gamma if self.center: - broadcast_beta = K.reshape(self.beta, broadcast_shape) + broadcast_beta = ops.reshape(self.beta, broadcast_shape) normed = normed + broadcast_beta return normed - def get_config(self): + def get_config(self) -> dict[str, T.Any]: + """Returns the config of the layer. + + A layer config is a Python dictionary (serializable) containing the configuration of a + layer. The same layer can be reinstated later (without its trained weights) from this + configuration. + + The configuration of a layer does not include connectivity information, nor the layer + class name. These are handled by `Network` (one layer of abstraction above). + + Returns + -------- + dict[str, Any] + A python dictionary containing the layer configuration + """ config = { - 'axis': self.axis, - 'epsilon': self.epsilon, - 'center': self.center, - 'scale': self.scale, - 'beta_initializer': initializers.serialize(self.beta_initializer), - 'gamma_initializer': initializers.serialize(self.gamma_initializer), - 'beta_regularizer': regularizers.serialize(self.beta_regularizer), - 'gamma_regularizer': regularizers.serialize(self.gamma_regularizer), - 'beta_constraint': constraints.serialize(self.beta_constraint), - 'gamma_constraint': constraints.serialize(self.gamma_constraint) + "axis": self.axis, + "epsilon": self.epsilon, + "center": self.center, + "scale": self.scale, + "beta_initializer": initializers.serialize(self.beta_initializer), + "gamma_initializer": initializers.serialize(self.gamma_initializer), + "beta_regularizer": regularizers.serialize(self.beta_regularizer), + "gamma_regularizer": regularizers.serialize(self.gamma_regularizer), + "beta_constraint": constraints.serialize(self.beta_constraint), + "gamma_constraint": constraints.serialize(self.gamma_constraint) } - base_config = super(InstanceNormalization, self).get_config() + base_config = super().get_config() return dict(list(base_config.items()) + list(config.items())) -class GroupNormalization(Layer): - """ Group Normalization - from: shoanlu GAN: https://github.com/shaoanlu/faceswap-GAN""" - - def __init__(self, axis=-1, - gamma_init='one', beta_init='zero', - gamma_regularizer=None, beta_regularizer=None, - epsilon=1e-6, - group=32, - data_format=None, - **kwargs): - self.beta = None - self.gamma = None - super(GroupNormalization, self).__init__(**kwargs) +class RMSNormalization(layers.Layer): # pylint:disable=too-many-ancestors,abstract-method + """ Root Mean Square Layer Normalization (Biao Zhang, Rico Sennrich, 2019) + + RMSNorm is a simplification of 
-class GroupNormalization(Layer):
-    """ Group Normalization
-    from: shoanlu GAN: https://github.com/shaoanlu/faceswap-GAN"""
-
-    def __init__(self, axis=-1,
-                 gamma_init='one', beta_init='zero',
-                 gamma_regularizer=None, beta_regularizer=None,
-                 epsilon=1e-6,
-                 group=32,
-                 data_format=None,
-                 **kwargs):
-        self.beta = None
-        self.gamma = None
-        super(GroupNormalization, self).__init__(**kwargs)
+class RMSNormalization(layers.Layer):  # pylint:disable=too-many-ancestors,abstract-method
+    """ Root Mean Square Layer Normalization (Biao Zhang, Rico Sennrich, 2019)
+
+    RMSNorm is a simplification of the original layer normalization (LayerNorm). LayerNorm is a
+    regularization technique that might handle the internal covariate shift issue so as to
+    stabilize the layer activations and improve model convergence. It has proved quite
+    successful in NLP-based models. In some cases, LayerNorm has become an essential component
+    to enable model optimization, such as in the SOTA NMT model Transformer.
+
+    RMSNorm simplifies LayerNorm by removing the mean-centering operation, normalizing layer
+    activations with the RMS statistic alone.
+
+    Parameters
+    ----------
+    axis: int
+        The axis to normalize across. Typically this is the features axis. The left-out axes are
+        typically the batch axis/axes. This argument defaults to `-1`, the last dimension in the
+        input.
+    epsilon: float, optional
+        Small float added to variance to avoid dividing by zero. Default: `1e-8`
+    partial: float, optional
+        Partial multiplier for calculating pRMSNorm. Valid values are between `0.0` and `1.0`.
+        Setting to `0.0` or `1.0` disables. Default: `0.0`
+    bias: bool, optional
+        Whether to use a bias term for RMSNorm. Disabled by default because RMSNorm does not
+        enforce re-centering invariance. Default: ``False``
+    kwargs: dict
+        Standard keras layer kwargs
+
+    References
+    ----------
+    - RMS Normalization - https://arxiv.org/abs/1910.07467
+    - Official implementation - https://github.com/bzhangGo/rmsnorm
+    """
+    def __init__(self,
+                 axis: int = -1,
+                 epsilon: float = 1e-8,
+                 partial: float = 0.0,
+                 bias: bool = False,
+                 **kwargs) -> None:
+        logger.debug(parse_class_init(locals()))
+        self.scale = None
+        super().__init__(**kwargs)
+
+        # Checks
+        if not isinstance(axis, int):
+            raise TypeError(f"Expected an int for the argument 'axis', but received: {axis}")
+
+        if not 0.0 <= partial <= 1.0:
+            raise ValueError(f"partial must be between 0.0 and 1.0, but received {partial}")

-        self.axis = to_list(axis)
-        self.gamma_init = initializers.get(gamma_init)
-        self.beta_init = initializers.get(beta_init)
-        self.gamma_regularizer = regularizers.get(gamma_regularizer)
-        self.beta_regularizer = regularizers.get(beta_regularizer)
+        self.axis = axis
         self.epsilon = epsilon
-        self.group = group
-        self.data_format = K.normalize_data_format(data_format)
-
-        self.supports_masking = True
-
-    def build(self, input_shape):
-        self.input_spec = [InputSpec(shape=input_shape)]
-        shape = [1 for _ in input_shape]
-        if self.data_format == 'channels_last':
-            channel_axis = -1
-            shape[channel_axis] = input_shape[channel_axis]
-        elif self.data_format == 'channels_first':
-            channel_axis = 1
-            shape[channel_axis] = input_shape[channel_axis]
-        # for i in self.axis:
-        #     shape[i] = input_shape[i]
-        self.gamma = self.add_weight(shape=shape,
-                                     initializer=self.gamma_init,
-                                     regularizer=self.gamma_regularizer,
-                                     name='gamma')
-        self.beta = self.add_weight(shape=shape,
-                                    initializer=self.beta_init,
-                                    regularizer=self.beta_regularizer,
-                                    name='beta')
-        self.built = True
-
-    def call(self, inputs, mask=None):
-        input_shape = K.int_shape(inputs)
-        if len(input_shape) != 4 and len(input_shape) != 2:
-            raise ValueError('Inputs should have rank ' +
-                             str(4) + " or " + str(2) +
-                             '; Received input shape:', str(input_shape))
-
-        if len(input_shape) == 4:
-            if self.data_format == 'channels_last':
-                batch_size, height, width, channels = input_shape
-                if batch_size is None:
-                    batch_size = -1
-
-                if channels < self.group:
-                    raise ValueError('Input channels should be larger than group size' +
-                                     '; Received input channels: ' + str(channels) +
-                                     '; Group size: ' +
str(self.group)) - - var_x = K.reshape(inputs, (batch_size, - height, - width, - self.group, - channels // self.group)) - mean = K.mean(var_x, axis=[1, 2, 4], keepdims=True) - std = K.sqrt(K.var(var_x, axis=[1, 2, 4], keepdims=True) + self.epsilon) - var_x = (var_x - mean) / std - - var_x = K.reshape(var_x, (batch_size, height, width, channels)) - retval = self.gamma * var_x + self.beta - elif self.data_format == 'channels_first': - batch_size, channels, height, width = input_shape - if batch_size is None: - batch_size = -1 - - if channels < self.group: - raise ValueError('Input channels should be larger than group size' + - '; Received input channels: ' + str(channels) + - '; Group size: ' + str(self.group)) - - var_x = K.reshape(inputs, (batch_size, - self.group, - channels // self.group, - height, - width)) - mean = K.mean(var_x, axis=[2, 3, 4], keepdims=True) - std = K.sqrt(K.var(var_x, axis=[2, 3, 4], keepdims=True) + self.epsilon) - var_x = (var_x - mean) / std - - var_x = K.reshape(var_x, (batch_size, channels, height, width)) - retval = self.gamma * var_x + self.beta - - elif len(input_shape) == 2: - reduction_axes = list(range(0, len(input_shape))) - del reduction_axes[0] - batch_size, _ = input_shape - if batch_size is None: - batch_size = -1 - - mean = K.mean(inputs, keepdims=True) - std = K.sqrt(K.var(inputs, keepdims=True) + self.epsilon) - var_x = (inputs - mean) / std - - retval = self.gamma * var_x + self.beta - return retval - - def get_config(self): - config = {'epsilon': self.epsilon, - 'axis': self.axis, - 'gamma_init': initializers.serialize(self.gamma_init), - 'beta_init': initializers.serialize(self.beta_init), - 'gamma_regularizer': regularizers.serialize(self.gamma_regularizer), - 'beta_regularizer': regularizers.serialize(self.gamma_regularizer), - 'group': self.group} - base_config = super(GroupNormalization, self).get_config() + self.partial = partial + self.bias = bias + self.offset = 0. + logger.debug("Initialized %s", self.__class__.__name__) + + def build(self, input_shape: tuple[int, ...]) -> None: + """ Validate and populate :attr:`axis` + + Parameters + ---------- + input_shape: tuple[int, ...] + Keras tensor (future input to layer) or ``list``/``tuple`` of Keras tensors to + reference for weight shape computations. 
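# --------------------------------------------------------------------------- #
# Editor's note (illustration only, not part of the patch): a minimal NumPy
# sketch of the forward pass that RMSNormalization.call() implements below,
# including the partial (pRMSNorm) variant. The function name and the shapes
# used are assumptions for the example.
import numpy as np

def rms_norm(x: np.ndarray, scale: np.ndarray, epsilon: float = 1e-8,
             partial: float = 0.0) -> np.ndarray:
    """ Normalize by the root mean square over the last axis (no mean-centering). """
    if partial in (0.0, 1.0):
        mean_square = np.mean(np.square(x), axis=-1, keepdims=True)
    else:  # pRMSNorm: estimate the RMS from a leading fraction of the features
        partial_size = int(x.shape[-1] * partial)
        mean_square = np.mean(np.square(x[..., :partial_size]), axis=-1, keepdims=True)
    return scale * x / np.sqrt(mean_square + epsilon)

x = np.random.rand(2, 16).astype("float32")
assert rms_norm(x, scale=np.ones(16, "float32")).shape == x.shape
# --------------------------------------------------------------------------- #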
+ """ + ndims = len(input_shape) + if ndims is None: + raise ValueError(f"Input shape {input_shape} has undefined rank.") + + # Resolve negative axis + if self.axis < 0: + self.axis += ndims + + # Validate axes + if self.axis < 0 or self.axis >= ndims: + raise ValueError(f"Invalid axis: {self.axis}") + + param_shape = [input_shape[self.axis]] + self.scale = self.add_weight( + name="scale", + shape=param_shape, + initializer="ones") + if self.bias: + self.offset = self.add_weight( + name="offset", + shape=param_shape, + initializer="zeros") + + self.built = True # pylint:disable=attribute-defined-outside-init + + def call(self, inputs: KerasTensor, *args, **kwargs # pylint:disable=arguments-differ + ) -> KerasTensor: + """ Call Root Mean Square Layer Normalization + + Parameters + ---------- + inputs: :class:`keras.KerasTensor` + Input tensor, or list/tuple of input tensors + + Returns + ------- + :class:`keras.KerasTensor` + A tensor or list/tuple of tensors + """ + # Compute the axes along which to reduce the mean / variance + input_shape = inputs.shape + layer_size = input_shape[self.axis] + + if self.partial in (0.0, 1.0): + mean_square = ops.mean(ops.square(inputs), axis=self.axis, keepdims=True) + else: + partial_size = int(layer_size * self.partial) + partial_x, _ = ops.split(inputs, [partial_size], axis=self.axis) + mean_square = ops.mean(ops.square(partial_x), axis=self.axis, keepdims=True) + + recip_square_root = ops.rsqrt(mean_square + self.epsilon) + output = self.scale * inputs * recip_square_root + self.offset + return output + + def compute_output_shape(self, input_shape: tuple[int, ...] # pylint:disable=arguments-differ + ) -> tuple[int, ...]: + """ The output shape of the layer is the same as the input shape. + + Parameters + ---------- + input_shape: tuple[int, ...] + The input shape to the layer + + Returns + ------- + tuple[int, ...] + The output shape to the layer + """ + return input_shape + + def get_config(self) -> dict[str, T.Any]: + """Returns the config of the layer. + + A layer config is a Python dictionary (serializable) containing the configuration of a + layer. The same layer can be reinstated later (without its trained weights) from this + configuration. + + The configuration of a layer does not include connectivity information, nor the layer + class name. These are handled by `Network` (one layer of abstraction above). 
+
+        Returns
+        --------
+        dict[str, Any]
+            A python dictionary containing the layer configuration
+        """
+        base_config = super().get_config()
+        config = {"axis": self.axis,
+                  "epsilon": self.epsilon,
+                  "partial": self.partial,
+                  "bias": self.bias}
         return dict(list(base_config.items()) + list(config.items()))


-# Update normalizations into Keras custom objects
+# Update normalization layers into Keras custom objects
 for name, obj in inspect.getmembers(sys.modules[__name__]):
     if inspect.isclass(obj) and obj.__module__ == __name__:
-        get_custom_objects().update({name: obj})
+        saving.get_custom_objects().update({name: obj})
+
+
+__all__ = get_module_objects(__name__)
diff --git a/lib/model/optimizers.py b/lib/model/optimizers.py
new file mode 100644
index 0000000000..835258ad28
--- /dev/null
+++ b/lib/model/optimizers.py
@@ -0,0 +1,331 @@
+#!/usr/bin/env python3
+""" Custom Optimizers for Torch/keras """
+from __future__ import annotations
+import inspect
+import logging
+import sys
+import typing as T
+
+from keras import ops, Optimizer, saving
+
+from lib.logger import parse_class_init
+from lib.utils import get_module_objects
+
+if T.TYPE_CHECKING:
+    from keras import KerasTensor, Variable
+
+logger = logging.getLogger(__name__)
+
+
+class AdaBelief(Optimizer):  # pylint:disable=too-many-instance-attributes,too-many-ancestors
+    """ Implementation of the AdaBelief Optimizer
+
+    Inherits from: keras.optimizers.Optimizer.
+
+    The AdaBelief Optimizer is not a replacement for heuristic warmup; the settings should be
+    kept if warmup has already been employed and tuned in the baseline method. You can enable
+    warmup by setting `total_steps` and `warmup_proportion` (see examples)
+
+    Lookahead (see references) can be integrated with the AdaBelief Optimizer, as announced by
+    Less Wright, and the new combined optimizer can also be called "Ranger". The mechanism can be
+    enabled by using the lookahead wrapper. (See examples)
+
+    Parameters
+    ----------
+    learning_rate: `Tensor`, float or :class:`keras.optimizers.schedules.LearningRateSchedule`
+        The learning rate.
+    beta_1: float
+        The exponential decay rate for the 1st moment estimates.
+    beta_2: float
+        The exponential decay rate for the 2nd moment estimates.
+    epsilon: float
+        A small constant for numerical stability.
+    amsgrad: bool
+        Whether to apply the AMSGrad variant of this algorithm from the paper "On the Convergence
+        of Adam and beyond".
+    rectify: bool
+        Whether to enable rectification as in RectifiedAdam.
+    sma_threshold: float
+        The threshold for the simple moving average.
+    total_steps: int
+        Total number of training steps. Enable warmup by setting a positive value.
+    warmup_proportion: float
+        The proportion of increasing steps.
+    min_learning_rate: float
+        Minimum learning rate after warmup.
+    name: str, optional
+        Name for the operations created when applying gradients. Default: ``"AdaBeliefOptimizer"``.
+    **kwargs: dict
+        Standard Keras Optimizer keyword arguments. Allowed to be (`weight_decay`, `clipnorm`,
+        `clipvalue`, `global_clipnorm`, `use_ema`, `ema_momentum`, `ema_overwrite_frequency`,
+        `loss_scale_factor`, `gradient_accumulation_steps`)
+
+    Examples
+    --------
+    >>> from lib.model.optimizers import AdaBelief
+    >>> opt = AdaBelief(learning_rate=1e-3)
+
+    Example of serialization:
+
+    >>> optimizer = AdaBelief(learning_rate=lr_scheduler, weight_decay=wd_scheduler)
+    >>> config = keras.optimizers.serialize(optimizer)
+    >>> new_optimizer = keras.optimizers.deserialize(config,
+    ...     custom_objects=dict(AdaBelief=AdaBelief))
+
+    Example of warm up:
+
+    >>> opt = AdaBelief(learning_rate=1e-3, total_steps=10000, warmup_proportion=0.1,
+    ...                 min_learning_rate=1e-5)
+
+    In the above example, the learning rate will increase linearly from 0 to `learning_rate` in
+    1000 steps, then decrease linearly from `learning_rate` to `min_learning_rate` in 9000 steps.
+
+    Example of enabling Lookahead:
+
+    >>> adabelief = AdaBelief()
+    >>> ranger = tfa.optimizers.Lookahead(adabelief, sync_period=6, slow_step_size=0.5)
+
+    Notes
+    -----
+    `amsgrad` is not described in the original paper. Use it with caution.
+
+    References
+    ----------
+    Juntang Zhuang et al. - AdaBelief Optimizer: Adapting stepsizes by the belief in observed
+    gradients - https://arxiv.org/abs/2010.07468.
+
+    Original implementation - https://github.com/juntang-zhuang/Adabelief-Optimizer
+
+    Michael R. Zhang et al. - Lookahead Optimizer: k steps forward, 1 step back -
+    https://arxiv.org/abs/1907.08610v1
+
+    Adapted from https://github.com/juntang-zhuang/Adabelief-Optimizer
+
+    BSD 2-Clause License
+
+    Copyright (c) 2021, Juntang Zhuang
+    All rights reserved.
+
+    Redistribution and use in source and binary forms, with or without
+    modification, are permitted provided that the following conditions are met:
+
+    1. Redistributions of source code must retain the above copyright notice, this
+    list of conditions and the following disclaimer.
+
+    2. Redistributions in binary form must reproduce the above copyright notice,
+    this list of conditions and the following disclaimer in the documentation
+    and/or other materials provided with the distribution.
+
+    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+    AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+    IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+    DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+    FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+    DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+    SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+    CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+    OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+    OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
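# --------------------------------------------------------------------------- #
# Editor's note (illustration only, not part of the patch): a plain-Python
# sketch of the warmup/decay schedule used in the "Example of warm up" above
# and applied by _maybe_warmup() below. The function name is an assumption.
import numpy as np

def warmup_lr(step, lr=1e-3, total_steps=10000, warmup_proportion=0.1, min_lr=1e-5):
    warmup_steps = total_steps * warmup_proportion
    decay_steps = max(total_steps - warmup_steps, 1)
    decay_rate = (min_lr - lr) / decay_steps
    if step <= warmup_steps:                # linear ramp from 0 to lr
        return lr * step / warmup_steps
    return lr + decay_rate * min(step - warmup_steps, decay_steps)

assert np.isclose(warmup_lr(1000), 1e-3)    # end of warmup: full learning rate
assert np.isclose(warmup_lr(10000), 1e-5)   # end of schedule: min_lr
# --------------------------------------------------------------------------- #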
+ """ + + def __init__(self, # pylint:disable=too-many-arguments,too-many-positional-arguments + learning_rate: float = 0.001, + beta_1: float = 0.9, + beta_2: float = 0.999, + epsilon: float = 1e-14, + amsgrad: bool = False, + rectify: bool = True, + sma_threshold: float = 5.0, + total_steps: int = 0, + warmup_proportion: float = 0.1, + min_learning_rate: float = 0.0, + name="AdaBeliefOptimizer", + **kwargs): + logger.debug(parse_class_init(locals())) + super().__init__(learning_rate=learning_rate, name=name, **kwargs) + self.beta_1 = beta_1 + self.beta_2 = beta_2 + self.epsilon = epsilon + self.amsgrad = amsgrad + self.rectify = rectify + self.sma_threshold = sma_threshold + # TODO change the following 2 to "warm_up_steps" + # TODO Make learning rate warm up a global option + # Or these params can be calculated from a user "warm_up_steps" parameter + self.total_steps = total_steps + self.warmup_proportion = warmup_proportion + self.min_learning_rate = min_learning_rate + logger.debug("Initialized %s", self.__class__.__name__) + + self._momentums: list[Variable] = [] + self._velocities: list[Variable] = [] + self._velocity_hats: list[Variable] = [] # Amsgrad only + + def build(self, variables: list[Variable]) -> None: + """Initialize optimizer variables. + + AdaBelief optimizer has 3 types of variables: momentums, velocities and + velocity_hat (only set when amsgrad is applied), + + Parameters + ---------- + variables: list[:class:`keras.Variable`] + list of model variables to build AdaBelief variables on. + """ + if self.built: + return + logger.debug("Building AdaBelief. var_list: %s", variables) + super().build(variables) + + for var in variables: + self._momentums.append(self.add_variable_from_reference( + reference_variable=var, name="momentum")) + self._velocities.append(self.add_variable_from_reference( + reference_variable=var, name="velocity")) + if self.amsgrad: + self._velocity_hats.append(self.add_variable_from_reference( + reference_variable=var, name="velocity_hat")) + logger.debug("Built AdaBelief. 
momentums: %s, velocities: %s, velocity_hats: %s",
+                     len(self._momentums), len(self._velocities), len(self._velocity_hats))
+
+    def _maybe_warmup(self, learning_rate: KerasTensor, local_step: KerasTensor) -> KerasTensor:
+        """ Do learning rate warm up if requested
+
+        Parameters
+        ----------
+        learning_rate: :class:`keras.KerasTensor`
+            The learning rate
+        local_step: :class:`keras.KerasTensor`
+            The current training step
+
+        Returns
+        -------
+        :class:`keras.KerasTensor`
+            Either the original learning rate or adjusted learning rate if warmup is requested
+        """
+        if self.total_steps <= 0:
+            return learning_rate
+
+        total_steps = ops.cast(self.total_steps, learning_rate.dtype)
+        warmup_steps = total_steps * ops.cast(self.warmup_proportion, learning_rate.dtype)
+        min_lr = ops.cast(self.min_learning_rate, learning_rate.dtype)
+        decay_steps = ops.maximum(total_steps - warmup_steps, 1)
+        decay_rate = ops.divide(min_lr - learning_rate, decay_steps)
+        return ops.where(local_step <= warmup_steps,
+                         ops.multiply(learning_rate, (ops.divide(local_step, warmup_steps))),
+                         ops.add(learning_rate,
+                                 ops.multiply(decay_rate,
+                                              ops.minimum(local_step - warmup_steps,
+                                                          decay_steps))))
+
+    def _maybe_rectify(self,
+                       momentum: KerasTensor,
+                       velocity: KerasTensor,
+                       local_step: KerasTensor,
+                       beta_2_power: KerasTensor) -> KerasTensor:
+        """ Apply rectification, if requested
+
+        Parameters
+        ----------
+        momentum: :class:`keras.KerasTensor`
+            The momentum update
+        velocity: :class:`keras.KerasTensor`
+            The velocity update
+        local_step: :class:`keras.KerasTensor`
+            The current training step
+        beta_2_power: :class:`keras.KerasTensor`
+            Adjusted exponential decay rate for the 2nd moment estimates.
+
+        Returns
+        -------
+        :class:`keras.KerasTensor`
+            The standard or rectified update (if rectification enabled)
+        """
+        if not self.rectify:
+            return ops.divide(momentum, ops.add(velocity, self.epsilon))
+
+        sma_inf = 2 / (1 - self.beta_2) - 1
+        sma_t = sma_inf - 2 * local_step * beta_2_power / (1 - beta_2_power)
+        rect = ops.sqrt((sma_t - 4) / (sma_inf - 4) *
+                        (sma_t - 2) / (sma_inf - 2) *
+                        sma_inf / sma_t)
+        return ops.where(sma_t >= self.sma_threshold,
+                         ops.divide(
+                             ops.multiply(rect, momentum),
+                             (ops.add(velocity, self.epsilon))),
+                         momentum)
+
+    def update_step(self,
+                    gradient: KerasTensor,
+                    variable: Variable,
+                    learning_rate: Variable) -> None:
+        """Update step given gradient and the associated model variable for AdaBelief.
+
+        Parameters
+        ----------
+        gradient: :class:`keras.KerasTensor`
+            The gradient to update
+        variable: :class:`keras.Variable`
+            The variable to update
+        learning_rate: :class:`keras.Variable`
+            The learning rate
+        """
+        local_step = ops.cast(self.iterations + 1, variable.dtype)
+        learning_rate = self._maybe_warmup(ops.cast(learning_rate, variable.dtype), local_step)
+        gradient = ops.cast(gradient, variable.dtype)
+        beta_1_power = ops.power(ops.cast(self.beta_1, variable.dtype), local_step)
+        beta_2_power = ops.power(ops.cast(self.beta_2, variable.dtype), local_step)
+
+        # m_t = b1 * m + (1 - b1) * g
+        # => m_t = m + (g - m) * (1 - b1)
+        momentum = self._momentums[self._get_variable_index(variable)]
+        self.assign_add(momentum, ops.multiply(ops.subtract(gradient, momentum), 1 - self.beta_1))
+        momentum_corr = ops.divide(momentum, (1 - beta_1_power))
+
+        # v_t = b2 * v + (1 - b2) * (g - m_t)^2 + e
+        # => v_t = v + ((g - m_t)^2 - v) * (1 - b2) + e
+        velocity = self._velocities[self._get_variable_index(variable)]
+        self.assign_add(velocity,
+                        ops.multiply(
+                            ops.subtract(ops.square(gradient - momentum), velocity),
+                            1 - self.beta_2) +
+                        self.epsilon)
+
+        if self.amsgrad:
+            velocity_hat = self._velocity_hats[self._get_variable_index(variable)]
+            self.assign(velocity_hat, ops.maximum(velocity, velocity_hat))
+            velocity_corr = ops.sqrt(ops.divide(velocity_hat, (1 - beta_2_power)))
+        else:
+            velocity_corr = ops.sqrt(ops.divide(velocity, (1 - beta_2_power)))
+
+        var_t = self._maybe_rectify(momentum_corr, velocity_corr, local_step, beta_2_power)
+
+        self.assign_sub(variable, ops.multiply(learning_rate, var_t))
+
+    def get_config(self) -> dict[str, T.Any]:
+        """ Returns the config of the optimizer.
+
+        Optimizer configuration for AdaBelief.
+
+        Returns
+        -------
+        dict[str, Any]
+            The optimizer configuration.
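# --------------------------------------------------------------------------- #
# Editor's note (illustration only, not part of the patch): a NumPy reference
# for one non-rectified AdaBelief update, mirroring update_step() above. The
# function name is an assumption; see https://arxiv.org/abs/2010.07468.
import numpy as np

def adabelief_step(param, grad, m, v, t, lr=1e-3, b1=0.9, b2=0.999, eps=1e-14):
    m = b1 * m + (1 - b1) * grad                       # 1st moment: follow the gradient
    v = b2 * v + (1 - b2) * np.square(grad - m) + eps  # 2nd moment: deviation from "belief"
    m_corr = m / (1 - b1 ** t)                         # bias corrections
    v_corr = np.sqrt(v / (1 - b2 ** t))
    return param - lr * m_corr / (v_corr + eps), m, v

w = np.zeros(3)
w, m, v = adabelief_step(w, grad=np.array([0.1, -0.2, 0.3]),
                         m=np.zeros(3), v=np.zeros(3), t=1)
# --------------------------------------------------------------------------- #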
+ """ + config = super().get_config() + config.update({"beta_1": self.beta_1, + "beta_2": self.beta_2, + "epsilon": self.epsilon, + "amsgrad": self.amsgrad, + "rectify": self.rectify, + "sma_threshold": self.sma_threshold, + "total_steps": self.total_steps, + "warmup_proportion": self.warmup_proportion, + "min_learning_rate": self.min_learning_rate}) + return config + + +# Update Optimizers into Keras custom objects +for _name, obj in inspect.getmembers(sys.modules[__name__]): + if inspect.isclass(obj) and obj.__module__ == __name__: + saving.get_custom_objects().update({_name: obj}) + + +__all__ = get_module_objects(__name__) diff --git a/lib/multithreading.py b/lib/multithreading.py index 55e9bb0cbc..e20862e643 100644 --- a/lib/multithreading.py +++ b/lib/multithreading.py @@ -1,124 +1,108 @@ #!/usr/bin/env python3 """ Multithreading/processing utils for faceswap """ - +from __future__ import annotations import logging -import multiprocessing as mp +import typing as T +from multiprocessing import cpu_count + import queue as Queue import sys import threading -from lib.logger import LOG_QUEUE, set_root_logger +from types import TracebackType -logger = logging.getLogger(__name__) # pylint: disable=invalid-name -_launched_processes = set() # pylint: disable=invalid-name +from lib.utils import get_module_objects +if T.TYPE_CHECKING: + from collections.abc import Callable, Generator -class PoolProcess(): - """ Pool multiple processes """ - def __init__(self, method, in_queue, out_queue, *args, processes=None, **kwargs): - self._name = method.__qualname__ - logger.debug("Initializing %s: (target: '%s', processes: %s)", - self.__class__.__name__, self._name, processes) +logger = logging.getLogger(__name__) +_ErrorType: T.TypeAlias = tuple[type[BaseException], + BaseException, + TracebackType] | tuple[T.Any, T.Any, T.Any] | None +_THREAD_NAMES: set[str] = set() - self.procs = self.set_procs(processes) - ctx = mp.get_context("spawn") - self.pool = ctx.Pool(processes=self.procs) - self._method = method - self._kwargs = self.build_target_kwargs(in_queue, out_queue, kwargs) - self._args = args +def total_cpus(): + """ Return total number of cpus """ + return cpu_count() - logger.debug("Initialized %s: '%s'", self.__class__.__name__, self._name) - @staticmethod - def build_target_kwargs(in_queue, out_queue, kwargs): - """ Add standard kwargs to passed in kwargs list """ - kwargs["log_init"] = set_root_logger - kwargs["log_queue"] = LOG_QUEUE - kwargs["in_queue"] = in_queue - kwargs["out_queue"] = out_queue - return kwargs - - def set_procs(self, processes): - """ Set the number of processes to use """ - if processes is None: - running_processes = len(mp.active_children()) - processes = max(mp.cpu_count() - running_processes, 1) - logger.verbose("Processing '%s' in %s processes", self._name, processes) - return processes - - def start(self): - """ Run the processing pool """ - logging.debug("Pooling Processes: (target: '%s', args: %s, kwargs: %s)", - self._name, self._args, self._kwargs) - for idx in range(self.procs): - logger.debug("Adding process %s of %s to mp.Pool '%s'", - idx + 1, self.procs, self._name) - self.pool.apply_async(self._method, args=self._args, kwds=self._kwargs) - logging.debug("Pooled Processes: '%s'", self._name) - - def join(self): - """ Join the process """ - logger.debug("Joining Pooled Process: '%s'", self._name) - self.pool.close() - self.pool.join() - logger.debug("Joined Pooled Process: '%s'", self._name) - - -class SpawnProcess(mp.context.SpawnProcess): - """ Process in 
spawnable context - Must be spawnable to share CUDA across processes """ - def __init__(self, target, in_queue, out_queue, *args, **kwargs): - name = target.__qualname__ - logger.debug("Initializing %s: (target: '%s', args: %s, kwargs: %s)", - self.__class__.__name__, name, args, kwargs) - ctx = mp.get_context("spawn") - self.event = ctx.Event() - kwargs = self.build_target_kwargs(in_queue, out_queue, kwargs) - super().__init__(target=target, name=name, args=args, kwargs=kwargs) - self.daemon = True - logger.debug("Initialized %s: '%s'", self.__class__.__name__, name) - - def build_target_kwargs(self, in_queue, out_queue, kwargs): - """ Add standard kwargs to passed in kwargs list """ - kwargs["event"] = self.event - kwargs["log_init"] = set_root_logger - kwargs["log_queue"] = LOG_QUEUE - kwargs["in_queue"] = in_queue - kwargs["out_queue"] = out_queue - return kwargs - - def start(self): - """ Add logging to start function """ - logger.debug("Spawning Process: (name: '%s', args: %s, kwargs: %s, daemon: %s)", - self._name, self._args, self._kwargs, self.daemon) - super().start() - _launched_processes.add(self) - logger.debug("Spawned Process: (name: '%s', PID: %s)", self._name, self.pid) - - def join(self, timeout=None): - """ Add logging to join function """ - logger.debug("Joining Process: (name: '%s', PID: %s)", self._name, self.pid) - super().join(timeout=timeout) - _launched_processes.remove(self) - logger.debug("Joined Process: (name: '%s', PID: %s)", self._name, self.pid) +def _get_name(name: str) -> str: + """ Obtain a unique name for a thread + + Parameters + ---------- + name: str + The requested name + + Returns + ------- + str + The request name with "_#" appended (# being an integer) making the name unique + """ + idx = 0 + real_name = name + while True: + if real_name in _THREAD_NAMES: + real_name = f"{name}_{idx}" + idx += 1 + continue + _THREAD_NAMES.add(real_name) + return real_name class FSThread(threading.Thread): - """ Subclass of thread that passes errors back to parent """ - def __init__(self, group=None, target=None, name=None, # pylint: disable=too-many-arguments - args=(), kwargs=None, *, daemon=None): - super().__init__(group=group, target=target, name=name, - args=args, kwargs=kwargs, daemon=daemon) - self.err = None - - def run(self): + """ Subclass of thread that passes errors back to parent + + Parameters + ---------- + target: callable object, Optional + The callable object to be invoked by the run() method. If ``None`` nothing is called. + Default: ``None`` + name: str, optional + The thread name. if ``None`` a unique name is constructed of the form "Thread-N" where N + is a small decimal number. Default: ``None`` + args: tuple + The argument tuple for the target invocation. Default: (). + kwargs: dict + keyword arguments for the target invocation. Default: {}. + """ + _target: Callable + _args: tuple + _kwargs: dict[str, T.Any] + _name: str + + def __init__(self, + target: Callable | None = None, + name: str | None = None, + args: tuple = (), + kwargs: dict[str, T.Any] | None = None, + *, + daemon: bool | None = None) -> None: + super().__init__(target=target, name=name, args=args, kwargs=kwargs, daemon=daemon) + self.err: _ErrorType = None + + def check_and_raise_error(self) -> None: + """ Checks for errors in thread and raises them in caller. 
+ + Raises + ------ + Error + Re-raised error from within the thread + """ + if not self.err: + return + logger.debug("Thread error caught: %s", self.err) + raise self.err[1].with_traceback(self.err[2]) + + def run(self) -> None: + """ Runs the target, reraising any errors from within the thread in the caller. """ try: - if self._target: + if self._target is not None: self._target(*self._args, **self._kwargs) - except Exception: # pylint: disable=broad-except + except Exception as err: # pylint:disable=broad-except self.err = sys.exc_info() - logger.debug("Error in thread (%s): %s", self._name, - self.err[1].with_traceback(self.err[2])) + logger.debug("Error in thread (%s): %s", self._name, str(err)) finally: # Avoid a refcycle if the thread is running a function with # an argument that has a member that points to the thread. @@ -126,36 +110,85 @@ def run(self): class MultiThread(): - """ Threading for IO heavy ops - Catches errors in thread and rethrows to parent """ - def __init__(self, target, *args, thread_count=1, name=None, **kwargs): - self._name = name if name else target.__name__ + """ Threading for IO heavy ops. Catches errors in thread and rethrows to parent. + + Parameters + ---------- + target: callable object + The callable object to be invoked by the run() method. + args: tuple + The argument tuple for the target invocation. Default: (). + thread_count: int, optional + The number of threads to use. Default: 1 + name: str, optional + The thread name. if ``None`` a unique name is constructed of the form {target.__name__}_N + where N is an incrementing integer. Default: ``None`` + kwargs: dict + keyword arguments for the target invocation. Default: {}. + """ + def __init__(self, + target: Callable, + *args, + thread_count: int = 1, + name: str | None = None, + **kwargs) -> None: + self._name = _get_name(name if name else target.__name__) logger.debug("Initializing %s: (target: '%s', thread_count: %s)", self.__class__.__name__, self._name, thread_count) - logger.trace("args: %s, kwargs: %s", args, kwargs) + logger.trace("args: %s, kwargs: %s", args, kwargs) # type:ignore self.daemon = True self._thread_count = thread_count - self._threads = list() + self._threads: list[FSThread] = [] self._target = target self._args = args self._kwargs = kwargs logger.debug("Initialized %s: '%s'", self.__class__.__name__, self._name) @property - def has_error(self): - """ Return true if a thread has errored, otherwise false """ + def has_error(self) -> bool: + """ bool: ``True`` if a thread has errored, otherwise ``False`` """ return any(thread.err for thread in self._threads) @property - def errors(self): - """ Return a list of thread errors """ - return [thread.err for thread in self._threads] + def errors(self) -> list[_ErrorType]: + """ list: List of thread error values """ + return [thread.err for thread in self._threads if thread.err] - def start(self): - """ Start a thread with the given method and args """ + @property + def name(self) -> str: + """ :str: The name of the thread """ + return self._name + + def check_and_raise_error(self) -> None: + """ Checks for errors in thread and raises them in caller. 
+
+        Raises
+        ------
+        Error
+            Re-raised error from within the thread
+        """
+        if not self.has_error:
+            return
+        logger.debug("Thread error caught: %s", self.errors)
+        error = self.errors[0]
+        assert error is not None
+        raise error[1].with_traceback(error[2])
+
+    def is_alive(self) -> bool:
+        """ Check if any threads are still alive
+
+        Returns
+        -------
+        bool
+            ``True`` if any threads are alive. ``False`` if no threads are alive
+        """
+        return any(thread.is_alive() for thread in self._threads)
+
+    def start(self) -> None:
+        """ Start all the threads for the given method, args and kwargs """
         logger.debug("Starting thread(s): '%s'", self._name)
         for idx in range(self._thread_count):
-            name = "{}_{}".format(self._name, idx)
+            name = self._name if self._thread_count == 1 else f"{self._name}_{idx}"
             logger.debug("Starting thread %s of %s: '%s'", idx + 1, self._thread_count, name)
             thread = FSThread(name=name,
@@ -167,58 +200,106 @@ def start(self):
             self._threads.append(thread)
         logger.debug("Started all threads '%s': %s", self._name, len(self._threads))

-    def join(self):
-        """ Join the running threads, catching and re-raising any errors """
+    def completed(self) -> bool:
+        """ Check if all threads have completed
+
+        Returns
+        -------
+        bool
+            ``True`` if all threads have completed otherwise ``False``
+        """
+        retval = all(not thread.is_alive() for thread in self._threads)
+        logger.debug(retval)
+        return retval
+
+    def join(self) -> None:
+        """ Join the running threads, catching and re-raising any errors
+
+        Clear the list of threads for class instance re-use
+        """
         logger.debug("Joining Threads: '%s'", self._name)
         for thread in self._threads:
-            logger.debug("Joining Thread: '%s'", thread._name)  # pylint: disable=protected-access
+            logger.debug("Joining Thread: '%s'", thread._name)  # pylint:disable=protected-access
             thread.join()
             if thread.err:
                 logger.error("Caught exception in thread: '%s'",
-                             thread._name)  # pylint: disable=protected-access
+                             thread._name)  # pylint:disable=protected-access
                 raise thread.err[1].with_traceback(thread.err[2])
+        del self._threads
+        self._threads = []
        logger.debug("Joined all Threads: '%s'", self._name)


-class BackgroundGenerator(threading.Thread):
-    """ Run a queue in the background. From:
-        https://stackoverflow.com/questions/7323664/ """
-    # See below why prefetch count is flawed
-    def __init__(self, generator, prefetch=1):
-        threading.Thread.__init__(self)
-        self.queue = Queue.Queue(maxsize=prefetch)
+class BackgroundGenerator(MultiThread):
+    """ Run a task in the background and queue data for consumption
+
+    Parameters
+    ----------
+    generator: iterable
+        The generator to run in the background
+    prefetch: int, optional
+        The number of items to pre-fetch from the generator before blocking (see Notes). Default: 1
+    name: str, optional
+        The thread name. If ``None`` a unique name is constructed of the form
+        {generator.__name__}_N where N is an incrementing integer. Default: ``None``
+    args: tuple, Optional
+        The argument tuple for generator invocation. Default: ``None``.
+    kwargs: dict, Optional
+        keyword arguments for the generator invocation. Default: ``None``.
+
+    Notes
+    -----
+    Putting to the internal queue only blocks if put is called while the queue has already
+    reached max size, so prefetch is effectively one more than the parameter supplied (N in the
+    queue, one waiting for insertion). A usage sketch follows below.
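# --------------------------------------------------------------------------- #
# Editor's note (illustration only, not part of the patch): hypothetical usage
# of BackgroundGenerator; the generator function and its argument are made up,
# and the import path is an assumption based on this diff.
from lib.multithreading import BackgroundGenerator  # assumed import path

def numbers(count):
    for idx in range(count):
        yield idx  # e.g. load and yield an image here

background = BackgroundGenerator(numbers, prefetch=2, args=(5,))
for item in background.iterator():  # re-raises any error from the worker thread
    print(item)
# --------------------------------------------------------------------------- #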
+
+    References
+    ----------
+    https://stackoverflow.com/questions/7323664/
+    """
+    def __init__(self,
+                 generator: Callable,
+                 prefetch: int = 1,
+                 name: str | None = None,
+                 args: tuple | None = None,
+                 kwargs: dict[str, T.Any] | None = None) -> None:
+        super().__init__(name=name, target=self._run)
+        self.queue: Queue.Queue = Queue.Queue(prefetch)
         self.generator = generator
-        self.daemon = True
+        self._gen_args = args or tuple()
+        self._gen_kwargs = kwargs or {}
         self.start()

-    def run(self):
-        """ Put until queue size is reached.
-            Note: put blocks only if put is called while queue has already
-            reached max size => this makes 2 prefetched items! One in the
-            queue, one waiting for insertion! """
-        for item in self.generator:
-            self.queue.put(item)
-        self.queue.put(None)
-
-    def iterator(self):
-        """ Iterate items out of the queue """
+    def _run(self) -> None:
+        """ Run the :attr:`generator` and put into the queue until the queue size is reached.
+
+        Raises
+        ------
+        Exception
+            If there is a failure to run the generator and put to the queue
+        """
+        try:
+            for item in self.generator(*self._gen_args, **self._gen_kwargs):
+                self.queue.put(item)
+            self.queue.put(None)
+        except Exception:
+            self.queue.put(None)
+            raise
+
+    def iterator(self) -> Generator:
+        """ Iterate items out of the queue
+
+        Yields
+        ------
+        Any
+            The items from the generator
+        """
         while True:
             next_item = self.queue.get()
-            if next_item is None:
+            self.check_and_raise_error()
+            if next_item is None or next_item == "EOF":
+                logger.debug("Got EOF OR NONE in BackgroundGenerator")
                 break
             yield next_item


-def terminate_processes():
-    """ Join all active processes on unexpected shutdown
-
-    If the process is doing long running work, make sure you
-    have a mechanism in place to terminate this work to avoid
-    long blocks
-    """
-    logger.debug("Processes to join: %s", [process.name
-                                           for process in _launched_processes
-                                           if process.is_alive()])
-    for process in list(_launched_processes):
-        if process.is_alive():
-            process.join()
+__all__ = get_module_objects(__name__)
diff --git a/lib/queue_manager.py b/lib/queue_manager.py
index 6acbfd05cb..1dfee26e9f 100644
--- a/lib/queue_manager.py
+++ b/lib/queue_manager.py
@@ -5,101 +5,178 @@
 a multiprocess on a Windows System it will break Faceswap"""

 import logging
-import multiprocessing as mp
-import sys
 import threading
-from queue import Queue, Empty as QueueEmpty  # pylint: disable=unused-import; # noqa
+from queue import Queue, Empty as QueueEmpty  # pylint:disable=unused-import; # noqa
 from time import sleep

-logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
+from lib.utils import get_module_objects

+logger = logging.getLogger(__name__)

-class QueueManager():
-    """ Manage queues for availabilty across processes
-        Don't import this class directly, instead
-        import the variable: queue_manager """
-    def __init__(self):
-        logger.debug("Initializing %s", self.__class__.__name__)
-        # Hacky fix to stop multiprocessing spawning managers in child processes
-        if mp.current_process().name == "MainProcess":
-            # Use a Multiprocessing manager in main process
-            self.manager = mp.Manager()
-        else:
-            # Use a standard mp.queue in child process. NB: This will never be used
-            # but spawned processes will load this module, so we need to dummy in a queue
-            self.manager = mp
-        self.shutdown = self.manager.Event()
-        self.queues = dict()
-        # Despite launching a subprocess, the scripts still want to access the same logging
-        # queue as the GUI, so make sure the GUI gets it's own queue
-        self._log_queue = self.manager.Queue() if "gui" not in sys.argv else mp.Queue()
-        logger.debug("Initialized %s", self.__class__.__name__)

+class EventQueue(Queue):
+    """ Standard Queue object with a separate global shutdown parameter indicating that the main
+    process, and by extension this queue, should be shut down.
+
+    Parameters
+    ----------
+    shutdown_event: :class:`threading.Event`
+        The global shutdown event common to all managed queues
+    maxsize: int, Optional
+        Upper bound limit on the number of items that can be placed in the queue. Default: `0`
+    """
+    def __init__(self, shutdown_event: threading.Event, maxsize: int = 0) -> None:
+        super().__init__(maxsize=maxsize)
+        self._shutdown = shutdown_event

-    def add_queue(self, name, maxsize=0, multiprocessing_queue=True):
-        """ Add a queue to the manager
+    @property
+    def shutdown_event(self) -> threading.Event:
+        """ :class:`threading.Event`: The global shutdown event """
+        return self._shutdown

-        Adds an event "shutdown" to the queue that can be used to indicate
-        to a process that any activity on the queue should cease """
-        logger.debug("QueueManager adding: (name: '%s', maxsize: %s)", name, maxsize)
-        if name in self.queues.keys():
-            raise ValueError("Queue '{}' already exists.".format(name))

+class _QueueManager():
+    """ Manage :class:`EventQueue` objects for availability across processes.
+
+    Notes
+    -----
+    Don't import this class directly, instead import the module level instance
+    ``queue_manager`` """
+    def __init__(self) -> None:
+        logger.debug("Initializing %s", self.__class__.__name__)

-        if multiprocessing_queue:
-            queue = self.manager.Queue(maxsize=maxsize)
-        else:
-            queue = Queue(maxsize=maxsize)
+        self.shutdown = threading.Event()
+        self.queues: dict[str, EventQueue] = {}
+        logger.debug("Initialized %s", self.__class__.__name__)

-        setattr(queue, "shutdown", self.shutdown)
-        self.queues[name] = queue
+    def add_queue(self, name: str, maxsize: int = 0, create_new: bool = False) -> str:
+        """ Add a :class:`EventQueue` to the manager.
+
+        Parameters
+        ----------
+        name: str
+            The name of the queue to create
+        maxsize: int, optional
+            The maximum queue size. Set to `0` for unlimited. Default: `0`
+        create_new: bool, optional
+            If a queue of the given name exists, and this value is ``False``, then an error is
+            raised preventing the creation of duplicate queues. If this value is ``True`` and
+            the given name exists then an integer is appended to the end of the queue name and
+            incremented until the given name is unique. Default: ``False``
+
+        Returns
+        -------
+        str
+            The final generated name for the queue
+        """
+        logger.debug("QueueManager adding: (name: '%s', maxsize: %s, create_new: %s)",
+                     name, maxsize, create_new)
+        if not create_new and name in self.queues:
+            raise ValueError(f"Queue '{name}' already exists.")
+        if create_new and name in self.queues:
+            base_name, i = name, 0
+            while name in self.queues:
+                name = f"{base_name}{i}"
+                i += 1
+                logger.debug("Duplicate queue name. Updated to: '%s'", name)
+
+        self.queues[name] = EventQueue(self.shutdown, maxsize=maxsize)
+        logger.debug("QueueManager added: (name: '%s')", name)
+        return name
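# --------------------------------------------------------------------------- #
# Editor's note (illustration only, not part of the patch): hypothetical usage
# of the module-level manager defined at the bottom of this file; the queue
# name is made up.
from lib.queue_manager import queue_manager

q_name = queue_manager.add_queue("convert_in", maxsize=16)
queue = queue_manager.get_queue(q_name)
queue.put("work item")
assert queue.get() == "work item"
queue_manager.del_queue(q_name)
# --------------------------------------------------------------------------- #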
+    def del_queue(self, name: str) -> None:
+        """ Remove a queue from the manager

-    def del_queue(self, name):
-        """ remove a queue from the manager """
+        Parameters
+        ----------
+        name: str
+            The name of the queue to be deleted. Must exist within the queue manager.
+        """
         logger.debug("QueueManager deleting: '%s'", name)
         del self.queues[name]
         logger.debug("QueueManager deleted: '%s'", name)

-    def get_queue(self, name, maxsize=0):
-        """ Return a queue from the manager
-            If it doesn't exist, create it """
+    def get_queue(self, name: str, maxsize: int = 0) -> EventQueue:
+        """ Return a :class:`EventQueue` from the manager. If it doesn't exist, create it.
+
+        Parameters
+        ----------
+        name: str
+            The name of the queue to obtain
+        maxsize: int, Optional
+            The maximum queue size. Set to `0` for unlimited. Only used if the requested queue
+            does not already exist. Default: `0`
+
+        Returns
+        -------
+        :class:`EventQueue`
+            The requested queue
+        """
         logger.debug("QueueManager getting: '%s'", name)
-        queue = self.queues.get(name, None)
+        queue = self.queues.get(name)
         if not queue:
             self.add_queue(name, maxsize)
             queue = self.queues[name]
         logger.debug("QueueManager got: '%s'", name)
         return queue

-    def terminate_queues(self):
-        """ Set shutdown event, clear and send EOF to all queues
-            To be called if there is an error """
+    def terminate_queues(self) -> None:
+        """ Terminates all managed queues.
+
+        Sets the global shutdown event, clears and sends EOF to all queues. To be called if there
+        is an error """
         logger.debug("QueueManager terminating all queues")
         self.shutdown.set()
+        self._flush_queues()
         for q_name, queue in self.queues.items():
             logger.debug("QueueManager terminating: '%s'", q_name)
-            while not queue.empty():
-                queue.get(True, 1)
             queue.put("EOF")
         logger.debug("QueueManager terminated all queues")

-    def debug_monitor(self, update_secs=2):
-        """ Debug tool for monitoring queues """
-        thread = threading.Thread(target=self.debug_queue_sizes,
-                                  args=(update_secs, ))
+    def _flush_queues(self):
+        """ Empty out the contents of every managed queue. """
+        for q_name in self.queues:
+            self.flush_queue(q_name)
+        logger.debug("QueueManager flushed all queues")
+
+    def flush_queue(self, name: str) -> None:
+        """ Flush the contents from a managed queue.
+
+        Parameters
+        ----------
+        name: str
+            The name of the managed :class:`EventQueue` to flush
+        """
+        logger.debug("QueueManager flushing: '%s'", name)
+        queue = self.queues[name]
+        while not queue.empty():
+            queue.get(True, 1)
+
+    def debug_monitor(self, update_interval: int = 2) -> None:
+        """ A debug tool for monitoring managed :class:`EventQueue` objects.
+
+        Prints queue sizes to the console for all managed queues.
+
+        Parameters
+        ----------
+        update_interval: int, Optional
+            The number of seconds between printing information to the console. Default: 2
+        """
+        thread = threading.Thread(target=self._debug_queue_sizes,
+                                  args=(update_interval, ))
         thread.daemon = True
         thread.start()

-    def debug_queue_sizes(self, update_secs):
-        """ Output the queue sizes
-            logged to INFO so it also displays in console """
+    def _debug_queue_sizes(self, update_interval: int) -> None:
+        """ Print the queue size for each managed queue to console.
+
+        Parameters
+        ----------
+        update_interval: int
+            The number of seconds between printing information to the console
+        """
         while True:
+            logger.info("====================================================")
             for name in sorted(self.queues.keys()):
                 logger.info("%s: %s", name, self.queues[name].qsize())
-            sleep(update_secs)
+            sleep(update_interval)
+
+
+queue_manager = _QueueManager()  # pylint:disable=invalid-name

-queue_manager = QueueManager()  # pylint: disable=invalid-name
+__all__ = get_module_objects(__name__)
diff --git a/lib/serializer.py b/lib/serializer.py
new file mode 100644
index 0000000000..0ab2277440
--- /dev/null
+++ b/lib/serializer.py
@@ -0,0 +1,345 @@
+#!/usr/bin/env python3
+"""
+Library for serializing python objects to and from various different serializer formats
+"""
+
+import json
+import logging
+import os
+import pickle
+import zlib
+
+from io import BytesIO
+
+import numpy as np
+
+from lib.utils import FaceswapError, get_module_objects
+
+try:
+    import yaml
+    _HAS_YAML = True
+except ImportError:
+    _HAS_YAML = False
+
+logger = logging.getLogger(__name__)
+
+
+class Serializer():
+    """ A convenience class for various serializers.
+
+    This class should not be called directly as it acts as the parent for various serializers.
+    All serializers should be called from :func:`get_serializer` or
+    :func:`get_serializer_from_filename`
+
+    Example
+    -------
+    >>> from lib.serializer import get_serializer
+    >>> serializer = get_serializer('json')
+    >>> json_file = '/path/to/json/file.json'
+    >>> data = serializer.load(json_file)
+    >>> serializer.save(json_file, data)
+
+    """
+    def __init__(self):
+        self._file_extension = None
+        self._write_option = "wb"
+        self._read_option = "rb"
+
+    @property
+    def file_extension(self):
+        """ str: The file extension of the serializer """
+        return self._file_extension
+
+    def save(self, filename, data):
+        """ Serialize data and save to a file
+
+        Parameters
+        ----------
+        filename: str
+            The path to where the serialized file should be saved
+        data: varies
+            The data that is to be serialized to file
+
+        Example
+        ------
+        >>> serializer = get_serializer('json')
+        >>> data = ['foo', 'bar']
+        >>> json_file = '/path/to/json/file.json'
+        >>> serializer.save(json_file, data)
+        """
+        logger.debug("filename: %s, data type: %s", filename, type(data))
+        filename = self._check_extension(filename)
+        try:
+            with open(filename, self._write_option) as s_file:
+                s_file.write(self.marshal(data))
+        except IOError as err:
+            msg = f"Error writing to '{filename}': {err.strerror}"
+            raise FaceswapError(msg) from err
+
+    def _check_extension(self, filename):
+        """ Check the filename has an extension. If not, add the correct one for the serializer """
+        extension = os.path.splitext(filename)[1]
+        retval = filename if extension else f"{filename}.{self.file_extension}"
+        logger.debug("Original filename: '%s', final filename: '%s'", filename, retval)
+        return retval
+
+    def load(self, filename):
+        """ Load data from an existing serialized file
+
+        Parameters
+        ----------
+        filename: str
+            The path to the serialized file
+
+        Returns
+        ----------
+        data: varies
+            The data in a python object format
+
+        Example
+        ------
+        >>> serializer = get_serializer('json')
+        >>> json_file = '/path/to/json/file.json'
+        >>> data = serializer.load(json_file)
+        """
+        logger.debug("filename: %s", filename)
+        try:
+            with open(filename, self._read_option) as s_file:
+                data = s_file.read()
+                logger.debug("stored data type: %s", type(data))
+                retval = self.unmarshal(data)
+
+        except IOError as err:
+            msg = f"Error reading from '{filename}': {err.strerror}"
+            raise FaceswapError(msg) from err
+        logger.debug("data type: %s", type(retval))
+        return retval
+
+    def marshal(self, data):
+        """ Serialize an object
+
+        Parameters
+        ----------
+        data: varies
+            The data that is to be serialized
+
+        Returns
+        -------
+        data: varies
+            The data in the serialized data format
+
+        Example
+        ------
+        >>> serializer = get_serializer('json')
+        >>> data = ['foo', 'bar']
+        >>> json_data = serializer.marshal(data)
+        """
+        logger.debug("data type: %s", type(data))
+        try:
+            retval = self._marshal(data)
+        except Exception as err:
+            msg = f"Error serializing data for type {type(data)}: {str(err)}"
+            raise FaceswapError(msg) from err
+        logger.debug("returned data type: %s", type(retval))
+        return retval
+
+    def unmarshal(self, serialized_data):
+        """ Unserialize data to its original object type
+
+        Parameters
+        ----------
+        serialized_data: varies
+            Data in serializer format that is to be unmarshalled to its original object
+
+        Returns
+        -------
+        data: varies
+            The data in a python object format
+
+        Example
+        ------
+        >>> serializer = get_serializer('json')
+        >>> json_data = serializer.marshal(['foo', 'bar'])
+        >>> data = serializer.unmarshal(json_data)
+        """
+        logger.debug("data type: %s", type(serialized_data))
+        try:
+            retval = self._unmarshal(serialized_data)
+        except Exception as err:
+            msg = f"Error unserializing data for type {type(serialized_data)}: {str(err)}"
+            raise FaceswapError(msg) from err
+        logger.debug("returned data type: %s", type(retval))
+        return retval
+
+    def _marshal(self, data):
+        """ Override for serializer specific marshalling """
+        raise NotImplementedError()
+
+    def _unmarshal(self, data):
+        """ Override for serializer specific unmarshalling """
+        raise NotImplementedError()
+
+
+class _YAMLSerializer(Serializer):
+    """ YAML Serializer """
+    def __init__(self):
+        super().__init__()
+        self._file_extension = "yml"
+
+    def _marshal(self, data):
+        return yaml.dump(data, default_flow_style=False).encode("utf-8")
+
+    def _unmarshal(self, data):
+        return yaml.load(data.decode("utf-8", errors="replace"), Loader=yaml.FullLoader)
+
+
+class _JSONSerializer(Serializer):
+    """ JSON Serializer """
+    def __init__(self):
+        super().__init__()
+        self._file_extension = "json"
+
+    def _marshal(self, data):
+        return json.dumps(data, indent=2).encode("utf-8")
+
+    def _unmarshal(self, data):
+        return json.loads(data.decode("utf-8", errors="replace"))
+
+
+class _PickleSerializer(Serializer):
+    """ Pickle Serializer """
+    def __init__(self):
+        super().__init__()
+        self._file_extension = "pickle"
+
+    def _marshal(self, data):
+        return pickle.dumps(data)
+
+    def _unmarshal(self, data):
+        return pickle.loads(data)
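# --------------------------------------------------------------------------- #
# Editor's note (illustration only, not part of the patch): a round-trip
# through the public helpers of this module; the payload and target path are
# made up for the example.
from lib.serializer import get_serializer

serializer = get_serializer("json")
payload = {"faces": 2, "files": ["a.png", "b.png"]}
raw = serializer.marshal(payload)         # bytes in the serializer's format
assert serializer.unmarshal(raw) == payload
serializer.save("/tmp/example", payload)  # ".json" extension appended automatically
# --------------------------------------------------------------------------- #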
+
+
+class _NPYSerializer(Serializer):
+    """ NPY Serializer """
+    def __init__(self):
+        super().__init__()
+        self._file_extension = "npy"
+        self._bytes = BytesIO()
+
+    def _marshal(self, data):
+        """ NPY Marshal to bytesIO so standard bytes writer can write out """
+        b_handler = BytesIO()
+        np.save(b_handler, data)
+        b_handler.seek(0)
+        return b_handler.read()
+
+    def _unmarshal(self, data):
+        """ NPY Unmarshal to bytesIO so we can use numpy loader """
+        b_handler = BytesIO(data)
+        retval = np.load(b_handler)
+        del b_handler
+        if retval.dtype == "object":
+            retval = retval[()]
+        return retval
+
+
+class _CompressedSerializer(Serializer):
+    """ A compressed pickle serializer for Faceswap """
+    def __init__(self):
+        super().__init__()
+        self._file_extension = "fsa"
+        self._child = get_serializer("pickle")
+
+    def _marshal(self, data):
+        """ Pickle and compress data """
+        data = self._child._marshal(data)  # pylint:disable=protected-access
+        return zlib.compress(data)
+
+    def _unmarshal(self, data):
+        """ Decompress and unpickle data """
+        data = zlib.decompress(data)
+        return self._child._unmarshal(data)  # pylint:disable=protected-access
+
+
+def get_serializer(serializer):
+    """ Obtain a serializer object
+
+    Parameters
+    ----------
+    serializer: {'json', 'pickle', 'yaml', 'npy', 'compressed'}
+        The required serializer format
+
+    Returns
+    -------
+    serializer: :class:`Serializer`
+        A serializer object for handling the requested data format
+
+    Example
+    -------
+    >>> serializer = get_serializer('json')
+    """
+    retval = None
+    if serializer.lower() == "npy":
+        retval = _NPYSerializer()
+    elif serializer.lower() == "compressed":
+        retval = _CompressedSerializer()
+    elif serializer.lower() == "json":
+        retval = _JSONSerializer()
+    elif serializer.lower() == "pickle":
+        retval = _PickleSerializer()
+    elif serializer.lower() == "yaml" and _HAS_YAML:
+        retval = _YAMLSerializer()
+    elif serializer.lower() == "yaml":
+        logger.warning("You must have PyYAML installed to use YAML as the serializer.\n"
+                       "Switching to JSON as the serializer.")
+        retval = _JSONSerializer()
+    else:
+        logger.warning("Unrecognized serializer: '%s'. Returning json serializer", serializer)
+        retval = _JSONSerializer()
+    logger.debug(retval)
+    return retval
+
+
+def get_serializer_from_filename(filename):
+    """ Obtain a serializer object from a filename
+
+    Parameters
+    ----------
+    filename: str
+        Filename to determine the serializer type from
+
+    Returns
+    -------
+    serializer: :class:`Serializer`
+        A serializer object for handling the requested data format
+
+    Example
+    -------
+    >>> filename = '/path/to/json/file.json'
+    >>> serializer = get_serializer_from_filename(filename)
+    """
+    logger.debug("filename: '%s'", filename)
+    extension = os.path.splitext(filename)[1].lower()
+    logger.debug("extension: '%s'", extension)
+
+    if extension == ".json":
+        retval = _JSONSerializer()
+    elif extension in (".p", ".pickle"):
+        retval = _PickleSerializer()
+    elif extension == ".npy":
+        retval = _NPYSerializer()
+    elif extension == ".fsa":
+        retval = _CompressedSerializer()
+    elif extension in (".yaml", ".yml") and _HAS_YAML:
+        retval = _YAMLSerializer()
+    elif extension in (".yaml", ".yml"):
+        logger.warning("You must have PyYAML installed to use YAML as the serializer.\n"
+                       "Switching to JSON as the serializer.")
+        retval = _JSONSerializer()
+    else:
+        logger.warning("Unrecognized extension: '%s'.
Returning json serializer", extension) + retval = _JSONSerializer() + logger.debug(retval) + return retval + + +__all__ = get_module_objects(__name__) diff --git a/lib/sysinfo.py b/lib/sysinfo.py deleted file mode 100644 index 411b000e1f..0000000000 --- a/lib/sysinfo.py +++ /dev/null @@ -1,308 +0,0 @@ -#!/usr/bin python3 -""" Obtain information about the running system, environment and gpu """ - -import locale -import os -import platform -import sys -from subprocess import PIPE, Popen - -import psutil - -from lib.gpu_stats import GPUStats - - -class SysInfo(): - """ System and Python Information """ - # pylint: disable=too-many-instance-attributes,too-many-public-methods - - def __init__(self): - gpu_stats = GPUStats(log=False) - - self.platform = platform.platform() - self.system = platform.system() - self.machine = platform.machine() - self.release = platform.release() - self.processor = platform.processor() - self.cpu_count = os.cpu_count() - self.py_implementation = platform.python_implementation() - self.py_version = platform.python_version() - self._cuda_path = self.get_cuda_path() - self.vram = gpu_stats.vram - self.gfx_driver = gpu_stats.driver - self.gfx_devices = gpu_stats.devices - - @property - def encoding(self): - """ Return system preferred encoding """ - return locale.getpreferredencoding() - - @property - def is_conda(self): - """ Boolean for whether in a conda environment """ - return "conda" in sys.version.lower() - - @property - def is_linux(self): - """ Boolean for whether system is Linux """ - return self.system.lower() == "linux" - - @property - def is_macos(self): - """ Boolean for whether system is macOS """ - return self.system.lower() == "darwin" - - @property - def is_windows(self): - """ Boolean for whether system is Windows """ - return self.system.lower() == "windows" - - @property - def is_virtual_env(self): - """ Boolean for whether running in a virtual environment """ - if not self.is_conda: - retval = (hasattr(sys, "real_prefix") or - (hasattr(sys, "base_prefix") and sys.base_prefix != sys.prefix)) - else: - prefix = os.path.dirname(sys.prefix) - retval = (os.path.basename(prefix) == "envs") - return retval - - @property - def ram(self): - """ Return RAM stats """ - return psutil.virtual_memory() - - @property - def ram_free(self): - """ return free RAM """ - return getattr(self.ram, "free") - - @property - def ram_total(self): - """ return total RAM """ - return getattr(self.ram, "total") - - @property - def ram_available(self): - """ return available RAM """ - return getattr(self.ram, "available") - - @property - def ram_used(self): - """ return used RAM """ - return getattr(self.ram, "used") - - @property - def fs_command(self): - """ Return the executed faceswap command """ - return " ".join(sys.argv) - - @property - def installed_pip(self): - """ Installed pip packages """ - pip = Popen("{} -m pip freeze".format(sys.executable), - shell=True, stdout=PIPE) - installed = pip.communicate()[0].decode().splitlines() - return "\n".join(installed) - - @property - def installed_conda(self): - """ Installed Conda packages """ - if not self.is_conda: - return None - conda = Popen("conda list", shell=True, stdout=PIPE, stderr=PIPE) - stdout, stderr = conda.communicate() - if stderr: - return "Could not get package list" - installed = stdout.decode().splitlines() - return "\n".join(installed) - - @property - def conda_version(self): - """ Get conda version """ - if not self.is_conda: - return "N/A" - conda = Popen("conda --version", shell=True, stdout=PIPE, 
stderr=PIPE) - stdout, stderr = conda.communicate() - if stderr: - return "Conda is used, but version not found" - version = stdout.decode().splitlines() - return "\n".join(version) - - @property - def git_branch(self): - """ Get the current git branch """ - git = Popen("git status", shell=True, stdout=PIPE, stderr=PIPE) - stdout, stderr = git.communicate() - if stderr: - return "Not Found" - branch = stdout.decode().splitlines()[0].replace("On branch ", "") - return branch - - @property - def git_commits(self): - """ Get last 5 git commits """ - git = Popen("git log --pretty=oneline --abbrev-commit -n 5", - shell=True, stdout=PIPE, stderr=PIPE) - stdout, stderr = git.communicate() - if stderr: - return "Not Found" - commits = stdout.decode().splitlines() - return ". ".join(commits) - - @property - def cuda_version(self): - """ Get the installed CUDA version """ - if self.is_linux: - version = self.cuda_version_linux() - elif self.is_windows: - version = self.cuda_version_windows() - else: - version = "Unsupported OS" - return version - - @property - def cudnn_version(self): - """ Get the installed cuDNN version """ - if not self._cuda_path: - retval = "Not Found" - if self.is_conda: - retval += ". Check Conda packages for Conda cuDNN" - return retval - cudnn_checkfile = os.path.join(self._cuda_path, "include", "cudnn.h") - if not os.path.isfile(cudnn_checkfile): - retval = "Not Found" - if self.is_conda: - retval += ". Check Conda packages for Conda cuDNN" - return retval - found = 0 - with open(cudnn_checkfile, "r") as ofile: - for line in ofile: - if line.lower().startswith("#define cudnn_major"): - major = line[line.rfind(" ") + 1:].strip() - found += 1 - elif line.lower().startswith("#define cudnn_minor"): - minor = line[line.rfind(" ") + 1:].strip() - found += 1 - elif line.lower().startswith("#define cudnn_patchlevel"): - patchlevel = line[line.rfind(" ") + 1:].strip() - found += 1 - if found == 3: - break - if found != 3: - retval = "Not Found" - if self.is_conda: - retval += ". Check Conda packages for Conda cuDNN" - return retval - return "{}.{}.{}".format(major, minor, patchlevel) - - def get_cuda_path(self): - """ Return the correct CUDA Path """ - if self.is_linux: - path = self.cuda_path_linux() - elif self.is_windows: - path = self.cuda_path_windows() - else: - path = None - return path - - @staticmethod - def cuda_path_linux(): - """ Get the path to Cuda on linux systems """ - ld_library_path = os.environ.get("LD_LIBRARY_PATH", None) - chk = os.popen("ldconfig -p | grep -P \"libcudart.so.\\d+.\\d+\" | head -n 1").read() - if ld_library_path and not chk: - paths = ld_library_path.split(":") - for path in paths: - chk = os.popen("ls {} | grep -P -o \"libcudart.so.\\d+.\\d+\" | " - "head -n 1".format(path)).read() - if chk: - break - if not chk: - return None - return chk[chk.find("=>") + 3:chk.find("targets") - 1] - - @staticmethod - def cuda_path_windows(): - """ Get the path to Cuda on Windows systems """ - cuda_path = os.environ.get("CUDA_PATH", None) - return cuda_path - - def cuda_version_linux(self): - """ Get CUDA version for linux systems """ - ld_library_path = os.environ.get("LD_LIBRARY_PATH", None) - chk = os.popen("ldconfig -p | grep -P \"libcudart.so.\\d+.\\d+\" | head -n 1").read() - if ld_library_path and not chk: - paths = ld_library_path.split(":") - for path in paths: - chk = os.popen("ls {} | grep -P -o \"libcudart.so.\\d+.\\d+\" | " - "head -n 1".format(path)).read() - if chk: - break - if not chk: - retval = "Not Found" - if self.is_conda: - retval += ". 
Check Conda packages for Conda Cuda" - return retval - cudavers = chk.strip().replace("libcudart.so.", "") - return cudavers[:cudavers.find(" ")] - - def cuda_version_windows(self): - """ Get CUDA version for Windows systems """ - cuda_keys = [key - for key in os.environ.keys() - if key.lower().startswith("cuda_path_v")] - if not cuda_keys: - retval = "Not Found" - if self.is_conda: - retval += ". Check Conda packages for Conda Cuda" - return retval - cudavers = [key.replace("CUDA_PATH_V", "").replace("_", ".") for key in cuda_keys] - return " ".join(cudavers) - - def full_info(self): - """ Format system info human readable """ - retval = "\n============ System Information ============\n" - sys_info = {"os_platform": self.platform, - "os_machine": self.machine, - "os_release": self.release, - "py_conda_version": self.conda_version, - "py_implementation": self.py_implementation, - "py_version": self.py_version, - "py_command": self.fs_command, - "py_virtual_env": self.is_virtual_env, - "sys_cores": self.cpu_count, - "sys_processor": self.processor, - "sys_ram": self.format_ram(), - "encoding": self.encoding, - "git_branch": self.git_branch, - "git_commits": self.git_commits, - "gpu_cuda": self.cuda_version, - "gpu_cudnn": self.cudnn_version, - "gpu_driver": self.gfx_driver, - "gpu_devices": ", ".join(["GPU_{}: {}".format(idx, device) - for idx, device in enumerate(self.gfx_devices)]), - "gpu_vram": ", ".join(["GPU_{}: {}MB".format(idx, int(vram)) - for idx, vram in enumerate(self.vram)])} - for key in sorted(sys_info.keys()): - retval += ("{0: <18} {1}\n".format(key + ":", sys_info[key])) - retval += "\n=============== Pip Packages ===============\n" - retval += self.installed_pip - if not self.is_conda: - return retval - retval += "\n\n============== Conda Packages ==============\n" - retval += self.installed_conda - return retval - - def format_ram(self): - """ Format the RAM stats for human output """ - retval = list() - for name in ("total", "available", "used", "free"): - value = getattr(self, "ram_{}".format(name)) - value = int(value / (1024 * 1024)) - retval.append("{}: {}MB".format(name.capitalize(), value)) - return ", ".join(retval) - - -sysinfo = SysInfo() # pylint: disable=invalid-name
diff --git a/lib/system/__init__.py b/lib/system/__init__.py new file mode 100644 index 0000000000..59b7bf0962 --- /dev/null +++ b/lib/system/__init__.py @@ -0,0 +1,5 @@ +#!/usr/bin/env python3 +""" Contains system information for error reporting and installation.""" + +from .system import Packages, System +from .ml_libs import Cuda, ROCm
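Taken together, these exports give a small query surface for installers and crash reports. A minimal usage sketch (the printed values are purely illustrative; `Cuda` resolves to a platform-specific finder class, as defined in `ml_libs.py` below):

from lib.system import Cuda, ROCm

cuda = Cuda()               # CudaWindows on Windows, CudaLinux elsewhere
print(cuda.version)         # e.g. (12, 4); (0, 0) when no system Cuda was found
print(cuda.cudnn_versions)  # e.g. {(12, 4): (9, 1, 0)}; key (0, 0) means a global cuDNN

rocm = ROCm()
print(rocm.version)         # e.g. (6, 0, 2); (0, 0, 0) when no ROCm was found
print(rocm.is_valid)        # True when the default version is inside the Torch range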
diff --git a/lib/system/ml_libs.py b/lib/system/ml_libs.py new file mode 100644 index 0000000000..d0c956db30 --- /dev/null +++ b/lib/system/ml_libs.py @@ -0,0 +1,998 @@ +#!/usr/bin/env python3 +""" +Queries information about system installed Machine Learning Libraries. + +NOTE: Only packages from Python's Standard Library should be imported in this module +""" +from __future__ import annotations + +import json +import logging +import os +import platform +import re +import typing as T + +from abc import ABC, abstractmethod +from shutil import which + +from lib.utils import get_module_objects + +from .system import _lines_from_command + +if platform.system() == "Windows": + import winreg # pylint:disable=import-error +else: + winreg = None # type:ignore[assignment] # pylint:disable=invalid-name + +if T.TYPE_CHECKING: + from winreg import HKEYType # type:ignore[attr-defined] + +logger = logging.getLogger(__name__) + + +_TORCH_ROCM_REQUIREMENTS = {">=2.2.1,<2.4.0": ((6, 0), (6, 0))} +"""dict[str, tuple[tuple[int, int], tuple[int, int]]]: Minimum and maximum ROCm versions """ + + +def _check_dynamic_linker(lib: str) -> list[str]: + """ Locate the folders that contain a given library in ldconfig and $LD_LIBRARY_PATH + + Parameters + ---------- + lib : str + The library to locate + + Returns + ------- + list[str] + All real existing folders from ldconfig or $LD_LIBRARY_PATH that contain the given lib + """ + paths: set[str] = set() + ldconfig = which("ldconfig") + if ldconfig: + paths.update({os.path.realpath(os.path.dirname(line.split("=>")[-1].strip())) + for line in _lines_from_command([ldconfig, "-p"]) + if lib in line and "=>" in line}) + + if not os.environ.get("LD_LIBRARY_PATH"): + return list(paths) + + paths.update({os.path.realpath(path) + for path in os.environ["LD_LIBRARY_PATH"].split(":") + if path and os.path.exists(path) + for fname in os.listdir(path) + if lib in fname}) + return list(paths) + + +def _files_from_folder(folder: str, prefix: str) -> list[str]: + """ Obtain all filenames from the given folder that start with the given prefix + + Parameters + ---------- + folder : str + The folder to search for files in + prefix : str + The filename prefix to search for + + Returns + ------- + list[str] + All filenames that exist in the given folder with the given prefix + """ + if not os.path.exists(folder): + return [] + return [f for f in os.listdir(folder) if f.startswith(prefix)]
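The `_Alternatives` class that follows scrapes the output of `update-alternatives --display <package>`. A hedged, self-contained sketch of the parsing it performs; the sample output lines are an assumption about the real command's format:

output = [
    "cuda - auto mode",
    "link currently points to /usr/local/cuda-12.4",
    "/usr/local/cuda-12.1 - priority 121",
    "/usr/local/cuda-12.4 - priority 124",
]
# the default is the target of the "link currently points to" line
default = next((x for x in output if x.startswith("link currently points to")),
               "").replace("link currently points to", "").strip()
# every line carrying a "priority" marker names one installed alternative
alternatives = [line.rsplit(" - ", maxsplit=1)[0]
                for line in output if "priority" in line.lower()]
assert default == "/usr/local/cuda-12.4"
assert alternatives == ["/usr/local/cuda-12.1", "/usr/local/cuda-12.4"]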
+class _Alternatives: + """ Holds output from the update-alternatives command for the given package + + Parameters + ---------- + package : str + The package to query update-alternatives for information + """ + def __init__(self, package: str) -> None: + self._package = package + self._bin = which("update-alternatives") + self._default_marker = "link currently points to" + self._alternatives_marker = "priority" + self._output: list[str] | None = None + + @property + def alternatives(self) -> list[str]: + """ list[str] : Full path to alternatives listed for the given package """ + if self._output is None: + self._query() + if not self._output: + return [] + retval = [line.rsplit(" - ", maxsplit=1)[0] for line in self._output + if self._alternatives_marker in line.lower()] + logger.debug("Versions from 'update-alternatives' for '%s': %s", self._package, retval) + return retval + + @property + def default(self) -> str: + """ str : Full path to the default package """ + if self._output is None: + self._query() + if not self._output: + return "" + retval = next((x for x in self._output + if x.startswith(self._default_marker)), "").replace(self._default_marker, + "").strip() + logger.debug("Default from update-alternatives for '%s': %s", self._package, retval) + return retval + + def _query(self) -> None: + """ Query update-alternatives for the given package and place stripped output into + :attr:`_output` """ + if not self._bin: + self._output = [] + return + cmd = [self._bin, "--display", self._package] + retval = [line.strip() for line in _lines_from_command(cmd)] + logger.debug("update-alternatives output for command %s: %s", + cmd, retval) + self._output = retval + + +class _Cuda(ABC): + """ Find the location of system installed Cuda and cuDNN on Windows and Linux. """ + def __init__(self) -> None: + self.versions: list[tuple[int, int]] = [] + """ list[tuple[int, int]] : All detected globally installed Cuda versions """ + self.version: tuple[int, int] = (0, 0) + """ tuple[int, int] : Default installed Cuda version. (0, 0) if not detected """ + self.cudnn_versions: dict[tuple[int, int], tuple[int, int, int]] = {} + """ dict[tuple[int, int], tuple[int, int, int]] : Detected cuDNN version for each installed + Cuda. key (0, 0) denotes globally installed cudnn """ + self._paths: list[str] = [] + """ list[str] : list of path to Cuda install folders relating to :attr:`versions` """ + + self._version_file = "version.json" + self._lib = "libcudart.so" + self._cudnn_header = "cudnn_version.h" + self._alternatives = _Alternatives("cuda") + self._re_cudnn = re.compile(r"#define CUDNN_(MAJOR|MINOR|PATCHLEVEL)\s+(\d+)") + + if platform.system() in ("Windows", "Linux"): + self._get_versions() + self._get_version() + self._get_cudnn_versions() + + def __repr__(self) -> str: + """ Pretty representation of this class """ + attrs = ", ".join(f"{k}={repr(v)}" for k, v in self.__dict__.items() + if not k.startswith("_")) + return f"{self.__class__.__name__}({attrs})" + + @classmethod + def _tuple_from_string(cls, version: str) -> tuple[int, int] | None: + """ Convert a Cuda version string to a version tuple + + Parameters + ---------- + version : str + The Cuda version string to convert + + Returns + ------- + tuple[int, int] | None + The converted Cuda version string. ``None`` if not a valid version string + """ + if version.startswith("."): + version = version[1:] + split = version.split(".") + if len(split) not in (2, 3): + return None + split = split[:2] + if not all(x.isdigit() for x in split): + return None + return (int(split[0]), int(split[1])) + + @abstractmethod + def get_versions(self) -> dict[tuple[int, int], str]: + """ Override to attempt to detect all installed Cuda versions on Linux or Windows systems + + Returns + ------- + dict[tuple[int, int], str] + The Cuda versions to the folder path on the system + """ + + @abstractmethod + def get_version(self) -> tuple[int, int] | None: + """ Override to attempt to locate the default Cuda version on Linux or Windows + + Returns + ------- + tuple[int, int] | None + The Default global Cuda version or ``None`` if not found + """ + + @abstractmethod + def get_cudnn_versions(self) -> dict[tuple[int, int], tuple[int, int, int]]: + """ Override to attempt to locate any installed cuDNN versions + + Returns + ------- + dict[tuple[int, int], tuple[int, int, int]] + Detected cuDNN version for each installed Cuda.
key (0, 0) denotes globally installed + cudnn + """ + + def version_from_version_file(self, folder: str) -> tuple[int, int] | None: + """ Attempt to get an installed Cuda version from its version.json file + + Parameters + ---------- + folder : str + Full path to the folder to check for a version file + + Returns + ------- + tuple[int, int] | None + The detected Cuda version or ``None`` if not detected + """ + vers_file = os.path.join(folder, self._version_file) + if not os.path.exists(vers_file): + return None + with open(vers_file, "r", encoding="utf-8", errors="replace") as f: + vers = json.load(f) + retval = self._tuple_from_string(vers.get("cuda_cudart", {}).get("version", "")) + logger.debug("Version from '%s': %s", vers_file, retval) + return retval + + def _version_from_nvcc(self) -> tuple[int, int] | None: + """ Obtain the version from NVCC output if it is on PATH + + Returns + ------- + tuple[int, int] | None + The detected default Cuda version. ``None`` if no version detected + """ + retval = None + nvcc = which("nvcc") + if not nvcc: + return retval + + for line in _lines_from_command([nvcc, "-V"]): + vers = re.match(r".*release (\d+\.\d+)", line) + if vers is not None: + retval = self._tuple_from_string(vers.group(1)) + break + logger.debug("Version from NVCC '%s': %s", nvcc, retval) + return retval + + def _get_versions(self) -> None: + """ Attempt to detect all installed Cuda versions and populate to :attr:`versions` """ + versions = self.get_versions() + if versions: + logger.debug("Cuda Versions: %s", versions) + self.versions = list(versions) + self._paths = list(versions.values()) + return + logger.debug("Could not locate any Cuda versions") + + def _get_version(self) -> None: + """ Attempt to detect the default Cuda version and populate to :attr:`version` """ + version: tuple[int, int] | None = None + if len(self.versions) == 1: + version = self.versions[0] + logger.debug("Only 1 installed Cuda version: %s", version) + if not version: + version = self._version_from_nvcc() + if not version: + version = self.get_version() + if version: + self.version = version + logger.debug("Cuda version: %s", self.version if version else "not detected") + + def _get_cudnn_versions(self) -> None: + """ Attempt to locate any installed cuDNN versions and add to :attr:`cudnn_versions` """ + versions = self.get_cudnn_versions() + if versions: + logger.debug("cudnn versions: %s", versions) + self.cudnn_versions = versions + return + logger.debug("No cudnn versions found") + + def cudnn_version_from_header(self, folder: str) -> tuple[int, int, int] | None: + """ Attempt to detect the cuDNN version from the version header file + + Parameters + ---------- + folder : str + The folder to check for the cuDNN header file + + Returns + ------- + tuple[int, int, int] | None + The cuDNN version found from the given folder or ``None`` if not detected + """ + path = os.path.join(folder, self._cudnn_header) + if not os.path.exists(path): + logger.debug("cudnn file '%s' does not exist", path) + return None + + with open(path, "r", encoding="utf-8", errors="ignore") as f: + file = f.read() + version = {v[0]: int(v[1]) if v[1].isdigit() else 0 + for v in self._re_cudnn.findall(file)} + if not version: + logger.debug("cudnn version could not be found in '%s'", path) + return None + + logger.debug("cudnn version from '%s': %s", path, version) + retval = (version.get("MAJOR", 0), version.get("MINOR", 0), version.get("PATCHLEVEL", 0)) + logger.debug("cudnn versions: %s", retval) + return retval + +
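As a quick illustration of `cudnn_version_from_header`, the regex defined in `_Cuda.__init__` pulls the three version defines out of `cudnn_version.h`. A self-contained sketch with a fabricated header snippet:

import re

re_cudnn = re.compile(r"#define CUDNN_(MAJOR|MINOR|PATCHLEVEL)\s+(\d+)")
header = ("#define CUDNN_MAJOR 9\n"
          "#define CUDNN_MINOR 1\n"
          "#define CUDNN_PATCHLEVEL 0\n")
# findall yields (name, digits) pairs, exactly as the method consumes them
version = {key: int(val) for key, val in re_cudnn.findall(header)}
assert (version["MAJOR"], version["MINOR"], version["PATCHLEVEL"]) == (9, 1, 0)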
+class CudaLinux(_Cuda): + """ Find the location of system installed Cuda and cuDNN on Linux. """ + def __init__(self) -> None: + self._folder_prefix = "cuda-" + super().__init__() + + def _version_from_lib(self, folder: str) -> tuple[int, int] | None: + """ Attempt to locate the version from the existence of libcudart.so within a Cuda + targets/x86_64-linux/lib folder + + Parameters + ---------- + folder : str + Full file path to the Cuda folder + + Returns + ------- + tuple[int, int] | None + The Cuda version identified by the existence of the libcudart.so file. ``None`` if + not detected + """ + lib_folder = os.path.join(folder, "targets", "x86_64-linux", "lib") + lib_versions = [f.replace(self._lib, "") + for f in _files_from_folder(lib_folder, self._lib)] + if not lib_versions: + return None + versions = [self._tuple_from_string(f[1:]) + for f in lib_versions if f and f.startswith(".")] + valid = [v for v in versions if v is not None] + if not valid or len(set(valid)) != 1: + return None + retval = valid[0] + logger.debug("Version from '%s': %s", os.path.join(lib_folder, self._lib), retval) + return retval + + def _versions_from_usr(self) -> dict[tuple[int, int], str]: + """ Attempt to detect all installed Cuda versions from the /usr/local folder + + Scan /usr/local for cuda-x.x folders containing either a version.json file or + targets/x86_64-linux/lib/libcudart.so.x. + + Returns + ------- + dict[tuple[int, int], str] + A dictionary of detected Cuda versions to their install paths + """ + retval: dict[tuple[int, int], str] = {} + usr = os.path.join(os.sep, "usr", "local") + + for folder in _files_from_folder(usr, self._folder_prefix): + path = os.path.join(usr, folder) + if os.path.islink(path): + continue + version = self.version_from_version_file(path) or self._version_from_lib(path) + if version is not None: + retval[version] = path + return retval + + def _versions_from_alternatives(self) -> dict[tuple[int, int], str]: + """ Attempt to detect all installed Cuda versions from update-alternatives + + Returns + ------- + dict[tuple[int, int], str] + A dictionary of detected Cuda versions to their install paths found in + update-alternatives + """ + retval: dict[tuple[int, int], str] = {} + alts = self._alternatives.alternatives + for path in alts: + vers = self.version_from_version_file(path) or self._version_from_lib(path) + if vers is not None: + retval[vers] = path + logger.debug("Versions from 'update-alternatives': %s", retval) + return retval + + def _parent_from_targets(self, folder: str) -> str: + """ Obtain the Cuda parent folder from a path obtained from a child 'targets' folder + + Parameters + ---------- + folder : str + Full path to a folder that has a 'targets' folder in its path + + Returns + ------- + str + The potential parent Cuda folder, or an empty string if not detected + """ + split = folder.split(os.sep) + return os.sep.join(split[:split.index("targets")]) if "targets" in split else "" + + def _versions_from_dynamic_linker(self) -> dict[tuple[int, int], str]: + """ Attempt to detect all installed Cuda versions from ldconfig + + Returns + ------- + dict[tuple[int, int], str] + The Cuda version to the folder path found from ldconfig + """ + retval: dict[tuple[int, int], str] = {} + folders = _check_dynamic_linker(self._lib) + cuda_roots = [self._parent_from_targets(f) for f in folders] + for path in cuda_roots: + if not path: + continue + version = self.version_from_version_file(path) or self._version_from_lib(path) + if version is not None: + retval[version] = path + +
logger.debug("Versions from 'ld_config': %s", retval) + return retval + + def get_versions(self) -> dict[tuple[int, int], str]: + """ Attempt to detect all installed Cuda versions on Linux systems + + Returns + ------- + dict[tuple[int, int], str] + The Cuda version to the folder path on Linux + """ + versions = (self._versions_from_usr() | + self._versions_from_alternatives() | + self._versions_from_dynamic_linker()) + return {k: versions[k] for k in sorted(versions)} + + def _version_from_alternatives(self) -> tuple[int, int] | None: + """ Attempt to get the default Cuda version from update-alternatives + + Returns + ------- + tuple[int, int] | None + The detected default Cuda version. ``None`` if not version detected + """ + default = self._alternatives.default + if not default: + return None + retval = self.version_from_version_file(default) or self._version_from_lib(default) + logger.debug("Version from update-alternatives: %s", retval) + return retval + + def _version_from_link(self) -> tuple[int, int] | None: + """ Attempt to get the default Cuda version from the /usr/local/cuda file + + Returns + ------- + tuple[int, int] | None + The detected default Cuda version. ``None`` if not version detected + """ + path = os.path.join(os.sep, "usr", "local", "cuda") + if not os.path.exists(path): + return None + real_path = os.path.abspath(os.path.realpath(path)) if os.path.islink(path) else path + retval = self.version_from_version_file(real_path) or self._version_from_lib(real_path) + logger.debug("Version from symlink: %s", retval) + return retval + + def _version_from_dynamic_linker(self) -> tuple[int, int] | None: + """ Attempt to get the default version from ldconfig or $LD_LIBRARY_PATH + + Returns + ------- + tuple[int, int, int] | None + The detected default ROCm version. ``None`` if not version detected + """ + paths = _check_dynamic_linker(self._lib) + if len(paths) != 1: # Multiple or None + return None + root = self._parent_from_targets(paths[0]) + retval = self.version_from_version_file(root) or self._version_from_lib(root) + logger.debug("Version from ld_config: %s", retval) + return retval + + def get_version(self) -> tuple[int, int] | None: + """ Attempt to locate the default Cuda version on Linux + + Checks, in order: update-alternatives, /usr/local/cuda, ldconfig, nvcc + + Returns + ------- + tuple[int, int] | None + The Default global Cuda version or ``None`` if not found + """ + return (self._version_from_alternatives() or + self._version_from_link() or + self._version_from_dynamic_linker()) + + def get_cudnn_versions(self) -> dict[tuple[int, int], tuple[int, int, int]]: + """ Attempt to locate any installed cuDNN versions on Linux + + Returns + ------- + dict[tuple[int, int], tuple[int, int, int]] + Detected cuDNN version for each installed Cuda. key (0, 0) denotes globally installed + cudnn + """ + retval: dict[tuple[int, int], tuple[int, int, int]] = {} + gbl = ["/usr/include", "/usr/local/include"] + lcl = [os.path.join(f, "include") for f in self._paths] + for root in gbl + lcl: + for folder, _, filenames in os.walk(root): + if self._cudnn_header not in filenames: + continue + version = self.cudnn_version_from_header(folder) + if not version: + continue + cuda_vers = ((0, 0) if root in gbl + else self.versions[self._paths.index(os.path.dirname(root))]) + retval[cuda_vers] = version + return retval + + +class CudaWindows(_Cuda): + """ Find the location of system installed Cuda and cuDNN on Windows. 
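Before the Windows equivalent, a sketch of the two Linux path manipulations used above. The `ldconfig -p` line is a fabricated example of the real output format; the chain recovers the Cuda root from the shared-library entry:

import os

line = ("libcudart.so.12 (libc6,x86-64) => "
        "/usr/local/cuda-12.4/targets/x86_64-linux/lib/libcudart.so.12")
# _check_dynamic_linker keeps the folder on the right of "=>"
folder = os.path.realpath(os.path.dirname(line.split("=>")[-1].strip()))
# _parent_from_targets walks back up from .../targets/... to the Cuda root
split = folder.split(os.sep)
root = os.sep.join(split[:split.index("targets")]) if "targets" in split else ""
print(root)  # /usr/local/cuda-12.4 (when that install actually exists on disk)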
""" + + @classmethod + def _enum_subkeys(cls, key: HKEYType) -> T.Generator[str, None, None]: + """ Iterate through a Registry key's sub-keys + + Parameters + ---------- + key : :class:`winreg.HKEYType` + The Registry key to iterate + + Yields + ------ + str + A sub-key name from the given registry key + """ + assert winreg is not None + i = 0 + while True: + try: + yield winreg.EnumKey(key, i) # type:ignore[attr-defined] + except OSError: + break + i += 1 + + def get_versions(self) -> dict[tuple[int, int], str]: + """ Attempt to detect all installed Cuda versions on Windows systems from the registry + + Returns + ------- + dict[tuple[int, int], str] + The Cuda version to the folder path on Windows + """ + retval: dict[tuple[int, int], str] = {} + assert winreg is not None + reg_key = r"SOFTWARE\NVIDIA Corporation\GPU Computing Toolkit\CUDA" + paths = {k.lower().replace("cuda_path_", "").replace("_", "."): v + for k, v in os.environ.items() + if "cuda_path_v" in k.lower()} + try: + with winreg.OpenKey(winreg.HKEY_LOCAL_MACHINE, # type:ignore[attr-defined] + reg_key) as key: + for version in self._enum_subkeys(key): + vers_tuple = self._tuple_from_string(version[1:]) + if vers_tuple is not None: + retval[vers_tuple] = paths.get(version, "") + except FileNotFoundError: + logger.debug("Could not find Windows Registry key '%s'", reg_key) + return {k: retval[k] for k in sorted(retval)} + + def get_version(self) -> tuple[int, int] | None: + """ Attempt to get the default Cuda version from the Environment Variable + + Returns + ------- + tuple[int, int] | None + The Default global Cuda version or ``None`` if not found + """ + path = os.environ.get("CUDA_PATH") + if not path or path not in self._paths: + return None + + retval = self.versions[self._paths.index(path)] + logger.debug("Version from CUDA_PATH Environment Variable: %s", path) + return retval + + def _get_cudnn_paths(self) -> list[str]: # noqa[C901] + """ Attempt to locate the locations of cuDNN installs for Windows + + Returns + ------- + list[str] + Full path to existing cuDNN installs under Windows + """ + assert winreg is not None + paths: set[str] = set() + cudnn_key = "cudnn_cuda" + reg_key = r"SOFTWARE\Microsoft\Windows\CurrentVersion\Uninstall" + lookups = (winreg.HKEY_LOCAL_MACHINE, # type:ignore[attr-defined] + winreg.HKEY_CURRENT_USER) # type:ignore[attr-defined] + for lookup in lookups: + try: + key = winreg.OpenKey(lookup, reg_key) # type:ignore[attr-defined] + except FileNotFoundError: + continue + for name in self._enum_subkeys(key): + if cudnn_key not in name.lower(): + logger.debug("Skipping subkey '%s'", name) + continue + try: + subkey = winreg.OpenKey(key, name) # type:ignore[attr-defined] + logger.debug("Skipping subkey not found '%s'", name) + except FileNotFoundError: + continue + logger.debug("Parsing cudnn key '%s'", cudnn_key) + try: + path, _ = winreg.QueryValueEx(subkey, # type:ignore[attr-defined] + "InstallLocation") + except (FileNotFoundError, OSError): + logger.debug("Skipping missing InstallLocation for sub-key '%s'", subkey) + continue + if not os.path.isdir(path): + logger.debug("Skipping non-existant path '%s'", path) + continue + paths.add(path) + retval = list(paths) + logger.debug("cudnn install paths: %s", retval) + return retval + + def get_cudnn_versions(self) -> dict[tuple[int, int], tuple[int, int, int]]: + """ Attempt to locate any installed cuDNN versions on Windows + + Returns + ------- + dict[tuple[int, int], tuple[int, int, int]] + Detected cuDNN version for each installed Cuda. 
key (0, 0) denotes globally installed + cudnn + """ + retval: dict[tuple[int, int], tuple[int, int, int]] = {} + gbl = self._get_cudnn_paths() + lcl = [os.path.join(f, "include") for f in self._paths] + for root in gbl + lcl: + for folder, _, filenames in os.walk(root): + if self._cudnn_header not in filenames: + continue + version = self.cudnn_version_from_header(folder) + if not version: + continue + cuda_vers = ((0, 0) if root in gbl + else self.versions[self._paths.index(os.path.dirname(root))]) + retval[cuda_vers] = version + return retval + + +def get_cuda_finder() -> type[_Cuda]: + """Return the platform-specific Cuda finder class. + + Returns + ------- + type[_Cuda] + The OS specific finder for system-wide Cuda + """ + if platform.system().lower() == "windows": + return CudaWindows + return CudaLinux + + +Cuda = get_cuda_finder()
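The `ROCm` class that follows validates versions with Python's lexicographic tuple comparison against `_TORCH_ROCM_REQUIREMENTS`. A minimal sketch of that check; the requirement table mirrors the one defined at the top of this module:

reqs = {">=2.2.1,<2.4.0": ((6, 0), (6, 0))}
version_min = min(v[0] for v in reqs.values())  # (6, 0)
version_max = max(v[1] for v in reqs.values())  # (6, 0)
# only the (major, minor) prefix of the installed (major, minor, patch) is compared
assert version_min <= (6, 0, 2)[:2] <= version_max      # ROCm 6.0.x is in range
assert not version_min <= (5, 7, 1)[:2] <= version_max  # ROCm 5.7.x is too old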
+class ROCm(): + """ Find the location of system installed ROCm on Linux """ + def __init__(self) -> None: + self.version_min = min(v[0] for v in _TORCH_ROCM_REQUIREMENTS.values()) + self.version_max = max(v[1] for v in _TORCH_ROCM_REQUIREMENTS.values()) + self.versions: list[tuple[int, int, int]] = [] + """ list[tuple[int, int, int]] : All detected ROCm installed versions """ + self.version: tuple[int, int, int] = (0, 0, 0) + """ tuple[int, int, int] : Default ROCm installed version. (0, 0, 0) if not detected """ + + self._folder_prefix = "rocm-" + self._version_files = ["version-rocm", "version"] + self._lib = "librocm-core.so" + self._alternatives = _Alternatives("rocm") + self._re_version = re.compile(r"(\d+\.\d+\.\d+)(?=$|[-.])") + self._re_config = re.compile(r"\sroc-(\d+\.\d+\.\d+)(?=\s|[-.])") + if platform.system() == "Linux": + self._rocm_check() + + def __repr__(self) -> str: + """ Pretty representation of this class """ + attrs = ", ".join(f"{k}={repr(v)}" for k, v in self.__dict__.items() + if not k.startswith("_")) + return f"{self.__class__.__name__}({attrs})" + + @property + def valid_versions(self) -> list[tuple[int, int, int]]: + """ list[tuple[int, int, int]] : Installed ROCm versions that fall within the supported + version range """ + return [v for v in self.versions if self.version_min <= v[:2] <= self.version_max] + + @property + def valid_installed(self) -> bool: + """ bool : ``True`` if a valid version of ROCm is installed """ + return any(self.valid_versions) + + @property + def is_valid(self) -> bool: + """ bool : ``True`` if the default ROCm version is valid """ + return self.version_min <= self.version[:2] <= self.version_max + + @classmethod + def _tuple_from_string(cls, version: str) -> tuple[int, int, int] | None: + """ Convert a ROCm version string to a version tuple + + Parameters + ---------- + version : str + The ROCm version string to convert + + Returns + ------- + tuple[int, int, int] | None + The converted ROCm version string. ``None`` if not a valid version string + """ + split = version.split(".") + if len(split) != 3: + return None + if not all(x.isdigit() for x in split): + return None + return (int(split[0]), int(split[1]), int(split[2])) + + def _version_from_string(self, string: str) -> tuple[int, int, int] | None: + """ Obtain the ROCm version from the end of a string + + Parameters + ---------- + string : str + The string to test for a valid ROCm version + + Returns + ------- + tuple[int, int, int] | None + The ROCm version from the end of the string or ``None`` if not detected + """ + re_vers = self._re_version.search(string) + if re_vers is None: + return None + return self._tuple_from_string(re_vers.group(1)) + + def _version_from_info(self, folder: str) -> tuple[int, int, int] | None: + """ Attempt to locate the version from a version file within a ROCm .info folder + + Parameters + ---------- + folder : str + Full path to the ROCm .info folder + + Returns + ------- + tuple[int, int, int] | None + The ROCm version extracted from a version file within the .info folder. ``None`` if + not detected + """ + info_loc = [os.path.join(folder, ".info", v) for v in self._version_files] + for info_file in info_loc: + if not os.path.exists(info_file): + continue + with open(info_file, "r", encoding="utf-8") as f: + vers_string = f.read().strip() + if not vers_string: + continue + retval = self._tuple_from_string(vers_string.split("-", maxsplit=1)[0]) + if retval is None: + continue + logger.debug("Version from '%s': %s", info_file, retval) + return retval + return None + + def _version_from_lib(self, folder: str) -> tuple[int, int, int] | None: + """ Attempt to locate the version from the existence of librocm-core.so within a ROCm + lib folder + + Parameters + ---------- + folder : str + Full file path to the ROCm folder + + Returns + ------- + tuple[int, int, int] | None + The ROCm version identified by the existence of the librocm-core.so file. ``None`` if + not detected + """ + lib_folder = os.path.join(folder, "lib") + lib_files = _files_from_folder(lib_folder, self._lib) + if not lib_files: + return None + + # librocm-core naming is librocm-core.so.1.0.##### which is ambiguous.
Get from folder + rocm_folder = os.path.basename(folder) + if not rocm_folder.startswith(self._folder_prefix): + return None + retval = self._version_from_string(rocm_folder) + logger.debug("Version from '%s': %s", os.path.join(lib_folder, self._lib), retval) + return retval + + def _versions_from_opt(self) -> list[tuple[int, int, int]]: + """ Attempt to detect all installed ROCm versions from the /opt folder + + Scan /opt for rocm.x.x.x folders containing either .info or lib/librocm-core.so.x + + Returns + ------- + list[tuple[int, int, int]] + Any ROCm versions found in the /opt folder + """ + retval: list[tuple[int, int, int]] = [] + opt = os.path.join(os.sep, "opt") + + for folder in _files_from_folder(opt, self._folder_prefix): + path = os.path.join(opt, folder) + version = self._version_from_info(path) or self._version_from_lib(path) + if version is not None: + retval.append(version) + + return retval + + def _versions_from_alternatives(self) -> list[tuple[int, int, int]]: + """ Attempt to detect all installed ROCm versions from update-alternatives + + Returns + ------- + list[tuple[int, int, int]] + Any ROCm versions found in update-alternatives + """ + alts = self._alternatives.alternatives + if not alts: + return [] + versions = [self._version_from_string(c) for c in alts] + retval = list(set(v for v in versions if v is not None)) + logger.debug("Versions from 'update-alternatives': %s", retval) + return retval + + def _versions_from_dynamic_linker(self) -> list[tuple[int, int, int]]: + """ Attempt to detect all installed ROCm versions from ldconfig + + Returns + ------- + list[tuple[int, int, int]] + The ROCm versions found from ldconfig + """ + retval: list[tuple[int, int, int]] = [] + folders = _check_dynamic_linker(self._lib) + for folder in folders: + path = os.path.dirname(folder) + version = self._version_from_info(path) or self._version_from_lib(path) + if version is not None: + retval.append(version) + + logger.debug("Versions from 'ld_config': %s", retval) + return retval + + def _get_versions(self) -> None: + """ Attempt to detect all installed ROCm versions and populate to :attr:`versions` """ + versions = list(sorted(set(self._versions_from_opt()) | + set(self._versions_from_alternatives()) | + set(self._versions_from_dynamic_linker()))) + if versions: + logger.debug("ROCm Versions: %s", versions) + self.versions = versions + return + logger.debug("Could not locate any ROCm versions") + + def _version_from_hipconfig(self) -> tuple[int, int, int] | None: + """ Attempt to get the default version from hipconfig + + Returns + ------- + tuple[int, int, int] | None + The detected default ROCm version. ``None`` if no version detected + """ + retval: tuple[int, int, int] | None = None + exe = which("hipconfig") + if not exe: + return retval + lines = _lines_from_command([exe, "--full"]) + if not lines: + return retval + for line in lines: + line = line.strip() + if line.startswith("ROCM_PATH"): + path = line.split(":", maxsplit=1)[-1] + retval = self._version_from_info(path) or self._version_from_lib(path) + match = self._re_config.search(line) + + if match is not None: + retval = self._tuple_from_string(match.group(1)) + + logger.debug("Version from hipconfig: %s", retval) + return retval + + def _version_from_alternatives(self) -> tuple[int, int, int] | None: + """ Attempt to get the default version from update-alternatives + + Returns + ------- + tuple[int, int, int] | None + The detected default ROCm version.
``None`` if no version detected + """ + default = self._alternatives.default + if not default: + return None + retval = self._version_from_string(default.rsplit(os.sep, maxsplit=1)[-1]) + logger.debug("Version from update-alternatives: %s", retval) + return retval + + def _version_from_link(self) -> tuple[int, int, int] | None: + """ Attempt to get the default version from the /opt/rocm file + + Returns + ------- + tuple[int, int, int] | None + The detected default ROCm version. ``None`` if no version detected + """ + path = os.path.join(os.sep, "opt", "rocm") + if not os.path.exists(path): + return None + real_path = os.path.abspath(os.path.realpath(path)) if os.path.islink(path) else path + retval = self._version_from_info(real_path) or self._version_from_lib(real_path) + logger.debug("Version from symlink: %s", retval) + return retval + + def _version_from_dynamic_linker(self) -> tuple[int, int, int] | None: + """ Attempt to get the default version from ldconfig or $LD_LIBRARY_PATH + + Returns + ------- + tuple[int, int, int] | None + The detected default ROCm version. ``None`` if no version detected + """ + paths = _check_dynamic_linker("librocm-core.so.") + if len(paths) != 1: # Multiple or None + return None + path = os.path.dirname(paths[0]) + retval = self._version_from_info(path) or self._version_from_lib(path) + logger.debug("Version from ld_config: %s", retval) + return retval + + def _get_version(self) -> None: + """ Attempt to detect the default ROCm version """ + version = (self._version_from_hipconfig() or + self._version_from_alternatives() or + self._version_from_link() or + self._version_from_dynamic_linker()) + if version is not None: + logger.debug("ROCm default version: %s", version) + self.version = version + return + logger.debug("Could not locate default ROCm version") + + def _rocm_check(self) -> None: + """ Attempt to locate the installed ROCm versions and the default ROCm version """ + self._get_versions() + self._get_version() + logger.debug("ROCm Versions: %s, Version: %s", self.versions, self.version) + + +__all__ = get_module_objects(__name__) + + +if __name__ == "__main__": + print(Cuda()) + print(ROCm())
diff --git a/lib/system/sysinfo.py b/lib/system/sysinfo.py new file mode 100644 index 0000000000..28f067eda9 --- /dev/null +++ b/lib/system/sysinfo.py @@ -0,0 +1,419 @@ +#!/usr/bin/env python3 +""" Obtain information about the running system, environment and GPU. """ + +import json +import os +import platform +import sys + +from subprocess import PIPE, Popen + +from lib.git import git +from lib.gpu_stats import GPUInfo, GPUStats +from lib.utils import get_backend, get_module_objects, PROJECT_ROOT + +from .ml_libs import Cuda, ROCm +from .system import Packages, System + +try: + import psutil +except ImportError: + psutil = None # type:ignore[assignment] + +
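The RAM properties below all follow the same guard: return -1 as an "unknown" sentinel when `psutil` could not be imported. The same pattern restated as a standalone sketch:

try:
    import psutil
except ImportError:
    psutil = None

def ram_total() -> int:
    # -1 signals "unknown" when psutil is unavailable on the system
    return -1 if psutil is None else psutil.virtual_memory().total

print(ram_total())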
""" + if psutil is None: + return -1 + return psutil.virtual_memory().total + + @property + def _ram_available(self) -> int: + """ int : The amount of available RAM in bytes. """ + if psutil is None: + return -1 + return psutil.virtual_memory().available + + @property + def _ram_used(self) -> int: + """ int : The amount of used RAM in bytes. """ + if psutil is None: + return -1 + return psutil.virtual_memory().used + + @property + def _fs_command(self) -> str: + """ str : The command line command used to execute faceswap. """ + return " ".join(sys.argv) + + @property + def _conda_version(self) -> str: + """ str : The installed version of Conda, or `N/A` if Conda is not installed. """ + if not self._system.is_conda: + return "N/A" + with Popen("conda --version", shell=True, stdout=PIPE, stderr=PIPE) as conda: + stdout, stderr = conda.communicate() + if stderr: + return "Conda is used, but version not found" + version = stdout.decode(self._system.encoding, errors="replace").splitlines() + return "\n".join(version) + + @property + def _git_commits(self) -> str: + """ str : The last 5 git commits for the currently running Faceswap. """ + commits = git.get_commits(3) + if not commits: + return "Not Found" + return " | ".join(commits) + + @property + def _cuda_versions(self) -> str: + """ str : The globally installed Cuda versions""" + if not self._cuda.versions: + return "No global Cuda versions found" + return ", ".join(".".join(str(x) for x in v) for v in self._cuda.versions) + + @property + def _cuda_version(self) -> str: + """ str : The installed CUDA version. """ + if self._cuda.version == (0, 0): + retval = "No global version found" + if self._system.is_conda: + retval += ". Check Conda packages for Conda Cuda" + return retval + return ".".join(str(x) for x in self._cuda.version) + + @property + def _cudnn_versions(self) -> str: + """ str : The installed cuDNN versions. """ + if not self._cuda.cudnn_versions: + retval = "No global version found" + if self._system.is_conda: + retval += ". Check Conda packages for Conda cuDNN" + return retval + retval = "" + for k, v in self._cuda.cudnn_versions.items(): + retval += f"{'.'.join(str(x) for x in v)}" + retval += f"({'global' if k == (0, 0) else '.'.join(str(x) for x in k)}), " + + return retval[:-2] + + @property + def _rocm_version(self) -> str: + """ str : The default ROCm version """ + if self._rocm.version == (0, 0, 0): + return "No default ROCm version found" + return ".".join(str(x) for x in self._rocm.version) + + @property + def _rocm_versions(self) -> str: + """ str : The installed ROCm versions """ + if not self._rocm.versions: + return "No ROCm versions found" + return ", ".join(".".join(str(x) for x in v) for v in self._rocm.versions) + + def _get_gpu_info(self) -> GPUInfo: + """ Obtain GPU Stats. If an error is raised, swallow the error, and add to GPUInfo output + + Returns + ------- + :class:`~lib.gpu_stats.GPUInfo` + The information on connected GPUs + """ + if GPUStats is None: + return GPUInfo(vram=[], + vram_free=[], + driver="N/A", + devices=["Error obtaining GPU Stats: 'GPUStats import error'"], + devices_active=[]) + try: + retval = GPUStats(log=False).sys_info + except Exception as err: # pylint:disable=broad-except + err_string = f"{type(err)}: {err}" + retval = GPUInfo(vram=[], + vram_free=[], + driver="N/A", + devices=[f"Error obtaining GPU Stats: '{err_string}'"], + devices_active=[]) + return retval + + def _format_ram(self) -> str: + """ Format the RAM stats into Megabytes to make it more readable. 
+ + Returns + ------- + str + The total, available, used and free RAM displayed in Megabytes + """ + retval = [] + for name in ("total", "available", "used", "free"): + value = getattr(self, f"_ram_{name}") + value = int(value / (1024 * 1024)) + retval.append(f"{name.capitalize()}: {value}MB") + return ", ".join(retval) + + def full_info(self) -> str: + """ Obtain extensive system information stats, formatted into a human readable format. + + Returns + ------- + str + The system information for the currently running system, formatted for output to + console or a log file. + """ + retval = "\n============ System Information ============\n" + sys_info = {"backend": get_backend(), + "os_platform": self._system.platform, + "os_machine": self._system.machine, + "os_release": self._system.release, + "py_conda_version": self._conda_version, + "py_implementation": self._system.python_implementation, + "py_version": self._system.python_version, + "py_command": self._fs_command, + "py_virtual_env": self._system.is_virtual_env, + "sys_cores": self._system.cpu_count, + "sys_processor": self._system.processor, + "sys_ram": self._format_ram(), + "encoding": self._system.encoding, + "git_branch": git.branch, + "git_commits": self._git_commits, + "gpu_cuda_versions": self._cuda_versions, + "gpu_cuda": self._cuda_version, + "gpu_cudnn": self._cudnn_versions, + "gpu_rocm_versions": self._rocm_versions, + "gpu_rocm_version": self._rocm_version, + "gpu_driver": self._gpu.driver, + "gpu_devices": ", ".join([f"GPU_{idx}: {device}" + for idx, device in enumerate(self._gpu.devices)]), + "gpu_vram": ", ".join( + f"GPU_{idx}: {int(vram)}MB ({int(vram_free)}MB free)" + for idx, (vram, vram_free) in enumerate(zip(self._gpu.vram, + self._gpu.vram_free))), + "gpu_devices_active": ", ".join([f"GPU_{idx}" + for idx in self._gpu.devices_active])} + for key in sorted(sys_info.keys()): + retval += (f"{key + ':':<20} {sys_info[key]}\n") + retval += "\n=============== Pip Packages ===============\n" + retval += self._packages.installed_python_pretty + if self._system.is_conda: + retval += "\n\n============== Conda Packages ==============\n" + retval += self._packages.installed_conda_pretty + retval += self._state_file + retval += "\n\n================= Configs ==================" + retval += self._configs + return retval + + +def get_sysinfo() -> str: + """ Obtain extensive system information stats, formatted into a human readable format. + If an error occurs obtaining the system information, then the error message is returned + instead. + + Returns + ------- + str + The system information for the currently running system, formatted for output to + console or a log file. + """ + try: + retval = _SysInfo().full_info() + except Exception as err: # pylint:disable=broad-except + retval = f"Exception occurred trying to retrieve sysinfo: {str(err)}" + return retval + +
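The `_Configs` parser below reduces each ini line to an aligned `key: value` row via `_format_text`. A sketch of that formatting against a fabricated config line; the `maxsplit=1` guard keeps values that themselves contain "=" intact:

line = "color_clahe_chance = 50.0"
key, value = line.split("=", maxsplit=1)
# pad the key (with its trailing colon) to a fixed 25-character column
print(f"{key.strip() + ':':<25} {value.strip()}")
# -> "color_clahe_chance:       50.0"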
+class _Configs(): # pylint:disable=too-few-public-methods + """ Parses the config files in /faceswap/config and outputs the information stored within them + in a human readable format. """ + + def __init__(self) -> None: + self.config_dir = os.path.join(PROJECT_ROOT, "config") + self.configs = self._get_configs() + + def _get_configs(self) -> str: + """ Obtain the formatted configurations from the config folder. + + Returns + ------- + str + The current configuration in the config files formatted in a human readable format + """ + try: + config_files = [os.path.join(self.config_dir, cfile) + for cfile in os.listdir(self.config_dir) + if os.path.basename(cfile) == ".faceswap" + or os.path.splitext(cfile)[1] == ".ini"] + return self._parse_configs(config_files) + except FileNotFoundError: + return "" + + def _parse_configs(self, config_files: list[str]) -> str: + """ Parse the given list of config files into a human readable format. + + Parameters + ---------- + config_files : list[str] + A list of paths to the faceswap config files + + Returns + ------- + str + The current configuration in the config files formatted in a human readable format + """ + formatted = "" + for cfile in config_files: + fname = os.path.basename(cfile) + ext = os.path.splitext(cfile)[1] + formatted += f"\n--------- {fname} ---------\n" + if ext == ".ini": + formatted += self._parse_ini(cfile) + elif fname == ".faceswap": + formatted += self._parse_json(cfile) + return formatted + + def _parse_ini(self, config_file: str) -> str: + """ Parse an ``.ini`` formatted config file into a human readable format. + + Parameters + ---------- + config_file : str + The path to the config.ini file + + Returns + ------- + str + The current configuration in the config file formatted in a human readable format + """ + formatted = "" + with open(config_file, "r", encoding="utf-8", errors="replace") as cfile: + for line in cfile.readlines(): + line = line.strip() + if line.startswith("#") or not line: + continue + item = line.split("=", maxsplit=1) + if len(item) == 1: + formatted += f"\n{item[0].strip()}\n" + else: + formatted += self._format_text(item[0], item[1]) + return formatted + + def _parse_json(self, config_file: str) -> str: + """ Parse a ``.json`` formatted config file into a formatted string. + + Parameters + ---------- + config_file : str + The path to the config.json file + + Returns + ------- + str + The current configuration in the config file formatted in a human readable format + """ + formatted: str = "" + with open(config_file, "r", encoding="utf-8", errors="replace") as cfile: + conf_dict = json.load(cfile) + for key in sorted(conf_dict.keys()): + formatted += self._format_text(key, conf_dict[key]) + return formatted + + @staticmethod + def _format_text(key: str, value: str) -> str: + """Format a key value pair into a consistently spaced string output for display. + + Parameters + ---------- + key : str + The label for this display item + value : str + The value for this display item + + Returns + ------- + str + The formatted key value pair for display + """ + return f"{key.strip() + ':':<25} {value.strip()}\n" + + +class _State(): # pylint:disable=too-few-public-methods + """ Parses the state file in the current model directory, if the model is training, and + formats the content into a human readable format. """ + def __init__(self) -> None: + self._model_dir = self._get_arg("-m", "--model-dir") + self._trainer = self._get_arg("-t", "--trainer") + self.state_file = self._get_state_file() + + @property + def _is_training(self) -> bool: + """ bool : ``True`` if this function has been called during a training session + otherwise ``False``. """ + return len(sys.argv) > 1 and sys.argv[1].lower() == "train" + + @staticmethod + def _get_arg(*args: str) -> str | None: + """ Obtain the value for a given command line option from sys.argv.
+ + Parameters + ---------- + *args : str + The command line option flag(s) to locate the value for + + Returns + ------- + str or ``None`` + The value of the given command line option, if it exists, otherwise ``None`` + """ + cmd = sys.argv + for opt in args: + if opt in cmd: + idx = cmd.index(opt) + 1 + if len(cmd) > idx: + return cmd[idx] + return None + + def _get_state_file(self) -> str: + """ Parses the model's state file and compiles the contents into a human readable string. + + Returns + ------- + str + The state file formatted into a human readable format + """ + if not self._is_training or self._model_dir is None or self._trainer is None: + return "" + fname = os.path.join(self._model_dir, f"{self._trainer}_state.json") + if not os.path.isfile(fname): + return "" + + retval = "\n\n=============== State File =================\n" + with open(fname, "r", encoding="utf-8", errors="replace") as sfile: + retval += sfile.read() + return retval + + +sysinfo = get_sysinfo() # pylint:disable=invalid-name + + +__all__ = get_module_objects(__name__) + + +if __name__ == "__main__": + print(sysinfo)
diff --git a/lib/system/system.py b/lib/system/system.py new file mode 100644 index 0000000000..9469fc1e68 --- /dev/null +++ b/lib/system/system.py @@ -0,0 +1,299 @@ +#!/usr/bin/env python3 +""" +Holds information about the running system. Used in setup.py and lib.sysinfo +NOTE: Only packages from Python's Standard Library should be imported in this module +""" +from __future__ import annotations + +import ctypes +import locale +import logging +import os +import platform +import re +import sys +import typing as T + +from shutil import which +from subprocess import CalledProcessError, run + +from lib.utils import get_module_objects + +logger = logging.getLogger(__name__) + + +VALID_PYTHON = ((3, 11), (3, 13)) +""" tuple[tuple[int, int], tuple[int, int]] : The minimum and maximum versions of Python that can +run Faceswap """ +VALID_TORCH = ((2, 3), (2, 9)) +""" tuple[tuple[int, int], tuple[int, int]] : The minimum and maximum versions of Torch that can +run Faceswap """ +VALID_KERAS = ((3, 12), (3, 12)) +""" tuple[tuple[int, int], tuple[int, int]] : The minimum and maximum versions of Keras that can +run Faceswap """ + + +def _lines_from_command(command: list[str]) -> list[str]: + """ Output stdout lines from an executed command. + + Parameters + ---------- + command : list[str] + The command to run + + Returns + ------- + list[str] + The output lines from the given command + """ + logger.debug("Running command %s", command) + try: + proc = run(command, + capture_output=True, + check=True, + encoding=locale.getpreferredencoding(), + errors="replace") + except (FileNotFoundError, CalledProcessError) as err: + logger.debug("Error from command: %s", str(err)) + return [] + return proc.stdout.splitlines() + +
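`_lines_from_command` deliberately swallows missing-binary and non-zero-exit errors so that callers can treat "command unavailable" and "no output" identically. A usage sketch; the printed output is illustrative:

print(_lines_from_command(["git", "--version"]))  # e.g. ['git version 2.43.0']
print(_lines_from_command(["no-such-binary"]))    # [] (the FileNotFoundError is logged, not raised)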
+class System: # pylint:disable=too-many-instance-attributes + """ Holds information about the currently running system and environment """ + def __init__(self) -> None: + self.platform = platform.platform() + """ str : Human readable platform identifier """ + self.system: T.Literal["darwin", "linux", "windows"] = T.cast( + T.Literal["darwin", "linux", "windows"], platform.system().lower()) + """ str : The system (OS type) that this code is running on. Always lowercase """ + self.machine = platform.machine() + """ str : The machine type (eg: "x86_64") """ + self.release = platform.release() + """ str : The OS Release that this code is running on """ + self.processor = platform.processor() + """ str : The processor in use, if detected """ + self.cpu_count = os.cpu_count() + """ int : The number of CPU cores on the system """ + self.python_implementation = platform.python_implementation() + """ str : The python implementation in use""" + self.python_version = platform.python_version() + """ str : The version of Python that is running """ + self.python_architecture = platform.architecture()[0] + """ str : The Python architecture that is running (eg: 64bit/32bit)""" + self.encoding = locale.getpreferredencoding() + """ str : The system encoding """ + self.is_conda = ("conda" in sys.version.lower() or + os.path.exists(os.path.join(sys.prefix, 'conda-meta'))) + """ bool : ``True`` if running under Conda otherwise ``False`` """ + self.is_admin = self._get_permissions() + """ bool : ``True`` if we are running with Admin privileges """ + self.is_virtual_env = self._check_virtual_env() + """ bool : ``True`` if Python is being run inside a virtual environment """ + + @property + def is_linux(self) -> bool: + """ bool : `True` if running on a Linux system otherwise ``False``. """ + return self.system == "linux" + + @property + def is_macos(self) -> bool: + """ bool : `True` if running on a macOS system otherwise ``False``. """ + return self.system == "darwin" + + @property + def is_windows(self) -> bool: + """ bool : `True` if running on a Windows system otherwise ``False``. """ + return self.system == "windows" + + def __repr__(self) -> str: + """ Pretty print the system information for logging """ + attrs = ", ".join(f"{k}={repr(v)}" for k, v in self.__dict__.items() + if not k.startswith("_")) + return f"{self.__class__.__name__}({attrs})" + + def _get_permissions(self) -> bool: + """ Check whether user is admin + + Returns + ------- + bool + ``True`` if we are running with Admin privileges + """ + if self.is_windows: + retval = ctypes.windll.shell32.IsUserAnAdmin() != 0 # type:ignore[attr-defined] + else: + retval = os.getuid() == 0 # type:ignore[attr-defined] # pylint:disable=no-member + return retval + + def _check_virtual_env(self) -> bool: + """ Check whether we are in a virtual environment + + Returns + ------- + bool + ``True`` if Python is being run inside a virtual environment + """ + if not self.is_conda: + retval = (hasattr(sys, "real_prefix") or + (hasattr(sys, "base_prefix") and sys.base_prefix != sys.prefix)) + else: + prefix = os.path.dirname(sys.prefix) + retval = os.path.basename(prefix) == "envs" + return retval + + def validate_python(self, max_version: tuple[int, int] | None = None) -> bool: + """ Check that the running Python version is valid + + Parameters + ---------- + max_version : tuple[int, int] | None, Optional + The max version to validate Python against. ``None`` for the project maximum.
+ Default: ``None`` (project maximum) + + Returns + ------- + bool + ``True`` if the running Python version is valid, otherwise logs an error and exits + """ + max_python = VALID_PYTHON[1] if max_version is None else max_version + retval = (VALID_PYTHON[0] <= sys.version_info[:2] <= max_python + and self.python_architecture == "64bit") + logger.debug("Python version %s(%s) within %s - %s(64bit): %s", + self.python_version, + self.python_architecture, + VALID_PYTHON[0], + max_python, + retval) + if not retval: + print() + logger.error("Your Python version %s(%s) is unsupported. Please run with Python " + "version %s to %s 64bit.", + self.python_version, + self.python_architecture, + ".".join(str(x) for x in VALID_PYTHON[0]), + ".".join(str(x) for x in max_python)) + print() + logger.error("If you have recently upgraded faceswap, then you will need to create a " + "new virtual environment.") + logger.error("The easiest way to do this is to run the latest version of the Faceswap " + "installer from:") + logger.error("https://github.com/deepfakes/faceswap/releases") + print() + input("Press Enter to close") + sys.exit(1) + + return retval + + def validate(self) -> None: + """ Perform validation that the running system can be used for faceswap. Log an error and + exit if it cannot """ + if not any((self.is_linux, self.is_macos, self.is_windows)): + logger.error("Your system %s is not supported!", self.system.title()) + sys.exit(1) + if self.is_macos and self.machine == "arm64" and not self.is_conda: + logger.error("Setting up Faceswap for Apple Silicon outside of a Conda " + "environment is unsupported") + sys.exit(1) + self.validate_python() + + +class Packages(): + """ Holds information about installed python and conda packages. + + Note: Packaging library is lazy loaded as it may not be available during setup.py + """ + def __init__(self) -> None: + self._conda_exe = which("conda") + self._installed_python = self._get_installed_python() + self._installed_conda: list[str] | None = None + self._get_installed_conda() + + @property + def installed_python(self) -> dict[str, str]: + """ dict[str, str] : Installed Python package names to Python package versions """ + return self._installed_python + + @property + def installed_python_pretty(self) -> str: + """ str: A pretty printed representation of installed Python packages """ + pkgs = self._installed_python + align = max(len(x) for x in pkgs) + 1 + return "\n".join(f"{k.ljust(align)} {v}" for k, v in pkgs.items()) + + @property + def installed_conda(self) -> dict[str, tuple[str, str, str]]: + """ dict[str, tuple[str, str, str]] : Installed Conda package names to the version, build + and channel """ + if not self._installed_conda: + return {} + + installed = [re.sub(" +", " ", line.strip()) + for line in self._installed_conda if not line.startswith("#")] + retval = {} + for pkg in installed: + item = pkg.split(" ") + assert len(item) == 4 + retval[item[0]] = T.cast(tuple[str, str, str], tuple(item[1:])) + return retval + + @property + def installed_conda_pretty(self) -> str: + """ str: A pretty printed representation of installed conda packages """ + if not self._installed_conda: + return "Could not get Conda package list" + return "\n".join(self._installed_conda) + + def __repr__(self) -> str: + """ Pretty print the installed packages for logging """ + props = ", ".join( + f"{k}={repr(getattr(self, k))}" + for k, v in self.__class__.__dict__.items() + if isinstance(v, property) and not k.startswith("_") and "pretty" not in k) + return
f"{self.__class__.__name__}({props})" + + def _get_installed_python(self) -> dict[str, str]: + """ Parse the installed python modules + + Returns + ------- + dict[str, str] + Installed Python package names to Python package versions + """ + installed = _lines_from_command([sys.executable, "-m", "pip", "freeze", "--local"]) + retval = {} + for pkg in installed: + if "==" not in pkg: + continue + item = pkg.split("==") + retval[item[0].lower()] = item[1] + logger.debug("Installed Python packages: %s", retval) + return retval + + def _get_installed_conda(self) -> None: + """ Collect the output from 'conda list' for the installed Conda packages and + populate :attr:`_installed_conda` + + Returns + ------- + list[str] + Each line of output from the 'conda list' command + """ + if not self._conda_exe: + logger.debug("Conda not found. Not collecting packages") + return + + lines = _lines_from_command([self._conda_exe, "list", "--show-channel-urls"]) + if not lines: + self._installed_conda = ["Could not get Conda package list"] + return + self._installed_conda = lines + logger.debug("Installed Conda packages: %s", self.installed_conda) + + +__all__ = get_module_objects(__name__) + + +if __name__ == "__main__": + print(System()) + print(Packages()) diff --git a/lib/training/__init__.py b/lib/training/__init__.py new file mode 100644 index 0000000000..cc254c0fae --- /dev/null +++ b/lib/training/__init__.py @@ -0,0 +1,20 @@ +#!/usr/bin/env python3 +""" Package for handling alignments files, detected faces and aligned faces along with their +associated objects. """ +from __future__ import annotations +import typing as T + +from .augmentation import ImageAugmentation +from .generator import Feeder +from .lr_finder import LearningRateFinder +from .lr_warmup import LearningRateWarmup +from .preview_cv import PreviewBuffer, TriggerType + +if T.TYPE_CHECKING: + from .preview_cv import PreviewBase + Preview: type[PreviewBase] + +try: + from .preview_tk import PreviewTk as Preview +except ImportError: + from .preview_cv import PreviewCV as Preview diff --git a/lib/training/augmentation.py b/lib/training/augmentation.py new file mode 100644 index 0000000000..4856be6525 --- /dev/null +++ b/lib/training/augmentation.py @@ -0,0 +1,593 @@ +#!/usr/bin/env python3 +""" Processes the augmentation of images for feeding into a Faceswap model. 
""" +from __future__ import annotations +import logging +from dataclasses import dataclass + +import cv2 +import numexpr as ne +import numpy as np +from scipy.interpolate import griddata + +from lib.image import batch_convert_color +from lib.logger import parse_class_init +from lib.utils import get_module_objects +from plugins.train.trainer import trainer_config as cfg + + +logger = logging.getLogger(__name__) + + +@dataclass +class ConstantsColor: + """ Dataclass for holding constants for enhancing an image (ie contrast/color adjustment) + + Parameters + ---------- + clahe_base_contrast : int + The base number for Contrast Limited Adaptive Histogram Equalization + clahe_chance : float + Probability to perform Contrast Limited Adaptive Histogram Equilization + clahe_max_size : int + Maximum clahe window size + lab_adjust : :class:`numpy.ndarray` + Adjustment amounts for L*A*B augmentation + """ + clahe_base_contrast: int + """ int : The base number for Contrast Limited Adaptive Histogram Equalization """ + clahe_chance: float + """ float : Probability to perform Contrast Limited Adaptive Histogram Equilization """ + clahe_max_size: int + """ int : Maximum clahe window size""" + lab_adjust: np.ndarray + """ :class:`numpy.ndarray` : Adjustment amounts for L*A*B augmentation """ + + +@dataclass +class ConstantsTransform: + """ Dataclass for holding constants for transforming an image + + Parameters + ---------- + rotation : int + Rotation range for transformations + zoom : float + Zoom range for transformations + shift : float + Shift range for transformations + """ + rotation: int + """ int : Rotation range for transformations """ + zoom: float + """ float : Zoom range for transformations """ + shift: float + """ float : Shift range for transformations """ + flip: float + """ float : The chance to flip an image """ + + +@dataclass +class ConstantsWarp: + """ Dataclass for holding constants for warping an image + + Parameters + ---------- + maps : :class:`numpy.ndarray` + The stacked (x, y) mappings for image warping + pad : tuple[int, int] + The padding to apply for image warping + slices : slice + The slices for extracting a warped image + lm_edge_anchors : :class:`numpy.ndarray` + The edge anchors for landmark based warping + lm_grids : :class:`numpy.ndarray` + The grids for landmark based warping + """ + maps: np.ndarray + """ :class:`numpy.ndarray` : The stacked (x, y) mappings for image warping """ + pad: tuple[int, int] + """ :tuple[int, int] : The padding to apply for image warping """ + slices: slice + """ slice : The slices for extracting a warped image """ + scale: float + """ float : The scaling to apply to standard warping """ + lm_edge_anchors: np.ndarray + """ :class:`numpy.ndarray` : The edge anchors for landmark based warping """ + lm_grids: np.ndarray + """ :class:`numpy.ndarray` : The grids for landmark based warping """ + lm_scale: float + """ float : The scaling to apply to landmark based warping """ + + def __repr__(self) -> str: + """ Display shape/type information for arrays in __repr__ """ + params = {k: f"array[shape: {v.shape}, dtype: {v.dtype}]" + if isinstance(v, np.ndarray) else v + for k, v in self.__dict__.items()} + str_params = ", ".join(f"{k}={v}" for k, v in params.items()) + return f"{self.__class__.__name__}({str_params})" + + +@dataclass +class ConstantsAugmentation: + """ Dataclass for holding constants for Image Augmentation. 
+
+    Attributes
+    ----------
+    color : :class:`ConstantsColor`
+        The constants for adjusting color/contrast in an image
+    transform : :class:`ConstantsTransform`
+        The constants for image transformation
+    warp : :class:`ConstantsWarp`
+        The constants for image warping
+
+    The dataclass should be initialized using its :func:`from_config` method:
+
+    Example
+    -------
+    >>> constants = ConstantsAugmentation.from_config(processing_size=256,
+    ...                                               batch_size=16)
+    """
+    color: ConstantsColor
+    """ :class:`ConstantsColor` : The constants for adjusting color/contrast in an image """
+    transform: ConstantsTransform
+    """ :class:`ConstantsTransform` : The constants for image transformation """
+    warp: ConstantsWarp
+    """ :class:`ConstantsWarp` : The constants for image warping """
+
+    @classmethod
+    def _get_clahe(cls, size: int) -> tuple[int, float, int]:
+        """ Get the CLAHE constants from user config
+
+        Parameters
+        ----------
+        size : int
+            The size of image to augment the data for
+
+        Returns
+        -------
+        clahe_base_contrast : int
+            The base number for Contrast Limited Adaptive Histogram Equalization
+        clahe_chance : float
+            Probability to perform Contrast Limited Adaptive Histogram Equalization
+        clahe_max_size : int
+            Maximum clahe window size
+        """
+        clahe_base_contrast = max(2, size // 128)
+        clahe_chance = cfg.color_clahe_chance() / 100
+        clahe_max_size = cfg.color_clahe_max_size()
+        logger.debug("clahe_base_contrast: %s, clahe_chance: %s, clahe_max_size: %s",
+                     clahe_base_contrast, clahe_chance, clahe_max_size)
+        return clahe_base_contrast, clahe_chance, clahe_max_size
+
+    @classmethod
+    def _get_lab(cls) -> np.ndarray:
+        """ Load the random L*A*B augmentation constants
+
+        Returns
+        -------
+        :class:`numpy.ndarray`
+            Adjustment amounts for L*A*B augmentation
+        """
+        amount_l = cfg.color_lightness() / 100.
+        amount_ab = cfg.color_ab() / 100.
+
+        lab_adjust = np.array([amount_l, amount_ab, amount_ab], dtype="float32")
+        logger.debug("lab_adjust: %s", lab_adjust)
+        return lab_adjust
+
+    @classmethod
+    def _get_color(cls, size: int) -> ConstantsColor:
+        """ Get the image enhancement constants from user config
+
+        Parameters
+        ----------
+        size : int
+            The size of image to augment the data for
+
+        Returns
+        -------
+        :class:`ConstantsColor`
+            The constants for image enhancement
+        """
+        clahe_base_contrast, clahe_chance, clahe_max_size = cls._get_clahe(size)
+        retval = ConstantsColor(clahe_base_contrast=clahe_base_contrast,
+                                clahe_chance=clahe_chance,
+                                clahe_max_size=clahe_max_size,
+                                lab_adjust=cls._get_lab())
+        logger.debug(retval)
+        return retval
+
+    @classmethod
+    def _get_transform(cls, size: int) -> ConstantsTransform:
+        """ Load the random transform constants
+
+        Parameters
+        ----------
+        size : int
+            The size of image to augment the data for
+
+        Returns
+        -------
+        :class:`ConstantsTransform`
+            The constants for image transformation
+        """
+        retval = ConstantsTransform(rotation=cfg.rotation_range(),
+                                    zoom=cfg.zoom_amount() / 100.,
+                                    shift=(cfg.shift_range() / 100.) * size,
+                                    flip=cfg.flip_chance() / 100.)
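+        # Config values are stored as percentages: zoom and flip become 0-1 fractions, while
+        # shift is additionally scaled by the image size to give a shift range in pixels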
+        logger.debug(retval)
+        return retval
+
+    @classmethod
+    def _get_warp_to_landmarks(cls, size: int, batch_size: int) -> tuple[np.ndarray, np.ndarray]:
+        """ Load the warp-to-landmarks augmentation constants
+
+        Parameters
+        ----------
+        size : int
+            The size of image to augment the data for
+        batch_size : int
+            The batch size that augmented data is being prepared for
+
+        Returns
+        -------
+        edge_anchors : :class:`numpy.ndarray`
+            The edge anchors for landmark based warping
+        grids : :class:`numpy.ndarray`
+            The grids for landmark based warping
+        """
+        p_mx = size - 1
+        p_hf = (size // 2) - 1
+        edge_anchors = np.array([(0, 0), (0, p_mx), (p_mx, p_mx), (p_mx, 0),
+                                 (p_hf, 0), (p_hf, p_mx), (p_mx, p_hf), (0, p_hf)]).astype("int32")
+        edge_anchors = np.broadcast_to(edge_anchors, (batch_size, 8, 2))
+        grids = np.mgrid[0: p_mx: complex(size),  # type:ignore[misc] # pylint:disable=no-member
+                         0: p_mx: complex(size)].astype("float32")  # type:ignore[misc]
+
+        logger.debug("edge_anchors: (%s, %s), grids: (%s, %s)",
+                     edge_anchors.shape, edge_anchors.dtype,
+                     grids.shape, grids.dtype)  # pylint:disable=no-member
+        return edge_anchors, grids
+
+    @classmethod
+    def _get_warp(cls, size: int, batch_size: int) -> ConstantsWarp:
+        """ Load the warp augmentation constants
+
+        Parameters
+        ----------
+        size : int
+            The size of image to augment the data for
+        batch_size : int
+            The batch size that augmented data is being prepared for
+
+        Returns
+        -------
+        :class:`ConstantsWarp`
+            The constants for image warping
+        """
+        lm_edge_anchors, lm_grids = cls._get_warp_to_landmarks(size, batch_size)
+
+        warp_range = np.linspace(0, size, 5, dtype='float32')
+        warp_mapx = np.broadcast_to(warp_range, (batch_size, 5, 5)).astype("float32")
+        warp_mapy = np.broadcast_to(warp_mapx[0].T, (batch_size, 5, 5)).astype("float32")
+        warp_pad = int(1.25 * size)
+
+        retval = ConstantsWarp(maps=np.stack((warp_mapx, warp_mapy), axis=1),
+                               pad=(warp_pad, warp_pad),
+                               slices=slice(warp_pad // 10, -warp_pad // 10),
+                               scale=5 / 256 * size,  # Normal random variable scale
+                               lm_edge_anchors=lm_edge_anchors,
+                               lm_grids=lm_grids,
+                               lm_scale=2 / 256 * size)  # Normal random variable scale
+        logger.debug(retval)
+        return retval
+
+    @classmethod
+    def from_config(cls,
+                    processing_size: int,
+                    batch_size: int) -> ConstantsAugmentation:
+        """ Create a new dataclass instance from user config
+
+        Parameters
+        ----------
+        processing_size : int
+            The size of image to augment the data for
+        batch_size : int
+            The batch size that augmented data is being prepared for
+
+        Returns
+        -------
+        :class:`ConstantsAugmentation`
+            The augmentation constants generated from the user configuration
+        """
+        logger.debug("Initializing %s(processing_size=%s, batch_size=%s)",
+                     cls.__name__, processing_size, batch_size)
+        retval = cls(color=cls._get_color(processing_size),
+                     transform=cls._get_transform(processing_size),
+                     warp=cls._get_warp(processing_size, batch_size))
+        logger.debug(retval)
+        return retval
+
+
+class ImageAugmentation():
+    """ Performs augmentation on batches of training images.
+
+    Parameters
+    ----------
+    batch_size : int
+        The number of images that will be fed through the augmentation functions at once.
+    processing_size : int
+        The largest input or output size of the model. This is the size that images are processed
+        at.
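+
+    Example
+    -------
+    A minimal sketch, assuming the training configuration has already been loaded. The zeroed
+    batch stands in for real uint8 `BGR` training images:
+
+    >>> augmenter = ImageAugmentation(batch_size=16, processing_size=256)
+    >>> batch = np.zeros((16, 256, 256, 3), dtype="uint8")
+    >>> batch = augmenter.color_adjust(batch)
+    >>> augmenter.transform(batch)  # transforms the batch in-place
+    >>> warped = augmenter.warp(batch)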
+ """ + def __init__(self, batch_size: int, processing_size: int) -> None: + logger.debug(parse_class_init(locals())) + self._processing_size = processing_size + self._batch_size = batch_size + self._constants = ConstantsAugmentation.from_config(processing_size, batch_size) + logger.debug("Initialized %s", self.__class__.__name__) + + def __repr__(self) -> str: + """ Pretty print this object """ + return (f"{self.__class__.__name__}(batch_size={self._batch_size}, " + f"processing_size={self._processing_size})") + + # <<< COLOR AUGMENTATION >>> # + def _random_lab(self, batch: np.ndarray) -> None: + """ Perform random color/lightness adjustment in L*a*b* color space on a batch of + images + + Parameters + ---------- + batch : :class:`numpy.ndarray` + The batch should be a 4-dimensional array of shape (`batchsize`, `height`, `width`, + `3`) and in `BGR` format of uint8 dtype. + """ + randoms = np.random.uniform(-self._constants.color.lab_adjust, + self._constants.color.lab_adjust, + size=(self._batch_size, 1, 1, 3)).astype("float32") + logger.trace("Random LAB adjustments: %s", randoms) # type:ignore[attr-defined] + # Iterating through the images and channels is much faster than numpy.where and slightly + # faster than numexpr.where. + for image, rand in zip(batch, randoms): + for idx in range(rand.shape[-1]): + adjustment = rand[:, :, idx] + if adjustment >= 0: + image[:, :, idx] = ((255 - image[:, :, idx]) * adjustment) + image[:, :, idx] + else: + image[:, :, idx] = image[:, :, idx] * (1 + adjustment) + + def _random_clahe(self, batch: np.ndarray) -> None: + """ Randomly perform Contrast Limited Adaptive Histogram Equalization on + a batch of images + + Parameters + ---------- + batch : :class:`numpy.ndarray` + The batch should be a 4-dimensional array of shape (`batchsize`, `height`, `width`, + `3`) and in `BGR` format of uint8 dtype. + """ + base_contrast = self._constants.color.clahe_base_contrast + + batch_random = np.random.rand(self._batch_size) + indices = np.where(batch_random < self._constants.color.clahe_chance)[0] + if not np.any(indices): + return + grid_bases = np.random.randint(self._constants.color.clahe_max_size + 1, + size=indices.shape[0], + dtype="uint8") + grid_sizes = (grid_bases * (base_contrast // 2)) + base_contrast + logger.trace("Adjusting Contrast. Grid Sizes: %s", grid_sizes) # type:ignore[attr-defined] + + clahes = [cv2.createCLAHE(clipLimit=2.0, + tileGridSize=(grid_size, grid_size)) + for grid_size in grid_sizes] + + for idx, clahe in zip(indices, clahes): + batch[idx, :, :, 0] = clahe.apply(batch[idx, :, :, 0], ) + + def color_adjust(self, batch: np.ndarray) -> np.ndarray: + """ Perform color augmentation on the passed in batch. + + The color adjustment parameters are set in :file:`config.train.ini` + + Parameters + ---------- + batch : :class:`numpy.ndarray` + The batch should be a 4-dimensional array of shape (`batchsize`, `height`, `width`, + `3`) and in `BGR` format of uint8 dtype. + + Returns + ---------- + :class:`numpy.ndarray` + A 4-dimensional array of the same shape as :attr:`batch` with color augmentation + applied. + """ + logger.trace("Augmenting color") # type:ignore[attr-defined] + batch = batch_convert_color(batch, "BGR2LAB") + self._random_lab(batch) + self._random_clahe(batch) + batch = batch_convert_color(batch, "LAB2BGR") + return batch + + # <<< IMAGE AUGMENTATION >>> # + def transform(self, batch: np.ndarray): + """ Perform random transformation on the passed in batch. 
+
+        The transformation parameters are set in :file:`config.train.ini`
+
+        Parameters
+        ----------
+        batch : :class:`numpy.ndarray`
+            The batch should be a 4-dimensional array of shape (`batchsize`, `height`, `width`,
+            `channels`) and in `BGR` format.
+        """
+        logger.trace("Randomly transforming image")  # type:ignore[attr-defined]
+        rotation = np.random.uniform(-self._constants.transform.rotation,
+                                     self._constants.transform.rotation,
+                                     size=self._batch_size).astype("float32")
+        scale = np.random.uniform(1 - self._constants.transform.zoom,
+                                  1 + self._constants.transform.zoom,
+                                  size=self._batch_size).astype("float32")
+
+        tform = np.random.uniform(-self._constants.transform.shift,
+                                  self._constants.transform.shift,
+                                  size=(self._batch_size, 2)).astype("float32")
+        mats = np.array(
+            [cv2.getRotationMatrix2D((self._processing_size // 2, self._processing_size // 2),
+                                     rot,
+                                     scl)
+             for rot, scl in zip(rotation, scale)]).astype("float32")
+        mats[..., 2] += tform
+
+        for image, mat in zip(batch, mats):
+            cv2.warpAffine(image,
+                           mat,
+                           (self._processing_size, self._processing_size),
+                           dst=image,
+                           borderMode=cv2.BORDER_REPLICATE)
+
+        logger.trace("Randomly transformed image")  # type:ignore[attr-defined]
+
+    def random_flip(self, batch: np.ndarray) -> None:
+        """ Perform random horizontal flipping on the passed in batch.
+
+        The probability of flipping an image is set in :file:`config.train.ini`
+
+        Parameters
+        ----------
+        batch : :class:`numpy.ndarray`
+            The batch should be a 4-dimensional array of shape (`batchsize`, `height`, `width`,
+            `channels`) and in `BGR` format.
+        """
+        logger.trace("Randomly flipping image")  # type:ignore[attr-defined]
+        randoms = np.random.rand(self._batch_size)
+        indices = np.where(randoms <= self._constants.transform.flip)[0]
+        batch[indices] = batch[indices, :, ::-1]
+        logger.trace("Randomly flipped %s images of %s",  # type:ignore[attr-defined]
+                     len(indices), self._batch_size)
+
+    def _random_warp(self, batch: np.ndarray) -> np.ndarray:
+        """ Randomly warp the input batch
+
+        Parameters
+        ----------
+        batch : :class:`numpy.ndarray`
+            The batch should be a 4-dimensional array of shape (`batchsize`, `height`, `width`,
+            `3`) and in `BGR` format.
+
+        Returns
+        -------
+        :class:`numpy.ndarray`
+            A 4-dimensional array of the same shape as :attr:`batch` with warping applied.
+        """
+        logger.trace("Randomly warping batch")  # type:ignore[attr-defined]
+        slices = self._constants.warp.slices
+        rands = np.random.normal(size=(self._batch_size, 2, 5, 5),
+                                 scale=self._constants.warp.scale).astype("float32")
+        batch_maps = ne.evaluate("m + r", local_dict={"m": self._constants.warp.maps, "r": rands})
+
+        batch_interp = np.array([[cv2.resize(map_, self._constants.warp.pad)[slices, slices]
+                                  for map_ in maps]
+                                 for maps in batch_maps])
+        warped_batch = np.array([cv2.remap(image, interp[0], interp[1], cv2.INTER_LINEAR)
+                                 for image, interp in zip(batch, batch_interp)])
+
+        logger.trace("Warped image shape: %s", warped_batch.shape)  # type:ignore[attr-defined]
+        return warped_batch
+
+    def _random_warp_landmarks(self,
+                               batch: np.ndarray,
+                               batch_src_points: np.ndarray,
+                               batch_dst_points: np.ndarray) -> np.ndarray:
+        """ From dfaker. Warp the image to a similar set of landmarks from the opposite side
+
+        Parameters
+        ----------
+        batch : :class:`numpy.ndarray`
+            The batch should be a 4-dimensional array of shape (`batchsize`, `height`, `width`,
+            `3`) and in `BGR` format.
+        batch_src_points : :class:`numpy.ndarray`
+            A batch of 68 point landmarks for the source faces.
This is a 3-dimensional array in + the shape (`batchsize`, `68`, `2`). + batch_dst_points : :class:`numpy.ndarray` + A batch of randomly chosen closest match destination faces landmarks. This is a + 3-dimensional array in the shape (`batchsize`, `68`, `2`). + + Returns + ---------- + :class:`numpy.ndarray` + A 4-dimensional array of the same shape as :attr:`batch` with warping applied. + """ + logger.trace("Randomly warping landmarks") # type:ignore[attr-defined] + edge_anchors = self._constants.warp.lm_edge_anchors + grids = self._constants.warp.lm_grids + + batch_dst = batch_dst_points + np.random.normal(size=batch_dst_points.shape, + scale=self._constants.warp.lm_scale) + + face_cores = [cv2.convexHull(np.concatenate([src[17:], dst[17:]], axis=0)) + for src, dst in zip(batch_src_points.astype("int32"), + batch_dst.astype("int32"))] + + batch_src = np.append(batch_src_points, edge_anchors, axis=1) + batch_dst = np.append(batch_dst, edge_anchors, axis=1) + + rem_indices = [list(set(idx for fpl in (src, dst) + for idx, (pty, ptx) in enumerate(fpl) + if cv2.pointPolygonTest(face_core, (pty, ptx), False) >= 0)) + for src, dst, face_core in zip(batch_src[:, :18, :], + batch_dst[:, :18, :], + face_cores)] + lbatch_src = [np.delete(src, idxs, axis=0) for idxs, src in zip(rem_indices, batch_src)] + lbatch_dst = [np.delete(dst, idxs, axis=0) for idxs, dst in zip(rem_indices, batch_dst)] + + grid_z = np.array([griddata(dst, src, (grids[0], grids[1]), method="linear") + for src, dst in zip(lbatch_src, lbatch_dst)]) + maps = grid_z.reshape((self._batch_size, + self._processing_size, + self._processing_size, + 2)).astype("float32") + + warped_batch = np.array([cv2.remap(image, + map_[..., 1], + map_[..., 0], + cv2.INTER_LINEAR, + borderMode=cv2.BORDER_TRANSPARENT) + for image, map_ in zip(batch, maps)]) + logger.trace("Warped batch shape: %s", warped_batch.shape) # type:ignore[attr-defined] + return warped_batch + + def warp(self, + batch: np.ndarray, + to_landmarks: bool = False, + batch_src_points: np.ndarray | None = None, + batch_dst_points: np.ndarray | None = None + ) -> np.ndarray: + + """ Perform random warping on the passed in batch by one of two methods. + + Parameters + ---------- + batch : :class:`numpy.ndarray` + The batch should be a 4-dimensional array of shape (`batchsize`, `height`, `width`, + `3`) and in `BGR` format. + to_landmarks : bool, optional + If ``False`` perform standard random warping of the input image. If ``True`` perform + warping to semi-random similar corresponding landmarks from the other side. Default: + ``False`` + batch_src_points : :class:`numpy.ndarray`, optional + Only used when :attr:`to_landmarks` is ``True``. A batch of 68 point landmarks for the + source faces. This is a 3-dimensional array in the shape (`batchsize`, `68`, `2`). + Default: ``None`` + batch_dst_points : :class:`numpy.ndarray`, optional + Only used when :attr:`to_landmarks` is ``True``. A batch of randomly chosen closest + match destination faces landmarks. This is a 3-dimensional array in the shape + (`batchsize`, `68`, `2`). Default ``None`` + + Returns + ---------- + :class:`numpy.ndarray` + A 4-dimensional array of the same shape as :attr:`batch` with warping applied. 
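+
+        Example
+        -------
+        A minimal sketch. The landmark arrays are placeholders for real 68 point landmark
+        batches when warping to landmarks:
+
+        >>> warped = augmenter.warp(batch)
+        >>> warped = augmenter.warp(batch,
+        ...                         to_landmarks=True,
+        ...                         batch_src_points=src_landmarks,  # (16, 68, 2) float32
+        ...                         batch_dst_points=dst_landmarks)  # (16, 68, 2) float32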
+ """ + if to_landmarks: + assert batch_src_points is not None + assert batch_dst_points is not None + return self._random_warp_landmarks(batch, batch_src_points, batch_dst_points) + return self._random_warp(batch) + + +__all__ = get_module_objects(__name__) diff --git a/lib/training/cache.py b/lib/training/cache.py new file mode 100644 index 0000000000..e3429a48f9 --- /dev/null +++ b/lib/training/cache.py @@ -0,0 +1,703 @@ +#!/usr/bin/env python3 +""" Holds the data cache for training data generators """ +from __future__ import annotations +import logging +import os +import typing as T + +from dataclasses import dataclass, field +from threading import Lock + +import cv2 +import numpy as np +from tqdm import tqdm + +from lib.align import CenteringType, DetectedFace, LandmarkType +from lib.image import read_image_batch, read_image_meta_batch +from lib.logger import parse_class_init +from lib.utils import FaceswapError, get_module_objects +from plugins.train import train_config as cfg + +if T.TYPE_CHECKING: + from lib.align.alignments import PNGHeaderAlignmentsDict, PNGHeaderDict + from lib import align + +logger = logging.getLogger(__name__) +_FACE_CACHES: dict[str, Cache] = {} + + +@dataclass +class _MaskConfig: + """ Holds the constants required for manipulating training masks """ + # pylint:disable=unnecessary-lambda + penalized: bool = field(default_factory=lambda: cfg.Loss.penalized_mask_loss()) + learn: bool = field(default_factory=lambda: cfg.Loss.learn_mask()) + mask_type: str | None = field(default_factory=lambda: None + if cfg.Loss.mask_type() == "none" + else cfg.Loss.mask_type()) + dilation: float = field(default_factory=lambda: cfg.Loss.mask_dilation()) + kernel: int = field(default_factory=lambda: cfg.Loss.mask_blur_kernel()) + threshold: int = field(default_factory=lambda: cfg.Loss.mask_threshold()) + multiplier_enabled: bool = field( + default_factory=lambda: ((cfg.Loss.eye_multiplier() > 1 or cfg.Loss.mouth_multiplier() > 1) + and cfg.Loss.penalized_mask_loss())) + + @property + def mask_enabled(self) -> bool: + """ bool : ``True`` if any of :attr:`penalized` or :attr:`learn` are true and + :attr:`mask_type` is not ``None`` """ + return self.mask_type is not None and (self.learn or self.penalized) + + +class _MaskProcessing: + """ Handle the extraction and processing of masks from faceswap PNG headers for caching + + Parameters + ---------- + size : int + The largest output size of the model + coverage_ratio : float + The coverage ratio that the model is using. 
+ centering : Literal["face", "head", "legacy"] + """ + def __init__(self, + size: int, + coverage_ratio: float, + centering: CenteringType) -> None: + + assert isinstance(size, int) + assert isinstance(coverage_ratio, float) + assert centering in T.get_args(CenteringType) + + self._size = size + self._coverage = coverage_ratio + self._centering: CenteringType = centering + + self._config = _MaskConfig() + logger.debug("Initialized %s", self) + + def __repr__(self) -> str: + """ Pretty print for logging """ + params = f"coverage_ratio={repr(self._coverage)}, centering={repr(self._centering)}" + return f"{self.__class__.__name__}({params})" + + def _check_mask_exists(self, filename: str, detected_face: DetectedFace) -> None: + """ Check that the requested mask exists for the current detected face + + Parameters + ---------- + filename : str + The file path for the current image + detected_face : :class:`~lib.align.detected_face.DetectedFace` + The detected face object that holds the masks + + Raises + ------ + FaceswapError + If the requested mask type is not available an error is returned along with a list + of available masks + """ + if self._config.mask_type in detected_face.mask: + return + + exist_masks = list(detected_face.mask) + msg = "No masks exist for this face" + if exist_masks: + msg = f"The masks that exist for this face are: {exist_masks}" + raise FaceswapError( + f"You have selected the mask type '{self._config.mask_type}' but at least one " + "face does not contain the selected mask.\n" + f"The face that failed was: '{filename}'\n{msg}") + + def _preprocess(self, detected_face: DetectedFace, mask_type: str) -> align.aligned_mask.Mask: + """ Apply pre-processing to the mask + + Parameters + ---------- + detected_face : :class:`~lib.align.detected_face.DetectedFace` + The detected face object that holds the masks + mask_type : str + The stored mask type to use + + Returns + ------- + :class:`~lib.align.aligned_mask.Mask` + The pre-processed mask at its stored size and crop + """ + mask = detected_face.mask[mask_type] + mask.set_dilation(self._config.dilation) + mask.set_blur_and_threshold(blur_kernel=self._config.kernel, + threshold=self._config.threshold) + return mask + + def _crop_and_resize(self, + detected_face: DetectedFace, + mask: align.aligned_mask.Mask) -> np.ndarray: + """ Crop and resize the mask to the correct centering and training size + + Parameters + ---------- + detected_face : :class:`~lib.align.detected_face.DetectedFace` + The detected face object that holds the masks + mask : :class:`~lib.align.aligned_mask.Mask` + The pre-processed mask at its stored size and crop + + Returns + ------- + :class:`numpy.ndarray` + The processed, cropped and resized final mask + """ + pose = detected_face.aligned.pose + mask.set_sub_crop(pose.offset[mask.stored_centering], + pose.offset[self._centering], + self._centering, + self._coverage, + detected_face.aligned.y_offset) + face_mask = mask.mask + if self._size != face_mask.shape[0]: + interpolator = cv2.INTER_CUBIC if mask.stored_size < self._size else cv2.INTER_AREA + face_mask = cv2.resize(face_mask, + (self._size, self._size), + interpolation=interpolator)[..., None] + return face_mask + + def _get_face_mask(self, filename: str, detected_face: DetectedFace) -> np.ndarray | None: + """ Obtain the training sized face mask from the DetectedFace for the requested mask type. 
+ + Parameters + ---------- + filename : str + The file path for the current image + detected_face : :class:`~lib.align.detected_face.DetectedFace` + The detected face object that holds the masks + + Returns + ------- + :class:`numpy.ndarray` | None + The face mask used for training or ``None`` if masks are disabled + """ + if not self._config.mask_enabled: + return None + + assert self._config.mask_type is not None + self._check_mask_exists(filename, detected_face) + mask = self._preprocess(detected_face, self._config.mask_type) + retval = self._crop_and_resize(detected_face, mask) + logger.trace("Obtained face mask for: %s %s", # type:ignore[attr-defined] + filename, retval.shape) + return retval + + def _get_localized_mask(self, + filename: str, + detected_face: DetectedFace, + area: T.Literal["eye", "mouth"]) -> np.ndarray | None: + """ Obtain a localized mask for the given area if it is required for training. + + Parameters + ---------- + filename : str + The file path for the current image + detected_face : :class:`~lib.align.detected_face.DetectedFace` + The detected face object that holds the masks + area : Literal["eye", "mouth"] + The area of the face to obtain the mask for + + Raises + ------ + :class:`~lib.utils.FaceswapError` + If landmark data is not available to generate the localized mask + """ + if not self._config.multiplier_enabled: + return None + + try: + mask = detected_face.get_landmark_mask(area, self._size // 16, 2.5) + except FaceswapError as err: + logger.error(str(err)) + raise FaceswapError("Eye/Mouth multiplier masks could not be generated due to missing " + f"landmark data. The file that failed was: '{filename}'") from err + logger.trace("Caching localized '%s' mask for: %s %s", # type:ignore[attr-defined] + area, filename, mask.shape) + return mask + + def __call__(self, filename: str, detected_face: DetectedFace) -> None: + """ Prepare the masks required for training and compile into a single compressed array + within the given DetectedFaces object + + Parameters + ---------- + filename : str + The file path for the image that masks are to be prepared for + detected_face : :class:`~lib.align.detected_face.DetectedFace` + The detected face object that holds the masks + """ + masks = [(self._get_face_mask(filename, detected_face))] + for area in T.get_args(T.Literal["eye", "mouth"]): + masks.append(self._get_localized_mask(filename, detected_face, area)) + + detected_face.store_training_masks(masks, delete_masks=True) + logger.trace("Stored masks for filename: %s)", filename) # type:ignore[attr-defined] + + +def _check_reset(face_cache: "Cache") -> bool: + """ Check whether a given cache needs to be reset because a face centering change has been + detected in the other cache. 
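+
+    This typically happens when one side detects legacy (v1.0) extracted faces and switches the
+    centering to 'legacy', requiring the opposite side's cache to be rebuilt to match.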
+
+    Parameters
+    ----------
+    face_cache : :class:`Cache`
+        The cache object that is checking whether it should reset
+
+    Returns
+    -------
+    bool
+        ``True`` if the given object should reset the cache, otherwise ``False``
+    """
+    check_cache = next((cache for cache in _FACE_CACHES.values() if cache != face_cache), None)
+    retval = False if check_cache is None else check_cache.check_reset()
+    return retval
+
+
+@dataclass
+class _CacheConfig:
+    """ Holds the configuration options for the cache """
+    size: int
+    """ int : The size to load images at """
+    centering: CenteringType
+    """ Literal["face", "head", "legacy"] : The centering type to train at """
+    coverage: float
+    """ float : The selected coverage ratio for training """
+
+
+class Cache():
+    """ A thread safe mechanism for collecting and holding face meta information (masks,
+    alignments data etc.) for multiple :class:`~lib.training.generator.TrainingDataGenerator`.
+
+    Each side may have up to 3 generators (training, preview and time-lapse). To conserve RAM
+    these need to share access to the same face information for the images they are processing.
+
+    As the cache is populated at run-time, thread safe writes are required for the first epoch.
+    Following that, the cache is only used for reads, which is thread safe intrinsically.
+
+    It would probably be quicker to set locks on each individual face but, for the sake of code
+    simplicity, and because the lock is only taken during cache population and the cache is only
+    read repeatedly on save iterations, the whole cache is locked during writes.
+
+    Parameters
+    ----------
+    filenames : list[str]
+        The filenames of all the images. This can either be the full path or the base name. If the
+        full paths are passed in, they are stripped to base name for use as the cache key.
+    size : int
+        The largest output size of the model
+    coverage_ratio : float
+        The coverage ratio that the model is using.
+    """
+    def __init__(self,
+                 filenames: list[str],
+                 size: int,
+                 coverage_ratio: float) -> None:
+        logger.debug(parse_class_init(locals()))
+        self._lock = Lock()
+        self._cache_info = {"cache_full": False, "has_reset": False}
+        self._partially_loaded: list[str] = []
+
+        self._image_count = len(filenames)
+        self._cache: dict[str, DetectedFace] = {}
+        self._aligned_landmarks: dict[str, np.ndarray] = {}
+        self._extract_version = 0.0
+
+        self._config = _CacheConfig(size=size,
+                                    centering=T.cast(CenteringType, cfg.centering()),
+                                    coverage=coverage_ratio)
+        self._mask_prepare = _MaskProcessing(size, coverage_ratio, self._config.centering)
+        logger.debug("Initialized: %s", self.__class__.__name__)
+
+    @property
+    def cache_full(self) -> bool:
+        """ bool : ``True`` if the cache has been fully populated. ``False`` if there are items
+        still to be cached. """
+        if self._cache_info["cache_full"]:
+            return self._cache_info["cache_full"]
+        with self._lock:
+            return self._cache_info["cache_full"]
+
+    @property
+    def aligned_landmarks(self) -> dict[str, np.ndarray]:
+        """ dict[str, :class:`numpy.ndarray`] : filename as key, aligned landmarks as value. """
+        # Note: Aligned landmarks are only used for warp-to-landmarks, so this can safely populate
+        # all of the aligned landmarks for the entire cache.
+        if not self._aligned_landmarks:
+            with self._lock:
+                # For Warp-To-Landmarks a race condition can occur where this is referenced from
+                # the opposite side prior to it being populated, so block on a lock.
+ self._aligned_landmarks = {key: face.aligned.landmarks + for key, face in self._cache.items()} + return self._aligned_landmarks + + @property + def size(self) -> int: + """ int : The pixel size of the cropped aligned face """ + return self._config.size + + def get_items(self, filenames: list[str]) -> list[DetectedFace]: + """ Obtain the cached items for a list of filenames. The returned list is in the same order + as the provided filenames. + + Parameters + ---------- + filenames : list[str] + A list of image filenames to obtain the cached data for + + Returns + ------- + list[:class:`~lib.align.detected_face.DetectedFace`] + List of DetectedFace objects holding the cached metadata. The list returns in the same + order as the filenames received + """ + return [self._cache[os.path.basename(filename)] for filename in filenames] + + def check_reset(self) -> bool: + """ Check whether this cache has been reset due to a face centering change, and reset the + flag if it has. + + Returns + ------- + bool + ``True`` if the cache has been reset because of a face centering change due to + legacy alignments, otherwise ``False``. """ + retval = self._cache_info["has_reset"] + if retval: + logger.debug("Resetting 'has_reset' flag") + self._cache_info["has_reset"] = False + return retval + + def _reset_cache(self, set_flag: bool) -> None: + """ In the event that a legacy extracted face has been seen, and centering is not legacy + the cache will need to be reset for legacy centering. + + Parameters + ---------- + set_flag: bool + ``True`` if the flag should be set to indicate that the cache is being reset because of + a legacy face set/centering mismatch. ``False`` if the cache is being reset because it + has detected a reset flag from the opposite cache. + """ + if set_flag: + logger.warning("You are using legacy extracted faces but have selected '%s' centering " + "which is incompatible. Switching centering to 'legacy'", + self._config.centering) + cfg.centering.set("legacy") + self._config.centering = "legacy" + self._cache = {} + self._cache_info["cache_full"] = False + if set_flag: + self._cache_info["has_reset"] = True + + def _validate_version(self, png_meta: PNGHeaderDict, filename: str) -> None: + """ Validate that there are not a mix of v1.0 extracted faces and v2.x faces. + + Parameters + ---------- + png_meta : :class:`~lib.align.alignments.PNGHeaderDict` + The information held within the Faceswap PNG Header + filename: str + The full path to the file being validated + + Raises + ------ + :class:`~lib.utils.FaceswapError` + If a version 1.0 face appears in a 2.x set or vice versa + """ + alignment_version = png_meta["source"]["alignments_version"] + + if not self._extract_version: + logger.debug("Setting initial extract version: %s", alignment_version) + self._extract_version = alignment_version + if alignment_version == 1.0 and self._config.centering != "legacy": + self._reset_cache(True) + return + + if (self._extract_version == 1.0 and alignment_version > 1.0) or ( + alignment_version == 1.0 and self._extract_version > 1.0): + raise FaceswapError("Mixing legacy and full head extracted facesets is not supported. 
" + "The following folder contains a mix of extracted face types: " + f"'{os.path.dirname(filename)}'") + + self._extract_version = min(alignment_version, self._extract_version) + + def _load_detected_face(self, + filename: str, + alignments: PNGHeaderAlignmentsDict) -> DetectedFace: + """ Load a :class:`~lib.align.detected_face.DetectedFace` object and load its associated + `aligned` property. + + Parameters + ---------- + filename : str + The file path for the current image + alignments : :class:`~lib.align.alignments.PNGHeaderAlignmentsDict` + The alignments for a single face, extracted from a PNG header + + Returns + ------- + :class:`~lib.align.detected_face.DetectedFace` + The loaded Detected Face object + """ + y_offset = cfg.vertical_offset() + detected_face = DetectedFace() + detected_face.from_png_meta(alignments) + detected_face.load_aligned(None, + size=self._config.size, + centering=self._config.centering, + coverage_ratio=self._config.coverage, + y_offset=y_offset / 100., + is_aligned=True, + is_legacy=self._extract_version == 1.0) + logger.trace("Cached aligned face for: %s", filename) # type:ignore[attr-defined] + return detected_face + + def _populate_cache(self, + needs_cache: list[str], + metadata: list[PNGHeaderDict], + filenames: list[str]) -> None: + """ Populate the given items into the cache + + Parameters + ---------- + needs_cache : list[str] + The full path to files within this batch that require caching + metadata : list[:class:`~lib.align.alignments.PNGHeaderDict`] + The faceswap metadata loaded from the image png header + filenames : list[str] + Full path to the filenames that are being loaded in this batch + """ + for filename in needs_cache: + key = os.path.basename(filename) + meta = metadata[filenames.index(filename)] + + # Version Check + self._validate_version(meta, filename) + if self._partially_loaded: # Faces already loaded for Warp-to-landmarks + self._partially_loaded.remove(key) + detected_face = self._cache[key] + else: + detected_face = self._load_detected_face(filename, meta["alignments"]) + + self._mask_prepare(filename, detected_face) + self._cache[key] = detected_face + + def _get_batch_with_metadata(self, + filenames: list[str]) -> tuple[np.ndarray, list[PNGHeaderDict]]: + """ Load a batch of images along with their faceswap metadata for loading into the cache + + Parameters + ---------- + filenames : list[str] + Full path to the images to be loaded + + Returns + ------- + batch : :class:`numpy.ndarray` + The batch of images in a single array + metadata : :class:`~lib.align.alignments.PNGHeaderDict` + The faceswap metadata corresponding to each image in the batch + """ + try: + batch, metadata = read_image_batch(filenames, with_metadata=True) + except ValueError as err: + if "inhomogeneous" in str(err): + raise FaceswapError( + "There was an error loading a batch of images. This is most likely due to " + "non-faceswap extracted faces in your training folder." + "\nAll training images should be Faceswap extracted faces." + "\nAll training images should be the same size." + f"\nThe files that caused this error are: {filenames}") from err + raise + if len(batch.shape) == 1: + folder = os.path.dirname(filenames[0]) + keys = [os.path.basename(filename) for filename in filenames] + details = [ + f"{key} ({f'{img.shape[1]}px' if isinstance(img, np.ndarray) else type(img)})" + for key, img in zip(keys, batch)] + msg = (f"There are mismatched image sizes in the folder '{folder}'. 
All training " + "images for each side must have the same dimensions.\nThe batch that " + f"failed contains the following files:\n{details}.") + raise FaceswapError(msg) + return batch, metadata + + def _update_cache_full(self, filenames: list[str]) -> None: + """ Check if cache is full and update the "cache_full" flag in :attr:`_cache_info` if so + + Parameters + ---------- + filenames : list[str] + Full path to the filenames being processed in the current batch + """ + cache_full = not self._partially_loaded and len(self._cache) == self._image_count + if cache_full: + logger.verbose("Cache filled: '%s'", # type:ignore[attr-defined] + os.path.dirname(filenames[0])) + self._cache_info["cache_full"] = cache_full + + def cache_metadata(self, filenames: list[str]) -> np.ndarray: + """ Obtain the batch with metadata for items that need caching and cache DetectedFace + objects to :attr:`_cache`. + + Parameters + ---------- + filenames : list[str] + List of full paths to image file names + + Returns + ------- + :class:`numpy.ndarray` + The batch of face images loaded from disk + """ + keys = [os.path.basename(filename) for filename in filenames] + with self._lock: + if _check_reset(self): + self._reset_cache(False) + + needs_cache = [filename for filename, key in zip(filenames, keys) + if key not in self._cache or key in self._partially_loaded] + logger.trace("Needs cache: %s", needs_cache) # type:ignore[attr-defined] + + if not needs_cache: # Metadata already cached. Just get images + logger.debug("All metadata already cached for: %s", keys) + return read_image_batch(filenames) + + batch, metadata = self._get_batch_with_metadata(filenames) + self._populate_cache(needs_cache, metadata, filenames) + self._update_cache_full(filenames) + + return batch + + def pre_fill(self, filenames: list[str], side: T.Literal["a", "b"]) -> None: + """ When warp to landmarks is enabled, the cache must be pre-filled, as each side needs + access to the other side's alignments. + + Parameters + ---------- + filenames : list[str] + The list of full paths to the images to load the metadata from + side : Literal["a", "b"] + The side of the model being cached. Used for info output + + Raises + ------ + :class:`~lib.utils.FaceSwapError` + If unsupported landmark type exists or a non-faceswap image is loaded + """ + with self._lock: + for filename, meta in tqdm(read_image_meta_batch(filenames), + desc=f"WTL: Caching Landmarks ({side.upper()})", + total=len(filenames), + leave=False): + if "itxt" not in meta or "alignments" not in meta["itxt"]: + raise FaceswapError(f"Invalid face image found. Aborting: '{filename}'") + + meta = meta["itxt"] + key = os.path.basename(filename) + self._validate_version(meta, filename) + detected_face = self._load_detected_face(filename, meta["alignments"]) + + aligned = detected_face.aligned + assert aligned is not None + if aligned.landmark_type != LandmarkType.LM_2D_68: + raise FaceswapError("68 Point facial Landmarks are required for Warp-to-" + f"landmarks. The face that failed was: '{filename}'") + + self._cache[key] = detected_face + self._partially_loaded.append(key) + + +def get_cache(side: T.Literal["a", "b"], + filenames: list[str] | None = None, + size: int | None = None, + coverage_ratio: float | None = None) -> Cache: + """ Obtain a :class:`Cache` object for the given side. If the object does not pre-exist then + create it. 
+ + Parameters + ---------- + side : Literal["a", "b"] + The side of the model to obtain the cache for + filenames : list[str] | None, optional + The filenames of all the images. This can either be the full path or the base name. If the + full paths are passed in, they are stripped to base name for use as the cache key. Must be + passed for the first call of this function for each side. For subsequent calls this + parameter is ignored. Default: ``None`` + size: int | None, optional + The largest output size of the model. Must be passed for the first call of this function + for each side. For subsequent calls this parameter is ignored. Default: ``None`` + coverage_ratio : float | None, optional + The coverage ratio that the model is using. Must be passed for the first call of this + function for each side. For subsequent calls this parameter is ignored. Default: ``None`` + + Returns + ------- + :class:`Cache` + The face meta information cache for the requested side + """ + assert side in ("a", "b") + if not _FACE_CACHES.get(side): + assert filenames is not None, "filenames must be provided for first call to cache" + assert size is not None, "size must be provided for first call to cache" + assert coverage_ratio is not None, ("coverage_ratio must be provided for first call to " + "cache") + logger.debug("Creating cache. side: %s, size: %s, coverage_ratio: %s", + side, size, coverage_ratio) + _FACE_CACHES[side] = Cache(filenames, size, coverage_ratio) + return _FACE_CACHES[side] + + +class RingBuffer(): + """ Rolling buffer for holding training/preview batches + + Parameters + ---------- + batch_size : int + The batch size to create the buffer for + image_shape : tuple[int, int, int] + The height/width/channels shape of a single image in the batch + buffer_size : int, optional + The number of arrays to hold in the rolling buffer. Default: `2` + dtype : str, optional + The datatype to create the buffer as. 
Default: `"uint8"` + """ + def __init__(self, + batch_size: int, + image_shape: tuple[int, int, int], + buffer_size: int = 2, + dtype: str = "uint8") -> None: + logger.debug(parse_class_init(locals())) + self._max_index = buffer_size - 1 + self._index = 0 + self._buffer = [np.empty((batch_size, *image_shape), dtype=dtype) + for _ in range(buffer_size)] + logger.debug("Initialized: %s", self) + + def __repr__(self) -> str: + """ Pretty string representation for logging """ + params = {"batch_size": repr(self._buffer[0].shape[0]), + "image_shape": repr(self._buffer[0].shape[1:]), + "buffer_size": repr(len(self._buffer)), + "dtype": repr(str(self._buffer[0].dtype))} + str_params = [f"{k}={v}" for k, v in params.items()] + return f"{self.__class__.__name__}({', '.join(str_params)})" + + def __call__(self) -> np.ndarray: + """ Obtain the next array from the ring buffer + + Returns + ------- + :class:`np.ndarray` + A pre-allocated numpy array from the buffer + """ + retval = self._buffer[self._index] + self._index += 1 if self._index < self._max_index else -self._max_index + return retval + + +__all__ = get_module_objects(__name__) diff --git a/lib/training/generator.py b/lib/training/generator.py new file mode 100644 index 0000000000..f674e07aae --- /dev/null +++ b/lib/training/generator.py @@ -0,0 +1,969 @@ +#!/usr/bin/env python3 +""" Handles Data Augmentation for feeding Faceswap Models """ +from __future__ import annotations +import logging +import os +import typing as T + +from concurrent import futures +from random import shuffle, choice + +import cv2 +import numpy as np +import numexpr as ne +from lib.align import AlignedFace, DetectedFace +from lib.align.aligned_face import CenteringType +from lib.image import read_image_batch +from lib.multithreading import BackgroundGenerator +from lib.utils import FaceswapError, get_module_objects +from plugins.train import train_config as mod_cfg +from plugins.train.trainer import trainer_config as trn_cfg + +from . import ImageAugmentation +from .cache import get_cache, RingBuffer + +if T.TYPE_CHECKING: + from collections.abc import Generator + from plugins.train.model._base import ModelBase + from .cache import Cache + +logger = logging.getLogger(__name__) +BatchType = tuple[np.ndarray, list[np.ndarray]] + + +class DataGenerator(): # pylint:disable=too-many-instance-attributes + """ Parent class for Training and Preview Data Generators. + + This class is called from :mod:`plugins.train.trainer._base` and launches a background + iterator that compiles augmented data, target data and sample data. + + Parameters + ---------- + model: :class:`~plugins.train.model.ModelBase` + The model that this data generator is feeding + side: {'a' or 'b'} + The side of the model that this iterator is for. + images: list + A list of image paths that will be used to compile the final augmented data from. + batch_size: int + The batch size for this iterator. Images will be returned in :class:`numpy.ndarray` + objects of this size from the iterator. 
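+
+    This class should not be used directly. Use a subclass, such as
+    :class:`TrainingDataGenerator`, which implements :func:`process_batch`.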
+ """ + def __init__(self, + model: ModelBase, + side: T.Literal["a", "b"], + images: list[str], + batch_size: int) -> None: + logger.debug("Initializing %s: (model: %s, side: %s, images: %s , " + "batch_size: %s)", self.__class__.__name__, model.name, side, + len(images), batch_size) + self._side = side + self._images = images + self._batch_size = batch_size + + self._process_size = max(img[1] for img in model.input_shapes + model.output_shapes) + self._output_sizes = self._get_output_sizes(model) + self._model_input_size = max(img[1] for img in model.input_shapes) + + self._coverage_ratio = model.coverage_ratio + self._color_order = model.color_order.lower() + self._use_mask = mod_cfg.Loss.mask_type() and (mod_cfg.Loss.penalized_mask_loss() or + mod_cfg.Loss.learn_mask()) + + self._validate_samples() + self._buffer = RingBuffer(batch_size, + (self._process_size, self._process_size, self._total_channels), + dtype="uint8") + self._face_cache: Cache = get_cache(side, + filenames=images, + size=self._process_size, + coverage_ratio=self._coverage_ratio) + logger.debug("Initialized %s", self.__class__.__name__) + + @property + def _total_channels(self) -> int: + """int: The total number of channels, including mask channels that the target image + should hold. """ + channels = 3 + if mod_cfg.Loss.mask_type() and (mod_cfg.Loss.learn_mask() or + mod_cfg.Loss.penalized_mask_loss()): + channels += 1 + + mults = [area + for area, amount in zip(["eye", "mouth"], + [mod_cfg.Loss.eye_multiplier(), + mod_cfg.Loss.mouth_multiplier()]) + if amount > 1] + if mod_cfg.Loss.penalized_mask_loss() and mults: + channels += len(mults) + return channels + + def _get_output_sizes(self, model: ModelBase) -> list[int]: + """ Obtain the size of each output tensor for the model. + + Parameters + ---------- + model: :class:`~plugins.train.model.ModelBase` + The model that this data generator is feeding + + Returns + ------- + list + A list of integers for the model output size for the current side + """ + out_shapes = model.output_shapes + split = len(out_shapes) // 2 + side_out = out_shapes[:split] if self._side == "a" else out_shapes[split:] + retval = [shape[1] for shape in side_out if shape[-1] != 1] + logger.debug("side: %s, model output shapes: %s, output sizes: %s", + self._side, model.output_shapes, retval) + return retval + + def minibatch_ab(self, do_shuffle: bool = True) -> Generator[BatchType, None, None]: + """ A Background iterator to return augmented images, samples and targets. + + The exit point from this class and the sole attribute that should be referenced. Called + from :mod:`plugins.train.trainer._base`. Returns an iterator that yields images for + training, preview and time-lapses. + + Parameters + ---------- + do_shuffle: bool, optional + Whether data should be shuffled prior to loading from disk. If true, each time the full + list of filenames are processed, the data will be reshuffled to make sure they are not + returned in the same order. Default: ``True`` + + Yields + ------ + feed: list + 4-dimensional array of faces to feed the training the model (:attr:`x` parameter for + :func:`keras.models.model.train_on_batch`.). The array returned is in the format + (`batch size`, `height`, `width`, `channels`). + targets: list + List of 4-dimensional :class:`numpy.ndarray` objects in the order and size of each + output of the model. The format of these arrays will be (`batch size`, `height`, + `width`, `x`). This is the :attr:`y` parameter for + :func:`keras.models.model.train_on_batch`. 
The number of channels here will vary. + The first 3 channels are (rgb/bgr). The 4th channel is the face mask. Any subsequent + channels are area masks (e.g. eye/mouth masks) + """ + logger.debug("do_shuffle: %s", do_shuffle) + args = (do_shuffle, ) + batcher = BackgroundGenerator(self._minibatch, args=args) + return batcher.iterator() + + # << INTERNAL METHODS >> # + def _validate_samples(self) -> None: + """ Ensures that the total number of images within :attr:`images` is greater or equal to + the selected :attr:`batch_size`. + + Raises + ------ + :class:`FaceswapError` + If the number of images loaded is smaller than the selected batch size + """ + length = len(self._images) + msg = ("Number of images is lower than batch-size (Note that too few images may lead to " + f"bad training). # images: {length}, batch-size: {self._batch_size}") + try: + assert length >= self._batch_size, msg + except AssertionError as err: + msg += ("\nYou should increase the number of images in your training set or lower " + "your batch-size.") + raise FaceswapError(msg) from err + + def _minibatch(self, do_shuffle: bool) -> Generator[BatchType, None, None]: + """ A generator function that yields the augmented, target and sample images for the + current batch on the current side. + + Parameters + ---------- + do_shuffle: bool, optional + Whether data should be shuffled prior to loading from disk. If true, each time the full + list of filenames are processed, the data will be reshuffled to make sure they are not + returned in the same order. Default: ``True`` + + Yields + ------ + feed: list + 4-dimensional array of faces to feed the training the model (:attr:`x` parameter for + :func:`keras.models.model.train_on_batch`.). The array returned is in the format + (`batch size`, `height`, `width`, `channels`). + targets: list + List of 4-dimensional :class:`numpy.ndarray` objects in the order and size of each + output of the model. The format of these arrays will be (`batch size`, `height`, + `width`, `x`). This is the :attr:`y` parameter for + :func:`keras.models.model.train_on_batch`. The number of channels here will vary. + The first 3 channels are (rgb/bgr). The 4th channel is the face mask. Any subsequent + channels are area masks (e.g. eye/mouth masks) + """ + logger.debug("Loading minibatch generator: (image_count: %s, do_shuffle: %s)", + len(self._images), do_shuffle) + + def _img_iter(imgs): + """ Infinite iterator for recursing through image list and reshuffling at each epoch""" + while True: + if do_shuffle: + shuffle(imgs) + yield from imgs + + img_iter = _img_iter(self._images[:]) + while True: + img_paths = [next(img_iter) # pylint:disable=stop-iteration-return + for _ in range(self._batch_size)] + retval = self._process_batch(img_paths) + yield retval + + def _get_images_with_meta(self, filenames: list[str]) -> tuple[np.ndarray, list[DetectedFace]]: + """ Obtain the raw face images with associated :class:`DetectedFace` objects for this + batch. + + If this is the first time a face has been loaded, then it's meta data is extracted + from the png header and added to :attr:`_face_cache`. 
+ + Parameters + ---------- + filenames: list + List of full paths to image file names + + Returns + ------- + raw_faces: :class:`numpy.ndarray` + The full sized batch of training images for the given filenames + list + Batch of :class:`~lib.align.DetectedFace` objects for the given filename including the + aligned face objects for the model output size + """ + if not self._face_cache.cache_full: + raw_faces = self._face_cache.cache_metadata(filenames) + else: + raw_faces = read_image_batch(filenames) + + detected_faces = self._face_cache.get_items(filenames) + logger.trace( # type:ignore[attr-defined] + "filenames: %s, raw_faces: '%s', detected_faces: %s", + filenames, raw_faces.shape, len(detected_faces)) + return raw_faces, detected_faces + + def _crop_to_coverage(self, + filenames: list[str], + images: np.ndarray, + detected_faces: list[DetectedFace], + batch: np.ndarray) -> None: + """ Crops the training image out of the full extract image based on the centering and + coveage used in the user's configuration settings. + + If legacy extract images are being used then this just returns the extracted batch with + their corresponding landmarks. + + Uses thread pool execution for about a 33% speed increase @ 64 batch size + + Parameters + ---------- + filenames: list + The list of filenames that correspond to this batch + images: :class:`numpy.ndarray` + The batch of faces that have been loaded from disk + detected_faces: list + The list of :class:`lib.align.DetectedFace` items corresponding to the batch + batch: :class:`np.ndarray` + The pre-allocated array to hold this batch + """ + logger.trace( # type:ignore[attr-defined] + "Cropping training images info: (filenames: %s, side: '%s')", filenames, self._side) + + with futures.ThreadPoolExecutor() as executor: + proc = {executor.submit(face.aligned.extract_face, img): idx + for idx, (face, img) in enumerate(zip(detected_faces, images))} + + for future in futures.as_completed(proc): + batch[proc[future], ..., :3] = future.result() + + def _apply_mask(self, detected_faces: list[DetectedFace], batch: np.ndarray) -> None: + """ Applies the masks to the 4th channel of the batch. + + If the configuration options `eye_multiplier` and/or `mouth_multiplier` are greater than 1 + then these masks are applied to the final channels of the batch respectively. + + If masks are not being used then this function returns having done nothing + + Parameters + ---------- + detected_face: list + The list of :class:`~lib.align.DetectedFace` objects corresponding to the batch + batch: :class:`numpy.ndarray` + The preallocated array to apply masks to + side: str + '"a"' or '"b"' the side that is being processed + """ + if not self._use_mask: + return + + masks = np.array([face.get_training_masks() for face in detected_faces]) + batch[..., 3:] = masks + + logger.trace("side: %s, masks: %s, batch: %s", # type:ignore[attr-defined] + self._side, masks.shape, batch.shape) + + def _process_batch(self, filenames: list[str]) -> BatchType: + """ Prepares data for feeding through subclassed methods. + + If this is the first time a face has been loaded, then it's meta data is extracted from the + png header and added to :attr:`_face_cache` + + Parameters + ---------- + filenames: list + List of full paths to image file names for a single batch + + Returns + ------- + :class:`numpy.ndarray` + 4-dimensional array of faces to feed the training the model. + list + List of 4-dimensional :class:`numpy.ndarray`. The number of channels here will vary. 
+ The first 3 channels are (rgb/bgr). The 4th channel is the face mask. Any subsequent + channels are area masks (e.g. eye/mouth masks) + """ + raw_faces, detected_faces = self._get_images_with_meta(filenames) + batch = self._buffer() + self._crop_to_coverage(filenames, raw_faces, detected_faces, batch) + self._apply_mask(detected_faces, batch) + feed, targets = self.process_batch(filenames, raw_faces, detected_faces, batch) + + logger.trace( # type:ignore[attr-defined] + "Processed %s batch side %s. (filenames: %s, feed: %s, targets: %s)", + self.__class__.__name__, self._side, filenames, feed.shape, [t.shape for t in targets]) + + return feed, targets + + def process_batch(self, + filenames: list[str], + images: np.ndarray, + detected_faces: list[DetectedFace], + batch: np.ndarray) -> BatchType: + """ Override for processing the batch for the current generator. + + Parameters + ---------- + filenames: list + List of full paths to image file names for a single batch + images: :class:`numpy.ndarray` + The batch of faces corresponding to the filenames + detected_faces: list + List of :class:`~lib.align.DetectedFace` objects with aligned data and masks loaded for + the current batch + batch: :class:`numpy.ndarray` + The pre-allocated batch with images and masks populated for the selected coverage and + centering + + Returns + ------- + list + 4-dimensional array of faces to feed the training the model. + list + List of 4-dimensional :class:`numpy.ndarray`. The number of channels here will vary. + The first 3 channels are (rgb/bgr). The 4th channel is the face mask. Any subsequent + channels are area masks (e.g. eye/mouth masks) + """ + raise NotImplementedError() + + def _set_color_order(self, batch) -> None: + """ Set the color order correctly for the model's input type. + + batch: :class:`numpy.ndarray` + The pre-allocated batch with images in the first 3 channels in BGR order + """ + if self._color_order == "rgb": + batch[..., :3] = batch[..., [2, 1, 0]] + + def _to_float32(self, in_array: np.ndarray) -> np.ndarray: + """ Cast an UINT8 array in 0-255 range to float32 in 0.0-1.0 range. + + in_array: :class:`numpy.ndarray` + The input uint8 array + """ + return ne.evaluate("x / c", + local_dict={"x": in_array, "c": np.float32(255)}, + casting="unsafe") + + +class TrainingDataGenerator(DataGenerator): + """ A Training Data Generator for compiling data for feeding to a model. + + This class is called from :mod:`plugins.train.trainer._base` and launches a background + iterator that compiles augmented data, target data and sample data. + + Parameters + ---------- + model: :class:`~plugins.train.model.ModelBase` + The model that this data generator is feeding + side: {'a' or 'b'} + The side of the model that this iterator is for. + images: list + A list of image paths that will be used to compile the final augmented data from. + batch_size: int + The batch size for this iterator. Images will be returned in :class:`numpy.ndarray` + objects of this size from the iterator. 
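+
+    Example
+    -------
+    A minimal usage sketch (``model`` is assumed to be an already initialized training
+    plugin and ``image_paths`` a list of extracted face image paths):
+
+    >>> generator = TrainingDataGenerator(model, "a", image_paths, batch_size=16)
+    >>> feeder = generator.minibatch_ab()
+    >>> feed, targets = next(feeder)  # feed: (16, height, width, 3) float32 array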
+ """ + def __init__(self, + model: ModelBase, + side: T.Literal["a", "b"], + images: list[str], + batch_size: int) -> None: + super().__init__(model, side, images, batch_size) + self._augment_color = not model.command_line_arguments.no_augment_color + self._no_flip = model.command_line_arguments.no_flip + self._no_warp = model.command_line_arguments.no_warp + self._warp_to_landmarks = (not self._no_warp + and model.command_line_arguments.warp_to_landmarks) + + if self._warp_to_landmarks: + self._face_cache.pre_fill(images, side) + self._processing = ImageAugmentation(batch_size, + self._process_size) + self._nearest_landmarks: dict[str, tuple[str, ...]] = {} + logger.debug("Initialized %s", self.__class__.__name__) + + def _create_targets(self, batch: np.ndarray) -> list[np.ndarray]: + """ Compile target images, with masks, for the model output sizes. + + Parameters + ---------- + batch: :class:`numpy.ndarray` + This should be a 4-dimensional array of training images in the format (`batch size`, + `height`, `width`, `channels`). Targets should be requested after performing image + transformations but prior to performing warps. The 4th channel should be the mask. + Any channels above the 4th should be any additional area masks (e.g. eye/mouth) that + are required. + + Returns + ------- + list + List of 4-dimensional target images, at all model output sizes, with masks compiled + into channels 4+ for each output size + """ + logger.trace("Compiling targets: batch shape: %s", # type:ignore[attr-defined] + batch.shape) + if len(self._output_sizes) == 1 and self._output_sizes[0] == self._process_size: + # Rolling buffer here makes next to no difference, so just create array on the fly + retval = [self._to_float32(batch)] + else: + retval = [self._to_float32(np.array([cv2.resize(image, + (size, size), + interpolation=cv2.INTER_AREA) + for image in batch])) + for size in self._output_sizes] + logger.trace("Processed targets: %s", # type:ignore[attr-defined] + [t.shape for t in retval]) + return retval + + def process_batch(self, + filenames: list[str], + images: np.ndarray, + detected_faces: list[DetectedFace], + batch: np.ndarray) -> BatchType: + """ Performs the augmentation and compiles target images and samples. + + Parameters + ---------- + filenames: list + List of full paths to image file names for a single batch + images: :class:`numpy.ndarray` + The batch of faces corresponding to the filenames + detected_faces: list + List of :class:`~lib.align.DetectedFace` objects with aligned data and masks loaded for + the current batch + batch: :class:`numpy.ndarray` + The pre-allocated batch with images and masks populated for the selected coverage and + centering + + Returns + ------- + feed: :class:`numpy.ndarray` + 4-dimensional array of faces to feed the training the model (:attr:`x` parameter for + :func:`keras.models.model.train_on_batch`.). The array returned is in the format + (`batch size`, `height`, `width`, `channels`). + targets: list + List of 4-dimensional :class:`numpy.ndarray` objects in the order and size of each + output of the model. The format of these arrays will be (`batch size`, `height`, + `width`, `x`). This is the :attr:`y` parameter for + :func:`keras.models.model.train_on_batch`. The number of channels here will vary. + The first 3 channels are (rgb/bgr). The 4th channel is the face mask. Any subsequent + channels are area masks (e.g. 
eye/mouth masks) + """ + logger.trace("Process training: (side: '%s', filenames: '%s', images: %s, " # type:ignore + "batch: %s, detected_faces: %s)", self._side, filenames, images.shape, + batch.shape, len(detected_faces)) + + # Color Augmentation of the image only + if self._augment_color: + batch[..., :3] = self._processing.color_adjust(batch[..., :3]) + + # Random Transform and flip + self._processing.transform(batch) + + if not self._no_flip: + self._processing.random_flip(batch) + + # Switch color order for RGB models + self._set_color_order(batch) + + # Get Targets + targets = self._create_targets(batch) + + # TODO Look at potential for applying mask on input + # Random Warp + if self._warp_to_landmarks: + landmarks = np.array([face.aligned.landmarks for face in detected_faces]) + batch_dst_pts = self._get_closest_match(filenames, landmarks) + warp_kwargs = {"batch_src_points": landmarks, "batch_dst_points": batch_dst_pts} + else: + warp_kwargs = {} + + warped = batch[..., :3] if self._no_warp else self._processing.warp( + batch[..., :3], + self._warp_to_landmarks, + **warp_kwargs) + + if self._model_input_size != self._process_size: + feed = self._to_float32(np.array([cv2.resize(image, + (self._model_input_size, + self._model_input_size), + interpolation=cv2.INTER_AREA) + for image in warped])) + else: + feed = self._to_float32(warped) + + return feed, targets + + def _get_closest_match(self, filenames: list[str], batch_src_points: np.ndarray) -> np.ndarray: + """ Only called if the :attr:`_warp_to_landmarks` is ``True``. Gets the closest + matched 68 point landmarks from the opposite training set. + + Parameters + ---------- + filenames: list + Filenames for current batch + batch_src_points: :class:`np.ndarray` + The source landmarks for the current batch + + Returns + ------- + :class:`np.ndarray` + Randomly selected closest matches from the other side's landmarks + """ + logger.trace( # type:ignore[attr-defined] + "Retrieving closest matched landmarks: (filenames: '%s', src_points: '%s')", + filenames, batch_src_points) + lm_side: T.Literal["a", "b"] = "a" if self._side == "b" else "b" + other_cache = get_cache(lm_side) + landmarks = other_cache.aligned_landmarks + + try: + closest_matches = [self._nearest_landmarks[os.path.basename(filename)] + for filename in filenames] + except KeyError: + # Resize mismatched training image size landmarks + sizes = {side: cache.size for side, cache in zip((self._side, lm_side), + (self._face_cache, other_cache))} + if len(set(sizes.values())) > 1: + scale = sizes[self._side] / sizes[lm_side] + landmarks = {key: lms * scale for key, lms in landmarks.items()} + closest_matches = self._cache_closest_matches(filenames, batch_src_points, landmarks) + + batch_dst_points = np.array([landmarks[choice(fname)] for fname in closest_matches]) + logger.trace("Returning: (batch_dst_points: %s)", # type:ignore[attr-defined] + batch_dst_points.shape) + return batch_dst_points + + def _cache_closest_matches(self, + filenames: list[str], + batch_src_points: np.ndarray, + landmarks: dict[str, np.ndarray]) -> list[tuple[str, ...]]: + """ Cache the nearest landmarks for this batch + + Parameters + ---------- + filenames: list + Filenames for current batch + batch_src_points: :class:`np.ndarray` + The source landmarks for the current batch + landmarks: dict + The destination landmarks with associated filenames + + """ + logger.trace("Caching closest matches") # type:ignore + dst_landmarks = list(landmarks.items()) + dst_points = np.array([lm[1] for lm in 
dst_landmarks]) + batch_closest_matches: list[tuple[str, ...]] = [] + + for filename, src_points in zip(filenames, batch_src_points): + closest = (np.mean(np.square(src_points - dst_points), axis=(1, 2))).argsort()[:10] + closest_matches = tuple(dst_landmarks[i][0] for i in closest) + self._nearest_landmarks[os.path.basename(filename)] = closest_matches + batch_closest_matches.append(closest_matches) + logger.trace("Cached closest matches") # type:ignore + return batch_closest_matches + + +class PreviewDataGenerator(DataGenerator): + """ Generator for compiling images for generating previews. + + This class is called from :mod:`plugins.train.trainer._base` and launches a background + iterator that compiles sample preview data for feeding the model's predict function and for + display. + + Parameters + ---------- + model: :class:`~plugins.train.model.ModelBase` + The model that this data generator is feeding + side: {'a' or 'b'} + The side of the model that this iterator is for. + images: list + A list of image paths that will be used to compile the final images. + batch_size: int + The batch size for this iterator. Images will be returned in :class:`numpy.ndarray` + objects of this size from the iterator. + """ + def _create_samples(self, + images: np.ndarray, + detected_faces: list[DetectedFace]) -> list[np.ndarray]: + """ Compile the 'sample' images. These are the 100% coverage images which hold the model + output in the preview window. + + Parameters + ---------- + images: :class:`numpy.ndarray` + The original batch of images as loaded from disk. + detected_faces: list + List of :class:`~lib.align.DetectedFace` for the current batch + + Returns + ------- + list + List of 4-dimensional target images, at final model output size + """ + logger.trace( # type:ignore[attr-defined] + "Compiling samples: images shape: %s, detected_faces: %s ", + images.shape, len(detected_faces)) + output_size = self._output_sizes[-1] + full_size = 2 * int(np.rint((output_size / self._coverage_ratio) / 2)) + + assert mod_cfg.centering() in T.get_args(CenteringType) + retval = np.empty((full_size, full_size, 3), dtype="float32") + y_offset = mod_cfg.vertical_offset() + assert isinstance(y_offset, int) + retval = self._to_float32(np.array([ + AlignedFace(face.landmarks_xy, + image=images[idx], + centering=T.cast(CenteringType, + mod_cfg.centering()), + y_offset=y_offset / 100., + size=full_size, + dtype="uint8", + is_aligned=True).face + for idx, face in enumerate(detected_faces)])) + + logger.trace("Processed samples: %s", retval.shape) # type:ignore[attr-defined] + return [retval] + + def process_batch(self, + filenames: list[str], + images: np.ndarray, + detected_faces: list[DetectedFace], + batch: np.ndarray) -> BatchType: + """ Creates the full size preview images and the sub-cropped images for feeding the model's + predict function. + + Parameters + ---------- + filenames: list + List of full paths to image file names for a single batch + images: :class:`numpy.ndarray` + The batch of faces corresponding to the filenames + detected_faces: list + List of :class:`~lib.align.DetectedFace` objects with aligned data and masks loaded for + the current batch + batch: :class:`numpy.ndarray` + The pre-allocated batch with images and masks populated for the selected coverage and + centering + + Returns + ------- + feed: :class:`numpy.ndarray` + List of 4-dimensional :class:`numpy.ndarray` objects at model output size for feeding + the model's predict function. The first 3 channels are (rgb/bgr). 
The 4th channel is + the face mask. + samples: list + 4-dimensional array containing the 100% coverage images at the model's centering for + for generating previews. The array returned is in the format + (`batch size`, `height`, `width`, `channels`). + """ + logger.trace("Process preview: (side: '%s', filenames: '%s', images: %s, " # type:ignore + "batch: %s, detected_faces: %s)", self._side, filenames, images.shape, + batch.shape, len(detected_faces)) + + # Switch color order for RGB models + self._set_color_order(batch) + self._set_color_order(images) + + if not self._use_mask: + mask = np.zeros_like(batch[..., 0])[..., None] + 255 + batch = np.concatenate([batch, mask], axis=-1) + + feed = self._to_float32(batch[..., :4]) # Don't resize here: we want masks at output res. + + # If user sets model input size as larger than output size, the preview will error, so + # resize in these rare instances + out_size = max(self._output_sizes) + if self._process_size > out_size: + feed = np.array([cv2.resize(img, (out_size, out_size), interpolation=cv2.INTER_AREA) + for img in feed]) + + samples = self._create_samples(images, detected_faces) + + return feed, samples + + +class Feeder(): + """ Handles the processing of a Batch for training the model and generating samples. + + Parameters + ---------- + images: dict + The list of full paths to the training images for this :class:`_Feeder` for each side + model: plugin from :mod:`plugins.train.model` + The selected model that will be running this trainer + batch_size: int + The size of the batch to be processed for each side at each iteration + include_preview: bool, optional + ``True`` to create a feeder for generating previews. Default: ``True`` + """ + def __init__(self, + images: dict[T.Literal["a", "b"], list[str]], + model: ModelBase, + batch_size: int, + include_preview: bool = True) -> None: + logger.debug("Initializing %s: num_images: %s, batch_size: %s, include_preview: %s)", + self.__class__.__name__, {k: len(v) for k, v in images.items()}, batch_size, + include_preview) + self._model = model + self._images = images + self._batch_size = batch_size + self._feeds = { + side: self._load_generator(side, False).minibatch_ab() + for side in T.get_args(T.Literal["a", "b"])} + + self._display_feeds = {"preview": self._set_preview_feed() if include_preview else {}, + "timelapse": {}} + logger.debug("Initialized %s:", self.__class__.__name__) + + def _load_generator(self, + side: T.Literal["a", "b"], + is_display: bool, + batch_size: int | None = None, + images: list[str] | None = None) -> DataGenerator: + """ Load the :class:`~lib.training_data.TrainingDataGenerator` for this feeder. + + Parameters + ---------- + side: ["a", "b"] + The side of the model to load the generator for + is_display: bool + ``True`` if the generator is for creating preview/time-lapse images. ``False`` if it is + for creating training images + batch_size: int, optional + If ``None`` then the batch size selected in command line arguments is used, otherwise + the batch size provided here is used. + images: list, optional. Default: ``None`` + If provided then this will be used as the list of images for the generator. If ``None`` + then the training folder images for the side will be used. 
Default: ``None``
+
+        Returns
+        -------
+        :class:`~lib.training_data.TrainingDataGenerator`
+            The training data generator
+        """
+        logger.debug("Loading generator, side: %s, is_display: %s, batch_size: %s",
+                     side, is_display, batch_size)
+        generator = PreviewDataGenerator if is_display else TrainingDataGenerator
+        retval = generator(self._model,
+                           side,
+                           self._images[side] if images is None else images,
+                           self._batch_size if batch_size is None else batch_size)
+        return retval
+
+    def _set_preview_feed(self) -> dict[T.Literal["a", "b"], Generator[BatchType, None, None]]:
+        """ Set the preview feed for this feeder.
+
+        Creates a generator from :class:`lib.training_data.PreviewDataGenerator` specifically
+        for previews for the feeder.
+
+        Returns
+        -------
+        dict
+            The side ("a" or "b") as key, :class:`~lib.training_data.PreviewDataGenerator` as
+            value.
+        """
+        retval: dict[T.Literal["a", "b"], Generator[BatchType, None, None]] = {}
+        num_images = trn_cfg.preview_images()
+        assert isinstance(num_images, int)
+        for side in T.get_args(T.Literal["a", "b"]):
+            logger.debug("Setting preview feed: (side: '%s')", side)
+            preview_images = min(max(num_images, 2), 16)
+            batchsize = min(len(self._images[side]), preview_images)
+            retval[side] = self._load_generator(side,
+                                                True,
+                                                batch_size=batchsize).minibatch_ab()
+        return retval
+
+    def get_batch(self) -> tuple[np.ndarray, list[np.ndarray]]:
+        """ Get the feed data and the targets for each training side for feeding into the model's
+        train function.
+
+        Returns
+        -------
+        model_inputs : :class:`numpy.ndarray`
+            The inputs to the model for each side A and B. The array is returned in `(side,
+            batch_size, *dims)` format where `side` 0 is "A" and `side` 1 is "B"
+        model_targets : list[:class:`numpy.ndarray`]
+            The targets for the model for each side A and B. For each required target resolution
+            output, an array is appended to the list in `(side, batch_size, *dims)` format
+            where `side` 0 is "A" and `side` 1 is "B"
+        """
+        model_inputs: list[np.ndarray] = []
+        model_targets: tuple[list[np.ndarray], list[np.ndarray]] = ([], [])
+        for idx, side in enumerate(("a", "b")):
+            side_feed, side_targets = next(self._feeds[side])
+            if mod_cfg.Loss.learn_mask():  # Add the face mask as its own target
+                side_targets += [side_targets[-1][..., 3][..., None]]
+            logger.trace(  # type:ignore[attr-defined]
+                "side: %s, input_shapes: %s, target_shapes: %s",
+                side, side_feed.shape, [i.shape for i in side_targets])
+            model_inputs.append(side_feed)
+            model_targets[idx].extend(side_targets)
+
+        grouped_targets = []
+
+        for tgt_a, tgt_b in zip(*model_targets):
+            grouped_targets.append(np.stack([tgt_a, tgt_b], axis=0))
+        inputs = np.stack(model_inputs, axis=0)
+        assert inputs.shape[0] == 2, "1st dimension should represent side A/B"
+        assert all(x.shape[0] == 2 for x in grouped_targets), ("1st dimension should represent "
+                                                               "side A/B")
+        return inputs, grouped_targets
+
+    def generate_preview(self, is_timelapse: bool = False
+                         ) -> dict[T.Literal["a", "b"], list[np.ndarray]]:
+        """ Generate the images for the preview window or time-lapse
+
+        Parameters
+        ----------
+        is_timelapse: bool, optional
+            ``True`` if the preview is to be generated for a time-lapse otherwise ``False``.
+ Default: ``False`` + + Returns + ------- + dict + Dictionary for side A and B of list of numpy arrays corresponding to the + samples, targets and masks for this preview + """ + logger.debug("Generating preview (is_timelapse: %s)", is_timelapse) + + batchsizes: list[int] = [] + feed: dict[T.Literal["a", "b"], np.ndarray] = {} + samples: dict[T.Literal["a", "b"], np.ndarray] = {} + masks: dict[T.Literal["a", "b"], np.ndarray] = {} + + # MyPy can't recurse into nested dicts to get the type :( + iterator = T.cast(dict[T.Literal["a", "b"], "Generator[BatchType, None, None]"], + self._display_feeds["timelapse" if is_timelapse else "preview"]) + for side in T.get_args(T.Literal["a", "b"]): + side_feed, side_samples = next(iterator[side]) + batchsizes.append(len(side_samples[0])) + samples[side] = side_samples[0] + feed[side] = side_feed[..., :3] + masks[side] = side_feed[..., 3][..., None] + + logger.debug("Generated samples: is_timelapse: %s, images: %s", is_timelapse, + {key: {k: v.shape for k, v in item.items()} + for key, item + in zip(("feed", "samples", "sides"), (feed, samples, masks))}) + return self.compile_sample(min(batchsizes), feed, samples, masks) + + def compile_sample(self, + image_count: int, + feed: dict[T.Literal["a", "b"], np.ndarray], + samples: dict[T.Literal["a", "b"], np.ndarray], + masks: dict[T.Literal["a", "b"], np.ndarray] + ) -> dict[T.Literal["a", "b"], list[np.ndarray]]: + """ Compile the preview samples for display. + + Parameters + ---------- + image_count: int + The number of images to limit the sample output to. + feed: dict + Dictionary for side "a", "b" of :class:`numpy.ndarray`. The images that should be fed + into the model for obtaining a prediction + samples: dict + Dictionary for side "a", "b" of :class:`numpy.ndarray`. The 100% coverage target images + that should be used for creating the preview. + masks: dict + Dictionary for side "a", "b" of :class:`numpy.ndarray`. The masks that should be used + for creating the preview. + + Returns + ------- + list + The list of samples, targets and masks as :class:`numpy.ndarrays` for creating a + preview image + """ + num_images = trn_cfg.preview_images() + assert isinstance(num_images, int) + num_images = min(image_count, num_images) + retval: dict[T.Literal["a", "b"], list[np.ndarray]] = {} + for side in T.get_args(T.Literal["a", "b"]): + logger.debug("Compiling samples: (side: '%s', samples: %s)", side, num_images) + retval[side] = [feed[side][0:num_images], + samples[side][0:num_images], + masks[side][0:num_images]] + logger.debug("Compiled Samples: %s", {k: [i.shape for i in v] for k, v in retval.items()}) + return retval + + def set_timelapse_feed(self, + images: dict[T.Literal["a", "b"], list[str]], + batch_size: int) -> None: + """ Set the time-lapse feed for this feeder. + + Creates a generator from :class:`lib.training_data.PreviewDataGenerator` specifically + for generating time-lapse previews for the feeder. + + Parameters + ---------- + images: dict + The list of full paths to the images for creating the time-lapse for each side + batch_size: int + The number of images to be used to create the time-lapse preview. 
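+
+        Example
+        -------
+        A minimal sketch, assuming ``feeder`` is an initialized :class:`Feeder` and
+        ``tl_images`` maps each side ("a", "b") to its time-lapse source image paths:
+
+        >>> feeder.set_timelapse_feed(tl_images, batch_size=4)
+        >>> previews = feeder.generate_preview(is_timelapse=True)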
+ """ + logger.debug("Setting time-lapse feed: (input_images: '%s', batch_size: %s)", + images, batch_size) + + # MyPy can't recurse into nested dicts to get the type :( + iterator = T.cast(dict[T.Literal["a", "b"], "Generator[BatchType, None, None]"], + self._display_feeds["timelapse"]) + + for side in T.get_args(T.Literal["a", "b"]): + imgs = images[side] + logger.debug("Setting preview feed: (side: '%s', images: %s)", side, len(imgs)) + + iterator[side] = self._load_generator(side, + True, + batch_size=batch_size, + images=imgs).minibatch_ab(do_shuffle=False) + logger.debug("Set time-lapse feed: %s", self._display_feeds["timelapse"]) + + +__all__ = get_module_objects(__name__) diff --git a/lib/training/lr_finder.py b/lib/training/lr_finder.py new file mode 100644 index 0000000000..b12bd677a4 --- /dev/null +++ b/lib/training/lr_finder.py @@ -0,0 +1,251 @@ +#!/usr/bin/env python3 +""" Learning Rate Finder for faceswap.py. """ +from __future__ import annotations +import logging +import os +import shutil +import typing as T +from datetime import datetime +from enum import Enum + +import matplotlib +import matplotlib.pyplot as plt +import numpy as np +from tqdm import tqdm + +from lib.logger import parse_class_init +from lib.utils import get_module_objects +from plugins.train import train_config as cfg + +if T.TYPE_CHECKING: + from keras import optimizers + from plugins.train import training + +logger = logging.getLogger(__name__) + + +class LRStrength(Enum): + """ Enum for how aggressively to set the optimal learning rate """ + DEFAULT = 10 + AGGRESSIVE = 5 + EXTREME = 2.5 + + +class LearningRateFinder: # pylint:disable=too-many-instance-attributes + """ Learning Rate Finder + + Parameters + ---------- + trainer : :class:`plugins.train.run_trainer.Trainer` + The training loop with the loaded training plugin + stop_factor : int + When to stop finding the optimal learning rate + beta : float + Amount to smooth loss by, for graphing purposes + """ + def __init__(self, # pylint:disable=too-many-positional-arguments + trainer: training.Trainer, + stop_factor: int = 4, + beta: float = 0.98) -> None: + logger.debug(parse_class_init(locals())) + self._iterations = cfg.lr_finder_iterations() + self._save_graph = cfg.lr_finder_mode() in ("graph_and_set", "graph_and_exit") + self._strength = LRStrength[cfg.lr_finder_strength().upper()].value + + self._start_lr = 1e-10 + end_lr = 1e+1 + + self._trainer = trainer + + self._model = trainer._plugin.model + self._optimizer = trainer._plugin.model.model.optimizer + + self._stop_factor = stop_factor + self._beta = beta + self._lr_multiplier: float = (end_lr / self._start_lr) ** (1.0 / self._iterations) + + self._metrics: dict[T.Literal["learning_rates", "losses"], list[float]] = { + "learning_rates": [], + "losses": []} + self._loss: dict[T.Literal["avg", "best"], float] = {"avg": 0.0, "best": 1e9} + + logger.debug("Initialized %s", self.__class__.__name__) + + def _on_batch_end(self, iteration: int, loss: float) -> None: + """ Learning rate actions to perform at the end of a batch + + Parameters + ---------- + iteration: int + The current iteration + loss: float + The loss value for the current batch + """ + learning_rate = float(self._optimizer.learning_rate.numpy()) + self._metrics["learning_rates"].append(learning_rate) + + self._loss["avg"] = (self._beta * self._loss["avg"]) + ((1 - self._beta) * loss) + smoothed = self._loss["avg"] / (1 - (self._beta ** iteration)) + self._metrics["losses"].append(smoothed) + + stop_loss = self._stop_factor * 
self._loss["best"] + + if iteration > 1 and smoothed > stop_loss: + self._model.model.stop_training = True + return + + if iteration == 1 or smoothed < self._loss["best"]: + self._loss["best"] = smoothed + + learning_rate *= self._lr_multiplier + + self._optimizer.learning_rate.assign(learning_rate) + + def _update_description(self, progress_bar: tqdm) -> None: + """ Update the description of the progress bar for the current iteration + + Parameters + ---------- + progress_bar: :class:`tqdm.tqdm` + The learning rate finder progress bar to update + """ + current = self._metrics['learning_rates'][-1] + best_idx = self._metrics["losses"].index(self._loss["best"]) + best = self._metrics["learning_rates"][best_idx] / self._strength + progress_bar.set_description(f"Current: {current:.1e} Best: {best:.1e}") + + def _train(self) -> None: + """ Train the model for the given number of iterations to find the optimal + learning rate and show progress""" + logger.info("Finding optimal learning rate...") + pbar = tqdm(range(1, self._iterations + 1), + desc="Current: N/A Best: N/A ", + leave=False) + for idx in pbar: + loss = self._trainer.train_one_batch() + + if any(np.isnan(x) for x in loss): + logger.warning("NaN detected! Exiting early") + break + self._on_batch_end(idx, loss[0]) + self._update_description(pbar) + + def _rebuild_optimizer(self, optimizer: optimizers.Optimizer) -> optimizers.Optimizer: + """ Pass through nested Optimizers (eg LossScaleOptimizer) and create new nested + optimizers based on their original config + + Returns + ------- + :class:`keras.optimizers.Optimizer` + A new optimizer of the same type as the given one, with the same config + """ + logger.debug("Processing optimizer: '%s'", optimizer.name) + config = optimizer.get_config() + if hasattr(optimizer, "inner_optimizer"): + config["inner_optimizer"] = self._rebuild_optimizer(optimizer.inner_optimizer) + retval = optimizer.__class__(**config) + logger.debug("Created optimizer '%s': (old: %s, new: %s)", + optimizer.name, optimizer, retval) + return retval + + def _reset_model(self, original_lr: float, new_lr: float) -> None: + """ Reset the model's weights to initial values, reset the model's optimizer and set the + learning rate + + Parameters + ---------- + original_lr: float + The model's original learning rate + new_lr: float + The discovered optimal learning rate + """ + self._model.state.add_lr_finder(new_lr) + self._model.state.save() + + if cfg.lr_finder_mode() == "graph_and_exit": + return + + logger.debug("Resetting optimizer") + optimizer = self._rebuild_optimizer(self._optimizer) + del self._optimizer + del self._model.model.optimizer + + logger.info("Loading initial weights") + self._model.model.load_weights(self._model.io.filename) + + self._model.model.compile(optimizer=optimizer, + loss=self._model.model.loss, + metrics=self._model.model.loss) + + logger.info("Updating Learning Rate from %s to %s", f"{original_lr:.1e}", f"{new_lr:.1e}") + self._model.model.optimizer.learning_rate.assign(new_lr) + self._optimizer = self._model.model.optimizer + + def find(self) -> bool: + """ Find the optimal learning rate + + Returns + ------- + bool + ``True`` if the learning rate was succesfully discovered otherwise ``False`` + """ + if not self._model.io.model_exists: + self._model.io.save() + + original_lr = float(self._model.model.optimizer.learning_rate.numpy()) + self._model.model.optimizer.learning_rate.assign(self._start_lr) + + self._train() + print("\x1b[2K", end="\r") # Clear line + + best_idx = 
self._metrics["losses"].index(self._loss["best"]) + new_lr = self._metrics["learning_rates"][best_idx] / self._strength + if new_lr < 1e-9: + logger.error("The optimal learning rate could not be found. This is most likely " + "because you did not run the finder for enough iterations.") + shutil.rmtree(self._model.io.model_dir) + return False + + self._plot_loss() + self._reset_model(original_lr, new_lr) + return True + + def _plot_loss(self, skip_begin: int = 10, skip_end: int = 1) -> None: + """ Plot a graph of loss vs learning rate and save to the training folder + + Parameters + ---------- + skip_begin: int, optional + Number of iterations to skip at the start. Default: `10` + skip_end: int, optional + Number of iterations to skip at the end. Default: `1` + """ + if not self._save_graph: + return + + matplotlib.use("Agg") + lrs = self._metrics["learning_rates"][skip_begin:-skip_end] + losses = self._metrics["losses"][skip_begin:-skip_end] + plt.plot(lrs, losses, label="Learning Rate") + best_idx = self._metrics["losses"].index(self._loss["best"]) + best_lr = self._metrics["learning_rates"][best_idx] + for val, color in zip(LRStrength, ("g", "y", "r")): + l_r = best_lr / val.value + idx = lrs.index(next(r for r in lrs if r >= l_r)) + plt.plot(l_r, losses[idx], + f"{color}o", + label=f"{val.name.title()}: {l_r:.1e}") + + plt.xscale("log") + plt.xlabel("Learning Rate (Log Scale)") + plt.ylabel("Loss") + plt.title("Learning Rate Finder") + plt.legend() + + now = datetime.now().strftime("%Y-%m-%d_%H.%M.%S") + output = os.path.join(self._model.io.model_dir, f"learning_rate_finder_{now}.png") + logger.info("Saving Learning Rate Finder graph to: '%s'", output) + plt.savefig(output) + + +__all__ = get_module_objects(__name__) diff --git a/lib/training/lr_warmup.py b/lib/training/lr_warmup.py new file mode 100644 index 0000000000..6bbb33ee7e --- /dev/null +++ b/lib/training/lr_warmup.py @@ -0,0 +1,105 @@ +#! 
/usr/bin/env python3
+""" Handles Learning Rate Warmup when training a model """
+from __future__ import annotations
+
+import logging
+import typing as T
+
+from lib.utils import get_module_objects
+
+if T.TYPE_CHECKING:
+    from keras import models
+
+logger = logging.getLogger(__name__)
+
+
+class LearningRateWarmup():
+    """ Handles the updating of the model's learning rate during Learning Rate Warmup
+
+    Parameters
+    ----------
+    model : :class:`keras.models.Model`
+        The keras model that is to be trained
+    target_learning_rate : float
+        The final learning rate at the end of warmup
+    steps : int
+        The number of iterations to warm the learning rate up for
+    """
+    def __init__(self, model: models.Model, target_learning_rate: float, steps: int) -> None:
+        self._model = model
+        self._target_lr = target_learning_rate
+        self._steps = steps
+        self._current_lr = 0.0
+        self._current_step = 0
+        self._reporting_points = [int(self._steps * i / 10) for i in range(11)]
+        logger.debug("Initialized %s", self)
+
+    def __repr__(self) -> str:
+        """ Pretty string representation for logging """
+        call_args = ", ".join(f"{k}={v}" for k, v in {"model": self._model,
+                                                      "target_learning_rate": self._target_lr,
+                                                      "steps": self._steps}.items())
+        current_params = ", ".join(f"{k[1:]}: {v}" for k, v in self.__dict__.items()
+                                   if k not in ("_model", "_target_lr", "_steps"))
+        return f"{self.__class__.__name__}({call_args}) [{current_params}]"
+
+    @classmethod
+    def _format_notation(cls, value: float) -> str:
+        """ Format a float to scientific notation at 1 decimal place
+
+        Parameters
+        ----------
+        value : float
+            The value to format
+
+        Returns
+        -------
+        str
+            The formatted float in scientific notation at 1 decimal place
+        """
+        return f"{value:.1e}"
+
+    def _set_learning_rate(self) -> None:
+        """ Set the learning rate for the current step """
+        self._current_lr = self._current_step / self._steps * self._target_lr
+        self._model.optimizer.learning_rate.assign(self._current_lr)
+        logger.debug("Learning rate set to %s for step %s/%s",
+                     self._current_lr, self._current_step, self._steps)
+
+    def _output_status(self) -> None:
+        """ Output the progress of Learning Rate Warmup at set intervals """
+        if self._current_step == 1:
+            logger.info("[Learning Rate Warmup] Start: %s, Target: %s, Steps: %s",
+                        self._format_notation(self._current_lr),
+                        self._format_notation(self._target_lr), self._steps)
+            return
+
+        if self._current_step == self._steps:
+            print()
+            logger.info("[Learning Rate Warmup] Final Learning Rate: %s",
+                        self._format_notation(self._target_lr))
+            return
+
+        if self._current_step in self._reporting_points:
+            print()
+            progress = int(round(100 / (len(self._reporting_points) - 1) *
+                                 self._reporting_points.index(self._current_step), 0))
+            logger.info("[Learning Rate Warmup] Step: %s/%s (%s), Current: %s, Target: %s",
+                        self._current_step,
+                        self._steps,
+                        f"{progress}%",
+                        self._format_notation(self._current_lr),
+                        self._format_notation(self._target_lr))
+
+    def __call__(self) -> None:
+        """ If a learning rate update is required, update the model's learning rate, otherwise
+        do nothing """
+        if self._steps == 0 or self._current_step >= self._steps:
+            return
+
+        self._current_step += 1
+        self._set_learning_rate()
+        self._output_status()
+
+
+__all__ = get_module_objects(__name__)
diff --git a/lib/training/preview_cv.py b/lib/training/preview_cv.py
new file mode 100644
index 0000000000..6a0d0a2ff1
--- /dev/null
+++ b/lib/training/preview_cv.py
@@ -0,0 +1,197 @@
+#!/usr/bin/python
+""" The pop up preview window for
Faceswap. + +If Tkinter is installed, then this will be used to manage the preview image, otherwise we +fallback to opencv's imshow +""" +from __future__ import annotations +import logging +import typing as T + +from threading import Event, Lock +from time import sleep + +import cv2 + +from lib.utils import get_module_objects + +if T.TYPE_CHECKING: + from collections.abc import Generator + import numpy as np + +logger = logging.getLogger(__name__) +TriggerType = dict[T.Literal["toggle_mask", "refresh", "save", "quit", "shutdown"], Event] +TriggerKeysType = T.Literal["m", "r", "s", "enter"] +TriggerNamesType = T.Literal["toggle_mask", "refresh", "save", "quit"] + + +class PreviewBuffer(): + """ A thread safe class for holding preview images """ + def __init__(self) -> None: + logger.debug("Initializing: %s", self.__class__.__name__) + self._images: dict[str, np.ndarray] = {} + self._lock = Lock() + self._updated = Event() + logger.debug("Initialized: %s", self.__class__.__name__) + + @property + def is_updated(self) -> bool: + """ bool: ``True`` when new images have been loaded into the preview buffer """ + return self._updated.is_set() + + def add_image(self, name: str, image: np.ndarray) -> None: + """ Add an image to the preview buffer in a thread safe way """ + logger.debug("Adding image: (name: '%s', shape: %s)", name, image.shape) + with self._lock: + self._images[name] = image + logger.debug("Added images: %s", list(self._images)) + self._updated.set() + + def get_images(self) -> Generator[tuple[str, np.ndarray], None, None]: + """ Get the latest images from the preview buffer. When iterator is exhausted clears the + :attr:`updated` event. + + Yields + ------ + name: str + The name of the image + :class:`numpy.ndarray` + The image in BGR format + """ + logger.debug("Retrieving images: %s", list(self._images)) + with self._lock: + for name, image in self._images.items(): + logger.debug("Yielding: '%s' (%s)", name, image.shape) + yield name, image + if self.is_updated: + logger.debug("Clearing updated event") + self._updated.clear() + logger.debug("Retrieved images") + + +class PreviewBase(): # pylint:disable=too-few-public-methods + """ Parent class for OpenCV and Tkinter Preview Windows + + Parameters + ---------- + preview_buffer: :class:`PreviewBuffer` + The thread safe object holding the preview images + triggers: dict, optional + Dictionary of event triggers for pop-up preview. Not required when running inside the GUI. 
+ Default: `None` + """ + def __init__(self, + preview_buffer: PreviewBuffer, + triggers: TriggerType | None = None) -> None: + logger.debug("Initializing %s parent (triggers: %s)", + self.__class__.__name__, triggers) + self._triggers = triggers + self._buffer = preview_buffer + self._keymaps: dict[TriggerKeysType, TriggerNamesType] = {"m": "toggle_mask", + "r": "refresh", + "s": "save", + "enter": "quit"} + self._title = "" + logger.debug("Initialized %s parent", self.__class__.__name__) + + @property + def _should_shutdown(self) -> bool: + """ bool: ``True`` if the preview has received an external signal to shutdown otherwise + ``False`` """ + if self._triggers is None or not self._triggers["shutdown"].is_set(): + return False + logger.debug("Shutdown signal received") + return True + + def _launch(self) -> None: + """ Wait until an image is loaded into the preview buffer and call the child's + :func:`_display_preview` function """ + logger.debug("Launching %s", self.__class__.__name__) + while True: + if self._should_shutdown: + logger.debug("Shutdown received") + return + if not self._buffer.is_updated: + logger.debug("Waiting for preview image") + sleep(1) + continue + break + logger.debug("Launching preview") + self._display_preview() + + def _display_preview(self) -> None: + """ Override for preview viewer's display loop """ + raise NotImplementedError() + + +class PreviewCV(PreviewBase): # pylint:disable=too-few-public-methods + """ Simple fall back preview viewer using OpenCV for when TKinter is not available + + Parameters + ---------- + preview_buffer: :class:`PreviewBuffer` + The thread safe object holding the preview images + triggers: dict + Dictionary of event triggers for pop-up preview. + """ + def __init__(self, + preview_buffer: PreviewBuffer, + triggers: TriggerType) -> None: + logger.debug("Unable to import Tkinter. Falling back to OpenCV") + super().__init__(preview_buffer, triggers=triggers) + self._triggers: TriggerType = self._triggers + self._windows: list[str] = [] + + self._lookup = {ord(key): val + for key, val in self._keymaps.items() if key != "enter"} + self._lookup[ord("\n")] = self._keymaps["enter"] + self._lookup[ord("\r")] = self._keymaps["enter"] + + self._launch() + + @property + def _window_closed(self) -> bool: + """ bool: ``True`` if any window has been closed otherwise ``False`` """ + retval = any(cv2.getWindowProperty(win, cv2.WND_PROP_VISIBLE) < 1 for win in self._windows) + if retval: + logger.debug("Window closed detected") + return retval + + def _check_keypress(self, key: int): + """ Check whether we have received a valid key press from OpenCV window and handle + accordingly. + + Parameters + ---------- + key_press: int + The key press received from OpenCV + """ + if not key or key == -1 or key not in self._lookup: + return + + if key == ord("r"): + print("\x1b[2K", end="\r") # clear last line + logger.info("Refresh preview requested...") + + self._triggers[self._lookup[key]].set() + logger.debug("Processed keypress '%s'. 
Set event for '%s'", key, self._lookup[key])
+
+    def _display_preview(self):
+        """ Handle the displaying of the images currently in :attr:`_preview_buffer` """
+        while True:
+            if self._buffer.is_updated or self._window_closed:
+                for name, image in self._buffer.get_images():
+                    logger.debug("showing image: '%s' (%s)", name, image.shape)
+                    cv2.imshow(name, image)
+                    self._windows.append(name)
+
+            key = cv2.waitKey(1000)
+            self._check_keypress(key)
+
+            if self._triggers["shutdown"].is_set():
+                logger.debug("Shutdown received")
+                break
+        logger.debug("%s shutdown", self.__class__.__name__)
+
+
+__all__ = get_module_objects(__name__)
diff --git a/lib/training/preview_tk.py b/lib/training/preview_tk.py
new file mode 100644
index 0000000000..c71610231c
--- /dev/null
+++ b/lib/training/preview_tk.py
@@ -0,0 +1,948 @@
+#!/usr/bin/python
+""" The pop up preview window for Faceswap.
+
+If Tkinter is installed, then this will be used to manage the preview image, otherwise we
+fall back to opencv's imshow
+"""
+from __future__ import annotations
+import logging
+import os
+import sys
+import tkinter as tk
+import typing as T
+
+from datetime import datetime
+from platform import system
+from tkinter import ttk
+from math import ceil, floor
+
+from PIL import Image, ImageTk
+
+import cv2
+
+from lib.utils import get_module_objects
+
+from .preview_cv import PreviewBase, TriggerKeysType
+
+if T.TYPE_CHECKING:
+    import numpy as np
+    from .preview_cv import PreviewBuffer, TriggerType
+
+logger = logging.getLogger(__name__)
+
+
+class _Taskbar():
+    """ Taskbar at bottom of Preview window
+
+    Parameters
+    ----------
+    parent: :class:`tkinter.Frame`
+        The parent frame that holds the canvas and taskbar
+    taskbar: :class:`tkinter.ttk.Frame` or ``None``
+        ``None`` if the preview is a pop-up window, otherwise a ttk.Frame when the taskbar is
+        managed by the GUI
+    """
+    def __init__(self, parent: tk.Frame, taskbar: ttk.Frame | None) -> None:
+        logger.debug("Initializing %s (parent: '%s', taskbar: %s)",
+                     self.__class__.__name__, parent, taskbar)
+        self._is_standalone = taskbar is None
+        self._gui_mapped: list[tk.Widget] = []
+        self._frame = tk.Frame(parent) if taskbar is None else taskbar
+
+        self._min_max_scales = (20, 400)
+        self._vars = {"save": tk.BooleanVar(),
+                      "scale": tk.StringVar(),
+                      "slider": tk.IntVar(),
+                      "interpolator": tk.IntVar()}
+        self._interpolators = [("nearest_neighbour", cv2.INTER_NEAREST),
+                               ("bicubic", cv2.INTER_CUBIC)]
+        self._scale = self._add_scale_combo()
+        self._slider = self._add_scale_slider()
+        self._add_interpolator_radio()
+
+        if self._is_standalone:
+            self._add_save_button()
+        self._frame.pack(side=tk.BOTTOM, fill=tk.X, padx=2, pady=2)
+
+        logger.debug("Initialized %s ('%s')", self.__class__.__name__, self)
+
+    @property
+    def min_scale(self) -> int:
+        """ int: The minimum allowed scale """
+        return self._min_max_scales[0]
+
+    @property
+    def max_scale(self) -> int:
+        """ int: The maximum allowed scale """
+        return self._min_max_scales[1]
+
+    @property
+    def save_var(self) -> tk.BooleanVar:
+        """:class:`tkinter.BooleanVar`: Variable which is set to ``True`` when the save button
+        has been pressed """
+        retval = self._vars["save"]
+        assert isinstance(retval, tk.BooleanVar)
+        return retval
+
+    @property
+    def scale_var(self) -> tk.StringVar:
+        """:class:`tkinter.StringVar`: The variable holding the currently selected "##%" formatted
+        percentage scaling amount displayed in the Combobox.
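+        The special value "Fit" scales the image to the available canvas space.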
""" + retval = self._vars["scale"] + assert isinstance(retval, tk.StringVar) + return retval + + @property + def slider_var(self) -> tk.IntVar: + """:class:`tkinter.IntVar`: The variable holding the currently selected percentage scaling + amount in the slider. """ + retval = self._vars["slider"] + assert isinstance(retval, tk.IntVar) + return retval + + @property + def interpolator_var(self) -> tk.IntVar: + """:class:`tkinter.IntVar`: The variable holding the CV2 Interpolator Enum. """ + retval = self._vars["interpolator"] + assert isinstance(retval, tk.IntVar) + return retval + + def _track_widget(self, widget: tk.Widget) -> None: + """ If running embedded in the GUI track the widgets so that they can be destroyed if + the preview is disabled """ + if self._is_standalone: + return + logger.debug("Tracking option bar widget for GUI: %s", widget) + self._gui_mapped.append(widget) + + def _add_scale_combo(self) -> ttk.Combobox: + """ Add a scale combo for selecting zoom amount. + + Returns + ------- + :class:`tkinter.ttk.Combobox` + The Combobox widget + """ + logger.debug("Adding scale combo") + self.scale_var.set("100%") + scale = ttk.Combobox(self._frame, + textvariable=self.scale_var, + values=["Fit"], + state="readonly", + width=10) + scale.pack(side=tk.RIGHT) + scale.bind("", self._clear_combo_focus) # Remove auto-focus on widget text box + self._track_widget(scale) + logger.debug("Added scale combo: '%s'", scale) + return scale + + def _clear_combo_focus(self, *args) -> None: # pylint:disable=unused-argument + """ Remove the highlighting and stealing of focus that the combobox annoyingly + implements. """ + logger.debug("Clearing scale combo focus") + self._scale.selection_clear() + self._scale.winfo_toplevel().focus_set() + logger.debug("Cleared scale combo focus") + + def _add_scale_slider(self) -> tk.Scale: + """ Add a scale slider for zooming the image. + + Returns + ------- + :class:`tkinter.Scale` + The scale widget + """ + logger.debug("Adding scale slider") + self.slider_var.set(100) + slider = tk.Scale(self._frame, + orient=tk.HORIZONTAL, + to=self.max_scale, + showvalue=False, + variable=self.slider_var, + command=self._on_slider_update) + slider.pack(side=tk.RIGHT) + self._track_widget(slider) + logger.debug("Added scale slider: '%s'", slider) + return slider + + def _add_interpolator_radio(self) -> None: + """ Add a radio box to choose interpolator """ + frame = tk.Frame(self._frame) + for text, mode in self._interpolators: + logger.debug("Adding %s radio button", text) + radio = tk.Radiobutton(frame, text=text, value=mode, variable=self.interpolator_var) + radio.pack(side=tk.LEFT, anchor=tk.W) + self._track_widget(radio) + + logger.debug("Added %s radio button", radio) + self.interpolator_var.set(cv2.INTER_NEAREST) + frame.pack(side=tk.RIGHT) + self._track_widget(frame) + + def _add_save_button(self) -> None: + """ Add a save button for saving out original preview """ + logger.debug("Adding save button") + button = tk.Button(self._frame, + text="Save", + cursor="hand2", + command=lambda: self.save_var.set(True)) + button.pack(side=tk.LEFT) + logger.debug("Added save burron: '%s'", button) + + def _on_slider_update(self, value) -> None: + """ Callback for when the scale slider is adjusted. Adjusts the combo box display to the + current slider value. 
+
+        Parameters
+        ----------
+        value: int
+            The value that the slider has been set to
+        """
+        self.scale_var.set(f"{value}%")
+
+    def set_min_max_scale(self, min_scale: int, max_scale: int) -> None:
+        """ Set the minimum and maximum values that we allow an image to be scaled to. This
+        impacts the slider and combo box min/max values:
+
+        Parameters
+        ----------
+        min_scale: int
+            The minimum percentage scale that is permitted
+        max_scale: int
+            The maximum percentage scale that is permitted
+        """
+        logger.debug("Setting min/max scales: (min: %s, max: %s)", min_scale, max_scale)
+        self._min_max_scales = (min_scale, max_scale)
+        self._slider.config(from_=self.min_scale, to=max_scale)
+        scales = [10, 25, 50, 75, 100, 200, 300, 400, 800]
+        if min_scale not in scales:
+            scales.insert(0, min_scale)
+        if max_scale not in scales:
+            scales.append(max_scale)
+        choices = ["Fit", *[f"{x}%" for x in scales if self.max_scale >= x >= self.min_scale]]
+        self._scale.config(values=choices)
+        logger.debug("Set min/max scale. min_max_scales: %s, scale combo choices: %s",
+                     self._min_max_scales, choices)
+
+    def cycle_interpolators(self, *args) -> None:  # pylint:disable=unused-argument
+        """ Cycle interpolators on a keypress callback """
+        current = next(i for i in self._interpolators if i[1] == self.interpolator_var.get())
+        next_idx = self._interpolators.index(current) + 1
+        next_idx = 0 if next_idx == len(self._interpolators) else next_idx
+        self.interpolator_var.set(self._interpolators[next_idx][1])
+
+    def destroy_widgets(self) -> None:
+        """ Remove the taskbar widgets when the preview within the GUI has been disabled """
+        if self._is_standalone:
+            return
+
+        for widget in reversed(self._gui_mapped):
+            if widget.winfo_ismapped():
+                logger.debug("Removing widget: %s", widget)
+                widget.pack_forget()
+                widget.destroy()
+            del widget
+
+        for var in list(self._vars):
+            logger.debug("Deleting tk variable: %s", var)
+            del self._vars[var]
+
+
+class _PreviewCanvas(tk.Canvas):  # pylint:disable=too-many-ancestors
+    """ The canvas that holds the preview image
+
+    Parameters
+    ----------
+    parent: :class:`tkinter.Frame`
+        The parent frame that will hold the Canvas and taskbar
+    scale_var: :class:`tkinter.StringVar`
+        The variable that holds the value from the scale combo box
+    screen_dimensions: tuple
+        The (`width`, `height`) of the displaying monitor
+    is_standalone: bool
+        ``True`` if the preview is standalone, ``False`` if it is in the GUI
+    """
+    def __init__(self,
+                 parent: tk.Frame,
+                 scale_var: tk.StringVar,
+                 screen_dimensions: tuple[int, int],
+                 is_standalone: bool) -> None:
+        logger.debug("Initializing %s (parent: '%s', scale_var: %s, screen_dimensions: %s)",
+                     self.__class__.__name__, parent, scale_var, screen_dimensions)
+        frame = tk.Frame(parent)
+        super().__init__(frame)
+
+        self._is_standalone = is_standalone
+        self._screen_dimensions = screen_dimensions
+        self._var_scale = scale_var
+        self._configure_scrollbars(frame)
+        self._image: ImageTk.PhotoImage | None = None
+        self._image_id = self.create_image(self.width / 2,
+                                           self.height / 2,
+                                           anchor=tk.CENTER,
+                                           image=self._image)
+        self.pack(fill=tk.BOTH, expand=True)
+        self.bind("<Configure>", self._resize)
+        frame.pack(side=tk.TOP, fill=tk.BOTH, expand=True)
+        logger.debug("Initialized %s ('%s')", self.__class__.__name__, self)
+
+    @property
+    def image_id(self) -> int:
+        """ int: The ID of the preview image item within the canvas """
+        return self._image_id
+
+    @property
+    def width(self) -> int:
+        """int: The pixel width of the canvas"""
+        return
self.winfo_width()
+
+    @property
+    def height(self) -> int:
+        """int: The pixel height of the canvas"""
+        return self.winfo_height()
+
+    def _configure_scrollbars(self, frame: tk.Frame) -> None:
+        """ Add X and Y scrollbars to the frame and set to scroll the canvas.
+
+        Parameters
+        ----------
+        frame: :class:`tkinter.Frame`
+            The parent frame to the canvas
+        """
+        logger.debug("Configuring scrollbars")
+        x_scrollbar = tk.Scrollbar(frame, orient="horizontal", command=self.xview)
+        x_scrollbar.pack(side=tk.BOTTOM, fill=tk.X)
+
+        y_scrollbar = tk.Scrollbar(frame, command=self.yview)
+        y_scrollbar.pack(side=tk.RIGHT, fill=tk.Y)
+
+        self.configure(xscrollcommand=x_scrollbar.set, yscrollcommand=y_scrollbar.set)
+        logger.debug("Configured scrollbars. x: '%s', y: '%s'", x_scrollbar, y_scrollbar)
+
+    def _resize(self, event: tk.Event) -> None:  # pylint:disable=unused-argument
+        """ Place the image in the center of the canvas on a resize event and move to top left
+
+        Parameters
+        ----------
+        event: :class:`tkinter.Event`
+            The canvas resize event. Unused.
+        """
+        if self._var_scale.get() == "Fit":  # Trigger an update to resize image
+            logger.debug("Triggering redraw for 'Fit' Scaling")
+            self._var_scale.set("Fit")
+            return
+
+        self.configure(scrollregion=self.bbox("all"))
+        self.update_idletasks()
+
+        assert self._image is not None
+        self._center_image(self.width / 2, self.height / 2)
+
+        # Move to top left when resizing into screen dimensions (initial startup)
+        if self.width > self._screen_dimensions[0]:
+            logger.debug("Moving image to left edge")
+            self.xview_moveto(0.0)
+        if self.height > self._screen_dimensions[1]:
+            logger.debug("Moving image to top edge")
+            self.yview_moveto(0.0)
+
+    def _center_image(self, point_x: float, point_y: float) -> None:
+        """ Center the image on the canvas on a resize or image update.
+
+        Parameters
+        ----------
+        point_x: float
+            The x point to center on
+        point_y: float
+            The y point to center on
+        """
+        canvas_location = (self.canvasx(point_x), self.canvasy(point_y))
+        logger.debug("Centering canvas for size (%s, %s). New image coordinates: %s",
+                     point_x, point_y, canvas_location)
+        self.coords(self.image_id, canvas_location)
+
+    def set_image(self,
+                  image: ImageTk.PhotoImage,
+                  center_image: bool = False) -> None:
+        """ Update the canvas with the given image and update area/scrollbars accordingly
+
+        Parameters
+        ----------
+        image: :class:`ImageTk.PhotoImage`
+            The preview image to display in the canvas
+        center_image: bool, optional
+            ``True`` if the image should be re-centered. Default: ``False``
+        """
+        logger.debug("Setting canvas image. ID: %s, size: %s for canvas size: %s (recenter: %s)",
+                     self.image_id, (image.width(), image.height()), (self.width, self.height),
+                     center_image)
+        self._image = image
+        self.itemconfig(self.image_id, image=self._image)
+
+        if self._is_standalone:  # canvas size should not be updated inside GUI
+            self.config(width=self._image.width(), height=self._image.height())
+
+        self.update_idletasks()
+        if center_image:
+            self._center_image(self.width / 2, self.height / 2)
+        self.configure(scrollregion=self.bbox("all"))
+        logger.debug("set canvas image. Canvas size: %s", (self.width, self.height))
+
+
+class _Image():
+    """ Holds the source image and the resized display image for the canvas
+
+    Parameters
+    ----------
+    save_variable: :class:`tkinter.BooleanVar`
+        Variable that indicates a save preview has been requested in standalone mode
+    is_standalone: bool
+        ``True`` if the preview is running in standalone mode.
``False`` if it is running in the + GUI + """ + def __init__(self, save_variable: tk.BooleanVar, is_standalone: bool) -> None: + logger.debug("Initializing %s: (save_variable: %s, is_standalone: %s)", + self.__class__.__name__, save_variable, is_standalone) + self._is_standalone = is_standalone + self._source: np.ndarray | None = None + self._display: ImageTk.PhotoImage | None = None + self._scale = 1.0 + self._interpolation = cv2.INTER_NEAREST + + self._save_var = save_variable + self._save_var.trace("w", self.save_preview) + logger.debug("Initialized %s", self.__class__.__name__) + + @property + def display_image(self) -> ImageTk.PhotoImage: + """ :class:`PIL.ImageTk.PhotoImage`: The current display image """ + assert self._display is not None + return self._display + + @property + def source(self) -> np.ndarray: + """ :class:`PIL.Image.Image`: The current source preview image """ + assert self._source is not None + return self._source + + @property + def scale(self) -> int: + """int: The current display scale as a percentage of original image size """ + return int(self._scale * 100) + + def set_source_image(self, name: str, image: np.ndarray) -> None: + """ Set the source image to :attr:`source` + + Parameters + ---------- + name: str + The name of the preview image to load + image: :class:`numpy.ndarray` + The image to use in RGB format + """ + logger.debug("Setting source image. name: '%s', shape: %s", name, image.shape) + self._source = image + + def set_display_image(self) -> None: + """ Obtain the scaled image and set to :attr:`display_image` """ + logger.debug("Setting display image. Scale: %s", self._scale) + image = self.source[..., 2::-1] # TO RGB + if self._scale not in (0.0, 1.0): # Scale will be 0,0 on initial load in GUI + interp = self._interpolation if self._scale > 1.0 else cv2.INTER_NEAREST + dims = (int(round(self.source.shape[1] * self._scale, 0)), + int(round(self.source.shape[0] * self._scale, 0))) + image = cv2.resize(image, dims, interpolation=interp) + self._display = ImageTk.PhotoImage(Image.fromarray(image)) + logger.debug("Set display image. Size: %s", + (self._display.width(), self._display.height())) + + def set_scale(self, scale: float) -> bool: + """ Set the display scale to the given value. + + Parameters + ---------- + scale: float + The value to set scaling to + + Returns + ------- + bool + ``True`` if the scale has been changed otherwise ``False`` + """ + if self._scale == scale: + return False + logger.debug("Setting scale: %s", scale) + self._scale = scale + return True + + def set_interpolation(self, interpolation: int) -> bool: + """ Set the interpolation enum to the given value. 
+    def set_interpolation(self, interpolation: int) -> bool:
+        """ Set the interpolation enum to the given value.
+
+        Parameters
+        ----------
+        interpolation: int
+            The value to set interpolation to
+
+        Returns
+        -------
+        bool
+            ``True`` if the interpolation has been changed otherwise ``False``
+        """
+        if self._interpolation == interpolation:
+            return False
+        logger.debug("Setting interpolation: %s", interpolation)
+        self._interpolation = interpolation
+        return True
+
+    def save_preview(self, *args) -> None:
+        """ Save out the full size preview to the faceswap folder on a save button press
+
+        Parameters
+        ----------
+        args: tuple
+            Tuple containing either the key press event (Ctrl+s shortcut), the tk variable
+            arguments (standalone save button press) or the folder location (GUI save button
+            press)
+        """
+        if self._is_standalone and not self._save_var.get() and not isinstance(args[0],
+                                                                               tk.Event):
+            return
+
+        if self._is_standalone:
+            root_path = os.path.join(os.path.realpath(os.path.dirname(sys.argv[0])))
+        else:
+            root_path = args[0]
+
+        now = datetime.now().strftime("%Y-%m-%d_%H.%M.%S")
+        filename = os.path.join(root_path, f"preview_{now}.png")
+        cv2.imwrite(filename, self.source)
+        print("\x1b[2K", end="\r")  # Clear last line
+        logger.info("Saved preview to: '%s'", filename)
+
+        if self._is_standalone:
+            self._save_var.set(False)
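The save flow is driven by a tkinter variable trace: writing ``True`` to the save variable fires `save_preview`, which writes the image and then resets the flag. The same hook in the modern spelling (`trace_add` replaces the legacy `trace("w", ...)` used above); the variable and callback names are illustrative:

import tkinter as tk

root = tk.Tk()
save_var = tk.BooleanVar(value=False)

def on_save(*args) -> None:
    # Fires on every write; only act when the flag is raised, then reset it
    if not save_var.get():
        return
    print("save requested")
    save_var.set(False)

save_var.trace_add("write", on_save)  # modern form of save_var.trace("w", on_save)
save_var.set(True)  # triggers the callback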
+
+
+class _Bindings():  # pylint:disable=too-few-public-methods
+    """ Handle Mouse and Keyboard bindings for the canvas.
+
+    Parameters
+    ----------
+    canvas: :class:`_PreviewCanvas`
+        The canvas that holds the preview image
+    taskbar: :class:`_Taskbar`
+        The taskbar widget which holds the scaling variables
+    image: :class:`_Image`
+        The object which holds the source and display version of the preview image
+    is_standalone: bool
+        ``True`` if the preview is standalone, ``False`` if it is embedded in the GUI
+    """
+    def __init__(self,
+                 canvas: _PreviewCanvas,
+                 taskbar: _Taskbar,
+                 image: _Image,
+                 is_standalone: bool) -> None:
+        logger.debug("Initializing %s (canvas: '%s', taskbar: '%s', image: '%s')",
+                     self.__class__.__name__, canvas, taskbar, image)
+        self._canvas = canvas
+        self._taskbar = taskbar
+        self._image = image
+
+        self._drag_data: list[float] = [0., 0.]
+        self._set_mouse_bindings()
+        self._set_key_bindings(is_standalone)
+        logger.debug("Initialized %s", self.__class__.__name__)
+
+    def _on_bound_zoom(self, event: tk.Event) -> None:
+        """ Action to perform on a valid zoom key press or mouse wheel action
+
+        Parameters
+        ----------
+        event: :class:`tkinter.Event`
+            The key press or mouse wheel event
+        """
+        if event.keysym in ("KP_Add", "plus") or event.num == 4 or event.delta > 0:
+            scale = min(self._taskbar.max_scale, self._image.scale + 25)
+        else:
+            scale = max(self._taskbar.min_scale, self._image.scale - 25)
+        logger.trace("Bound zoom action: (event: %s, scale: %s)", event, scale)  # type: ignore
+        self._taskbar.scale_var.set(f"{scale}%")
+
+    def _on_mouse_click(self, event: tk.Event) -> None:
+        """ Log initial click coordinates for mouse click + drag action
+
+        Parameters
+        ----------
+        event: :class:`tkinter.Event`
+            The mouse event
+        """
+        self._drag_data = [event.x / self._image.display_image.width(),
+                           event.y / self._image.display_image.height()]
+        logger.trace("Mouse click action: (event: %s, drag_data: %s)",  # type: ignore
+                     event, self._drag_data)
+
+    def _on_mouse_drag(self, event: tk.Event) -> None:
+        """ Drag image left, right, up or down
+
+        Parameters
+        ----------
+        event: :class:`tkinter.Event`
+            The mouse event
+        """
+        location_x = event.x / self._image.display_image.width()
+        location_y = event.y / self._image.display_image.height()
+
+        if self._canvas.xview() != (0.0, 1.0):
+            to_x = min(1.0, max(0.0, self._drag_data[0] - location_x + self._canvas.xview()[0]))
+            self._canvas.xview_moveto(to_x)
+        if self._canvas.yview() != (0.0, 1.0):
+            to_y = min(1.0, max(0.0, self._drag_data[1] - location_y + self._canvas.yview()[0]))
+            self._canvas.yview_moveto(to_y)
+
+        self._drag_data = [location_x, location_y]
+
+    def _on_key_move(self, event: tk.Event) -> None:
+        """ Action to perform on a valid move key press
+
+        Parameters
+        ----------
+        event: :class:`tkinter.Event`
+            The key press event
+        """
+        move_axis = self._canvas.xview if event.keysym in ("Left", "Right") else self._canvas.yview
+        visible = move_axis()[1] - move_axis()[0]
+        amount = -visible / 25 if event.keysym in ("Up", "Left") else visible / 25
+        logger.trace("Key move event: (event: %s, move_axis: %s, visible: %s, "  # type: ignore
+                     "amount: %s)", event, move_axis, visible, amount)
+        move_axis(tk.MOVETO, min(1.0, max(0.0, move_axis()[0] + amount)))
+
+    def _set_mouse_bindings(self) -> None:
+        """ Set the mouse bindings for interacting with the preview image
+
+        Mousewheel: Zoom in and out
+        Mouse click: Move image
+        """
+        logger.debug("Binding mouse events")
+        if system() == "Linux":
+            self._canvas.tag_bind(self._canvas.image_id, "<Button-4>", self._on_bound_zoom)
+            self._canvas.tag_bind(self._canvas.image_id, "<Button-5>", self._on_bound_zoom)
+        else:
+            self._canvas.bind("<MouseWheel>", self._on_bound_zoom)
+
+        self._canvas.tag_bind(self._canvas.image_id, "<Button-1>", self._on_mouse_click)
+        self._canvas.tag_bind(self._canvas.image_id, "<B1-Motion>", self._on_mouse_drag)
+        logger.debug("Bound mouse events")
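Mouse-wheel delivery is platform specific, which is why the bindings above split on `system()`: X11 reports wheel movement as `<Button-4>`/`<Button-5>` press events (read via `event.num`), while Windows and macOS send a single `<MouseWheel>` event with a signed `event.delta`. A minimal standalone sketch of handling both:

import tkinter as tk
from platform import system

root = tk.Tk()

def on_zoom(event: tk.Event) -> None:
    # X11 reports wheel-up as Button-4 (event.num == 4); elsewhere delta > 0 is up
    zoom_in = getattr(event, "num", None) == 4 or getattr(event, "delta", 0) > 0
    print("zoom in" if zoom_in else "zoom out")

if system() == "Linux":
    root.bind("<Button-4>", on_zoom)
    root.bind("<Button-5>", on_zoom)
else:
    root.bind("<MouseWheel>", on_zoom)
root.mainloop()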
+    def _set_key_bindings(self, is_standalone: bool) -> None:
+        """ Set the keyboard bindings.
+
+        Up/Down/Left/Right: Moves image
+        +/-: Zooms image
+        ctrl+s: Save
+        i: Cycle interpolators
+
+        Parameters
+        ----------
+        is_standalone: bool
+            ``True`` if the preview is standalone, ``False`` if it is embedded in the GUI
+        """
+        if not is_standalone:
+            # Don't bind keys for GUI as it adds complication
+            return
+        logger.debug("Binding key events")
+        root = self._canvas.winfo_toplevel()
+        for key in ("Left", "Right", "Up", "Down"):
+            root.bind(f"<{key}>", self._on_key_move)
+        for key in ("Key-plus", "Key-minus", "Key-KP_Add", "Key-KP_Subtract"):
+            root.bind(f"<{key}>", self._on_bound_zoom)
+        root.bind("<Control-s>", self._image.save_preview)
+        root.bind("<i>", self._taskbar.cycle_interpolators)
+        logger.debug("Bound key events")
+
+
+class PreviewTk(PreviewBase):
+    """ Holds a preview window for displaying the pop out preview.
+
+    Parameters
+    ----------
+    preview_buffer: :class:`PreviewBuffer`
+        The thread safe object holding the preview images
+    parent: tkinter widget, optional
+        If this viewer is being called from the GUI the parent widget should be passed in here.
+        If this is a standalone pop-up window then pass ``None``. Default: ``None``
+    taskbar: :class:`tkinter.ttk.Frame`, optional
+        If this viewer is being called from the GUI the parent's option frame should be passed
+        in here. If this is a standalone pop-up window then pass ``None``. Default: ``None``
+    triggers: dict, optional
+        Dictionary of event triggers for pop-up preview. Not required when running inside the
+        GUI. Default: ``None``
+    """
+    def __init__(self,
+                 preview_buffer: PreviewBuffer,
+                 parent: tk.Widget | None = None,
+                 taskbar: ttk.Frame | None = None,
+                 triggers: TriggerType | None = None) -> None:
+        logger.debug("Initializing %s (parent: '%s')", self.__class__.__name__, parent)
+        super().__init__(preview_buffer, triggers=triggers)
+        self._is_standalone = parent is None
+        self._initialized = False
+        self._root = parent if parent is not None else tk.Tk()
+        self._master_frame = tk.Frame(self._root)
+
+        self._taskbar = _Taskbar(self._master_frame, taskbar)
+
+        self._screen_dimensions = self._get_geometry()
+        self._canvas = _PreviewCanvas(self._master_frame,
+                                      self._taskbar.scale_var,
+                                      self._screen_dimensions,
+                                      self._is_standalone)
+
+        self._image = _Image(self._taskbar.save_var, self._is_standalone)
+
+        _Bindings(self._canvas, self._taskbar, self._image, self._is_standalone)
+
+        self._taskbar.scale_var.trace("w", self._set_scale)
+        self._taskbar.interpolator_var.trace("w", self._set_interpolation)
+
+        self._process_triggers()
+
+        if self._is_standalone:
+            self.pack(fill=tk.BOTH, expand=True)
+
+        self._output_helptext()
+
+        logger.debug("Initialized %s", self.__class__.__name__)
+
+        self._launch()
+
+    @property
+    def master_frame(self) -> tk.Frame:
+        """ :class:`tkinter.Frame`: The master frame that holds the preview window """
+        return self._master_frame
+
+    def pack(self, *args, **kwargs):
+        """ Redirect calls to pack the widget to pack the actual :attr:`_master_frame`.
+
+        Takes standard :class:`tkinter.Frame` pack arguments
+        """
+        logger.debug("Packing master frame: (args: %s, kwargs: %s)", args, kwargs)
+        self._master_frame.pack(*args, **kwargs)
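Driving the viewer follows the same shape as the `main()` test harness at the bottom of this module: fill a `PreviewBuffer`, then construct `PreviewTk`. A hedged usage sketch (the image path and preview name are illustrative):

import cv2
from lib.training.preview_cv import PreviewBuffer
from lib.training.preview_tk import PreviewTk

buffer = PreviewBuffer()
image = cv2.imread("/path/to/preview.png", cv2.IMREAD_UNCHANGED)
buffer.add_image("my_preview", image)

# Standalone: with no parent widget, PreviewTk creates its own tk.Tk() root
# and runs the mainloop itself
PreviewTk(buffer)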
+    def save(self, location: str) -> None:
+        """ Save action to be performed when save button pressed from the GUI.
+
+        Parameters
+        ----------
+        location: str
+            Full path to the folder to save the preview image to
+        """
+        self._image.save_preview(location)
+
+    def remove_option_controls(self) -> None:
+        """ Remove the taskbar options controls when the preview is disabled in the GUI """
+        self._taskbar.destroy_widgets()
+
+    def _output_helptext(self) -> None:
+        """ Output the keybindings to Console. """
+        if not self._is_standalone:
+            return
+        logger.info("---------------------------------------------------")
+        logger.info(" Preview key bindings:")
+        logger.info("   Zoom:             +/-")
+        logger.info("   Toggle Zoom Mode: i")
+        logger.info("   Move:             arrow keys")
+        logger.info("   Save Preview:     Ctrl+s")
+        logger.info("---------------------------------------------------")
+
+    def _get_geometry(self) -> tuple[int, int]:
+        """ Obtain the geometry of the current screen (standalone) or the dimensions of the
+        widget holding the preview window (GUI).
+
+        Just pulling screen width and height does not account for multiple monitors, so dummy
+        in a window to pull actual dimensions before hiding it again.
+
+        Returns
+        -------
+        tuple[int, int]
+            The (`width`, `height`) of the current monitor's display
+        """
+        if not self._is_standalone:
+            root = self._root.winfo_toplevel()  # Get dims of whole GUI
+            retval = root.winfo_width(), root.winfo_height()
+            logger.debug("Obtained frame geometry: %s", retval)
+            return retval
+
+        assert isinstance(self._root, tk.Tk)
+        logger.debug("Obtaining screen geometry")
+        self._root.update_idletasks()
+        self._root.attributes("-fullscreen", True)
+        self._root.state("iconic")
+        retval = self._root.winfo_width(), self._root.winfo_height()
+        self._root.attributes("-fullscreen", False)
+        self._root.state("withdrawn")
+        logger.debug("Obtained screen geometry: %s", retval)
+        return retval
+
+    def _set_min_max_scales(self) -> None:
+        """ Set the minimum and maximum area that we allow to scale image to. """
+        logger.debug("Calculating minimum scale for screen dimensions %s",
+                     self._screen_dimensions)
+        half_screen = tuple(x // 2 for x in self._screen_dimensions)
+        min_scales = (half_screen[0] / self._image.source.shape[1],
+                      half_screen[1] / self._image.source.shape[0])
+        min_scale = min(1.0, *min_scales)
+        min_scale = (ceil(min_scale * 10)) * 10
+
+        eight_screen = tuple(x * 8 for x in self._screen_dimensions)
+        max_scales = (eight_screen[0] / self._image.source.shape[1],
+                      eight_screen[1] / self._image.source.shape[0])
+        max_scale = min(8.0, max(1.0, min(max_scales)))
+        max_scale = (floor(max_scale * 10)) * 10
+
+        logger.debug("Calculated minimum scale: %s, maximum_scale: %s", min_scale, max_scale)
+        self._taskbar.set_min_max_scale(min_scale, max_scale)
+
+    def _initialize_window(self) -> None:
+        """ Initialize the window to fit into the current screen """
+        logger.debug("Initializing window")
+        assert isinstance(self._root, tk.Tk)
+        width = min(self._master_frame.winfo_reqwidth(), self._screen_dimensions[0])
+        height = min(self._master_frame.winfo_reqheight(), self._screen_dimensions[1])
+        self._set_min_max_scales()
+        self._root.state("normal")
+        self._root.geometry(f"{width}x{height}")
+        self._root.protocol("WM_DELETE_WINDOW", lambda: None)  # Intercept close window
+        self._initialized = True
+        logger.debug("Initialized window: (width: %s, height: %s)", width, height)
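`_get_geometry` sizes against the current monitor by briefly flashing up a fullscreen window, since plain `winfo_screenwidth`/`winfo_screenheight` can report the combined desktop on multi-monitor setups. The probe in isolation, as a sketch:

import tkinter as tk

root = tk.Tk()
root.update_idletasks()
root.attributes("-fullscreen", True)  # snap the window to the current monitor
root.state("iconic")                  # keep the probe window out of sight
width, height = root.winfo_width(), root.winfo_height()
root.attributes("-fullscreen", False)
root.state("withdrawn")
print(f"current monitor: {width}x{height}")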
+    def _update_image(self, center_image: bool = False) -> None:
+        """ Update the image displayed in the canvas and set the canvas size and scroll region
+        accordingly
+
+        Parameters
+        ----------
+        center_image: bool, optional
+            ``True`` if the image in the canvas should be recentered. Default: ``False``
+        """
+        logger.debug("Updating image (center_image: %s)", center_image)
+        self._image.set_display_image()
+        self._canvas.set_image(self._image.display_image, center_image)
+        logger.debug("Updated image")
+
+    def _convert_fit_scale(self) -> str:
+        """ Convert "Fit" scale to the actual scaling amount
+
+        Returns
+        -------
+        str
+            The fit scaling in '##%' format
+        """
+        logger.debug("Converting 'Fit' scaling")
+        width_scale = self._canvas.width / self._image.source.shape[1]
+        height_scale = self._canvas.height / self._image.source.shape[0]
+        scale = min(width_scale, height_scale) * 100
+        retval = f"{floor(scale)}%"
+        logger.debug("Converted 'Fit' scaling: (width_scale: %s, height_scale: %s, scale: %s, "
+                     "retval: '%s')", width_scale, height_scale, scale, retval)
+        return retval
+
+    def _set_scale(self, *args) -> None:  # pylint:disable=unused-argument
+        """ Update the image on a scale request """
+        txtscale = self._taskbar.scale_var.get()
+        logger.debug("Setting scale: '%s'", txtscale)
+        txtscale = self._convert_fit_scale() if txtscale == "Fit" else txtscale
+        scale = int(txtscale[:-1])  # Strip percentage and convert to int
+        logger.debug("Got scale: %s", scale)
+
+        if self._image.set_scale(scale / 100):
+            logger.debug("Updating for new scale")
+            self._taskbar.slider_var.set(scale)
+            self._update_image(center_image=True)
+
+    def _set_interpolation(self, *args) -> None:  # pylint:disable=unused-argument
+        """ Callback for when the interpolator is changed """
+        interp = self._taskbar.interpolator_var.get()
+        # Interpolation only affects the display when zoomed above 100% (scale is a percentage)
+        if not self._image.set_interpolation(interp) or self._image.scale <= 100:
+            return
+        self._update_image(center_image=False)
+
+    def _process_triggers(self) -> None:
+        """ Process the standard faceswap key press triggers:
+
+        m = toggle_mask
+        r = refresh
+        s = save
+        enter = quit
+        """
+        if self._triggers is None:  # Don't need triggers for GUI
+            return
+        logger.debug("Processing triggers")
+        root = self._canvas.winfo_toplevel()
+        for key in self._keymaps:
+            bindkey = "Return" if key == "enter" else key
+            logger.debug("Adding trigger for key: '%s'", bindkey)
+
+            root.bind(f"<{bindkey}>", self._on_keypress)
+        logger.debug("Processed triggers")
+
+    def _on_keypress(self, event: tk.Event) -> None:
+        """ Update the triggers on a keypress event for picking up by main faceswap process.
+
+        Parameters
+        ----------
+        event: :class:`tkinter.Event`
+            The valid preview trigger keypress
+        """
+        if self._triggers is None:  # Don't need triggers for GUI
+            return
+        keypress = "enter" if event.keysym == "Return" else event.keysym
+        key = T.cast(TriggerKeysType, keypress)
+        logger.debug("Processing keypress '%s'", key)
+        if key == "r":
+            print("\x1b[2K", end="\r")  # Clear last line
+            logger.info("Refresh preview requested...")
+
+        self._triggers[self._keymaps[key]].set()
+        logger.debug("Processed keypress '%s'. Set event for '%s'", key, self._keymaps[key])
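The "Fit" conversion is simply the limiting canvas-to-image ratio, floored to a whole percentage. A worked example with illustrative dimensions:

from math import floor

canvas_w, canvas_h = 1600, 900
image_w, image_h = 2048, 1024  # source.shape[1], source.shape[0]

scale = min(canvas_w / image_w, canvas_h / image_h) * 100
print(f"{floor(scale)}%")  # min(0.78125, 0.8789) * 100 -> '78%'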
+
+    def _display_preview(self) -> None:
+        """ Handle the displaying of the images currently in :attr:`_preview_buffer` """
+        if self._should_shutdown:
+            self._root.destroy()
+            return
+
+        if not self._buffer.is_updated:
+            self._root.after(1000, self._display_preview)
+            return
+
+        for name, image in self._buffer.get_images():
+            logger.debug("Updating image: (name: '%s', shape: %s)", name, image.shape)
+            if self._is_standalone and not self._title:
+                assert isinstance(self._root, tk.Tk)
+                self._title = name
+                logger.debug("Setting title: '%s'", self._title)
+                self._root.title(self._title)
+            self._image.set_source_image(name, image)
+            self._update_image(center_image=not self._initialized)
+
+        self._root.after(1000, self._display_preview)
+
+        if not self._initialized and self._is_standalone:
+            self._initialize_window()
+            self._root.mainloop()
+        if not self._initialized:  # Set initialized to True for GUI
+            self._set_min_max_scales()
+            self._taskbar.scale_var.set("Fit")
+            self._initialized = True
+
+
+def main():
+    """ Load the image given as the final command line argument and display it
+
+    python -m lib.training.preview_tk
+    """
+    from lib.logger import log_setup  # pylint:disable=import-outside-toplevel
+    from .preview_cv import PreviewBuffer  # pylint:disable=import-outside-toplevel
+    log_setup("DEBUG", "faceswap_preview.log", "Test", False)
+
+    img = cv2.imread(sys.argv[-1], cv2.IMREAD_UNCHANGED)
+    buff = PreviewBuffer()  # pylint:disable=used-before-assignment
+    buff.add_image("test_image", img)
+    PreviewTk(buff)
+
+
+__all__ = get_module_objects(__name__)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/lib/training/tensorboard.py b/lib/training/tensorboard.py
new file mode 100644
index 0000000000..07ed576a26
--- /dev/null
+++ b/lib/training/tensorboard.py
@@ -0,0 +1,221 @@
+#!/usr/bin/env python3
+""" Tensorboard callback for PyTorch logging. Hopefully temporary until a native Keras version
+is implemented """
+from __future__ import annotations
+
+import logging
+import os
+import struct
+import typing as T
+
+import keras
+from torch.utils.tensorboard import SummaryWriter
+
+from lib.logger import parse_class_init
+from lib.utils import get_module_objects
+
+logger = logging.getLogger(__name__)
+
+
+class RecordIterator:
+    """ A replacement for tensorflow's :func:`compat.v1.io.tf_record_iterator`
+
+    Parameters
+    ----------
+    log_file : str
+        The event log file to obtain records from
+    is_live : bool, optional
+        ``True`` if the log file is for a live training session that will constantly provide
+        data. Default: ``False``
+    """
+    def __init__(self, log_file: str, is_live: bool = False) -> None:
+        logger.debug(parse_class_init(locals()))
+        self._file_path = log_file
+        self._log_file = open(self._file_path, "rb")  # pylint:disable=consider-using-with
+        self._is_live = is_live
+        self._position = 0
+        logger.debug("Initialized %s", self.__class__.__name__)
+
+    def __iter__(self) -> RecordIterator:
+        """ Iterate over a Tensorboard event file """
+        return self
+
+    def _on_file_read(self) -> None:
+        """ If the file is closed and we are reading live data, re-open the file and seek to
+        the correct position """
+        if not self._is_live or not self._log_file.closed:
+            return
+
+        logger.trace("Re-opening '%s' and seeking to %s",  # type:ignore[attr-defined]
+                     self._file_path, self._position)
+        self._log_file = open(self._file_path, "rb")  # pylint:disable=consider-using-with
+        self._log_file.seek(self._position, 0)
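TensorBoard event files use TFRecord framing: an 8-byte little-endian uint64 payload length, a 4-byte CRC of that length field, the payload itself, then a 4-byte CRC of the payload. The `__next__` implementation below skips both CRCs with `seek(4, 1)` rather than verifying them. A standalone sketch of reading one record under the same assumption:

import struct

def read_one_record(handle) -> bytes | None:
    """Read a single TFRecord-framed record from an open binary file handle."""
    header = handle.read(8)  # uint64 little-endian payload length
    if not header:
        return None  # end of file
    (length,) = struct.unpack("<Q", header)
    handle.seek(4, 1)  # skip the 4-byte CRC of the length field
    payload = handle.read(length)
    handle.seek(4, 1)  # skip the 4-byte CRC of the payload
    return payload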
+    def _on_file_end(self) -> None:
+        """ Close the event file. If live data, record the current position """
+        if self._is_live:
+            self._position = self._log_file.tell()
+            logger.trace("Setting live position to %s",  # type:ignore[attr-defined]
+                         self._position)
+
+        logger.trace("EOF. Closing '%s'", self._file_path)  # type:ignore[attr-defined]
+        self._log_file.close()
+
+    def __next__(self) -> bytes:
+        """ Get the next event log from a Tensorboard event file
+
+        Returns
+        -------
+        bytes
+            A Tensorboard event log
+
+        Raises
+        ------
+        StopIteration
+            When the event log is fully consumed
+        """
+        self._on_file_read()
+
+        b_header = self._log_file.read(8)
+
+        if not b_header:
+            self._on_file_end()
+            raise StopIteration
+
+        read_len = int(struct.unpack("<Q", b_header)[0])  # little-endian uint64 record length
+        self._log_file.seek(4, 1)  # skip the length field's CRC
+        data = self._log_file.read(read_len)
+
+        self._log_file.seek(4, 1)  # skip the payload's CRC
+        logger.trace("Returning event data of len %s", read_len)  # type:ignore[attr-defined]
+
+        return data
+
+
+class TorchTensorBoard(keras.callbacks.Callback):
+    """Enable visualizations for TensorBoard. Adapted from Keras' Tensorboard Callback, keeping
+    only the parts we need, and using Torch rather than TensorFlow
+
+    Parameters
+    ----------
+    log_dir: str
+        The path of the directory where to save the log files to be parsed by TensorBoard.
+        e.g., `log_dir = os.path.join(working_dir, 'logs')`. This directory should not be
+        reused by any other callbacks.
+    write_graph: bool (Not supported at this time)
+        Whether to visualize the graph in TensorBoard. Note that the log file can become quite
+        large when `write_graph` is set to `True`.
+    update_freq: Literal["batch", "epoch"] | int
+        When using `"epoch"`, writes the losses and metrics to TensorBoard after every epoch.
+        If using an integer, let's say `1000`, all metrics and losses (including custom ones
+        added by `Model.compile`) will be logged to TensorBoard every 1000 batches. `"batch"`
+        is a synonym for 1, meaning that they will be written every batch. Note however that
+        writing too frequently to TensorBoard can slow down your training, especially when used
+        with distribution strategies as it will incur additional synchronization overhead.
+        Batch-level summary writing is also available via `train_step` override.
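The writer side of this callback (defined below) is a thin wrapper over torch's `SummaryWriter`; per-batch losses land as scalars under a `batch_` prefix. A minimal sketch of the underlying API (the log directory, tag and loss values are illustrative):

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter("logs/train")  # illustrative log directory
for step, loss in enumerate([0.9, 0.7, 0.5]):  # dummy loss values
    writer.add_scalar("batch_total", loss, global_step=step)
writer.flush()   # mirrors what the callback does on save
writer.close()

Being a `keras.callbacks.Callback` subclass, the class itself can then be attached like any other callback, e.g. `model.fit(..., callbacks=[TorchTensorBoard(log_dir="logs")])`.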
Please see [TensorBoard + Scalars + tutorial](https://www.tensorflow.org/tensorboard/scalars_and_keras#batch-level_logging) + """ + + def __init__(self, + log_dir: str = "logs", + write_graph: bool = True, + update_freq: T.Literal["batch", "epoch"] | int = "epoch") -> None: + logger.debug(parse_class_init(locals())) + super().__init__() + self.log_dir = str(log_dir) + self.write_graph = write_graph + self.update_freq = 1 if update_freq == "batch" else update_freq + + self._should_write_train_graph = False + self._train_dir = os.path.join(self.log_dir, "train") + self._train_step = 0 + self._global_train_batch = 0 + self._previous_epoch_iterations = 0 + + self._model: keras.models.Model | None = None + self._writers: dict[str, SummaryWriter] = {} + logger.debug("Initialized %s", self.__class__.__name__) + + @property + def _train_writer(self) -> SummaryWriter: + """:class:`torch.utils.tensorboard.SummaryWriter`: The summary writer """ + if "train" not in self._writers: + self._writers["train"] = SummaryWriter(self._train_dir) + return self._writers["train"] + + def _write_keras_model_summary(self) -> None: + """Writes Keras graph network summary to TensorBoard.""" + assert self._model is not None + summary = self._model.to_json() + self._train_writer.add_text("keras", summary, global_step=0) + + def _write_keras_model_train_graph(self) -> None: + """Writes Keras graph to TensorBoard.""" + # TODO implement + logger.debug("Tensorboard graph logging not yet implemented") + + def set_model(self, model: keras.models.Model) -> None: + """Sets Keras model and writes graph if specified. + + Parameters + ---------- + model: :class:`keras.models.Model` + The model that is being trained + """ + self._model = model + + if self.write_graph: + self._write_keras_model_summary() + self._should_write_train_graph = True + + def on_train_begin(self, logs=None) -> None: + """ Initialize the call back on train start + + Parameters + ---------- + logs: None + Unused + """ + self._global_train_batch = 0 + self._previous_epoch_iterations = 0 + + def on_train_batch_end(self, batch: int, logs: dict[str, float] | None = None) -> None: + """ Update Tensorboard logs on batch end + + Parameters + ---------- + batch: int + The current iteration count + logs: dict[str, float] + The logs to write + """ + assert logs is not None + if self._should_write_train_graph: + self._write_keras_model_train_graph() + self._should_write_train_graph = False + + for key, value in logs.items(): + self._train_writer.add_scalar(f"batch_{key}", + value, + global_step=batch) + + def on_save(self) -> None: + """ Flush data to disk on save """ + logger.debug("Flushing Tensorboard writer") + self._train_writer.flush() + + def on_train_end(self, logs=None) -> None: + """ Close the writer on train completion + + Parameters + ---------- + logs: None + Unused + """ + for writer in self._writers.values(): + writer.flush() + writer.close() + + +__all__ = get_module_objects(__name__) diff --git a/lib/training_data.py b/lib/training_data.py deleted file mode 100644 index 32037c9245..0000000000 --- a/lib/training_data.py +++ /dev/null @@ -1,401 +0,0 @@ -#!/usr/bin/env python3 -""" Process training data for model training """ - -import logging - -from hashlib import sha1 -from random import shuffle - -import cv2 -import numpy as np -from scipy.interpolate import griddata - -from lib.model import masks -from lib.multithreading import MultiThread -from lib.queue_manager import queue_manager -from lib.umeyama import umeyama - -logger = 
logging.getLogger(__name__) # pylint: disable=invalid-name - - -class TrainingDataGenerator(): - """ Generate training data for models """ - def __init__(self, model_input_size, model_output_size, training_opts): - logger.debug("Initializing %s: (model_input_size: %s, model_output_shape: %s, " - "training_opts: %s, landmarks: %s)", - self.__class__.__name__, model_input_size, model_output_size, - {key: val for key, val in training_opts.items() if key != "landmarks"}, - bool(training_opts.get("landmarks", None))) - self.batchsize = 0 - self.model_input_size = model_input_size - self.training_opts = training_opts - self.mask_function = self.set_mask_function() - self.landmarks = self.training_opts.get("landmarks", None) - - self.processing = ImageManipulation(model_input_size, - model_output_size, - training_opts.get("coverage_ratio", 0.625)) - logger.debug("Initialized %s", self.__class__.__name__) - - def set_mask_function(self): - """ Set the mask function to use if using mask """ - mask_type = self.training_opts.get("mask_type", None) - if mask_type: - logger.debug("Mask type: '%s'", mask_type) - mask_func = getattr(masks, mask_type) - else: - mask_func = None - logger.debug("Mask function: %s", mask_func) - return mask_func - - def minibatch_ab(self, images, batchsize, side, do_shuffle=True, is_timelapse=False): - """ Keep a queue filled to 8x Batch Size """ - logger.debug("Queue batches: (image_count: %s, batchsize: %s, side: '%s', do_shuffle: %s, " - "is_timelapse: %s)", len(images), batchsize, side, do_shuffle, is_timelapse) - self.batchsize = batchsize - q_name = "timelapse_{}".format(side) if is_timelapse else "train_{}".format(side) - q_size = batchsize * 8 - # Don't use a multiprocessing queue because sometimes the MP Manager borks on numpy arrays - queue_manager.add_queue(q_name, maxsize=q_size, multiprocessing_queue=False) - load_thread = MultiThread(self.load_batches, - images, - q_name, - side, - is_timelapse, - do_shuffle) - load_thread.start() - logger.debug("Batching to queue: (side: '%s', queue: '%s')", side, q_name) - return self.minibatch(q_name, load_thread) - - def load_batches(self, images, q_name, side, is_timelapse, do_shuffle=True): - """ Load the warped images and target images to queue """ - logger.debug("Loading batch: (image_count: %s, q_name: '%s', side: '%s', " - "is_timelapse: %s, do_shuffle: %s)", - len(images), q_name, side, is_timelapse, do_shuffle) - epoch = 0 - queue = queue_manager.get_queue(q_name) - self.validate_samples(images) - while True: - if do_shuffle: - shuffle(images) - for img in images: - logger.trace("Putting to batch queue: (q_name: '%s', side: '%s')", q_name, side) - queue.put(self.process_face(img, side, is_timelapse)) - epoch += 1 - logger.debug("Finished batching: (epoch: %s, q_name: '%s', side: '%s')", - epoch, q_name, side) - - def validate_samples(self, data): - """ Check the total number of images against batchsize and return - the total number of images """ - length = len(data) - msg = ("Number of images is lower than batch-size (Note that too few " - "images may lead to bad training). 
# images: {}, " - "batch-size: {}".format(length, self.batchsize)) - assert length >= self.batchsize, msg - - def minibatch(self, q_name, load_thread): - """ A generator function that yields epoch, batchsize of warped_img - and batchsize of target_img from the load queue """ - logger.debug("Launching minibatch generator for queue: '%s'", q_name) - queue = queue_manager.get_queue(q_name) - while True: - if load_thread.has_error: - logger.debug("Thread error detected") - break - batch = list() - for _ in range(self.batchsize): - images = queue.get() - for idx, image in enumerate(images): - if len(batch) < idx + 1: - batch.append(list()) - batch[idx].append(image) - batch = [np.float32(image) for image in batch] - logger.trace("Yielding batch: (size: %s, item shapes: %s, queue: '%s'", - len(batch), [item.shape for item in batch], q_name) - yield batch - logger.debug("Finished minibatch generator for queue: '%s'", q_name) - load_thread.join() - - def process_face(self, filename, side, is_timelapse): - """ Load an image and perform transformation and warping """ - logger.trace("Process face: (filename: '%s', side: '%s', is_timelapse: %s)", - filename, side, is_timelapse) - try: - image = cv2.imread(filename) # pylint: disable=no-member - except TypeError: - raise Exception("Error while reading image", filename) - - if self.mask_function or self.training_opts["warp_to_landmarks"]: - src_pts = self.get_landmarks(filename, image, side) - if self.mask_function: - image = self.mask_function(src_pts, image, channels=4) - - image = self.processing.color_adjust(image) - - if not is_timelapse: - image = self.processing.random_transform(image) - if not self.training_opts["no_flip"]: - image = self.processing.do_random_flip(image) - sample = image.copy()[:, :, :3] - - if self.training_opts["warp_to_landmarks"]: - dst_pts = self.get_closest_match(filename, side, src_pts) - processed = self.processing.random_warp_landmarks(image, src_pts, dst_pts) - else: - processed = self.processing.random_warp(image) - - processed.insert(0, sample) - logger.trace("Processed face: (filename: '%s', side: '%s', shapes: %s)", - filename, side, [img.shape for img in processed]) - return processed - - def get_landmarks(self, filename, image, side): - """ Return the landmarks for this face """ - logger.trace("Retrieving landmarks: (filename: '%s', side: '%s'", filename, side) - lm_key = sha1(image).hexdigest() - try: - src_points = self.landmarks[side][lm_key] - except KeyError: - raise Exception("Landmarks not found for hash: '{}' file: '{}'".format(lm_key, - filename)) - logger.trace("Returning: (src_points: %s)", src_points) - return src_points - - def get_closest_match(self, filename, side, src_points): - """ Return closest matched landmarks from opposite set """ - logger.trace("Retrieving closest matched landmarks: (filename: '%s', src_points: '%s'", - filename, src_points) - dst_points = self.landmarks["a"] if side == "b" else self.landmarks["b"] - dst_points = list(dst_points.values()) - closest = (np.mean(np.square(src_points - dst_points), - axis=(1, 2))).argsort()[:10] - closest = np.random.choice(closest) - dst_points = dst_points[closest] - logger.trace("Returning: (dst_points: %s)", dst_points) - return dst_points - - -class ImageManipulation(): - """ Manipulations to be performed on training images """ - def __init__(self, input_size, output_size, coverage_ratio): - """ input_size: Size of the face input into the model - output_size: Size of the face that comes out of the modell - coverage_ratio: Coverage ratio of 
full image. Eg: 256 * 0.625 = 160 - """ - logger.debug("Initializing %s: (input_size: %s, output_size: %s, coverage_ratio: %s)", - self.__class__.__name__, input_size, output_size, coverage_ratio) - # Transform args - self.rotation_range = 10 # Range to randomly rotate the image by - self.zoom_range = 0.05 # Range to randomly zoom the image by - self.shift_range = 0.05 # Range to randomly translate the image by - self.random_flip = 0.5 # Chance to flip the image horizontally - # Transform and Warp args - self.input_size = input_size - self.output_size = output_size - # Warp args - self.coverage_ratio = coverage_ratio # Coverage ratio of full image. Eg: 256 * 0.625 = 160 - self.scale = 5 # Normal random variable scale - logger.debug("Initialized %s", self.__class__.__name__) - - @staticmethod - def color_adjust(img): - """ Color adjust RGB image """ - logger.trace("Color adjusting image") - return img.astype('float32') / 255.0 - - @staticmethod - def separate_mask(image): - """ Return the image and the mask from a 4 channel image """ - mask = None - if image.shape[2] == 4: - logger.trace("Image contains mask") - mask = np.expand_dims(image[:, :, -1], axis=2) - image = image[:, :, :3] - else: - logger.trace("Image has no mask") - return image, mask - - def get_coverage(self, image): - """ Return coverage value for given image """ - coverage = int(image.shape[0] * self.coverage_ratio) - logger.trace("Coverage: %s", coverage) - return coverage - - def random_transform(self, image): - """ Randomly transform an image """ - logger.trace("Randomly transforming image") - height, width = image.shape[0:2] - - rotation = np.random.uniform(-self.rotation_range, self.rotation_range) - scale = np.random.uniform(1 - self.zoom_range, 1 + self.zoom_range) - tnx = np.random.uniform(-self.shift_range, self.shift_range) * width - tny = np.random.uniform(-self.shift_range, self.shift_range) * height - - mat = cv2.getRotationMatrix2D( # pylint: disable=no-member - (width // 2, height // 2), rotation, scale) - mat[:, 2] += (tnx, tny) - result = cv2.warpAffine( # pylint: disable=no-member - image, mat, (width, height), - borderMode=cv2.BORDER_REPLICATE) # pylint: disable=no-member - - logger.trace("Randomly transformed image") - return result - - def do_random_flip(self, image): - """ Perform flip on image if random number is within threshold """ - logger.trace("Randomly flipping image") - if np.random.random() < self.random_flip: - logger.trace("Flip within threshold. Flipping") - retval = image[:, ::-1] - else: - logger.trace("Flip outside threshold. 
Not Flipping") - retval = image - logger.trace("Randomly flipped image") - return retval - - def random_warp(self, image): - """ get pair of random warped images from aligned face image """ - logger.trace("Randomly warping image") - height, width = image.shape[0:2] - coverage = self.get_coverage(image) - assert height == width and height % 2 == 0 - - range_ = np.linspace(height // 2 - coverage // 2, - height // 2 + coverage // 2, - 5, dtype='float32') - mapx = np.broadcast_to(range_, (5, 5)).copy() - mapy = mapx.T - # mapx, mapy = np.float32(np.meshgrid(range_,range_)) # instead of broadcast - - pad = int(1.25 * self.input_size) - slices = slice(pad // 10, -pad // 10) - dst_slice = slice(0, (self.output_size + 1), (self.output_size // 4)) - interp = np.empty((2, self.input_size, self.input_size), dtype='float32') - #### - - for i, map_ in enumerate([mapx, mapy]): - map_ = map_ + np.random.normal(size=(5, 5), scale=self.scale) - interp[i] = cv2.resize(map_, (pad, pad))[slices, slices] # pylint: disable=no-member - - warped_image = cv2.remap( # pylint: disable=no-member - image, interp[0], interp[1], cv2.INTER_LINEAR) # pylint: disable=no-member - logger.trace("Warped image shape: %s", warped_image.shape) - - src_points = np.stack([mapx.ravel(), mapy.ravel()], axis=-1) - dst_points = np.mgrid[dst_slice, dst_slice] - mat = umeyama(src_points, True, dst_points.T.reshape(-1, 2))[0:2] - target_image = cv2.warpAffine( # pylint: disable=no-member - image, mat, (self.output_size, self.output_size)) - logger.trace("Target image shape: %s", target_image.shape) - - warped_image, warped_mask = self.separate_mask(warped_image) - target_image, target_mask = self.separate_mask(target_image) - - if target_mask is None: - logger.trace("Randomly warped image") - return [warped_image, target_image] - - logger.trace("Target mask shape: %s", target_mask.shape) - logger.trace("Randomly warped image and mask") - return [warped_image, target_image, target_mask] - - def random_warp_landmarks(self, image, src_points=None, dst_points=None): - """ get warped image, target image and target mask - From DFAKER plugin """ - logger.trace("Randomly warping landmarks") - size = image.shape[0] - coverage = self.get_coverage(image) - - p_mx = size - 1 - p_hf = (size // 2) - 1 - - edge_anchors = [(0, 0), (0, p_mx), (p_mx, p_mx), (p_mx, 0), - (p_hf, 0), (p_hf, p_mx), (p_mx, p_hf), (0, p_hf)] - grid_x, grid_y = np.mgrid[0:p_mx:complex(size), 0:p_mx:complex(size)] - - source = src_points - destination = (dst_points.copy().astype('float32') + - np.random.normal(size=dst_points.shape, scale=2.0)) - destination = destination.astype('uint8') - - face_core = cv2.convexHull(np.concatenate( # pylint: disable=no-member - [source[17:], destination[17:]], axis=0).astype(int)) - - source = [(pty, ptx) for ptx, pty in source] + edge_anchors - destination = [(pty, ptx) for ptx, pty in destination] + edge_anchors - - indicies_to_remove = set() - for fpl in source, destination: - for idx, (pty, ptx) in enumerate(fpl): - if idx > 17: - break - elif cv2.pointPolygonTest(face_core, # pylint: disable=no-member - (pty, ptx), - False) >= 0: - indicies_to_remove.add(idx) - - for idx in sorted(indicies_to_remove, reverse=True): - source.pop(idx) - destination.pop(idx) - - grid_z = griddata(destination, source, (grid_x, grid_y), method="linear") - map_x = np.append([], [ar[:, 1] for ar in grid_z]).reshape(size, size) - map_y = np.append([], [ar[:, 0] for ar in grid_z]).reshape(size, size) - map_x_32 = map_x.astype('float32') - map_y_32 = 
map_y.astype('float32') - - warped_image = cv2.remap(image, # pylint: disable=no-member - map_x_32, - map_y_32, - cv2.INTER_LINEAR, # pylint: disable=no-member - cv2.BORDER_TRANSPARENT) # pylint: disable=no-member - target_image = image - - # TODO Make sure this replacement is correct - slices = slice(size // 2 - coverage // 2, size // 2 + coverage // 2) -# slices = slice(size // 32, size - size // 32) # 8px on a 256px image - warped_image = cv2.resize( # pylint: disable=no-member - warped_image[slices, slices, :], (self.input_size, self.input_size), - cv2.INTER_AREA) # pylint: disable=no-member - logger.trace("Warped image shape: %s", warped_image.shape) - target_image = cv2.resize( # pylint: disable=no-member - target_image[slices, slices, :], (self.output_size, self.output_size), - cv2.INTER_AREA) # pylint: disable=no-member - logger.trace("Target image shape: %s", target_image.shape) - - warped_image, warped_mask = self.separate_mask(warped_image) - target_image, target_mask = self.separate_mask(target_image) - - if target_mask is None: - logger.trace("Randomly warped image") - return [warped_image, target_image] - - logger.trace("Target mask shape: %s", target_mask.shape) - logger.trace("Randomly warped image and mask") - return [warped_image, target_image, target_mask] - - -def stack_images(images): - """ Stack images """ - logger.debug("Stack images") - - def get_transpose_axes(num): - if num % 2 == 0: - logger.debug("Even number of images to stack") - y_axes = list(range(1, num - 1, 2)) - x_axes = list(range(0, num - 1, 2)) - else: - logger.debug("Odd number of images to stack") - y_axes = list(range(0, num - 1, 2)) - x_axes = list(range(1, num - 1, 2)) - return y_axes, x_axes, [num - 1] - - images_shape = np.array(images.shape) - new_axes = get_transpose_axes(len(images_shape)) - new_shape = [np.prod(images_shape[x]) for x in new_axes] - logger.debug("Stacked images") - return np.transpose( - images, - axes=np.concatenate(new_axes) - ).reshape(new_shape) diff --git a/lib/umeyama.py b/lib/umeyama.py deleted file mode 100644 index f3af365a9a..0000000000 --- a/lib/umeyama.py +++ /dev/null @@ -1,105 +0,0 @@ -## License (Modified BSD) -## Copyright (C) 2011, the scikit-image team All rights reserved. -## -## Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: -## -## Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. -## Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. -## Neither the name of skimage nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. -## THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -# umeyama function from scikit-image/skimage/transform/_geometric.py - -import numpy as np - -MEAN_FACE_X = np.array([ - 0.000213256, 0.0752622, 0.18113, 0.29077, 0.393397, 0.586856, 0.689483, - 0.799124, 0.904991, 0.98004, 0.490127, 0.490127, 0.490127, 0.490127, - 0.36688, 0.426036, 0.490127, 0.554217, 0.613373, 0.121737, 0.187122, - 0.265825, 0.334606, 0.260918, 0.182743, 0.645647, 0.714428, 0.793132, - 0.858516, 0.79751, 0.719335, 0.254149, 0.340985, 0.428858, 0.490127, - .551395, 0.639268, 0.726104, 0.642159, 0.556721, 0.490127, 0.423532, - 0.338094, 0.290379, 0.428096, 0.490127, 0.552157, 0.689874, 0.553364, - 0.490127, 0.42689]) - -MEAN_FACE_Y = np.array([ - 0.106454, 0.038915, 0.0187482, 0.0344891, 0.0773906, 0.0773906, 0.0344891, - 0.0187482, 0.038915, 0.106454, 0.203352, 0.307009, 0.409805, 0.515625, - 0.587326, 0.609345, 0.628106, 0.609345, 0.587326, 0.216423, 0.178758, - 0.179852, 0.231733, 0.245099, 0.244077, 0.231733, 0.179852, 0.178758, - 0.216423, 0.244077, 0.245099, 0.780233, 0.745405, 0.727388, 0.742578, - 0.727388, 0.745405, 0.780233, 0.864805, 0.902192, 0.909281, 0.902192, - 0.864805, 0.784792, 0.778746, 0.785343, 0.778746, 0.784792, 0.824182, - 0.831803, 0.824182]) - -def umeyama(src, estimate_scale, dst=None): - """Estimate N-D similarity transformation with or without scaling. - Parameters - ---------- - src : (M, N) array - Source coordinates. - dst : (M, N) array - Destination coordinates. - estimate_scale : bool - Whether to estimate scaling factor. - Returns - ------- - T : (N + 1, N + 1) - The homogeneous similarity transformation matrix. The matrix contains - NaN values only if the problem is not well-conditioned. - References - ---------- - .. [1] "Least-squares estimation of transformation parameters between two - point patterns", Shinji Umeyama, PAMI 1991, DOI: 10.1109/34.88573 - """ - if dst is None: - dst = np.stack([MEAN_FACE_X, MEAN_FACE_Y], axis=1) - - num = src.shape[0] - dim = src.shape[1] - - # Compute mean of src and dst. - src_mean = src.mean(axis=0) - dst_mean = dst.mean(axis=0) - - # Subtract mean from src and dst. - src_demean = src - src_mean - dst_demean = dst - dst_mean - - # Eq. (38). - A = np.dot(dst_demean.T, src_demean) / num - - # Eq. (39). - d = np.ones((dim,), dtype=np.double) - if np.linalg.det(A) < 0: - d[dim - 1] = -1 - - T = np.eye(dim + 1, dtype=np.double) - - U, S, V = np.linalg.svd(A) - - # Eq. (40) and (43). - rank = np.linalg.matrix_rank(A) - if rank == 0: - return np.nan * T - elif rank == dim - 1: - if np.linalg.det(U) * np.linalg.det(V) > 0: - T[:dim, :dim] = np.dot(U, V) - else: - s = d[dim - 1] - d[dim - 1] = -1 - T[:dim, :dim] = np.dot(U, np.dot(np.diag(d), V)) - d[dim - 1] = s - else: - T[:dim, :dim] = np.dot(U, np.dot(np.diag(d), V.T)) - - if estimate_scale: - # Eq. (41) and (42). 
-        scale = 1.0 / src_demean.var(axis=0).sum() * np.dot(S, d)
-    else:
-        scale = 1.0
-
-    T[:dim, dim] = dst_mean - scale * np.dot(T[:dim, :dim], src_mean.T)
-    T[:dim, :dim] *= scale
-
-    return T
diff --git a/lib/utils.py b/lib/utils.py
index 72d5dd2f15..747f6ca02e 100644
--- a/lib/utils.py
+++ b/lib/utils.py
@@ -1,45 +1,286 @@
 #!/usr/bin python3
 """ Utilities available across all scripts """
+# NOTE: Do not import keras/pytorch in this script, as it is accessed before they should be loaded
+from __future__ import annotations
+import inspect
+import json
 import logging
 import os
-import warnings
+import sys
+import tkinter as tk
+import typing as T
+import zipfile
 
-from hashlib import sha1
-from pathlib import Path
+from importlib import import_module
+from multiprocessing import current_process
 from re import finditer
-
-import cv2
-import numpy as np
-
-import dlib
-
-from lib.faces_detect import DetectedFace
-from lib.logger import get_loglevel
-
-
-logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
+from socket import timeout as socket_timeout, error as socket_error
+from threading import get_ident
+from time import time
+from urllib import request, error as urlliberror
+
+try:
+    import numpy as np
+    from tqdm import tqdm
+except:  # noqa[E722]  # pylint:disable=bare-except
+    # Importing outside of faceswap environment, these packages should not be required
+    np = None  # type:ignore[assignment]  # pylint:disable=invalid-name
+    tqdm = None  # pylint:disable=invalid-name
+
+if T.TYPE_CHECKING:
+    from argparse import Namespace
+    from http.client import HTTPResponse
 
 # Global variables
-_image_extensions = [  # pylint: disable=invalid-name
-    ".bmp", ".jpeg", ".jpg", ".png", ".tif", ".tiff"]
-_video_extensions = [  # pylint: disable=invalid-name
-    ".avi", ".flv", ".mkv", ".mov", ".mp4", ".mpeg", ".webm"]
-
-
-def get_folder(path):
-    """ Return a path to a folder, creating it if it doesn't exist """
+PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
+""" str : Full path to the root faceswap folder """
+IMAGE_EXTENSIONS = [".bmp", ".jpeg", ".jpg", ".png", ".tif", ".tiff"]
+VIDEO_EXTENSIONS = [".avi", ".flv", ".mkv", ".mov", ".mp4", ".mpeg", ".mpg", ".webm", ".wmv",
+                    ".ts", ".vob"]
+ValidBackends = T.Literal["nvidia", "cpu", "apple_silicon", "rocm"]
+_FS_BACKEND: ValidBackends | None = None
+
+
+class _Backend():  # pylint:disable=too-few-public-methods
+    """ Return the backend from config/.faceswap or from the `FACESWAP_BACKEND` Environment
+    Variable.
+
+    If the file doesn't exist and a variable hasn't been set, create the config file. """
+    def __init__(self) -> None:
+        self._backends: dict[str, ValidBackends] = {"1": "cpu",
+                                                    "2": "nvidia",
+                                                    "3": "apple_silicon",
+                                                    "4": "rocm"}
+        self._valid_backends = list(self._backends.values())
+        self._config_file = self._get_config_file()
+        self.backend: ValidBackends = self._get_backend()
+
+    @classmethod
+    def _get_config_file(cls) -> str:
+        """ Obtain the location of the main Faceswap configuration file.
+
+        Returns
+        -------
+        str
+            The path to the Faceswap configuration file
+        """
+        config_file = os.path.join(PROJECT_ROOT, "config", ".faceswap")
+        return config_file
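Backend resolution, per `_Backend` below, works in a fixed order: the `FACESWAP_BACKEND` environment variable wins, then the JSON file at `config/.faceswap`, and finally an interactive prompt which writes that file. A sketch of the two non-interactive routes (the path here is shown relative to the faceswap root):

import json
import os

# Route 1: the environment variable short-circuits everything else
os.environ["FACESWAP_BACKEND"] = "cpu"  # one of: cpu, nvidia, apple_silicon, rocm

# Route 2: pre-seed the config file that the interactive prompt would otherwise create
with open(os.path.join("config", ".faceswap"), "w", encoding="utf8") as cnf:
    json.dump({"backend": "nvidia"}, cnf)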
+
+    def _get_backend(self) -> ValidBackends:
+        """ Return the backend from either the `FACESWAP_BACKEND` Environment Variable or from
+        the :file:`config/.faceswap` configuration file. If neither of these exist, prompt the
+        user to select a backend.
+
+        Returns
+        -------
+        str
+            The backend configuration in use by Faceswap
+        """
+        # Check if environment variable is set, if so use that
+        if "FACESWAP_BACKEND" in os.environ:
+            fs_backend = T.cast(ValidBackends, os.environ["FACESWAP_BACKEND"].lower())
+            assert fs_backend in T.get_args(ValidBackends), (
+                f"Faceswap backend must be one of {T.get_args(ValidBackends)}")
+            print(f"Setting Faceswap backend from environment variable to {fs_backend.upper()}")
+            return fs_backend
+        # Intercept for sphinx docs build
+        if sys.argv[0].endswith("sphinx-build"):
+            return "nvidia"
+        if not os.path.isfile(self._config_file):
+            self._configure_backend()
+        while True:
+            try:
+                with open(self._config_file, "r", encoding="utf8") as cnf:
+                    config = json.load(cnf)
+                break
+            except json.decoder.JSONDecodeError:
+                self._configure_backend()
+                continue
+        fs_backend = config.get("backend", "").lower()
+        if not fs_backend or fs_backend not in self._backends.values():
+            fs_backend = self._configure_backend()
+        if current_process().name == "MainProcess":
+            print(f"Setting Faceswap backend to {fs_backend.upper()}")
+        return fs_backend
+
+    def _configure_backend(self) -> ValidBackends:
+        """ Get user input to select the backend that Faceswap should use.
+
+        Returns
+        -------
+        str
+            The backend configuration in use by Faceswap
+        """
+        print("First time configuration. Please select the required backend")
+        while True:
+            txt = ", ".join([": ".join([key, val.upper().replace("_", " ")])
+                             for key, val in self._backends.items()])
+            selection = input(f"{txt}: ")
+            if selection not in self._backends:
+                print(f"'{selection}' is not a valid selection. Please try again")
+                continue
+            break
+        fs_backend = self._backends[selection]
+        config = {"backend": fs_backend}
+        with open(self._config_file, "w", encoding="utf8") as cnf:
+            json.dump(config, cnf)
+        print(f"Faceswap config written to: {self._config_file}")
+        return fs_backend
+
+
+def get_backend() -> ValidBackends:
+    """ Get the backend that Faceswap is currently configured to use.
+
+    Returns
+    -------
+    str
+        The backend configuration in use by Faceswap. One of ["cpu", "nvidia", "rocm",
+        "apple_silicon"]
+
+    Example
+    -------
+    >>> from lib.utils import get_backend
+    >>> get_backend()
+    'nvidia'
+    """
+    global _FS_BACKEND  # pylint:disable=global-statement
+    if _FS_BACKEND is None:
+        _FS_BACKEND = _Backend().backend
+    return _FS_BACKEND
+
+
+def set_backend(backend: str) -> None:
+    """ Override the configured backend with the given backend.
+
+    Parameters
+    ----------
+    backend: ["cpu", "nvidia", "rocm", "apple_silicon"]
+        The backend to set faceswap to
+
+    Example
+    -------
+    >>> from lib.utils import set_backend
+    >>> set_backend("nvidia")
+    """
+    global _FS_BACKEND  # pylint:disable=global-statement
+    backend = T.cast(ValidBackends, backend.lower())
+    _FS_BACKEND = backend
+
+
+_versions: dict[T.Literal["torch", "keras"], tuple[int, int]] = {}
+
+
+def get_torch_version() -> tuple[int, int]:
+    """ Obtain the major.minor version of currently installed PyTorch.
+
+    Returns
+    -------
+    tuple[int, int]
+        A tuple of the form (major, minor) representing the version of PyTorch that is installed
+
+    Example
+    -------
+    >>> from lib.utils import get_torch_version
+    >>> get_torch_version()
+    (2, 2)
+    """
+    if "torch" not in _versions:
+        torch = import_module("torch")
+        split = torch.__version__.split(".")[:2]
+        _versions["torch"] = (int(split[0]), int(split[1]))
+    return _versions["torch"]
+
+
+def get_keras_version() -> tuple[int, int]:
+    """ Obtain the major.minor version of currently installed Keras.
+
+    Returns
+    -------
+    tuple[int, int]
+        A tuple of the form (major, minor) representing the version of Keras that is installed
+
+    Example
+    -------
+    >>> from lib.utils import get_keras_version
+    >>> get_keras_version()
+    (3, 0)
+    """
+    if "keras" not in _versions:
+        keras = import_module("keras")
+        split = keras.__version__.split(".")[:2]
+        _versions["keras"] = (int(split[0]), int(split[1]))
+    return _versions["keras"]
+
+
+def get_folder(path: str, make_folder: bool = True) -> str:
+    """ Return a path to a folder, creating it if it doesn't exist
+
+    Parameters
+    ----------
+    path: str
+        The path to the folder to obtain
+    make_folder: bool, optional
+        ``True`` if the folder should be created if it does not already exist, ``False`` if the
+        folder should not be created
+
+    Returns
+    -------
+    str
+        The path to the requested folder. If `make_folder` is set to ``False`` and the requested
+        path does not exist, then an empty string is returned
+
+    Example
+    -------
+    >>> from lib.utils import get_folder
+    >>> get_folder('/tmp/myfolder')
+    '/tmp/myfolder'
+
+    >>> get_folder('/tmp/myfolder', make_folder=False)
+    ''
+    """
+    logger = logging.getLogger(__name__)
     logger.debug("Requested path: '%s'", path)
-    output_dir = Path(path)
-    output_dir.mkdir(parents=True, exist_ok=True)
-    logger.debug("Returning: '%s'", output_dir)
-    return output_dir
-
-
-def get_image_paths(directory):
-    """ Return a list of images that reside in a folder """
-    image_extensions = _image_extensions
-    dir_contents = list()
+    if not make_folder and not os.path.isdir(path):
+        logger.debug("%s does not exist", path)
+        return ""
+    os.makedirs(path, exist_ok=True)
+    logger.debug("Returning: '%s'", path)
+    return path
+
+
+def get_image_paths(directory: str, extension: str | None = None) -> list[str]:
+    """ Gets the image paths from a given directory.
+
+    The function searches for files with the specified extension(s) in the given directory, and
+    returns a list of their paths. If no extension is provided, the function will search for
+    files with any of the following extensions: '.bmp', '.jpeg', '.jpg', '.png', '.tif', '.tiff'
+
+    Parameters
+    ----------
+    directory: str
+        The directory to search in
+    extension: str, optional
+        The file extension to search for.
If not provided, all image file types will be searched + for + + Returns + ------- + list[str] + The list of full paths to the images contained within the given folder + + Example + ------- + >>> from lib.utils import get_image_paths + >>> get_image_paths('/path/to/directory') + ['/path/to/directory/image1.jpg', '/path/to/directory/image2.png'] + >>> get_image_paths('/path/to/directory', '.jpg') + ['/path/to/directory/image1.jpg'] + """ + logger = logging.getLogger(__name__) + image_extensions = IMAGE_EXTENSIONS if extension is None else [extension] + dir_contents = [] if not os.path.exists(directory): logger.debug("Creating folder: '%s'", directory) @@ -47,168 +288,663 @@ def get_image_paths(directory): dir_scanned = sorted(os.scandir(directory), key=lambda x: x.name) logger.debug("Scanned Folder contains %s files", len(dir_scanned)) - logger.trace("Scanned Folder Contents: %s", dir_scanned) + logger.trace("Scanned Folder Contents: %s", dir_scanned) # type:ignore[attr-defined] for chkfile in dir_scanned: - if any([chkfile.name.lower().endswith(ext) - for ext in image_extensions]): - logger.trace("Adding '%s' to image list", chkfile.path) + if any(chkfile.name.lower().endswith(ext) for ext in image_extensions): + logger.trace("Adding '%s' to image list", chkfile.path) # type:ignore[attr-defined] dir_contents.append(chkfile.path) logger.debug("Returning %s images", len(dir_contents)) return dir_contents -def hash_image_file(filename): - """ Return an image file's sha1 hash """ - img = cv2.imread(filename) # pylint: disable=no-member - img_hash = sha1(img).hexdigest() - logger.trace("filename: '%s', hash: %s", filename, img_hash) - return img_hash - - -def hash_encode_image(image, extension): - """ Encode the image, get the hash and return the hash with - encoded image """ - img = cv2.imencode(extension, image)[1] # pylint: disable=no-member - f_hash = sha1( - cv2.imdecode(img, cv2.IMREAD_UNCHANGED)).hexdigest() # pylint: disable=no-member - return f_hash, img - - -def backup_file(directory, filename): - """ Backup a given file by appending .bk to the end """ - logger.trace("Backing up: '%s'", filename) - origfile = os.path.join(directory, filename) - backupfile = origfile + '.bk' - if os.path.exists(backupfile): - logger.trace("Removing existing file: '%s'", backup_file) - os.remove(backupfile) - if os.path.exists(origfile): - logger.trace("Renaming: '%s' to '%s'", origfile, backup_file) - os.rename(origfile, backupfile) - - -def set_system_verbosity(loglevel): - """ Set the verbosity level of tensorflow and suppresses - future and deprecation warnings from any modules - From: - https://stackoverflow.com/questions/35911252/disable-tensorflow-debugging-information - Can be set to: - 0 - all logs shown - 1 - filter out INFO logs - 2 - filter out WARNING logs - 3 - filter out ERROR logs """ - - numeric_level = get_loglevel(loglevel) - loglevel = "2" if numeric_level > 15 else "0" - logger.debug("System Verbosity level: %s", loglevel) - os.environ['TF_CPP_MIN_LOG_LEVEL'] = loglevel - if loglevel != '0': - for warncat in (FutureWarning, DeprecationWarning, UserWarning): - warnings.simplefilter(action='ignore', category=warncat) - - -def rotate_landmarks(face, rotation_matrix): - # pylint: disable=c-extension-no-member - """ Rotate the landmarks and bounding box for faces - found in rotated images. 
- Pass in a DetectedFace object, Alignments dict or DLib rectangle""" - logger.trace("Rotating landmarks: (rotation_matrix: %s, type(face): %s", - rotation_matrix, type(face)) - if isinstance(face, DetectedFace): - bounding_box = [[face.x, face.y], - [face.x + face.w, face.y], - [face.x + face.w, face.y + face.h], - [face.x, face.y + face.h]] - landmarks = face.landmarksXY - - elif isinstance(face, dict): - bounding_box = [[face.get("x", 0), face.get("y", 0)], - [face.get("x", 0) + face.get("w", 0), - face.get("y", 0)], - [face.get("x", 0) + face.get("w", 0), - face.get("y", 0) + face.get("h", 0)], - [face.get("x", 0), - face.get("y", 0) + face.get("h", 0)]] - landmarks = face.get("landmarksXY", list()) - - elif isinstance(face, - dlib.rectangle): # pylint: disable=c-extension-no-member - bounding_box = [[face.left(), face.top()], - [face.right(), face.top()], - [face.right(), face.bottom()], - [face.left(), face.bottom()]] - landmarks = list() - else: - raise ValueError("Unsupported face type") - - logger.trace("Original landmarks: %s", landmarks) - - rotation_matrix = cv2.invertAffineTransform( # pylint: disable=no-member - rotation_matrix) - rotated = list() - for item in (bounding_box, landmarks): - if not item: - continue - points = np.array(item, np.int32) - points = np.expand_dims(points, axis=0) - transformed = cv2.transform(points, # pylint: disable=no-member - rotation_matrix).astype(np.int32) - rotated.append(transformed.squeeze()) - - # Bounding box should follow x, y planes, so get min/max - # for non-90 degree rotations - pt_x = min([pnt[0] for pnt in rotated[0]]) - pt_y = min([pnt[1] for pnt in rotated[0]]) - pt_x1 = max([pnt[0] for pnt in rotated[0]]) - pt_y1 = max([pnt[1] for pnt in rotated[0]]) - - if isinstance(face, DetectedFace): - face.x = int(pt_x) - face.y = int(pt_y) - face.w = int(pt_x1 - pt_x) - face.h = int(pt_y1 - pt_y) - face.r = 0 - if len(rotated) > 1: - rotated_landmarks = [tuple(point) for point in rotated[1].tolist()] - face.landmarksXY = rotated_landmarks - elif isinstance(face, dict): - face["x"] = int(pt_x) - face["y"] = int(pt_y) - face["w"] = int(pt_x1 - pt_x) - face["h"] = int(pt_y1 - pt_y) - face["r"] = 0 - if len(rotated) > 1: - rotated_landmarks = [tuple(point) for point in rotated[1].tolist()] - face["landmarksXY"] = rotated_landmarks - else: - rotated_landmarks = dlib.rectangle( # pylint: disable=c-extension-no-member - int(pt_x), int(pt_y), int(pt_x1), int(pt_y1)) - face = rotated_landmarks - - logger.trace("Rotated landmarks: %s", rotated_landmarks) - return face - - -def camel_case_split(identifier): - """ Split a camel case name - from: https://stackoverflow.com/questions/29916065 """ +def get_dpi() -> float | None: + """ Gets the DPI (dots per inch) of the display screen. + + Returns + ------- + float or ``None`` + The DPI of the display screen or ``None`` if the dpi couldn't be obtained (ie: if the + function is called on a headless system) + + Example + ------- + >>> from lib.utils import get_dpi + >>> get_dpi() + 96.0 + """ + logger = logging.getLogger(__name__) + try: + root = tk.Tk() + dpi = root.winfo_fpixels('1i') + except tk.TclError: + logger.warning("Display not detected. 
Could not obtain DPI")
+        return None
+
+    return float(dpi)
+
+
+def get_module_objects(module: str) -> list[str]:
+    """ Return a list of all public objects within the given module
+
+    Parameters
+    ----------
+    module: str
+        The module to parse for public objects
+
+    Returns
+    -------
+    list[str]
+        A list of object names that exist within the given module
+
+    Example
+    -------
+    >>> __all__ = get_module_objects(__name__)
+    >>> __all__
+    ['bar', 'baz', 'foo']
+    """
+    return [name_ for name_, obj in inspect.getmembers(sys.modules[module])
+            if getattr(obj, "__module__", None) == module
+            and not name_.startswith("_")]
+
+
+def convert_to_secs(*args: int) -> int:
+    """ Convert time in hours, minutes, and seconds to seconds.
+
+    Parameters
+    ----------
+    *args: int
+        1, 2 or 3 ints. If 1 int is supplied, then (`seconds`) is implied. If 2 ints are
+        supplied, then (`minutes`, `seconds`) is implied. If 3 ints are supplied then
+        (`hours`, `minutes`, `seconds`) is implied.
+
+    Returns
+    -------
+    int
+        The given time converted to seconds
+
+    Example
+    -------
+    >>> from lib.utils import convert_to_secs
+    >>> convert_to_secs(1, 30, 0)
+    5400
+    >>> convert_to_secs(0, 15, 30)
+    930
+    >>> convert_to_secs(0, 0, 45)
+    45
+    """
+    logger = logging.getLogger(__name__)
+    logger.debug("from time: %s", args)
+    retval = 0.0
+    if len(args) == 1:
+        retval = float(args[0])
+    elif len(args) == 2:
+        retval = 60 * float(args[0]) + float(args[1])
+    elif len(args) == 3:
+        retval = 3600 * float(args[0]) + 60 * float(args[1]) + float(args[2])
+    retval = int(retval)
+    logger.debug("to secs: %s", retval)
+    return retval
+
+
+def full_path_split(path: str) -> list[str]:
+    """ Split a file path into all of its parts.
+
+    Parameters
+    ----------
+    path: str
+        The full path to be split
+
+    Returns
+    -------
+    list
+        The full path split into a separate item for each part
+
+    Example
+    -------
+    >>> from lib.utils import full_path_split
+    >>> full_path_split("/usr/local/bin/python")
+    ['/', 'usr', 'local', 'bin', 'python']
+    >>> full_path_split("relative/path/to/file.txt")
+    ['relative', 'path', 'to', 'file.txt']
+    """
+    logger = logging.getLogger(__name__)
+    allparts: list[str] = []
+    while True:
+        parts = os.path.split(path)
+        if parts[0] == path:  # sentinel for absolute paths
+            allparts.insert(0, parts[0])
+            break
+        if parts[1] == path:  # sentinel for relative paths
+            allparts.insert(0, parts[1])
+            break
+        path = parts[0]
+        allparts.insert(0, parts[1])
+    logger.trace("path: %s, allparts: %s", path, allparts)  # type:ignore[attr-defined]
+    # Remove any empty strings which may have got inserted
+    allparts = [part for part in allparts if part]
+    return allparts
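+
+
+# A minimal sketch of how ``get_module_objects`` (above) behaves; the module ``demo`` and
+# its contents are hypothetical. Imported names are dropped because their ``__module__``
+# attribute (if any) points at the module where they were defined, not at ``demo``:
+#
+#     >>> import demo                    # demo contains: import os; def foo(): ...
+#     >>> get_module_objects("demo")
+#     ['foo']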
+
+
+def deprecation_warning(function: str, additional_info: str | None = None) -> None:
+    """ Log a deprecation warning message.
+
+    This function logs a warning message to indicate that the specified function has been
+    deprecated and will be removed in a future update. An optional additional message can
+    also be included.
+
+    Parameters
+    ----------
+    function: str
+        The name of the function that is being deprecated.
+    additional_info: str, optional
+        Any additional information to display with the deprecation message. Default: ``None``
+
+    Example
+    -------
+    >>> from lib.utils import deprecation_warning
+    >>> deprecation_warning('old_function', 'Use new_function instead.')
+    """
+    logger = logging.getLogger(__name__)
+    logger.debug("func_name: %s, additional_info: %s", function, additional_info)
+    msg = f"{function} has been deprecated and will be removed from a future update."
+    if additional_info is not None:
+        msg += f" {additional_info}"
+    logger.warning(msg)
+
+
+def handle_deprecated_cliopts(arguments: Namespace) -> Namespace:
+    """ Handle deprecated command line arguments and update to the correct argument.
+
+    Deprecated cli opts will be provided in the following format:
+    `"depr_<opt_name>_<deprecated_flag>_<new_flag>"`
+
+    Parameters
+    ----------
+    arguments: :class:`argparse.Namespace`
+        The passed in faceswap cli arguments
+
+    Returns
+    -------
+    :class:`argparse.Namespace`
+        The cli arguments with deprecated values mapped to the correct entry
+    """
+    logger = logging.getLogger(__name__)
+
+    for key, selected in vars(arguments).items():
+        if not key.startswith("depr_") or selected is None:
+            continue  # Not a deprecated opt, or deprecated opt not supplied
+        if isinstance(selected, bool) and not selected:
+            continue  # store-true opt with default value
+
+        opt, old, new = key.replace("depr_", "").rsplit("_", maxsplit=2)
+        deprecation_warning(f"Command line option '-{old}'", f"Use '-{new}, --{opt}' instead")
+
+        exist = getattr(arguments, opt)
+        if exist == selected:
+            logger.debug("Keeping existing '%s' value of '%s'", opt, exist)
+        else:
+            logger.debug("Updating arg '%s' from '%s' to '%s' from deprecated opt",
+                         opt, exist, selected)
+            setattr(arguments, opt, selected)
+
+    return arguments
+
+
+def camel_case_split(identifier: str) -> list[str]:
+    """ Split a camelCase string into a list of its individual parts
+
+    Parameters
+    ----------
+    identifier: str
+        The camelCase text to be split
+
+    Returns
+    -------
+    list[str]
+        A list of the individual parts of the camelCase string.
+
+    References
+    ----------
+    https://stackoverflow.com/questions/29916065
+
+    Example
+    -------
+    >>> from lib.utils import camel_case_split
+    >>> camel_case_split('camelCaseExample')
+    ['camel', 'Case', 'Example']
+    """
    matches = finditer(
        ".+?(?:(?<=[a-z])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])|$)",
        identifier)
    return [m.group(0) for m in matches]


-def safe_shutdown():
-    """ Close queues, threads and processes in event of crash """
+def safe_shutdown(got_error: bool = False) -> None:
+    """ Safely shut down the system.
+
+    This function terminates the queue manager and exits the program in a clean and orderly
+    manner. An optional boolean parameter can be used to indicate whether an error occurred
+    during the program's execution.
+
+    Parameters
+    ----------
+    got_error: bool, optional
+        ``True`` if this function is being called as the result of a raised error.
+        Default: ``False``
+
+    Example
+    -------
+    >>> from lib.utils import safe_shutdown
+    >>> safe_shutdown()
+    >>> safe_shutdown(True)
+    """
+    logger = logging.getLogger(__name__)
    logger.debug("Safely shutting down")
-    from lib.queue_manager import queue_manager
-    from lib.multithreading import terminate_processes
+    from lib.queue_manager import queue_manager  # pylint:disable=import-outside-toplevel
    queue_manager.terminate_queues()
-    terminate_processes()
    logger.debug("Cleanup complete. Shutting down queue manager and exiting")
-    queue_manager._log_queue.put(None)  # pylint: disable=protected-access
-    while not queue_manager._log_queue.empty():  # pylint: disable=protected-access
-        continue
-    queue_manager.manager.shutdown()
+    sys.exit(1 if got_error else 0)
+
+
+class FaceswapError(Exception):
+    """ Faceswap Error for handling specific errors with useful information.
+
+    Raises
+    ------
+    FaceswapError
+        on a captured error
+
+    Example
+    -------
+    >>> from lib.utils import FaceswapError
+    >>> try:
+    ...     # Some code that may raise an error
+    ... except SomeError:
+    ...     raise FaceswapError("There was an error while running the code")
+    FaceswapError: There was an error while running the code
+    """
+    pass  # pylint:disable=unnecessary-pass
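+
+
+# Sketch of the remapping performed by handle_deprecated_cliopts above. The namespace and
+# flags below are hypothetical: the key "depr_frame_ranges_fr_R" decodes, via
+# rsplit("_", maxsplit=2), to option "frame_ranges", old flag "-fr" and new flag "-R",
+# and the supplied value is copied onto the real option:
+#
+#     >>> from argparse import Namespace
+#     >>> args = Namespace(frame_ranges=None, depr_frame_ranges_fr_R="10-20")
+#     >>> handle_deprecated_cliopts(args).frame_ranges
+#     '10-20'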
+
+
+class GetModel():
+    """ Check for models in the cache path.
+
+    If available, return the path, if not available, get, unzip and install model
+
+    Parameters
+    ----------
+    model_filename: str or list
+        The name of the model to be loaded (see notes below)
+    git_model_id: int
+        The second digit in the github tag that identifies this model. See
+        https://github.com/deepfakes-models/faceswap-models for more information
+
+    Notes
+    -----
+    Models must have a certain naming convention: `<model_name>_v<version_number>.<extension>`
+    (eg: `s3fd_v1.pb`).
+
+    Multiple models can exist within the model_filename. They should be passed as a list and
+    follow the same naming convention as above. Any differences in filename should occur AFTER
+    the version number: `<model_name>_v<version_number><differentiator>.<extension>` (eg:
+    `["mtcnn_det_v1.1.py", "mtcnn_det_v1.2.py", "mtcnn_det_v1.3.py"]`,
+    `["resnet_ssd_v1.caffemodel", "resnet_ssd_v1.prototext"]`)
+
+    Example
+    -------
+    >>> from lib.utils import GetModel
+    >>> model_downloader = GetModel("s3fd_keras_v2.h5", 11)
+    """
+
+    def __init__(self, model_filename: str | list[str], git_model_id: int) -> None:
+        self.logger = logging.getLogger(__name__)
+        if not isinstance(model_filename, list):
+            model_filename = [model_filename]
+        self._model_filename = model_filename
+        self._cache_dir = os.path.join(PROJECT_ROOT, ".fs_cache")
+        self._git_model_id = git_model_id
+        self._url_base = "https://github.com/deepfakes-models/faceswap-models/releases/download"
+        self._chunk_size = 1024  # Chunk size for downloading and unzipping
+        self._retries = 6
+        self._get()
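+
+    # How the naming convention above is decoded by the properties that follow, using the
+    # multi-file example from the class docstring (a git_model_id of 2 is hypothetical).
+    # os.path.commonprefix gives 'mtcnn_det_v1.'; stripping the extension leaves the full
+    # name, the version is read from after the trailing '_v', and the release tag is
+    # f"v{git_model_id}.{model_version}":
+    #
+    #     _model_full_name -> 'mtcnn_det_v1'
+    #     _model_name      -> 'mtcnn_det'
+    #     _model_version   -> 1
+    #     download tag     -> 'v2.1'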
""" + if isinstance(self.model_path, list): + retval = all(os.path.exists(pth) for pth in self.model_path) + else: + retval = os.path.exists(self.model_path) + self.logger.trace(retval) # type:ignore[attr-defined] + return retval + + @property + def _url_download(self) -> str: + """ strL Base download URL for models. """ + tag = f"v{self._git_model_id}.{self._model_version}" + retval = f"{self._url_base}/{tag}/{self._model_full_name}.zip" + self.logger.trace("Download url: %s", retval) # type:ignore[attr-defined] + return retval + + @property + def _url_partial_size(self) -> int: + """ int: How many bytes have already been downloaded. """ + zip_file = self._model_zip_path + retval = os.path.getsize(zip_file) if os.path.exists(zip_file) else 0 + self.logger.trace(retval) # type:ignore[attr-defined] + return retval + + def _get(self) -> None: + """ Check the model exists, if not, download the model, unzip it and place it in the + model's cache folder. """ + if self._model_exists: + self.logger.debug("Model exists: %s", self.model_path) + return + self._download_model() + self._unzip_model() + os.remove(self._model_zip_path) + + def _download_model(self) -> None: + """ Download the model zip from github to the cache folder. """ + self.logger.info("Downloading model: '%s' from: %s", self._model_name, self._url_download) + for attempt in range(self._retries): + try: + downloaded_size = self._url_partial_size + req = request.Request(self._url_download) + if downloaded_size != 0: + req.add_header("Range", f"bytes={downloaded_size}-") + with request.urlopen(req, timeout=10) as response: + self.logger.debug("header info: {%s}", response.info()) + self.logger.debug("Return Code: %s", response.getcode()) + self._write_zipfile(response, downloaded_size) + break + except (socket_error, socket_timeout, + urlliberror.HTTPError, urlliberror.URLError) as err: + if attempt + 1 < self._retries: + self.logger.warning("Error downloading model (%s). Retrying %s of %s...", + str(err), attempt + 2, self._retries) + else: + self.logger.error("Failed to download model. Exiting. (Error: '%s', URL: " + "'%s')", str(err), self._url_download) + self.logger.info("You can try running again to resume the download.") + self.logger.info("Alternatively, you can manually download the model from: %s " + "and unzip the contents to: %s", + self._url_download, self._cache_dir) + sys.exit(1) + + def _write_zipfile(self, response: HTTPResponse, downloaded_size: int) -> None: + """ Write the model zip file to disk. + + Parameters + ---------- + response: :class:`http.client.HTTPResponse` + The response from the model download task + downloaded_size: int + The amount of bytes downloaded so far + """ + content_length = response.getheader("content-length") + content_length = "0" if content_length is None else content_length + length = int(content_length) + downloaded_size + if length == downloaded_size: + self.logger.info("Zip already exists. 
+
+    def _write_zipfile(self, response: HTTPResponse, downloaded_size: int) -> None:
+        """ Write the model zip file to disk.
+
+        Parameters
+        ----------
+        response: :class:`http.client.HTTPResponse`
+            The response from the model download task
+        downloaded_size: int
+            The amount of bytes downloaded so far
+        """
+        content_length = response.getheader("content-length")
+        content_length = "0" if content_length is None else content_length
+        length = int(content_length) + downloaded_size
+        if length == downloaded_size:
+            self.logger.info("Zip already exists. Skipping download")
+            return
+        write_type = "wb" if downloaded_size == 0 else "ab"
+        assert tqdm is not None
+        with open(self._model_zip_path, write_type) as out_file:
+            pbar = tqdm(desc="Downloading",
+                        unit="B",
+                        total=length,
+                        unit_scale=True,
+                        unit_divisor=1024)
+            if downloaded_size != 0:
+                pbar.update(downloaded_size)
+            while True:
+                buffer = response.read(self._chunk_size)
+                if not buffer:
+                    break
+                pbar.update(len(buffer))
+                out_file.write(buffer)
+            pbar.close()
+
+    def _unzip_model(self) -> None:
+        """ Unzip the model file to the cache folder """
+        self.logger.info("Extracting: '%s'", self._model_name)
+        try:
+            with zipfile.ZipFile(self._model_zip_path, "r") as zip_file:
+                self._write_model(zip_file)
+        except Exception as err:  # pylint:disable=broad-except
+            self.logger.error("Unable to extract model file: %s", str(err))
+            sys.exit(1)
+
+    def _write_model(self, zip_file: zipfile.ZipFile) -> None:
+        """ Extract files from zip file and write, with progress bar.
+
+        Parameters
+        ----------
+        zip_file: :class:`zipfile.ZipFile`
+            The downloaded model zip file
+        """
+        length = sum(f.file_size for f in zip_file.infolist())
+        fnames = zip_file.namelist()
+        self.logger.debug("Zipfile: Filenames: %s, Total Size: %s", fnames, length)
+        assert tqdm is not None
+        pbar = tqdm(desc="Decompressing",
+                    unit="B",
+                    total=length,
+                    unit_scale=True,
+                    unit_divisor=1024)
+        for fname in fnames:
+            out_fname = os.path.join(self._cache_dir, fname)
+            self.logger.debug("Extracting from: '%s' to '%s'", self._model_zip_path, out_fname)
+            zipped = zip_file.open(fname)
+            with open(out_fname, "wb") as out_file:
+                while True:
+                    buffer = zipped.read(self._chunk_size)
+                    if not buffer:
+                        break
+                    pbar.update(len(buffer))
+                    out_file.write(buffer)
+        pbar.close()
+
+
+class DebugTimes():
+    """ A simple tool to help debug timings.
+
+    Parameters
+    ----------
+    show_min: bool, optional
+        Display the minimum time taken in the summary stats. Default: ``True``
+    show_mean: bool, optional
+        Display the mean time taken in the summary stats. Default: ``True``
+    show_max: bool, optional
+        Display the maximum time taken in the summary stats. Default: ``True``
+
+    Example
+    -------
+    >>> from lib.utils import DebugTimes
+    >>> debug_times = DebugTimes()
+    >>> debug_times.step_start("step 1")
+    >>> # do something here
+    >>> debug_times.step_end("step 1")
+    >>> debug_times.summary()
+    ----------------------------------
+    Step           Count    Min
+    ----------------------------------
+    step 1         1        0.000000
+    """
+    def __init__(self,
+                 show_min: bool = True, show_mean: bool = True, show_max: bool = True) -> None:
+        self._times: dict[str, list[float]] = {}
+        self._steps: dict[str, float] = {}
+        self._interval = 1
+        self._display = {"min": show_min, "mean": show_mean, "max": show_max}
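+
+    # Note: timings are keyed per-thread. step_start/step_end append the result of
+    # threading.get_ident() to the step name for the in-flight start time, so concurrent
+    # threads timing the same step do not overwrite each other, while the recorded results
+    # still aggregate under the plain step name.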
+    def step_start(self, name: str, record: bool = True) -> None:
+        """ Start the timer for the given step name.
+
+        Parameters
+        ----------
+        name: str
+            The name of the step to start the timer for
+        record: bool, optional
+            ``True`` to record the step time, ``False`` to not record it.
+            Used for when you have conditional code to time, but do not want to insert if/else
+            statements in the code. Default: ``True``
+
+        Example
+        -------
+        >>> from lib.utils import DebugTimes
+        >>> debug_times = DebugTimes()
+        >>> debug_times.step_start("Example Step")
+        >>> # do something here
+        >>> debug_times.step_end("Example Step")
+        """
+        if not record:
+            return
+        storename = name + str(get_ident())
+        self._steps[storename] = time()
+
+    def step_end(self, name: str, record: bool = True) -> None:
+        """ Stop the timer and record elapsed time for the given step name.
+
+        Parameters
+        ----------
+        name: str
+            The name of the step to end the timer for
+        record: bool, optional
+            ``True`` to record the step time, ``False`` to not record it.
+            Used for when you have conditional code to time, but do not want to insert if/else
+            statements in the code. Default: ``True``
+
+        Example
+        -------
+        >>> from lib.utils import DebugTimes
+        >>> debug_times = DebugTimes()
+        >>> debug_times.step_start("Example Step")
+        >>> # do something here
+        >>> debug_times.step_end("Example Step")
+        """
+        if not record:
+            return
+        storename = name + str(get_ident())
+        self._times.setdefault(name, []).append(time() - self._steps.pop(storename))
+
+    @classmethod
+    def _format_column(cls, text: str, width: int) -> str:
+        """ Pad the given text to be aligned to the given width.
+
+        Parameters
+        ----------
+        text: str
+            The text to be formatted
+        width: int
+            The size of the column to insert the text into
+
+        Returns
+        -------
+        str
+            The text with the correct amount of padding applied
+        """
+        return f"{text}{' ' * (width - len(text))}"
+
+    def summary(self, decimal_places: int = 6, interval: int = 1) -> None:
+        """ Print a summary of step times.
+
+        Parameters
+        ----------
+        decimal_places: int, optional
+            The number of decimal places to display the summary elapsed times to. Default: 6
+        interval: int, optional
+            How many times summary must be called before printing to the console.
Default: 1 + + Example + ------- + >>> from lib.utils import DebugTimes + >>> debug = DebugTimes() + >>> debug.step_start("test") + >>> time.sleep(0.5) + >>> debug.step_end("test") + >>> debug.summary() + ---------------------------------- + Step Count Min + ---------------------------------- + test 1 0.500000 + """ + interval = max(1, interval) + if interval != self._interval: + self._interval += 1 + return + + name_col = max(len(key) for key in self._times) + 4 + items_col = 8 + time_col = (decimal_places + 4) * sum(1 for v in self._display.values() if v) + separator = "-" * (name_col + items_col + time_col) + print("") + print(separator) + header = (f"{self._format_column('Step', name_col)}" + f"{self._format_column('Count', items_col)}") + header += f"{self._format_column('Min', time_col)}" if self._display["min"] else "" + header += f"{self._format_column('Avg', time_col)}" if self._display["mean"] else "" + header += f"{self._format_column('Max', time_col)}" if self._display["max"] else "" + print(header) + print(separator) + assert np is not None + for key, val in self._times.items(): + num = str(len(val)) + contents = f"{self._format_column(key, name_col)}{self._format_column(num, items_col)}" + if self._display["min"]: + _min = f"{np.min(val):.{decimal_places}f}" + contents += f"{self._format_column(_min, time_col)}" + if self._display["mean"]: + avg = f"{np.mean(val):.{decimal_places}f}" + contents += f"{self._format_column(avg, time_col)}" + if self._display["max"]: + _max = f"{np.max(val):.{decimal_places}f}" + contents += f"{self._format_column(_max, time_col)}" + print(contents) + self._interval = 1 + + +__all__ = get_module_objects(__name__) diff --git a/locales/es/LC_MESSAGES/faceswap.mo b/locales/es/LC_MESSAGES/faceswap.mo new file mode 100644 index 0000000000..724deab30e Binary files /dev/null and b/locales/es/LC_MESSAGES/faceswap.mo differ diff --git a/locales/es/LC_MESSAGES/faceswap.po b/locales/es/LC_MESSAGES/faceswap.po new file mode 100644 index 0000000000..c2c6381f88 --- /dev/null +++ b/locales/es/LC_MESSAGES/faceswap.po @@ -0,0 +1,34 @@ +# SOME DESCRIPTIVE TITLE. +# Copyright (C) YEAR ORGANIZATION +# FIRST AUTHOR , YEAR. 
+# +msgid "" +msgstr "" +"Project-Id-Version: faceswap.spanish\n" +"POT-Creation-Date: 2021-02-18 23:48-0000\n" +"PO-Revision-Date: 2021-02-19 17:37+0000\n" +"Language-Team: tokafondo\n" +"MIME-Version: 1.0\n" +"Content-Type: text/plain; charset=UTF-8\n" +"Content-Transfer-Encoding: 8bit\n" +"Generated-By: pygettext.py 1.5\n" +"X-Generator: Poedit 2.3\n" +"Last-Translator: \n" +"Plural-Forms: nplurals=2; plural=(n != 1);\n" +"Language: es_ES\n" + +#: faceswap.py:43 +msgid "Extract the faces from pictures or a video" +msgstr "Extraer las caras de las fotos o de un vídeo" + +#: faceswap.py:44 +msgid "Train a model for the two faces A and B" +msgstr "Entrenar un modelo para las dos caras A y B" + +#: faceswap.py:47 +msgid "Convert source pictures or video to a new one with the face swapped" +msgstr "Convertir las imágenes o el vídeo de origen en uno nuevo con la cara cambiada" + +#: faceswap.py:48 +msgid "Launch the Faceswap Graphical User Interface" +msgstr "Inicie la interfaz gráfica de usuario (GUI) de Faceswap" diff --git a/locales/es/LC_MESSAGES/gui.menu.mo b/locales/es/LC_MESSAGES/gui.menu.mo new file mode 100644 index 0000000000..f31697bbbc Binary files /dev/null and b/locales/es/LC_MESSAGES/gui.menu.mo differ diff --git a/locales/es/LC_MESSAGES/gui.menu.po b/locales/es/LC_MESSAGES/gui.menu.po new file mode 100644 index 0000000000..ba02769135 --- /dev/null +++ b/locales/es/LC_MESSAGES/gui.menu.po @@ -0,0 +1,155 @@ +# SOME DESCRIPTIVE TITLE. +# Copyright (C) YEAR ORGANIZATION +# FIRST AUTHOR , YEAR. +# +msgid "" +msgstr "" +"Project-Id-Version: faceswap.spanish\n" +"Report-Msgid-Bugs-To: \n" +"POT-Creation-Date: 2023-06-07 13:54+0100\n" +"PO-Revision-Date: 2023-06-07 14:11+0100\n" +"Last-Translator: \n" +"Language-Team: tokafondo\n" +"Language: es_ES\n" +"MIME-Version: 1.0\n" +"Content-Type: text/plain; charset=UTF-8\n" +"Content-Transfer-Encoding: 8bit\n" +"Plural-Forms: nplurals=2; plural=(n != 1);\n" +"Generated-By: pygettext.py 1.5\n" +"X-Generator: Poedit 3.3.1\n" + +#: lib/gui/menu.py:37 +msgid "faceswap.dev - Guides and Forum" +msgstr "faceswap.dev - Guías y foro" + +#: lib/gui/menu.py:38 +msgid "Patreon - Support this project" +msgstr "Patreon - Apoya este proyecto" + +#: lib/gui/menu.py:39 +msgid "Discord - The FaceSwap Discord server" +msgstr "Discord - El servidor de Discord de FaceSwap" + +#: lib/gui/menu.py:40 +msgid "Github - Our Source Code" +msgstr "Github - Nuestro código fuente" + +#: lib/gui/menu.py:60 +msgid "File" +msgstr "" + +#: lib/gui/menu.py:61 +msgid "Settings" +msgstr "" + +#: lib/gui/menu.py:62 +msgid "Help" +msgstr "" + +#: lib/gui/menu.py:85 +msgid "Configure Settings..." +msgstr "" + +#: lib/gui/menu.py:116 +msgid "New Project..." +msgstr "" + +#: lib/gui/menu.py:121 +msgid "Open Project..." +msgstr "" + +#: lib/gui/menu.py:126 +msgid "Save Project" +msgstr "" + +#: lib/gui/menu.py:131 +msgid "Save Project as..." +msgstr "" + +#: lib/gui/menu.py:136 +msgid "Reload Project from Disk" +msgstr "" + +#: lib/gui/menu.py:141 +msgid "Close Project" +msgstr "" + +#: lib/gui/menu.py:147 +msgid "Open Task..." +msgstr "" + +#: lib/gui/menu.py:154 +msgid "Open recent" +msgstr "" + +#: lib/gui/menu.py:156 +msgid "Quit" +msgstr "" + +#: lib/gui/menu.py:211 +msgid "{} Task" +msgstr "" + +#: lib/gui/menu.py:223 +msgid "Clear recent files" +msgstr "" + +#: lib/gui/menu.py:391 +msgid "Check for updates..." +msgstr "" + +#: lib/gui/menu.py:394 +msgid "Update Faceswap..." 
+msgstr "" + +#: lib/gui/menu.py:398 +msgid "Switch Branch" +msgstr "" + +#: lib/gui/menu.py:401 +msgid "Resources" +msgstr "" + +#: lib/gui/menu.py:404 +msgid "Output System Information" +msgstr "" + +#: lib/gui/menu.py:589 +msgid "currently selected Task" +msgstr "tarea actualmente seleccionada" + +#: lib/gui/menu.py:589 +msgid "Project" +msgstr "Proyecto" + +#: lib/gui/menu.py:591 +msgid "Reload {} from disk" +msgstr "Recargar {} del disco" + +#: lib/gui/menu.py:593 +msgid "Create a new {}..." +msgstr "Crear un nuevo {}..." + +#: lib/gui/menu.py:595 +msgid "Reset {} to default" +msgstr "Reiniciar {} a los ajustes por defecto" + +#: lib/gui/menu.py:597 +msgid "Save {}" +msgstr "Guardar {}" + +#: lib/gui/menu.py:599 +msgid "Save {} as..." +msgstr "Guardar {} como..." + +#: lib/gui/menu.py:603 +msgid " from a task or project file" +msgstr " de un archivo de tarea o proyecto" + +#: lib/gui/menu.py:604 +msgid "Load {}..." +msgstr "Cargar {}..." + +#: lib/gui/menu.py:659 +msgid "Configure {} settings..." +msgstr "Configurar los ajustes de {}..." diff --git a/locales/es/LC_MESSAGES/gui.tooltips.mo b/locales/es/LC_MESSAGES/gui.tooltips.mo new file mode 100644 index 0000000000..9df3225181 Binary files /dev/null and b/locales/es/LC_MESSAGES/gui.tooltips.mo differ diff --git a/locales/es/LC_MESSAGES/gui.tooltips.po b/locales/es/LC_MESSAGES/gui.tooltips.po new file mode 100644 index 0000000000..ab23f031fc --- /dev/null +++ b/locales/es/LC_MESSAGES/gui.tooltips.po @@ -0,0 +1,210 @@ +# SOME DESCRIPTIVE TITLE. +# Copyright (C) YEAR ORGANIZATION +# FIRST AUTHOR , YEAR. +# +msgid "" +msgstr "" +"Project-Id-Version: faceswap.spanish\n" +"POT-Creation-Date: 2021-03-22 18:37+0000\n" +"PO-Revision-Date: 2023-06-07 14:12+0100\n" +"Last-Translator: \n" +"Language-Team: tokafondo\n" +"Language: es_ES\n" +"MIME-Version: 1.0\n" +"Content-Type: text/plain; charset=UTF-8\n" +"Content-Transfer-Encoding: 8bit\n" +"Plural-Forms: nplurals=2; plural=(n != 1);\n" +"Generated-By: pygettext.py 1.5\n" +"X-Generator: Poedit 3.3.1\n" + +#: lib/gui/command.py:184 +msgid "Output command line options to the console" +msgstr "Devuelve las opciones de la línea de comandos a la consola" + +#: lib/gui/command.py:195 +msgid "Run the {} script" +msgstr "Ejecuta el script {}" + +#: lib/gui/control_helper.py:1234 +msgid "Select a folder..." +msgstr "" + +#: lib/gui/control_helper.py:1235 lib/gui/control_helper.py:1236 +msgid "Select a file..." +msgstr "" + +#: lib/gui/control_helper.py:1237 +msgid "Select a folder of images..." +msgstr "" + +#: lib/gui/control_helper.py:1238 +msgid "Select a video..." +msgstr "" + +#: lib/gui/control_helper.py:1239 +msgid "Select a model folder..." +msgstr "" + +#: lib/gui/control_helper.py:1240 +msgid "Select one or more files..." +msgstr "" + +#: lib/gui/control_helper.py:1241 +msgid "Select a file or folder..." +msgstr "" + +#: lib/gui/control_helper.py:1242 +msgid "Select a save location..." +msgstr "" + +#: lib/gui/display.py:71 +msgid "Summary statistics for each training session" +msgstr "Resumen de estadísticas para cada sesión de entrenamiento" + +#: lib/gui/display.py:113 +msgid "Preview updates every 5 seconds" +msgstr "Previsualiza actualizaciones cada 5 segundos" + +#: lib/gui/display.py:122 +msgid "Graph showing Loss vs Iterations" +msgstr "Gráfico mostrando Pérdida contra iteraciones" + +#: lib/gui/display.py:125 +msgid "Training preview. Updated on every save iteration" +msgstr "" +"Previsualización del entrenamiento. 
Actualizado en cada iteración de guardado"
+
+#: lib/gui/display_analysis.py:342
+msgid "Load/Refresh stats for the currently training session"
+msgstr "Carga/Refresca estadísticas para la sesión actual de entrenamiento"
+
+#: lib/gui/display_analysis.py:344
+msgid "Clear currently displayed session stats"
+msgstr "Borra las estadísticas mostradas de la sesión"
+
+#: lib/gui/display_analysis.py:346
+msgid "Save session stats to csv"
+msgstr "Guarda las estadísticas de la sesión a un archivo csv"
+
+#: lib/gui/display_analysis.py:348
+msgid "Load saved session stats"
+msgstr "Carga estadísticas de sesión ya guardadas"
+
+#: lib/gui/display_command.py:94
+msgid "Preview updates at every model save. Click to refresh now."
+msgstr ""
+"Previsualización de actualizaciones cada guardado de modelo. Pulsar para "
+"actualizar ahora."
+
+#: lib/gui/display_command.py:261
+msgid "Graph updates at every model save. Click to refresh now."
+msgstr ""
+"Previsualización de gráficos cada guardado de modelo. Pulsar para actualizar "
+"ahora."
+
+#: lib/gui/display_command.py:275
+msgid "Display the raw loss data"
+msgstr "Muestra los datos de pérdida sin procesar"
+
+#: lib/gui/display_command.py:287
+msgid "Display the smoothed loss data"
+msgstr "Muestra los datos de pérdida regularizados"
+
+#: lib/gui/display_command.py:294
+msgid "Set the smoothing amount. 0 is no smoothing, 0.99 is maximum smoothing."
+msgstr ""
+"Ajusta el nivel de regularización. 0 es sin regularización, 0.99 es máxima "
+"regularización."
+
+#: lib/gui/display_command.py:324
+msgid "Set the number of iterations to display. 0 displays the full session."
+msgstr ""
+"Ajusta el número de iteraciones a mostrar. 0 muestra la sesión completa."
+
+#: lib/gui/display_page.py:238
+msgid "Save {}(s) to file"
+msgstr "Grabar {} a un fichero"
+
+#: lib/gui/display_page.py:250
+msgid "Enable or disable {} display"
+msgstr "Activar o desactivar la muestra de {}"
+
+#: lib/gui/popup_configure.py:209
+msgid "Close without saving"
+msgstr "Cerrar sin guardar"
+
+#: lib/gui/popup_configure.py:210
+msgid "Save this page's config"
+msgstr "Guardar la configuración de esta página"
+
+#: lib/gui/popup_configure.py:211
+msgid "Reset this page's config to default values"
+msgstr "Reiniciar la configuración de esta página a sus valores por defecto"
+
+#: lib/gui/popup_configure.py:213
+msgid "Save all settings for the currently selected config"
+msgstr "Guardar todos los ajustes para la configuración seleccionada"
+
+#: lib/gui/popup_configure.py:216
+msgid "Reset all settings for the currently selected config to default values"
+msgstr ""
+"Reiniciar todos los ajustes de la configuración seleccionada a sus ajustes "
+"por defecto"
+
+#: lib/gui/popup_configure.py:538
+msgid "Select a plugin to configure:"
+msgstr ""
+
+#: lib/gui/popup_session.py:191
+msgid "Display {}"
+msgstr "Mostrar {}"
+
+#: lib/gui/popup_session.py:342
+msgid "Refresh graph"
+msgstr "Refrescar gráfico"
+
+#: lib/gui/popup_session.py:344
+msgid "Save display data to csv"
+msgstr "Guardar datos de muestra a un archivo csv"
+
+#: lib/gui/popup_session.py:346
+msgid "Number of data points to sample for rolling average"
+msgstr "Número de puntos de datos a muestrear para la media móvil"
+
+#: lib/gui/popup_session.py:348
+msgid "Set the smoothing amount. 0 is no smoothing, 0.99 is maximum smoothing"
+msgstr ""
+"Establezca la cantidad de regularización.
0 es sin regularización, 0,99 es "
+"máxima regularización"
+
+#: lib/gui/popup_session.py:350
+msgid ""
+"Flatten data points that fall more than 1 standard deviation from the mean "
+"to the mean value."
+msgstr ""
+"Aplanar los puntos de datos que se alejan más de 1 desviación estándar de la "
+"media al valor medio."
+
+#: lib/gui/popup_session.py:353
+msgid "Display rolling average of the data"
+msgstr "Mostrar la media móvil de los datos"
+
+#: lib/gui/popup_session.py:355
+msgid "Smooth the data"
+msgstr "Regularizar los datos"
+
+#: lib/gui/popup_session.py:357
+msgid "Display raw data"
+msgstr "Mostrar los datos sin procesar"
+
+#: lib/gui/popup_session.py:359
+msgid "Display polynormal data trend"
+msgstr "Mostrar la tendencia de los datos polinormales"
+
+#: lib/gui/popup_session.py:361
+msgid "Set the data to display"
+msgstr "Ajustar los datos a mostrar"
+
+#: lib/gui/popup_session.py:363
+msgid "Change y-axis scale"
+msgstr "Cambiar la escala del eje Y"
diff --git a/locales/es/LC_MESSAGES/lib.cli.args.mo b/locales/es/LC_MESSAGES/lib.cli.args.mo
new file mode 100644
index 0000000000..914c7bf3b8
Binary files /dev/null and b/locales/es/LC_MESSAGES/lib.cli.args.mo differ
diff --git a/locales/es/LC_MESSAGES/lib.cli.args.po b/locales/es/LC_MESSAGES/lib.cli.args.po
new file mode 100755
index 0000000000..d6b64bef41
--- /dev/null
+++ b/locales/es/LC_MESSAGES/lib.cli.args.po
@@ -0,0 +1,59 @@
+# SOME DESCRIPTIVE TITLE.
+# Copyright (C) YEAR ORGANIZATION
+# FIRST AUTHOR , YEAR.
+#
+msgid ""
+msgstr ""
+"Project-Id-Version: faceswap.spanish\n"
+"Report-Msgid-Bugs-To: \n"
+"POT-Creation-Date: 2024-03-28 18:06+0000\n"
+"PO-Revision-Date: 2024-03-28 18:14+0000\n"
+"Last-Translator: \n"
+"Language-Team: tokafondo\n"
+"Language: es\n"
+"MIME-Version: 1.0\n"
+"Content-Type: text/plain; charset=UTF-8\n"
+"Content-Transfer-Encoding: 8bit\n"
+"Plural-Forms: nplurals=2; plural=(n != 1);\n"
+"Generated-By: pygettext.py 1.5\n"
+"X-Generator: Poedit 3.4.2\n"
+
+#: lib/cli/args.py:188 lib/cli/args.py:199 lib/cli/args.py:208
+#: lib/cli/args.py:219
+msgid "Global Options"
+msgstr "Opciones Globales"
+
+#: lib/cli/args.py:190
+msgid ""
+"R|Exclude GPUs from use by Faceswap. Select the number(s) which correspond "
+"to any GPU(s) that you do not wish to be made available to Faceswap. "
+"Selecting all GPUs here will force Faceswap into CPU mode.\n"
+"L|{}"
+msgstr ""
+"R|Excluir GPUs de su uso por Faceswap. Seleccione el/los número(s) que "
+"correspondan a cualquier GPU(s) que no desee que esté disponible para su uso "
+"con Faceswap. Marcar todas las GPUs forzará a Faceswap a usar sólo la CPU.\n"
+"L|{}"
+
+#: lib/cli/args.py:201
+msgid ""
+"Optionally overide the saved config with the path to a custom config file."
+msgstr "Usar un fichero alternativo de configuración, almacenado en esta ruta."
+
+#: lib/cli/args.py:210
+msgid ""
+"Log level. Stick with INFO or VERBOSE unless you need to file an error "
+"report. Be careful with TRACE as it will generate a lot of data"
+msgstr ""
+"Nivel de registro. Dejarlo en INFO o VERBOSE, a menos que necesite informar "
+"de un error. Tenga en cuenta que TRACE generará muchísima información"
+
+#: lib/cli/args.py:220
+msgid "Path to store the logfile. Leave blank to store in the faceswap folder"
+msgstr ""
+"Ruta para almacenar el fichero de registro.
Dejarlo en blanco para "
+"almacenarlo en la carpeta de instalación de faceswap"
+
+#: lib/cli/args.py:319
+msgid "Output to Shell console instead of GUI console"
+msgstr "Salida a la consola Shell en lugar de la consola GUI"
diff --git a/locales/es/LC_MESSAGES/lib.cli.args_extract_convert.mo b/locales/es/LC_MESSAGES/lib.cli.args_extract_convert.mo
new file mode 100644
index 0000000000..5ccd0c2f21
Binary files /dev/null and b/locales/es/LC_MESSAGES/lib.cli.args_extract_convert.mo differ
diff --git a/locales/es/LC_MESSAGES/lib.cli.args_extract_convert.po b/locales/es/LC_MESSAGES/lib.cli.args_extract_convert.po
new file mode 100755
index 0000000000..4b2e299256
--- /dev/null
+++ b/locales/es/LC_MESSAGES/lib.cli.args_extract_convert.po
@@ -0,0 +1,720 @@
+# SOME DESCRIPTIVE TITLE.
+# Copyright (C) YEAR ORGANIZATION
+# FIRST AUTHOR , YEAR.
+#
+msgid ""
+msgstr ""
+"Project-Id-Version: faceswap.spanish\n"
+"Report-Msgid-Bugs-To: \n"
+"POT-Creation-Date: 2024-04-12 11:56+0100\n"
+"PO-Revision-Date: 2024-04-12 12:02+0100\n"
+"Last-Translator: \n"
+"Language-Team: tokafondo\n"
+"Language: es\n"
+"MIME-Version: 1.0\n"
+"Content-Type: text/plain; charset=UTF-8\n"
+"Content-Transfer-Encoding: 8bit\n"
+"Plural-Forms: nplurals=2; plural=(n != 1);\n"
+"Generated-By: pygettext.py 1.5\n"
+"X-Generator: Poedit 3.4.2\n"
+
+#: lib/cli/args_extract_convert.py:46 lib/cli/args_extract_convert.py:56
+#: lib/cli/args_extract_convert.py:64 lib/cli/args_extract_convert.py:122
+#: lib/cli/args_extract_convert.py:483 lib/cli/args_extract_convert.py:492
+msgid "Data"
+msgstr "Datos"
+
+#: lib/cli/args_extract_convert.py:48
+msgid ""
+"Input directory or video. Either a directory containing the image files you "
+"wish to process or path to a video file. NB: This should be the source video/"
+"frames NOT the source faces."
+msgstr ""
+"Directorio o vídeo de entrada. Un directorio que contenga los archivos de "
+"imagen que desea procesar o la ruta a un archivo de vídeo. NB: Debe ser el "
+"vídeo/los fotogramas de origen, NO las caras de origen."
+
+#: lib/cli/args_extract_convert.py:57
+msgid "Output directory. This is where the converted files will be saved."
+msgstr ""
+"Directorio de salida. Aquí es donde se guardarán los archivos convertidos."
+
+#: lib/cli/args_extract_convert.py:66
+msgid ""
+"Optional path to an alignments file. Leave blank if the alignments file is "
+"at the default location."
+msgstr ""
+"Ruta opcional a un archivo de alineaciones. Dejar en blanco si el archivo de "
+"alineaciones está en la ubicación por defecto."
+
+#: lib/cli/args_extract_convert.py:97
+msgid ""
+"Extract faces from image or video sources.\n"
+"Extraction plugins can be configured in the 'Settings' Menu"
+msgstr ""
+"Extrae caras de fuentes de imagen o video.\n"
+"Los plugins de extracción pueden ser configurados en el menú de 'Ajustes'"
+
+#: lib/cli/args_extract_convert.py:124
+msgid ""
+"R|If selected then the input_dir should be a parent folder containing "
+"multiple videos and/or folders of images you wish to extract from. The faces "
+"will be output to separate sub-folders in the output_dir."
+msgstr ""
+"Si se selecciona, input_dir debe ser una carpeta principal que contenga "
+"varios videos y/o carpetas de imágenes de las que desea extraer. Las caras "
+"se enviarán a subcarpetas separadas en output_dir."
+
+#: lib/cli/args_extract_convert.py:133 lib/cli/args_extract_convert.py:152
+#: lib/cli/args_extract_convert.py:167 lib/cli/args_extract_convert.py:206
+#: lib/cli/args_extract_convert.py:224 lib/cli/args_extract_convert.py:237
+#: lib/cli/args_extract_convert.py:247 lib/cli/args_extract_convert.py:257
+#: lib/cli/args_extract_convert.py:503 lib/cli/args_extract_convert.py:529
+#: lib/cli/args_extract_convert.py:568
+msgid "Plugins"
+msgstr "Extensiones"
+
+#: lib/cli/args_extract_convert.py:135
+msgid ""
+"R|Detector to use. Some of these have configurable settings in '/config/"
+"extract.ini' or 'Settings > Configure Extract 'Plugins':\n"
+"L|cv2-dnn: A CPU only extractor which is the least reliable and least "
+"resource intensive. Use this if not using a GPU and time is important.\n"
+"L|mtcnn: Good detector. Fast on CPU, faster on GPU. Uses fewer resources "
+"than other GPU detectors but can often return more false positives.\n"
+"L|s3fd: Best detector. Slow on CPU, faster on GPU. Can detect more faces and "
+"fewer false positives than other GPU detectors, but is a lot more resource "
+"intensive.\n"
+"L|external: Import a face detection bounding box from a json file. "
+"(configurable in Detect settings)"
+msgstr ""
+"R|Detector de caras a usar. Algunos tienen ajustes configurables en '/config/"
+"extract.ini' o 'Ajustes > Configurar Extensiones de Extracción':\n"
+"L|cv2-dnn: Extractor que usa sólo la CPU. Es el menos fiable y el que menos "
+"recursos usa. Elegir este si necesita rapidez y no usar la GPU.\n"
+"L|mtcnn: Buen detector. Rápido en la CPU y más rápido en la GPU. Usa menos "
+"recursos que otros detectores basados en GPU, pero puede devolver más falsos "
+"positivos.\n"
+"L|s3fd: El mejor detector. Lento en la CPU, y más rápido en la GPU. Puede "
+"detectar más caras y tiene menos falsos positivos que otros detectores "
+"basados en GPU, pero usa muchos más recursos.\n"
+"L|external: importar un cuadro delimitador de detección de caras desde un "
+"archivo json. (configurable en los ajustes de Detect)"
+
+#: lib/cli/args_extract_convert.py:154
+msgid ""
+"R|Aligner to use.\n"
+"L|cv2-dnn: A CPU only landmark detector. Faster, less resource intensive, "
+"but less accurate. Only use this if not using a GPU and time is important.\n"
+"L|fan: Best aligner. Fast on GPU, slow on CPU.\n"
+"L|external: Import 68 point 2D landmarks or an aligned bounding box from a "
+"json file. (configurable in Align settings)"
+msgstr ""
+"R|Alineador a usar.\n"
+"L|cv2-dnn: Detector que usa sólo la CPU. Más rápido, usa menos recursos, "
+"pero es menos preciso. Elegir este si necesita rapidez y no usar la GPU.\n"
+"L|fan: El mejor alineador. Rápido en la GPU, y lento en la CPU.\n"
+"L|external: importar 68 puntos de referencia 2D o un cuadro delimitador "
+"alineado desde un archivo json. (configurable en los ajustes de Align)"
+
+#: lib/cli/args_extract_convert.py:169
+msgid ""
+"R|Additional Masker(s) to use. The masks generated here will all take up GPU "
+"RAM. You can select none, one or multiple masks, but the extraction may take "
+"longer the more you select. NB: The Extended and Components (landmark based) "
+"masks are automatically generated on extraction.\n"
+"L|bisenet-fp: Relatively lightweight NN based mask that provides more "
+"refined control over the area to be masked including full head masking "
+"(configurable in mask settings).\n"
+"L|custom: A dummy mask that fills the mask area with all 1s or 0s "
+"(configurable in settings).
This is only required if you intend to manually " +"edit the custom masks yourself in the manual tool. This mask does not use " +"the GPU so will not use any additional VRAM.\n" +"L|vgg-clear: Mask designed to provide smart segmentation of mostly frontal " +"faces clear of obstructions. Profile faces and obstructions may result in " +"sub-par performance.\n" +"L|vgg-obstructed: Mask designed to provide smart segmentation of mostly " +"frontal faces. The mask model has been specifically trained to recognize " +"some facial obstructions (hands and eyeglasses). Profile faces may result in " +"sub-par performance.\n" +"L|unet-dfl: Mask designed to provide smart segmentation of mostly frontal " +"faces. The mask model has been trained by community members and will need " +"testing for further description. Profile faces may result in sub-par " +"performance.\n" +"The auto generated masks are as follows:\n" +"L|components: Mask designed to provide facial segmentation based on the " +"positioning of landmark locations. A convex hull is constructed around the " +"exterior of the landmarks to create a mask.\n" +"L|extended: Mask designed to provide facial segmentation based on the " +"positioning of landmark locations. A convex hull is constructed around the " +"exterior of the landmarks and the mask is extended upwards onto the " +"forehead.\n" +"(eg: `-M unet-dfl vgg-clear`, `--masker vgg-obstructed`)" +msgstr "" +"R|Enmascarador(es) adicional(es) a usar. Las máscaras generadas aquí usarán " +"todas RAM de la GPU. Puede seleccionar una, varias o ninguna máscaras, pero " +"la extracción tardará más cuanto más marque. Las máscaras Extended y " +"Components son siempre generadas durante la extracción.\n" +"L|bisenet-fp: Máscara relativamente ligera basada en NN que proporciona un " +"control más refinado sobre el área a enmascarar, incluido el enmascaramiento " +"completo de la cabeza (configurable en la configuración de la máscara).\n" +"L|custom: Una máscara ficticia que llena el área de la máscara con 1 o 0 " +"(configurable en la configuración). Esto solo es necesario si tiene la " +"intención de editar manualmente las máscaras personalizadas usted mismo en " +"la herramienta manual. Esta máscara no usa la GPU, por lo que no usará VRAM " +"adicional.\n" +"L|vgg-clear: Máscara diseñada para proporcionar una segmentación inteligente " +"de rostros principalmente frontales y libres de obstrucciones. Los rostros " +"de perfil y las obstrucciones pueden dar lugar a un rendimiento inferior.\n" +"L|vgg-obstructed: Máscara diseñada para proporcionar una segmentación " +"inteligente de rostros principalmente frontales. El modelo de la máscara ha " +"sido entrenado específicamente para reconocer algunas obstrucciones faciales " +"(manos y gafas). Los rostros de perfil pueden dar lugar a un rendimiento " +"inferior.\n" +"L|unet-dfl: Máscara diseñada para proporcionar una segmentación inteligente " +"de rostros principalmente frontales. El modelo de máscara ha sido entrenado " +"por los miembros de la comunidad y necesitará ser probado para una mayor " +"descripción. Los rostros de perfil pueden dar lugar a un rendimiento " +"inferior.\n" +"Las máscaras que siempre se generan son:\n" +"L|components: Máscara diseñada para proporcionar una segmentación facial " +"basada en el posicionamiento de las ubicaciones de los puntos de referencia. 
" +"Se construye un casco convexo alrededor del exterior de los puntos de " +"referencia para crear una máscara.\n" +"L|extended: Máscara diseñada para proporcionar una segmentación facial " +"basada en el posicionamiento de las ubicaciones de los puntos de referencia. " +"Se construye un casco convexo alrededor del exterior de los puntos de " +"referencia y la máscara se extiende hacia arriba en la frente.\n" +"(eg: `-M unet-dfl vgg-clear`, `--masker vgg-obstructed`)" + +#: lib/cli/args_extract_convert.py:208 +msgid "" +"R|Performing normalization can help the aligner better align faces with " +"difficult lighting conditions at an extraction speed cost. Different methods " +"will yield different results on different sets. NB: This does not impact the " +"output face, just the input to the aligner.\n" +"L|none: Don't perform normalization on the face.\n" +"L|clahe: Perform Contrast Limited Adaptive Histogram Equalization on the " +"face.\n" +"L|hist: Equalize the histograms on the RGB channels.\n" +"L|mean: Normalize the face colors to the mean." +msgstr "" +"R|Realizar la normalización puede ayudar al alineador a alinear mejor las " +"caras con condiciones de iluminación difíciles a un coste de velocidad de " +"extracción. Diferentes métodos darán diferentes resultados en diferentes " +"conjuntos. NB: Esto no afecta a la cara de salida, sólo a la entrada del " +"alineador.\n" +"L|none: No realice la normalización en la cara.\n" +"L|clahe: Realice la ecualización adaptativa del histograma con contraste " +"limitado en el rostro.\n" +"L|hist: Iguala los histogramas de los canales RGB.\n" +"L|mean: Normalizar los colores de la cara a la media." + +#: lib/cli/args_extract_convert.py:226 +msgid "" +"The number of times to re-feed the detected face into the aligner. Each time " +"the face is re-fed into the aligner the bounding box is adjusted by a small " +"amount. The final landmarks are then averaged from each iteration. Helps to " +"remove 'micro-jitter' but at the cost of slower extraction speed. The more " +"times the face is re-fed into the aligner, the less micro-jitter should " +"occur but the longer extraction will take." +msgstr "" +"El número de veces que hay que volver a introducir la cara detectada en el " +"alineador. Cada vez que la cara se vuelve a introducir en el alineador, el " +"cuadro delimitador se ajusta en una pequeña cantidad. Los puntos de " +"referencia finales se promedian en cada iteración. Esto ayuda a eliminar el " +"'micro-jitter', pero a costa de una menor velocidad de extracción. Cuantas " +"más veces se vuelva a introducir la cara en el alineador, menos " +"microfluctuaciones se producirán, pero la extracción será más larga." + +#: lib/cli/args_extract_convert.py:239 +msgid "" +"Re-feed the initially found aligned face through the aligner. Can help " +"produce better alignments for faces that are rotated beyond 45 degrees in " +"the frame or are at extreme angles. Slows down extraction." +msgstr "" +"Vuelva a introducir la cara alineada encontrada inicialmente a través del " +"alineador. Puede ayudar a producir mejores alineaciones para las caras que " +"se giran más de 45 grados en el marco o se encuentran en ángulos extremos. " +"Ralentiza la extracción." + +#: lib/cli/args_extract_convert.py:249 +msgid "" +"If a face isn't found, rotate the images to try to find a face. Can find " +"more faces at the cost of extraction speed. 
Pass in a single number to use " +"increments of that size up to 360, or pass in a list of numbers to enumerate " +"exactly what angles to check." +msgstr "" +"Si no se encuentra una cara, gira las imágenes para intentar encontrar una " +"cara. Puede encontrar más caras a costa de la velocidad de extracción. Pase " +"un solo número para usar incrementos de ese tamaño hasta 360, o pase una " +"lista de números para enumerar exactamente qué ángulos comprobar." + +#: lib/cli/args_extract_convert.py:259 +msgid "" +"Obtain and store face identity encodings from VGGFace2. Slows down extract a " +"little, but will save time if using 'sort by face'" +msgstr "" +"Obtenga y almacene codificaciones de identidad facial de VGGFace2. Ralentiza " +"un poco la extracción, pero ahorrará tiempo si usa 'sort by face'" + +#: lib/cli/args_extract_convert.py:269 lib/cli/args_extract_convert.py:280 +#: lib/cli/args_extract_convert.py:293 lib/cli/args_extract_convert.py:307 +#: lib/cli/args_extract_convert.py:614 lib/cli/args_extract_convert.py:623 +#: lib/cli/args_extract_convert.py:638 lib/cli/args_extract_convert.py:651 +#: lib/cli/args_extract_convert.py:665 +msgid "Face Processing" +msgstr "Proceso de Caras" + +#: lib/cli/args_extract_convert.py:271 +msgid "" +"Filters out faces detected below this size. Length, in pixels across the " +"diagonal of the bounding box. Set to 0 for off" +msgstr "" +"Filtra las caras detectadas por debajo de este tamaño. Longitud, en píxeles " +"a lo largo de la diagonal del cuadro delimitador. Establecer a 0 para " +"desactivar" + +#: lib/cli/args_extract_convert.py:282 +msgid "" +"Optionally filter out people who you do not wish to extract by passing in " +"images of those people. Should be a small variety of images at different " +"angles and in different conditions. A folder containing the required images " +"or multiple image files, space separated, can be selected." +msgstr "" +"Opcionalmente, filtre a las personas que no desea extraer pasando imágenes " +"de esas personas. Debe ser una pequeña variedad de imágenes en diferentes " +"ángulos y en diferentes condiciones. Se puede seleccionar una carpeta que " +"contenga las imágenes requeridas o múltiples archivos de imágenes, separados " +"por espacios." + +#: lib/cli/args_extract_convert.py:295 +msgid "" +"Optionally select people you wish to extract by passing in images of that " +"person. Should be a small variety of images at different angles and in " +"different conditions A folder containing the required images or multiple " +"image files, space separated, can be selected." +msgstr "" +"Opcionalmente, seleccione las personas que desea extraer pasando imágenes de " +"esa persona. Debe haber una pequeña variedad de imágenes en diferentes " +"ángulos y en diferentes condiciones. Se puede seleccionar una carpeta que " +"contenga las imágenes requeridas o múltiples archivos de imágenes, separados " +"por espacios." + +#: lib/cli/args_extract_convert.py:309 +msgid "" +"For use with the optional nfilter/filter files. Threshold for positive face " +"recognition. Higher values are stricter." +msgstr "" +"Para usar con los archivos nfilter/filter opcionales. Umbral para el " +"reconocimiento facial positivo. Los valores más altos son más estrictos." + +#: lib/cli/args_extract_convert.py:318 lib/cli/args_extract_convert.py:331 +#: lib/cli/args_extract_convert.py:344 lib/cli/args_extract_convert.py:356 +msgid "output" +msgstr "salida" + +#: lib/cli/args_extract_convert.py:320 +msgid "" +"The output size of extracted faces. 
Make sure that the model you intend to " +"train supports your required size. This will only need to be changed for hi-" +"res models." +msgstr "" +"El tamaño de salida de las caras extraídas. Asegúrese de que el modelo que " +"pretende entrenar admite el tamaño deseado. Esto sólo tendrá que ser " +"cambiado para los modelos de alta resolución." + +#: lib/cli/args_extract_convert.py:333 +msgid "" +"Extract every 'nth' frame. This option will skip frames when extracting " +"faces. For example a value of 1 will extract faces from every frame, a value " +"of 10 will extract faces from every 10th frame." +msgstr "" +"Extraer cada 'enésimo' fotograma. Esta opción omitirá los fotogramas al " +"extraer las caras. Por ejemplo, un valor de 1 extraerá las caras de cada " +"fotograma, un valor de 10 extraerá las caras de cada 10 fotogramas." + +#: lib/cli/args_extract_convert.py:346 +msgid "" +"Automatically save the alignments file after a set amount of frames. By " +"default the alignments file is only saved at the end of the extraction " +"process. NB: If extracting in 2 passes then the alignments file will only " +"start to be saved out during the second pass. WARNING: Don't interrupt the " +"script when writing the file because it might get corrupted. Set to 0 to " +"turn off" +msgstr "" +"Guardar automáticamente el archivo de alineaciones después de una cantidad " +"determinada de cuadros. Por defecto, el archivo de alineaciones sólo se " +"guarda al final del proceso de extracción. Nota: Si se extrae en 2 pases, el " +"archivo de alineaciones sólo se empezará a guardar durante el segundo pase. " +"ADVERTENCIA: No interrumpa el script al escribir el archivo porque podría " +"corromperse. Poner a 0 para desactivar" + +#: lib/cli/args_extract_convert.py:357 +msgid "Draw landmarks on the ouput faces for debugging purposes." +msgstr "" +"Dibujar puntos de referencia en las caras de salida para fines de depuración." + +#: lib/cli/args_extract_convert.py:363 lib/cli/args_extract_convert.py:373 +#: lib/cli/args_extract_convert.py:381 lib/cli/args_extract_convert.py:388 +#: lib/cli/args_extract_convert.py:678 lib/cli/args_extract_convert.py:691 +#: lib/cli/args_extract_convert.py:712 lib/cli/args_extract_convert.py:718 +msgid "settings" +msgstr "ajustes" + +#: lib/cli/args_extract_convert.py:365 +msgid "" +"Don't run extraction in parallel. Will run each part of the extraction " +"process separately (one after the other) rather than all at the same time. " +"Useful if VRAM is at a premium." +msgstr "" +"No ejecute la extracción en paralelo. Ejecutará cada parte del proceso de " +"extracción por separado (una tras otra) en lugar de hacerlo todo al mismo " +"tiempo. Útil si la VRAM es escasa." + +#: lib/cli/args_extract_convert.py:375 +msgid "" +"Skips frames that have already been extracted and exist in the alignments " +"file" +msgstr "" +"Omite los fotogramas que ya han sido extraídos y que existen en el archivo " +"de alineaciones" + +#: lib/cli/args_extract_convert.py:382 +msgid "Skip frames that already have detected faces in the alignments file" +msgstr "" +"Omitir los fotogramas que ya tienen caras detectadas en el archivo de " +"alineaciones" + +#: lib/cli/args_extract_convert.py:389 +msgid "Skip saving the detected faces to disk. Just create an alignments file" +msgstr "" +"No guardar las caras detectadas en el disco. 
Crear sólo un archivo de " +"alineaciones" + +#: lib/cli/args_extract_convert.py:463 +msgid "" +"Swap the original faces in a source video/images to your final faces.\n" +"Conversion plugins can be configured in the 'Settings' Menu" +msgstr "" +"Cambia las caras originales de un vídeo/imágenes de origen por las caras " +"finales.\n" +"Los plugins de conversión pueden ser configurados en el menú " +"\"Configuración\"" + +#: lib/cli/args_extract_convert.py:485 +msgid "" +"Only required if converting from images to video. Provide The original video " +"that the source frames were extracted from (for extracting the fps and " +"audio)." +msgstr "" +"Sólo es necesario si se convierte de imágenes a vídeo. Proporcione el vídeo " +"original del que se extrajeron los fotogramas de origen (para extraer los " +"fps y el audio)." + +#: lib/cli/args_extract_convert.py:494 +msgid "" +"Model directory. The directory containing the trained model you wish to use " +"for conversion." +msgstr "" +"Directorio del modelo. El directorio que contiene el modelo entrenado que " +"desea utilizar para la conversión." + +#: lib/cli/args_extract_convert.py:505 +msgid "" +"R|Performs color adjustment to the swapped face. Some of these options have " +"configurable settings in '/config/convert.ini' or 'Settings > Configure " +"Convert Plugins':\n" +"L|avg-color: Adjust the mean of each color channel in the swapped " +"reconstruction to equal the mean of the masked area in the original image.\n" +"L|color-transfer: Transfers the color distribution from the source to the " +"target image using the mean and standard deviations of the L*a*b* color " +"space.\n" +"L|manual-balance: Manually adjust the balance of the image in a variety of " +"color spaces. Best used with the Preview tool to set correct values.\n" +"L|match-hist: Adjust the histogram of each color channel in the swapped " +"reconstruction to equal the histogram of the masked area in the original " +"image.\n" +"L|seamless-clone: Use cv2's seamless clone function to remove extreme " +"gradients at the mask seam by smoothing colors. Generally does not give very " +"satisfactory results.\n" +"L|none: Don't perform color adjustment." +msgstr "" +"R|Realiza un ajuste de color a la cara intercambiada. Algunas de estas " +"opciones tienen ajustes configurables en '/config/convert.ini' o 'Ajustes > " +"Configurar Extensiones de Conversión':\n" +"L|avg-color: Ajuste la media de cada canal de color en la reconstrucción " +"intercambiada para igualar la media del área enmascarada en la imagen " +"original.\n" +"L|color-transfer: Transfiere la distribución del color de la imagen de " +"origen a la de destino utilizando la media y las desviaciones estándar del " +"espacio de color L*a*b*.\n" +"L|manual-balance: Ajuste manualmente el equilibrio de la imagen en una " +"variedad de espacios de color. Se utiliza mejor con la herramienta de vista " +"previa para establecer los valores correctos.\n" +"L|match-hist: Ajuste el histograma de cada canal de color en la " +"reconstrucción intercambiada para igualar el histograma del área enmascarada " +"en la imagen original.\n" +"L|seamless-clone: Utilice la función de clonación sin costuras de cv2 para " +"eliminar los gradientes extremos en la costura de la máscara, suavizando los " +"colores. Generalmente no da resultados muy satisfactorios.\n" +"L|none: No realice el ajuste de color." + +#: lib/cli/args_extract_convert.py:531 +msgid "" +"R|Masker to use. NB: The mask you require must exist within the alignments " +"file. 
+"file. You can add additional masks with the Mask Tool.\n"
+"L|none: Don't use a mask.\n"
+"L|bisenet-fp_face: Relatively lightweight NN based mask that provides more "
+"refined control over the area to be masked (configurable in mask settings). "
+"Use this version of bisenet-fp if your model is trained with 'face' or "
+"'legacy' centering.\n"
+"L|bisenet-fp_head: Relatively lightweight NN based mask that provides more "
+"refined control over the area to be masked (configurable in mask settings). "
+"Use this version of bisenet-fp if your model is trained with 'head' "
+"centering.\n"
+"L|custom_face: Custom user created, face centered mask.\n"
+"L|custom_head: Custom user created, head centered mask.\n"
+"L|components: Mask designed to provide facial segmentation based on the "
+"positioning of landmark locations. A convex hull is constructed around the "
+"exterior of the landmarks to create a mask.\n"
+"L|extended: Mask designed to provide facial segmentation based on the "
+"positioning of landmark locations. A convex hull is constructed around the "
+"exterior of the landmarks and the mask is extended upwards onto the "
+"forehead.\n"
+"L|vgg-clear: Mask designed to provide smart segmentation of mostly frontal "
+"faces clear of obstructions. Profile faces and obstructions may result in "
+"sub-par performance.\n"
+"L|vgg-obstructed: Mask designed to provide smart segmentation of mostly "
+"frontal faces. The mask model has been specifically trained to recognize "
+"some facial obstructions (hands and eyeglasses). Profile faces may result in "
+"sub-par performance.\n"
+"L|unet-dfl: Mask designed to provide smart segmentation of mostly frontal "
+"faces. The mask model has been trained by community members and will need "
+"testing for further description. Profile faces may result in sub-par "
+"performance.\n"
+"L|predicted: If the 'Learn Mask' option was enabled during training, this "
+"will use the mask that was created by the trained model."
+msgstr ""
+"R|Máscara a utilizar. NB: La máscara que necesita debe existir en el archivo "
+"de alineaciones. Puede añadir máscaras adicionales con la herramienta de "
+"máscaras.\n"
+"L|none: No utilizar una máscara.\n"
+"L|bisenet-fp_face: Máscara relativamente ligera basada en NN que proporciona "
+"un control más refinado sobre el área a enmascarar (configurable en la "
+"configuración de la máscara). Utilice esta versión de bisenet-fp si su "
+"modelo está entrenado con centrado 'face' o 'legacy'.\n"
+"L|bisenet-fp_head: Máscara relativamente ligera basada en NN que proporciona "
+"un control más refinado sobre el área a enmascarar (configurable en la "
+"configuración de la máscara). Utilice esta versión de bisenet-fp si su "
+"modelo está entrenado con centrado 'head'.\n"
+"L|custom_face: Máscara personalizada creada por el usuario y centrada en el "
+"rostro.\n"
+"L|custom_head: Máscara personalizada centrada en la cabeza creada por el "
+"usuario.\n"
+"L|components: Máscara diseñada para proporcionar una segmentación facial "
+"basada en el posicionamiento de las ubicaciones de los puntos de referencia. "
+"Se construye un casco convexo alrededor del exterior de los puntos de "
+"referencia para crear una máscara.\n"
+"L|extended: Máscara diseñada para proporcionar una segmentación facial "
+"basada en el posicionamiento de las ubicaciones de los puntos de referencia. "
" +"Se construye un casco convexo alrededor del exterior de los puntos de " +"referencia y la máscara se extiende hacia arriba en la frente.\n" +"L|vgg-clear: Máscara diseñada para proporcionar una segmentación inteligente " +"de rostros principalmente frontales y libres de obstrucciones. Los rostros " +"de perfil y las obstrucciones pueden dar lugar a un rendimiento inferior.\n" +"L|vgg-obstructed: Máscara diseñada para proporcionar una segmentación " +"inteligente de rostros principalmente frontales. El modelo de la máscara ha " +"sido entrenado específicamente para reconocer algunas obstrucciones faciales " +"(manos y gafas). Los rostros de perfil pueden dar lugar a un rendimiento " +"inferior.\n" +"L|unet-dfl: Máscara diseñada para proporcionar una segmentación inteligente " +"de rostros principalmente frontales. El modelo de máscara ha sido entrenado " +"por los miembros de la comunidad y necesitará ser probado para una mayor " +"descripción. Los rostros de perfil pueden dar lugar a un rendimiento " +"inferior.\n" +"L|predicted: Si la opción 'Learn Mask' se habilitó durante el entrenamiento, " +"esto usará la máscara que fue creada por el modelo entrenado." + +#: lib/cli/args_extract_convert.py:570 +msgid "" +"R|The plugin to use to output the converted images. The writers are " +"configurable in '/config/convert.ini' or 'Settings > Configure Convert " +"Plugins:'\n" +"L|ffmpeg: [video] Writes out the convert straight to video. When the input " +"is a series of images then the '-ref' (--reference-video) parameter must be " +"set.\n" +"L|gif: [animated image] Create an animated gif.\n" +"L|opencv: [images] The fastest image writer, but less options and formats " +"than other plugins.\n" +"L|patch: [images] Outputs the raw swapped face patch, along with the " +"transformation matrix required to re-insert the face back into the original " +"frame. Use this option if you wish to post-process and composite the final " +"face within external tools.\n" +"L|pillow: [images] Slower than opencv, but has more options and supports " +"more formats." +msgstr "" +"R|El plugin a utilizar para dar salida a las imágenes convertidas. Los " +"escritores son configurables en '/config/convert.ini' o 'Ajustes > " +"Configurar Extensiones de Conversión:'\n" +"L|ffmpeg: [video] Escribe la conversión directamente en vídeo. Cuando la " +"entrada es una serie de imágenes, el parámetro '-ref' (--reference-video) " +"debe ser establecido.\n" +"L|gif: [imagen animada] Crea un gif animado.\n" +"L|opencv: [images] El escritor de imágenes más rápido, pero con menos " +"opciones y formatos que otros plugins.\n" +"L|patch: [images] Genera el parche de cara intercambiado sin formato, junto " +"con la matriz de transformación necesaria para volver a insertar la cara en " +"el marco original.\n" +"L|pillow: [images] Más lento que opencv, pero tiene más opciones y soporta " +"más formatos." + +#: lib/cli/args_extract_convert.py:591 lib/cli/args_extract_convert.py:600 +#: lib/cli/args_extract_convert.py:703 +msgid "Frame Processing" +msgstr "Proceso de fotogramas" + +#: lib/cli/args_extract_convert.py:593 +#, python-format +msgid "" +"Scale the final output frames by this amount. 100%% will output the frames " +"at source dimensions. 50%% at half size 200%% at double size" +msgstr "" +"Escala los fotogramas finales de salida en esta cantidad. 100%% dará salida " +"a los fotogramas a las dimensiones de origen. 50%% a la mitad de tamaño. 
" +"200%% al doble de tamaño" + +#: lib/cli/args_extract_convert.py:602 +msgid "" +"Frame ranges to apply transfer to e.g. For frames 10 to 50 and 90 to 100 use " +"--frame-ranges 10-50 90-100. Frames falling outside of the selected range " +"will be discarded unless '-k' (--keep-unchanged) is selected. NB: If you are " +"converting from images, then the filenames must end with the frame-number!" +msgstr "" +"Rangos de fotogramas a los que aplicar la transferencia, por ejemplo, para " +"los fotogramas de 10 a 50 y de 90 a 100 utilice --frame-ranges 10-50 90-100. " +"Los fotogramas que queden fuera del rango seleccionado se descartarán a " +"menos que se seleccione '-k' (--keep-unchanged). Nota: Si está convirtiendo " +"imágenes, ¡los nombres de los archivos deben terminar con el número de " +"fotograma!" + +#: lib/cli/args_extract_convert.py:616 +msgid "" +"Scale the swapped face by this percentage. Positive values will enlarge the " +"face, Negative values will shrink the face." +msgstr "" +"Escale la cara intercambiada según este porcentaje. Los valores positivos " +"agrandarán la cara, los valores negativos la reducirán." + +#: lib/cli/args_extract_convert.py:625 +msgid "" +"If you have not cleansed your alignments file, then you can filter out faces " +"by defining a folder here that contains the faces extracted from your input " +"files/video. If this folder is defined, then only faces that exist within " +"your alignments file and also exist within the specified folder will be " +"converted. Leaving this blank will convert all faces that exist within the " +"alignments file." +msgstr "" +"Si no ha limpiado su archivo de alineaciones, puede filtrar las caras " +"definiendo aquí una carpeta que contenga las caras extraídas de sus archivos/" +"vídeos de entrada. Si se define esta carpeta, sólo se convertirán las caras " +"que existan en el archivo de alineaciones y también en la carpeta " +"especificada. Si se deja en blanco, se convertirán todas las caras que " +"existan en el archivo de alineaciones." + +#: lib/cli/args_extract_convert.py:640 +msgid "" +"Optionally filter out people who you do not wish to process by passing in an " +"image of that person. Should be a front portrait with a single person in the " +"image. Multiple images can be added space separated. NB: Using face filter " +"will significantly decrease extraction speed and its accuracy cannot be " +"guaranteed." +msgstr "" +"Opcionalmente, puede filtrar las personas que no desea procesar pasando una " +"imagen de esa persona. Debe ser un retrato frontal con una sola persona en " +"la imagen. Se pueden añadir varias imágenes separadas por espacios. NB: El " +"uso del filtro de caras disminuirá significativamente la velocidad de " +"extracción y no se puede garantizar su precisión." + +#: lib/cli/args_extract_convert.py:653 +msgid "" +"Optionally select people you wish to process by passing in an image of that " +"person. Should be a front portrait with a single person in the image. " +"Multiple images can be added space separated. NB: Using face filter will " +"significantly decrease extraction speed and its accuracy cannot be " +"guaranteed." +msgstr "" +"Opcionalmente, seleccione las personas que desea procesar pasando una imagen " +"de esa persona. Debe ser un retrato frontal con una sola persona en la " +"imagen. Se pueden añadir varias imágenes separadas por espacios. NB: El uso " +"del filtro facial disminuirá significativamente la velocidad de extracción y " +"no se puede garantizar su precisión." 
+
+#: lib/cli/args_extract_convert.py:667
+msgid ""
+"For use with the optional nfilter/filter files. Threshold for positive face "
+"recognition. Lower values are stricter. NB: Using face filter will "
+"significantly decrease extraction speed and its accuracy cannot be "
+"guaranteed."
+msgstr ""
+"Para usar con los archivos opcionales nfilter/filter. Umbral para el "
+"reconocimiento positivo de caras. Los valores más bajos son más estrictos. "
+"NB: El uso del filtro facial disminuirá significativamente la velocidad de "
+"extracción y no se puede garantizar su precisión."
+
+#: lib/cli/args_extract_convert.py:680
+msgid ""
+"The maximum number of parallel processes for performing conversion. "
+"Converting images is system RAM heavy so it is possible to run out of memory "
+"if you have a lot of processes and not enough RAM to accommodate them all. "
+"Setting this to 0 will use the maximum available. No matter what you set "
+"this to, it will never attempt to use more processes than are available on "
+"your system. If singleprocess is enabled this setting will be ignored."
+msgstr ""
+"El número máximo de procesos paralelos para realizar la conversión. La "
+"conversión de imágenes requiere mucha RAM del sistema, por lo que es posible "
+"que se agote la memoria si tiene muchos procesos y no hay suficiente RAM "
+"para acomodarlos a todos. Si se ajusta a 0, se utilizará el máximo "
+"disponible. No importa lo que establezca, nunca intentará utilizar más "
+"procesos que los disponibles en su sistema. Si 'singleprocess' está "
+"habilitado, este ajuste será ignorado."
+
+#: lib/cli/args_extract_convert.py:693
+msgid ""
+"Enable On-The-Fly Conversion. NOT recommended. You should generate a clean "
+"alignments file for your destination video. However, if you wish you can "
+"generate the alignments on-the-fly by enabling this option. This will use an "
+"inferior extraction pipeline and will lead to substandard results. If an "
+"alignments file is found, this option will be ignored."
+msgstr ""
+"Activar la conversión sobre la marcha. NO se recomienda. Debe generar un "
+"archivo de alineación limpio para su vídeo de destino. Sin embargo, si lo "
+"desea, puede generar las alineaciones sobre la marcha activando esta opción. "
+"Esto utilizará una tubería de extracción inferior y conducirá a resultados "
+"de baja calidad. Si se encuentra un archivo de alineaciones, esta opción "
+"será ignorada."
+
+#: lib/cli/args_extract_convert.py:705
+msgid ""
+"When used with --frame-ranges outputs the unchanged frames that are not "
+"processed instead of discarding them."
+msgstr ""
+"Cuando se usa con --frame-ranges, la salida incluye los fotogramas no "
+"procesados en vez de descartarlos."
+
+#: lib/cli/args_extract_convert.py:713
+msgid "Swap the model. Instead converting from of A -> B, converts B -> A"
+msgstr ""
+"Intercambiar el modelo. En vez de convertir de A a B, convierte de B a A"
+
+#: lib/cli/args_extract_convert.py:719
+msgid "Disable multiprocessing. Slower but less resource intensive."
+msgstr "Desactiva el multiproceso. Es más lento, pero usa menos recursos."
+
+#~ msgid ""
+#~ "[LEGACY] This only needs to be selected if a legacy model is being loaded "
+#~ "or if there are multiple models in the model folder"
+#~ msgstr ""
+#~ "[LEGACY] Sólo es necesario seleccionar esta opción si se está cargando un "
+#~ "modelo heredado o si hay varios modelos en la carpeta de modelos"
diff --git a/locales/es/LC_MESSAGES/lib.cli.args_train.mo b/locales/es/LC_MESSAGES/lib.cli.args_train.mo
new file mode 100644
index 0000000000..84370ccc53
Binary files /dev/null and b/locales/es/LC_MESSAGES/lib.cli.args_train.mo differ
diff --git a/locales/es/LC_MESSAGES/lib.cli.args_train.po b/locales/es/LC_MESSAGES/lib.cli.args_train.po
new file mode 100755
index 0000000000..b97e77df4f
--- /dev/null
+++ b/locales/es/LC_MESSAGES/lib.cli.args_train.po
@@ -0,0 +1,390 @@
+# SOME DESCRIPTIVE TITLE.
+# Copyright (C) YEAR ORGANIZATION
+# FIRST AUTHOR , YEAR.
+#
+msgid ""
+msgstr ""
+"Project-Id-Version: faceswap.spanish\n"
+"Report-Msgid-Bugs-To: \n"
+"POT-Creation-Date: 2025-12-15 20:02+0000\n"
+"PO-Revision-Date: 2025-12-16 14:55+0000\n"
+"Last-Translator: \n"
+"Language-Team: tokafondo\n"
+"Language: es\n"
+"MIME-Version: 1.0\n"
+"Content-Type: text/plain; charset=UTF-8\n"
+"Content-Transfer-Encoding: 8bit\n"
+"Plural-Forms: nplurals=2; plural=(n != 1);\n"
+"Generated-By: pygettext.py 1.5\n"
+"X-Generator: Poedit 3.8\n"
+
+#: lib/cli/args_train.py:30
+msgid ""
+"Train a model on extracted original (A) and swap (B) faces.\n"
+"Training models can take a long time. Anything from 24hrs to over a week\n"
+"Model plugins can be configured in the 'Settings' Menu"
+msgstr ""
+"Entrene un modelo con las caras originales (A) e intercambiadas (B) "
+"extraídas.\n"
+"El entrenamiento de los modelos puede llevar mucho tiempo. Desde 24 horas "
+"hasta más de una semana.\n"
+"Los plugins de los modelos pueden configurarse en el menú \"Ajustes\""
+
+#: lib/cli/args_train.py:49 lib/cli/args_train.py:58
+msgid "faces"
+msgstr "caras"
+
+#: lib/cli/args_train.py:51
+msgid ""
+"Input directory. A directory containing training images for face A. This is "
+"the original face, i.e. the face that you want to remove and replace with "
+"face B."
+msgstr ""
+"Directorio de entrada. Un directorio que contiene imágenes de entrenamiento "
+"para la cara A. Esta es la cara original, es decir, la cara que se quiere "
+"eliminar y sustituir por la cara B."
+
+#: lib/cli/args_train.py:60
+msgid ""
+"Input directory. A directory containing training images for face B. This is "
+"the swap face, i.e. the face that you want to place onto the head of person "
+"A."
+msgstr ""
+"Directorio de entrada. Un directorio que contiene imágenes de entrenamiento "
+"para la cara B. Esta es la cara de intercambio, es decir, la cara que se "
+"quiere colocar en la cabeza de la persona A."
+
+#: lib/cli/args_train.py:67 lib/cli/args_train.py:80 lib/cli/args_train.py:97
+#: lib/cli/args_train.py:123 lib/cli/args_train.py:133
+msgid "model"
+msgstr "modelo"
+
+#: lib/cli/args_train.py:69
+msgid ""
+"Model directory. This is where the training data will be stored. You should "
+"always specify a new folder for new models. If starting a new model, select "
+"either an empty folder, or a folder which does not exist (which will be "
+"created). If continuing to train an existing model, specify the location of "
+"the existing model."
+msgstr ""
+"Directorio del modelo. Aquí es donde se almacenarán los datos de "
+"entrenamiento. Siempre debe especificar una nueva carpeta para los nuevos "
+"modelos. Si se inicia un nuevo modelo, seleccione una carpeta vacía o una "
+"carpeta que no exista (que se creará). Si continúa entrenando un modelo "
+"existente, especifique la ubicación del modelo existente."
+
+#: lib/cli/args_train.py:82
+msgid ""
+"R|Load the weights from a pre-existing model into a newly created model. For "
+"most models this will load weights from the Encoder of the given model into "
+"the encoder of the newly created model. Some plugins may have specific "
+"configuration options allowing you to load weights from other layers. "
+"Weights will only be loaded when creating a new model. This option will be "
+"ignored if you are resuming an existing model. Generally you will also want "
+"to 'freeze-weights' whilst the rest of your model catches up with your "
+"Encoder.\n"
+"NB: Weights can only be loaded from models of the same plugin as you intend "
+"to train."
+msgstr ""
+"R|Cargue los pesos de un modelo preexistente en un modelo recién creado. "
+"Para la mayoría de los modelos, esto cargará pesos del codificador del "
+"modelo dado en el codificador del modelo recién creado. Algunos complementos "
+"pueden tener opciones de configuración específicas que le permiten cargar "
+"pesos de otras capas. Los pesos solo se cargarán al crear un nuevo modelo. "
+"Esta opción se ignorará si está reanudando un modelo existente. En general, "
+"también querrá 'congelar pesos' mientras el resto de su modelo se pone al "
+"día con su codificador.\n"
+"NB: Los pesos solo se pueden cargar desde modelos del mismo complemento que "
+"desea entrenar."
+
+#: lib/cli/args_train.py:99
+msgid ""
+"R|Select which trainer to use. Trainers can be configured from the Settings "
+"menu or the config folder.\n"
+"L|original: The original model created by /u/deepfakes.\n"
+"L|dfaker: 64px in/128px out model from dfaker. Enable 'warp-to-landmarks' "
+"for full dfaker method.\n"
+"L|dfl-h128: 128px in/out model from deepfacelab\n"
+"L|dfl-sae: Adaptable model from deepfacelab\n"
+"L|dlight: A lightweight, high resolution DFaker variant.\n"
+"L|iae: A model that uses intermediate layers to try to get better details\n"
+"L|lightweight: A lightweight model for low-end cards. Don't expect great "
+"results. Can train as low as 1.6GB with batch size 8.\n"
+"L|realface: A high detail, dual density model based on DFaker, with "
+"customizable in/out resolution. The autoencoders are unbalanced so B>A swaps "
+"won't work so well. By andenixa et al. Very configurable.\n"
+"L|unbalanced: 128px in/out model from andenixa. The autoencoders are "
+"unbalanced so B>A swaps won't work so well. Very configurable.\n"
+"L|villain: 128px in/out model from villainguy. Very resource hungry (You "
+"will require a GPU with a fair amount of VRAM). Good for details, but more "
+"susceptible to color differences."
+msgstr ""
+"R|Seleccione el entrenador que desea utilizar. Los entrenadores se pueden "
+"configurar desde el menú de configuración o la carpeta de configuración.\n"
+"L|original: El modelo original creado por /u/deepfakes.\n"
+"L|dfaker: Modelo de 64px in/128px out de dfaker. Habilitar 'warp-to-"
+"landmarks' para el método completo de dfaker.\n"
+"L|dfl-h128: modelo de 128px in/out de deepfacelab\n"
+"L|dfl-sae: Modelo adaptable de deepfacelab\n"
+"L|dlight: Una variante de DFaker ligera y de alta resolución.\n"
+"L|iae: Un modelo que utiliza capas intermedias para tratar de obtener "
+"mejores detalles.\n"
+"L|lightweight: Un modelo ligero para tarjetas de gama baja. No esperes "
+"grandes resultados. Puede entrenar con tan sólo 1,6GB con tamaño de lote 8.\n"
+"L|realface: Un modelo de alto detalle y doble densidad basado en DFaker, con "
+"resolución de entrada y salida personalizable. Los autocodificadores están "
+"desequilibrados, por lo que los intercambios B>A no funcionan tan bien. Por "
+"andenixa et al. Muy configurable.\n"
+"L|unbalanced: modelo de 128px de entrada/salida de andenixa. Los "
+"autocodificadores están desequilibrados por lo que los intercambios B>A no "
+"funcionarán tan bien. Muy configurable.\n"
+"L|villain: Modelo de 128px de entrada/salida de villainguy. Requiere muchos "
+"recursos (se necesita una GPU con una buena cantidad de VRAM). Bueno para "
+"los detalles, pero más susceptible a las diferencias de color."
+
+#: lib/cli/args_train.py:125
+msgid ""
+"Output a summary of the model and exit. If a model folder is provided then a "
+"summary of the saved model is displayed. Otherwise a summary of the model "
+"that would be created by the chosen plugin and configuration settings is "
+"displayed."
+msgstr ""
+"Genere un resumen del modelo y salga. Si se proporciona una carpeta de "
+"modelo, se muestra un resumen del modelo guardado. De lo contrario, se "
+"muestra un resumen del modelo que crearía el complemento elegido y los "
+"ajustes de configuración."
+
+#: lib/cli/args_train.py:135
+msgid ""
+"Freeze the weights of the model. Freezing weights means that some of the "
+"parameters in the model will no longer continue to learn, but those that are "
+"not frozen will continue to learn. For most models, this will freeze the "
+"encoder, but some models may have configuration options for freezing other "
+"layers."
+msgstr ""
+"Congele los pesos del modelo. Congelar pesos significa que algunos de los "
+"parámetros del modelo ya no seguirán aprendiendo, pero los que no están "
+"congelados seguirán aprendiendo. Para la mayoría de los modelos, esto "
+"congelará el codificador, pero algunos modelos pueden tener opciones de "
+"configuración para congelar otras capas."
+
+#: lib/cli/args_train.py:147 lib/cli/args_train.py:160
+#: lib/cli/args_train.py:174 lib/cli/args_train.py:183
+#: lib/cli/args_train.py:190 lib/cli/args_train.py:199
+msgid "training"
+msgstr "entrenamiento"
+
+#: lib/cli/args_train.py:149
+msgid ""
+"Batch size. This is the number of images processed through the model for "
+"each side per iteration. NB: As the model is fed 2 sides at a time, the "
+"actual number of images within the model at any one time is double the "
+"number that you set here. Larger batches require more GPU RAM."
+msgstr ""
+"Tamaño del lote. Este es el número de imágenes procesadas a través del "
+"modelo para cada lado por iteración. Nota: Como el modelo se alimenta de 2 "
+"lados a la vez, el número real de imágenes dentro del modelo en cualquier "
+"momento es el doble del número que se establece aquí. Los lotes más grandes "
+"requieren más RAM de la GPU."
+
+#: lib/cli/args_train.py:162
+msgid ""
+"Length of training in iterations. This is only really used for automation. "
+"There is no 'correct' number of iterations a model should be trained for. "
+"You should stop training when you are happy with the previews. However, if "
+"you want the model to stop automatically at a set number of iterations, you "
+"can set that value here."
+msgstr ""
+"Duración del entrenamiento en iteraciones. Esto sólo se utiliza realmente "
+"para la automatización. No hay un número 'correcto' de iteraciones para las "
+"que deba entrenarse un modelo. Debe dejar de entrenar cuando esté satisfecho "
+"con las vistas previas. Sin embargo, si desea que el modelo se detenga "
+"automáticamente en un número determinado de iteraciones, puede establecer "
+"ese valor aquí."
+
+#: lib/cli/args_train.py:176
+msgid ""
+"Learning rate warmup. Linearly increase the learning rate from 0 to the "
+"chosen target rate over the number of iterations given here. 0 to disable."
+msgstr ""
+"Calentamiento de la tasa de aprendizaje. Aumenta linealmente la tasa de "
+"aprendizaje desde 0 hasta la tasa objetivo elegida a lo largo del número de "
+"iteraciones indicado aquí. 0 para desactivar."
+
+#: lib/cli/args_train.py:184
+msgid "Use distibuted training on multi-gpu setups."
+msgstr "Utilice entrenamiento distribuido en configuraciones de varias GPU."
+
+#: lib/cli/args_train.py:192
+msgid ""
+"Disables TensorBoard logging. NB: Disabling logs means that you will not be "
+"able to use the graph or analysis for this session in the GUI."
+msgstr ""
+"Desactiva el registro de TensorBoard. NB: Desactivar los registros significa "
+"que no podrá utilizar el gráfico o el análisis de esta sesión en la GUI."
+
+#: lib/cli/args_train.py:201
+msgid ""
+"Use the Learning Rate Finder to discover the optimal learning rate for "
+"training. For new models, this will calculate the optimal learning rate for "
+"the model. For existing models this will use the optimal learning rate that "
+"was discovered when initializing the model. Setting this option will ignore "
+"the manually configured learning rate (configurable in train settings)."
+msgstr ""
+"Utilice el Buscador de tasa de aprendizaje para descubrir la tasa de "
+"aprendizaje óptima para el entrenamiento. Para modelos nuevos, esto "
+"calculará la tasa de aprendizaje óptima para el modelo. Para los modelos "
+"existentes, esto utilizará la tasa de aprendizaje óptima que se descubrió "
+"al inicializar el modelo. Configurar esta opción ignorará la tasa de "
+"aprendizaje configurada manualmente (configurable en los ajustes de entrenamiento)."
+
+#: lib/cli/args_train.py:214 lib/cli/args_train.py:224
+msgid "Saving"
+msgstr "Guardar"
+
+#: lib/cli/args_train.py:215
+msgid "Sets the number of iterations between each model save."
+msgstr "Establece el número de iteraciones entre cada guardado del modelo."
+
+#: lib/cli/args_train.py:226
+msgid ""
+"Sets the number of iterations before saving a backup snapshot of the model "
+"in it's current state. Set to 0 for off."
+msgstr ""
+"Establece el número de iteraciones antes de guardar una copia de seguridad "
+"del modelo en su estado actual. Establecer a 0 para desactivar."
+
+#: lib/cli/args_train.py:233 lib/cli/args_train.py:245
+#: lib/cli/args_train.py:257
+msgid "timelapse"
+msgstr "intervalo"
+
+#: lib/cli/args_train.py:235
+msgid ""
+"Optional for creating a timelapse. Timelapse will save an image of your "
+"selected faces into the timelapse-output folder at every save iteration. "
+"This should be the input folder of 'A' faces that you would like to use for "
+"creating the timelapse. You must also supply a --timelapse-output and a --"
+"timelapse-input-B parameter."
+msgstr ""
+"Opcional para crear un timelapse. Timelapse guardará una imagen de las caras "
+"seleccionadas en la carpeta timelapse-output en cada iteración de guardado. "
+"Esta debe ser la carpeta de entrada de las caras \"A\" que desea utilizar "
+"para crear el timelapse. También debe suministrar un parámetro --timelapse-"
+"output y un parámetro --timelapse-input-B."
+
+#: lib/cli/args_train.py:247
+msgid ""
+"Optional for creating a timelapse. Timelapse will save an image of your "
+"selected faces into the timelapse-output folder at every save iteration. "
+"This should be the input folder of 'B' faces that you would like to use for "
+"creating the timelapse. You must also supply a --timelapse-output and a --"
+"timelapse-input-A parameter."
+msgstr ""
+"Opcional para crear un timelapse. Timelapse guardará una imagen de las caras "
+"seleccionadas en la carpeta timelapse-output en cada iteración de guardado. "
+"Esta debe ser la carpeta de entrada de las caras \"B\" que desea utilizar "
+"para crear el timelapse. También debe suministrar un parámetro --timelapse-"
+"output y un parámetro --timelapse-input-A."
+
+#: lib/cli/args_train.py:259
+msgid ""
+"Optional for creating a timelapse. Timelapse will save an image of your "
+"selected faces into the timelapse-output folder at every save iteration. If "
+"the input folders are supplied but no output folder, it will default to your "
+"model folder/timelapse/"
+msgstr ""
+"Opcional para crear un timelapse. Timelapse guardará una imagen de las caras "
+"seleccionadas en la carpeta timelapse-output en cada iteración de guardado. "
+"Si se suministran las carpetas de entrada pero no la carpeta de salida, se "
+"guardará por defecto en la carpeta del modelo/timelapse/"
+
+#: lib/cli/args_train.py:268 lib/cli/args_train.py:275
+msgid "preview"
+msgstr "previsualización"
+
+#: lib/cli/args_train.py:269
+msgid "Show training preview output. in a separate window."
+msgstr ""
+"Mostrar la salida de la vista previa del entrenamiento en una ventana "
+"separada."
+
+#: lib/cli/args_train.py:277
+msgid ""
+"Writes the training result to a file. The image will be stored in the root "
+"of your FaceSwap folder."
+msgstr ""
+"Escribe el resultado del entrenamiento en un archivo. La imagen se "
+"almacenará en la raíz de su carpeta FaceSwap."
+
+#: lib/cli/args_train.py:284 lib/cli/args_train.py:294
+#: lib/cli/args_train.py:304 lib/cli/args_train.py:314
+msgid "augmentation"
+msgstr "aumento"
+
+#: lib/cli/args_train.py:286
+msgid ""
+"Warps training faces to closely matched Landmarks from the opposite face-set "
+"rather than randomly warping the face. This is the 'dfaker' way of doing "
+"warping."
+msgstr ""
+"Deforma las caras de entrenamiento a puntos de referencia muy parecidos del "
+"conjunto de caras opuestas en lugar de deformar la cara al azar. Esta es la "
+"forma 'dfaker' de hacer la deformación."
+
+#: lib/cli/args_train.py:296
+msgid ""
+"To effectively learn, a random set of images are flipped horizontally. "
+"Sometimes it is desirable for this not to occur. Generally this should be "
+"left off except for during 'fit training'."
+msgstr ""
+"Para aprender de forma efectiva, se voltea horizontalmente un conjunto "
+"aleatorio de imágenes. A veces es deseable que esto no ocurra. Por lo "
+"general, esta opción debería dejarse desactivada, excepto durante el "
+"'entrenamiento de ajuste'."
+
+#: lib/cli/args_train.py:306
+msgid ""
+"Color augmentation helps make the model less susceptible to color "
+"differences between the A and B sets, at an increased training time cost. "
+"Enable this option to disable color augmentation."
+msgstr ""
+"El aumento del color ayuda a que el modelo sea menos susceptible a las "
+"diferencias de color entre los conjuntos A y B, con un mayor coste de tiempo "
+"de entrenamiento. Activa esta opción para desactivar el aumento de color."
+
+#: lib/cli/args_train.py:316
+msgid ""
+"Warping is integral to training the Neural Network. This option should only "
+"be enabled towards the very end of training to try to bring out more detail. "
+"Think of it as 'fine-tuning'. Enabling this option from the beginning is "
+"likely to kill a model and lead to terrible results."
+msgstr ""
+"La deformación es fundamental para el entrenamiento de la red neuronal. Esta "
+"opción sólo debería activarse hacia el final del entrenamiento para tratar "
+"de obtener más detalles. Piense en ello como un 'ajuste fino'. Si se activa "
+"esta opción desde el principio, es probable que arruine el modelo y se "
+"obtengan resultados terribles."
+
+#~ msgid ""
+#~ "R|Select the distribution stategy to use.\n"
+#~ "L|default: Use Tensorflow's default distribution strategy.\n"
+#~ "L|central-storage: Centralizes variables on the CPU whilst operations are "
+#~ "performed on 1 or more local GPUs. This can help save some VRAM at the "
+#~ "cost of some speed by not storing variables on the GPU. Note: Mixed-"
+#~ "Precision is not supported on multi-GPU setups.\n"
+#~ "L|mirrored: Supports synchronous distributed training across multiple "
+#~ "local GPUs. A copy of the model and all variables are loaded onto each "
+#~ "GPU with batches distributed to each GPU at each iteration."
+#~ msgstr ""
+#~ "R|Seleccione la estrategia de distribución a utilizar.\n"
+#~ "L|default: utiliza la estrategia de distribución predeterminada de "
+#~ "Tensorflow.\n"
+#~ "L|central-storage: centraliza las variables en la CPU mientras que las "
+#~ "operaciones se realizan en 1 o más GPU locales. Esto puede ayudar a "
+#~ "ahorrar algo de VRAM a costa de cierta velocidad al no almacenar "
+#~ "variables en la GPU. Nota: Mixed-Precision no es compatible con "
+#~ "configuraciones de múltiples GPU.\n"
+#~ "L|mirrored: Admite el entrenamiento distribuido síncrono en varias GPU "
+#~ "locales. Se carga una copia del modelo y todas las variables en cada GPU "
+#~ "con lotes distribuidos a cada GPU en cada iteración."
diff --git a/locales/es/LC_MESSAGES/tools.alignments.cli.mo b/locales/es/LC_MESSAGES/tools.alignments.cli.mo
new file mode 100644
index 0000000000..9499ca1e8a
Binary files /dev/null and b/locales/es/LC_MESSAGES/tools.alignments.cli.mo differ
diff --git a/locales/es/LC_MESSAGES/tools.alignments.cli.po b/locales/es/LC_MESSAGES/tools.alignments.cli.po
new file mode 100644
index 0000000000..ceb263f11a
--- /dev/null
+++ b/locales/es/LC_MESSAGES/tools.alignments.cli.po
@@ -0,0 +1,296 @@
+# SOME DESCRIPTIVE TITLE.
+# Copyright (C) YEAR ORGANIZATION
+# FIRST AUTHOR , YEAR.
+#
+msgid ""
+msgstr ""
+"Project-Id-Version: faceswap.spanish\n"
+"Report-Msgid-Bugs-To: \n"
+"POT-Creation-Date: 2024-04-19 11:28+0100\n"
+"PO-Revision-Date: 2024-04-19 11:29+0100\n"
+"Last-Translator: \n"
+"Language-Team: tokafondo\n"
+"Language: es_ES\n"
+"MIME-Version: 1.0\n"
+"Content-Type: text/plain; charset=UTF-8\n"
+"Content-Transfer-Encoding: 8bit\n"
+"Plural-Forms: nplurals=2; plural=(n != 1);\n"
+"Generated-By: pygettext.py 1.5\n"
+"X-Generator: Poedit 3.4.2\n"
+
+#: tools/alignments/cli.py:16
+msgid ""
+"This command lets you perform various tasks pertaining to an alignments file."
+msgstr ""
+"Este comando le permite realizar varias tareas relacionadas con un archivo "
+"de alineación."
+
+#: tools/alignments/cli.py:31
+msgid ""
+"Alignments tool\n"
+"This tool allows you to perform numerous actions on or using an alignments "
+"file against its corresponding faceset/frame source."
+msgstr ""
+"Herramienta de alineación\n"
+"Esta herramienta le permite realizar numerosas acciones sobre un archivo de "
+"alineaciones, o utilizándolo, contra su correspondiente conjunto de caras o "
+"fuente de fotogramas."
+
+#: tools/alignments/cli.py:43
+msgid " Must Pass in a frames folder/source video file (-r)."
+msgstr ""
+" Debe indicar una carpeta de fotogramas o archivo de vídeo de origen (-r)."
+
+#: tools/alignments/cli.py:44
+msgid " Must Pass in a faces folder (-c)."
+msgstr " Debe indicar una carpeta de caras (-c)."
+
+#: tools/alignments/cli.py:45
+msgid ""
+" Must Pass in either a frames folder/source video file OR a faces folder (-r "
+"or -c)."
+msgstr ""
+" Debe indicar una carpeta de fotogramas o archivo de vídeo de origen, o una "
+"carpeta de caras (-r o -c)."
+
+#: tools/alignments/cli.py:47
+msgid ""
+" Must Pass in a frames folder/source video file AND a faces folder (-r and -"
+"c)."
+msgstr ""
+" Debe indicar una carpeta de fotogramas o archivo de vídeo de origen, y una "
+"carpeta de caras (-r y -c)."
+
+#: tools/alignments/cli.py:49
+msgid " Use the output option (-o) to process results."
+msgstr " Usar la opción de salida (-o) para procesar los resultados."
+
+#: tools/alignments/cli.py:58 tools/alignments/cli.py:104
+msgid "processing"
+msgstr "proceso"
+
+#: tools/alignments/cli.py:61
+#, python-brace-format
+msgid ""
+"R|Choose which action you want to perform. NB: All actions require an "
+"alignments file (-a) to be passed in.\n"
+"L|'draw': Draw landmarks on frames in the selected folder/video. A subfolder "
+"will be created within the frames folder to hold the output.{0}\n"
+"L|'export': Export the contents of an alignments file to a json file. Can be "
+"used for editing alignment information in external tools and then re-"
+"importing by using Faceswap's Extract 'Import' plugins. Note: masks and "
+"identity vectors will not be included in the exported file, so will be re-"
+"generated when the json file is imported back into Faceswap. All data is "
+"exported with the origin (0, 0) at the top left of the canvas.\n"
+"L|'extract': Re-extract faces from the source frames/video based on "
+"alignment data. This is a lot quicker than re-detecting faces. Can pass in "
+"the '-een' (--extract-every-n) parameter to only extract every nth frame."
+"{1}\n"
+"L|'from-faces': Generate alignment file(s) from a folder of extracted faces. "
+"if the folder of faces comes from multiple sources, then multiple alignments "
+"files will be created. NB: for faces which have been extracted from folders "
+"of source images, rather than a video, a single alignments file will be "
+"created as there is no way for the process to know how many folders of "
+"images were originally used. You do not need to provide an alignments file "
+"path to run this job. {3}\n"
+"L|'missing-alignments': Identify frames that do not exist in the alignments "
+"file.{2}{0}\n"
+"L|'missing-frames': Identify frames in the alignments file that do not "
+"appear within the frames folder/video.{2}{0}\n"
+"L|'multi-faces': Identify where multiple faces exist within the alignments "
+"file.{2}{4}\n"
+"L|'no-faces': Identify frames that exist within the alignment file but no "
+"faces were detected.{2}{0}\n"
+"L|'remove-faces': Remove deleted faces from an alignments file. The original "
+"alignments file will be backed up.{3}\n"
+"L|'rename' - Rename faces to correspond with their parent frame and position "
+"index in the alignments file (i.e. how they are named after running extract)."
+"{3}\n"
+"L|'sort': Re-index the alignments from left to right. For alignments with "
+"multiple faces this will ensure that the left-most face is at index 0.\n"
+"L|'spatial': Perform spatial and temporal filtering to smooth alignments "
+"(EXPERIMENTAL!)"
+msgstr ""
+"R|Elija la acción que desea realizar. NB: Todas las acciones requieren que "
+"se indique un archivo de alineación (-a).\n"
+"L|'draw': Dibuja puntos de referencia en los fotogramas de la carpeta o "
+"vídeo seleccionado. Se creará una subcarpeta dentro de la carpeta de "
+"fotogramas para guardar el resultado.{0}\n"
+"L|'export': Exportar el contenido de un archivo de alineaciones a un archivo "
+"JSON. Se puede utilizar para editar información de alineación en "
+"herramientas externas y luego volver a importarla mediante los complementos "
+"de importación ('Import') de la extracción de Faceswap. Nota: Las máscaras y "
+"los vectores de identidad no se incluirán en el archivo exportado, por lo "
+"que se volverán a generar cuando el archivo JSON se importe de nuevo a "
+"Faceswap. Todos los datos se exportan con el origen (0, 0) en la parte "
+"superior izquierda del lienzo.\n"
+"L|'extract': Reextrae las caras de los fotogramas o vídeos de origen "
+"basándose en los datos de alineación. Esto es mucho más rápido que volver a "
+"detectar las caras. Se puede pasar el parámetro '-een' (--extract-every-n) "
+"para extraer sólo cada enésimo fotograma.{1}\n"
+"L|'from-faces': genera archivos de alineación a partir de una carpeta de "
+"caras extraídas. Si la carpeta de caras proviene de varias fuentes, se "
+"crearán varios archivos de alineación. NB: para las caras que se han "
+"extraído de carpetas de imágenes de origen, en lugar de un vídeo, se creará "
+"un único archivo de alineaciones, ya que el proceso no tiene forma de saber "
+"cuántas carpetas de imágenes se usaron originalmente. No necesita "
+"proporcionar una ruta de archivo de alineaciones para ejecutar este trabajo. "
+"{3}\n"
+"L|'missing-alignments': Identifica los fotogramas que no existen en el "
+"archivo de alineaciones.{2}{0}\n"
+"L|'missing-frames': Identifica los fotogramas del archivo de alineaciones "
+"que no aparecen en la carpeta de fotogramas o vídeo.{2}{0}\n"
+"L|'multi-faces': Identifica los casos en los que existen múltiples caras "
+"dentro de un mismo fotograma, en el archivo de alineaciones.{2}{4}\n"
+"L|'no-faces': Identifica los fotogramas que existen en el archivo de "
+"alineación pero no se detectan caras.{2}{0}\n"
+"L|'remove-faces': Elimina las caras previamente eliminadas de un archivo de "
+"alineaciones. Se hará una copia de seguridad del archivo de alineaciones "
+"original.{3}\n"
+"L|'rename': Cambia el nombre de las caras para que se correspondan con su "
+"fotograma de origen y su índice de posición en el archivo de alineaciones "
+"(es decir, cómo se nombran después de ejecutar la extracción).{3}\n"
+"L|'sort': Reordena las alineaciones de izquierda a derecha. En el caso de "
+"alineaciones con múltiples caras, esto asegurará que la cara más a la "
+"izquierda esté en el índice 0.\n"
+"L|'spatial': Realiza un filtrado espacial y temporal para suavizar las "
+"alineaciones (¡EXPERIMENTAL!)"
+
+#: tools/alignments/cli.py:107
+msgid ""
+"R|How to output discovered items ('faces' and 'frames' only):\n"
+"L|'console': Print the list of frames to the screen. (DEFAULT)\n"
+"L|'file': Output the list of frames to a text file (stored within the source "
+"directory).\n"
+"L|'move': Move the discovered items to a sub-folder within the source "
+"directory."
+msgstr ""
+"R|Cómo procesar los elementos descubiertos (sólo 'caras' y 'fotogramas'):\n"
+"L|'console': Muestra la lista de fotogramas en la pantalla. (POR DEFECTO)\n"
+"L|'file': Redirige la lista de fotogramas a un archivo de texto (almacenado "
+"en el directorio de origen).\n"
+"L|'move': Mueve los elementos descubiertos a una subcarpeta dentro del "
+"directorio de origen."
+
+#: tools/alignments/cli.py:118 tools/alignments/cli.py:141
+#: tools/alignments/cli.py:148
+msgid "data"
+msgstr "datos"
+
+#: tools/alignments/cli.py:125
+msgid ""
+"Full path to the alignments file to be processed. If you have input a "
+"'frames_dir' and don't provide this option, the process will try to find the "
+"alignments file at the default location. All jobs require an alignments file "
+"with the exception of 'from-faces' when the alignments file will be "
+"generated in the specified faces folder."
+msgstr ""
+"Ruta completa al archivo de alineaciones a procesar. Si ingresó un "
+"'frames_dir' y no proporciona esta opción, el proceso intentará encontrar el "
+"archivo de alineaciones en la ubicación predeterminada. Todos los trabajos "
+"requieren un archivo de alineaciones con la excepción de 'from-faces' cuando "
+"el archivo de alineaciones se generará en la carpeta de caras especificada."
+
+#: tools/alignments/cli.py:142
+msgid "Directory containing source frames that faces were extracted from."
+msgstr ""
+"Directorio que contiene los fotogramas de origen de los que se extrajeron "
+"las caras."
+
+#: tools/alignments/cli.py:150
+msgid ""
+"R|Run the aligmnents tool on multiple sources. The following jobs support "
+"batch mode:\n"
+"L|draw, extract, from-faces, missing-alignments, missing-frames, no-faces, "
+"sort, spatial.\n"
+"If batch mode is selected then the other options should be set as follows:\n"
+"L|alignments_file: For 'sort' and 'spatial' this should point to the parent "
+"folder containing the alignments files to be processed. For all other jobs "
+"this option is ignored, and the alignments files must exist at their default "
+"location relative to the original frames folder/video.\n"
+"L|faces_dir: For 'from-faces' this should be a parent folder, containing sub-"
+"folders of extracted faces from which to generate alignments files. For "
+"'extract' this should be a parent folder where sub-folders will be created "
+"for each extraction to be run. For all other jobs this option is ignored.\n"
+"L|frames_dir: For 'draw', 'extract', 'missing-alignments', 'missing-frames' "
+"and 'no-faces' this should be a parent folder containing video files or sub-"
+"folders of images to perform the alignments job on. The alignments file "
+"should exist at the default location. For all other jobs this option is "
+"ignored."
+msgstr ""
+"R|Ejecute la herramienta de alineación en varias fuentes. Los siguientes "
+"trabajos admiten el modo por lotes:\n"
+"L|draw, extract, from-faces, missing-alignments, missing-frames, no-faces, "
+"sort, spatial.\n"
+"Si se selecciona el modo por lotes, las otras opciones deben configurarse de "
+"la siguiente manera:\n"
+"L|alignments_file: para 'sort' y 'spatial', debe apuntar a la carpeta "
+"principal que contiene los archivos de alineación que se van a procesar. "
+"Para todos los demás trabajos, esta opción se ignora y los archivos de "
+"alineaciones deben existir en su ubicación predeterminada en relación con la "
+"carpeta/video de fotogramas originales.\n"
+"L|faces_dir: para 'from-faces', esta debe ser una carpeta principal que "
+"contenga subcarpetas de caras extraídas desde las cuales generar archivos de "
+"alineación. Para 'extract', esta debe ser una carpeta principal donde se "
+"crearán subcarpetas para cada extracción que se ejecute. Para todos los "
+"demás trabajos, esta opción se ignora.\n"
+"L|frames_dir: para 'draw', 'extract', 'missing-alignments', 'missing-frames' "
+"y 'no-faces', esta debe ser una carpeta principal que contenga archivos de "
+"vídeo o subcarpetas de imágenes sobre los que realizar el trabajo de "
+"alineaciones. El archivo de alineaciones debe existir en la ubicación "
+"predeterminada. Para todos los demás trabajos, esta opción se ignora."
+
+#: tools/alignments/cli.py:176 tools/alignments/cli.py:188
+#: tools/alignments/cli.py:198
+msgid "extract"
+msgstr "extracción"
+
+#: tools/alignments/cli.py:178
+msgid ""
+"[Extract only] Extract every 'nth' frame. This option will skip frames when "
+"extracting faces. For example a value of 1 will extract faces from every "
+"frame, a value of 10 will extract faces from every 10th frame."
+msgstr ""
+"[Sólo extracción] Extraer cada 'enésimo' fotograma. Esta opción omitirá los "
+"fotogramas al extraer las caras. Por ejemplo, un valor de 1 extraerá las "
+"caras de cada fotograma, un valor de 10 extraerá las caras de cada 10 "
+"fotogramas."
+
+#: tools/alignments/cli.py:189
+msgid "[Extract only] The output size of extracted faces."
+msgstr "[Sólo extracción] El tamaño de salida de las caras extraídas."
+
+#: tools/alignments/cli.py:200
+msgid ""
+"[Extract only] Only extract faces that have been resized by this percent or "
+"more to meet the specified extract size (`-sz`, `--size`). Useful for "
+"excluding low-res images from a training set. Set to 0 to extract all faces. "
+"Eg: For an extract size of 512px, A setting of 50 will only include faces "
+"that have been resized from 256px or above. Setting to 100 will only extract "
+"faces that have been resized from 512px or above. A setting of 200 will only "
+"extract faces that have been downscaled from 1024px or above."
+msgstr ""
+"[Sólo extracción] Solo extraiga las caras que hayan cambiado de tamaño en "
+"este porcentaje o más para cumplir con el tamaño de extracción especificado "
+"(`-sz`, `--size`). Útil para excluir imágenes de baja resolución de un "
+"conjunto de entrenamiento. Establézcalo en 0 para extraer todas las caras. "
+"Por ejemplo: para un tamaño de extracción de 512 px, una configuración de 50 "
+"solo incluirá caras que se hayan redimensionado desde 256 px o más. Si se "
+"establece en 100, solo se extraerán las caras que se hayan redimensionado "
+"desde 512 px o más. Una configuración de 200 solo extraerá las caras que se "
+"hayan reducido desde 1024 px o más."
+
+#~ msgid "Directory containing extracted faces."
+#~ msgstr "Directorio que contiene las caras extraídas."
+ +#~ msgid "Full path to the alignments file to be processed." +#~ msgstr "Ruta completa del archivo de alineaciones a procesar." + +#~ msgid "" +#~ "[Extract only] Only extract faces that have not been upscaled to the " +#~ "required size (`-sz`, `--size). Useful for excluding low-res images from " +#~ "a training set." +#~ msgstr "" +#~ "[Sólo extracción] Sólo extraer las caras que son de origen iguales como " +#~ "mínimo al tamaño de salida (`-sz`, `--size). Es útil para excluir las " +#~ "imágenes de baja resolución de un conjunto de entrenamiento." diff --git a/locales/es/LC_MESSAGES/tools.effmpeg.cli.mo b/locales/es/LC_MESSAGES/tools.effmpeg.cli.mo new file mode 100644 index 0000000000..0b973d69f6 Binary files /dev/null and b/locales/es/LC_MESSAGES/tools.effmpeg.cli.mo differ diff --git a/locales/es/LC_MESSAGES/tools.effmpeg.cli.po b/locales/es/LC_MESSAGES/tools.effmpeg.cli.po new file mode 100644 index 0000000000..ea47680568 --- /dev/null +++ b/locales/es/LC_MESSAGES/tools.effmpeg.cli.po @@ -0,0 +1,199 @@ +# SOME DESCRIPTIVE TITLE. +# Copyright (C) YEAR ORGANIZATION +# FIRST AUTHOR , YEAR. +# +msgid "" +msgstr "" +"Project-Id-Version: faceswap.spanish\n" +"Report-Msgid-Bugs-To: \n" +"POT-Creation-Date: 2024-03-28 23:50+0000\n" +"PO-Revision-Date: 2024-03-29 00:02+0000\n" +"Last-Translator: \n" +"Language-Team: tokafondo\n" +"Language: es_ES\n" +"MIME-Version: 1.0\n" +"Content-Type: text/plain; charset=UTF-8\n" +"Content-Transfer-Encoding: 8bit\n" +"Plural-Forms: nplurals=2; plural=(n != 1);\n" +"Generated-By: pygettext.py 1.5\n" +"X-Generator: Poedit 3.4.2\n" + +#: tools/effmpeg/cli.py:15 +msgid "This command allows you to easily execute common ffmpeg tasks." +msgstr "Este comando le permite ejecutar fácilmente tareas comunes de ffmpeg." + +#: tools/effmpeg/cli.py:52 +msgid "A wrapper for ffmpeg for performing image <> video converting." +msgstr "Un interfaz de ffmpeg para realizar la conversión de imagen <> vídeo." + +#: tools/effmpeg/cli.py:64 +msgid "" +"R|Choose which action you want ffmpeg ffmpeg to do.\n" +"L|'extract': turns videos into images \n" +"L|'gen-vid': turns images into videos \n" +"L|'get-fps' returns the chosen video's fps.\n" +"L|'get-info' returns information about a video.\n" +"L|'mux-audio' add audio from one video to another.\n" +"L|'rescale' resize video.\n" +"L|'rotate' rotate video.\n" +"L|'slice' cuts a portion of the video into a separate video file." +msgstr "" +"R|Elige qué acción quieres que haga ffmpeg\n" +"L|'extract': convierte los vídeos en imágenes \n" +"L|'gen-vid': convierte las imágenes en vídeos \n" +"L|'get-fps' devuelve los fps del vídeo elegido.\n" +"L|'get-info' devuelve información sobre un vídeo.\n" +"L|'mux-audio' añade audio de un vídeo a otro.\n" +"L|'rescale' cambia el tamaño del vídeo.\n" +"L|'rotate' rotar video\n" +"L|'slice' corta una parte del video en un archivo de video separado." + +#: tools/effmpeg/cli.py:78 +msgid "Input file." +msgstr "Archivo de entrada." + +#: tools/effmpeg/cli.py:79 tools/effmpeg/cli.py:86 tools/effmpeg/cli.py:100 +msgid "data" +msgstr "datos" + +#: tools/effmpeg/cli.py:89 +msgid "" +"Output file. If no output is specified then: if the output is meant to be a " +"video then a video called 'out.mkv' will be created in the input directory; " +"if the output is meant to be a directory then a directory called 'out' will " +"be created inside the input directory. Note: the chosen output file " +"extension will determine the file encoding." +msgstr "" +"R|Archivo de salida. 
+"L|si la salida es un vídeo, se creará un vídeo llamado 'out.mkv' en el "
+"directorio de entrada;\n"
+"L|si la salida es un directorio, se creará un directorio llamado 'out' "
+"dentro del directorio de entrada.\n"
+"Nota: la extensión del archivo de salida elegida determinará la codificación "
+"del archivo."
+
+#: tools/effmpeg/cli.py:102
+msgid "Path to reference video if 'input' was not a video."
+msgstr ""
+"Ruta de acceso al vídeo de referencia si se dio una carpeta con fotogramas "
+"en vez de un vídeo."
+
+#: tools/effmpeg/cli.py:108 tools/effmpeg/cli.py:118 tools/effmpeg/cli.py:156
+#: tools/effmpeg/cli.py:185
+msgid "output"
+msgstr "salida"
+
+#: tools/effmpeg/cli.py:110
+msgid ""
+"Provide video fps. Can be an integer, float or fraction. Negative values "
+"will will make the program try to get the fps from the input or reference "
+"videos."
+msgstr ""
+"Introducir los fps del vídeo. Puede ser un número entero, flotante o una "
+"fracción. Los valores negativos harán que el programa intente obtener los "
+"fps de los vídeos de entrada o de referencia."
+
+#: tools/effmpeg/cli.py:120
+msgid ""
+"Image format that extracted images should be saved as. '.bmp' will offer the "
+"fastest extraction speed, but will take the most storage space. '.png' will "
+"be slower but will take less storage."
+msgstr ""
+"Formato de imagen en el que se deben guardar las imágenes extraídas. '.bmp' "
+"ofrecerá la mayor velocidad de extracción, pero ocupará el mayor espacio de "
+"almacenamiento. '.png' será más lento pero ocupará menos espacio de "
+"almacenamiento."
+
+#: tools/effmpeg/cli.py:127 tools/effmpeg/cli.py:136 tools/effmpeg/cli.py:145
+msgid "clip"
+msgstr "recorte"
+
+#: tools/effmpeg/cli.py:129
+msgid ""
+"Enter the start time from which an action is to be applied. Default: "
+"00:00:00, in HH:MM:SS format. You can also enter the time with or without "
+"the colons, e.g. 00:0000 or 026010."
+msgstr ""
+"Introduzca el momento a partir del cual se debe aplicar una acción. Por "
+"defecto: 00:00:00, en formato HH:MM:SS. También puede introducir la hora con "
+"o sin los dos puntos, por ejemplo, 00:0000 o 026010."
+
+#: tools/effmpeg/cli.py:138
+msgid ""
+"Enter the end time to which an action is to be applied. If both an end time "
+"and duration are set, then the end time will be used and the duration will "
+"be ignored. Default: 00:00:00, in HH:MM:SS."
+msgstr ""
+"Introduzca el momento hasta el cual se debe aplicar una acción. Si se "
+"establecen tanto el momento final como la duración, se usará el momento "
+"final y se ignorará la duración. Por defecto: 00:00:00, en formato HH:MM:SS."
+
+#: tools/effmpeg/cli.py:147
+msgid ""
+"Enter the duration of the chosen action, for example if you enter 00:00:10 "
+"for slice, then the first 10 seconds after and including the start time will "
+"be cut out into a new video. Default: 00:00:00, in HH:MM:SS format. You can "
+"also enter the time with or without the colons, e.g. 00:0000 or 026010."
+msgstr ""
+"Introduzca la duración de la acción seleccionada; por ejemplo, si introduce "
+"00:00:10 para 'slice', los primeros 10 segundos a partir del momento "
+"inicial (inclusive) se recortarán en un nuevo vídeo. Por defecto: 00:00:00, "
+"en formato HH:MM:SS. También puede introducir la hora con o sin los dos "
+"puntos, por ejemplo, 00:0000 o 026010."
+
+#: tools/effmpeg/cli.py:158
+msgid ""
+"Mux the audio from the reference video into the input video. This option is "
+"only used for the 'gen-vid' action. 'mux-audio' action has this turned on "
+"implicitly."
+msgstr ""
+"Copia el audio del vídeo de referencia al vídeo de entrada. Esta opción sólo "
+"se utiliza para la acción 'gen-vid'. La acción 'mux-audio' la tiene activada "
+"implícitamente."
+
+#: tools/effmpeg/cli.py:169 tools/effmpeg/cli.py:179
+msgid "rotate"
+msgstr "rotación"
+
+#: tools/effmpeg/cli.py:171
+msgid ""
+"Transpose the video. If transpose is set, then degrees will be ignored. For "
+"cli you can enter either the number or the long command name, e.g. to use "
+"(1, 90Clockwise) -tr 1 or -tr 90Clockwise"
+msgstr ""
+"Rotar el vídeo. Si la rotación está establecida, los grados serán ignorados. "
+"En la línea de comandos puede introducir el número o el nombre largo del "
+"comando, por ejemplo, para usar (1, 90Clockwise) son válidas las opciones -"
+"tr 1 y -tr 90Clockwise"
+
+#: tools/effmpeg/cli.py:180
+msgid "Rotate the video clockwise by the given number of degrees."
+msgstr ""
+"Gira el vídeo en el sentido de las agujas del reloj el número de grados "
+"indicado."
+
+#: tools/effmpeg/cli.py:187
+msgid "Set the new resolution scale if the chosen action is 'rescale'."
+msgstr ""
+"Establece la nueva escala de resolución si la acción elegida es "
+"'rescale'."
+
+#: tools/effmpeg/cli.py:192 tools/effmpeg/cli.py:200
+msgid "settings"
+msgstr "ajustes"
+
+#: tools/effmpeg/cli.py:194
+msgid ""
+"Reduces output verbosity so that only serious errors are printed. If both "
+"quiet and verbose are set, verbose will override quiet."
+msgstr ""
+"Reduce el detalle de la salida del registro para que sólo se impriman los "
+"errores graves. Si se establecen tanto 'quiet' como 'verbose', 'verbose' "
+"tendrá preferencia y anulará a 'quiet'."
+
+#: tools/effmpeg/cli.py:202
+msgid ""
+"Increases output verbosity. If both quiet and verbose are set, verbose will "
+"override quiet."
+msgstr ""
+"Aumenta el detalle de la información de registro. Si se establecen tanto "
+"'quiet' como 'verbose', 'verbose' tendrá preferencia y anulará a 'quiet'."
diff --git a/locales/es/LC_MESSAGES/tools.manual.mo b/locales/es/LC_MESSAGES/tools.manual.mo
new file mode 100644
index 0000000000..33cfdb8142
Binary files /dev/null and b/locales/es/LC_MESSAGES/tools.manual.mo differ
diff --git a/locales/es/LC_MESSAGES/tools.manual.po b/locales/es/LC_MESSAGES/tools.manual.po
new file mode 100644
index 0000000000..0e03295c00
--- /dev/null
+++ b/locales/es/LC_MESSAGES/tools.manual.po
@@ -0,0 +1,295 @@
+# SOME DESCRIPTIVE TITLE.
+# Copyright (C) YEAR ORGANIZATION
+# FIRST AUTHOR , YEAR.
+#
+msgid ""
+msgstr ""
+"Project-Id-Version: faceswap.spanish\n"
+"Report-Msgid-Bugs-To: \n"
+"POT-Creation-Date: 2024-03-28 23:55+0000\n"
+"PO-Revision-Date: \n"
+"Last-Translator: \n"
+"Language-Team: \n"
+"Language: es\n"
+"MIME-Version: 1.0\n"
+"Content-Type: text/plain; charset=UTF-8\n"
+"Content-Transfer-Encoding: 8bit\n"
+"X-Generator: Poedit 3.4.2\n"
+
+#: tools/manual/cli.py:13
+msgid ""
+"This command lets you perform various actions on frames, faces and "
+"alignments files using visual tools."
+msgstr ""
+"Este comando le permite realizar varias acciones en los archivos de "
+"fotogramas, caras y alineaciones utilizando herramientas visuales."
+ +#: tools/manual/cli.py:23 +msgid "" +"A tool to perform various actions on frames, faces and alignments files " +"using visual tools" +msgstr "" +"Una herramienta que permite realizar diversas acciones en archivos de " +"fotogramas, caras y alineaciones mediante herramientas visuales" + +#: tools/manual/cli.py:35 tools/manual/cli.py:44 +msgid "data" +msgstr "datos" + +#: tools/manual/cli.py:38 +msgid "" +"Path to the alignments file for the input, if not at the default location" +msgstr "" +"Ruta del archivo de alineaciones para la entrada, si no está en la ubicación " +"por defecto" + +#: tools/manual/cli.py:46 +msgid "" +"Video file or directory containing source frames that faces were extracted " +"from." +msgstr "" +"Archivo de vídeo o directorio que contiene los fotogramas de origen de los " +"que se extrajeron las caras." + +#: tools/manual/cli.py:53 tools/manual/cli.py:62 +msgid "options" +msgstr "opciones" + +#: tools/manual/cli.py:55 +msgid "" +"Force regeneration of the low resolution jpg thumbnails in the alignments " +"file." +msgstr "" +"Forzar la regeneración de las miniaturas jpg de baja resolución en el " +"archivo de alineaciones." + +#: tools/manual/cli.py:64 +msgid "" +"The process attempts to speed up generation of thumbnails by extracting from " +"the video in parallel threads. For some videos, this causes the caching " +"process to hang. If this happens, then set this option to generate the " +"thumbnails in a slower, but more stable single thread." +msgstr "" +"El proceso intenta acelerar la generación de miniaturas extrayendo del vídeo " +"en hilos paralelos. En algunos vídeos, esto hace que el proceso de " +"almacenamiento en caché se cuelgue. Si esto sucede, configure esta opción " +"para generar las miniaturas en un solo hilo más lento, pero más estable." + +#: tools/manual\faceviewer\frame.py:163 +msgid "Display the landmarks mesh" +msgstr "Mostrar la malla de puntos de referencia" + +#: tools/manual\faceviewer\frame.py:164 +msgid "Display the mask" +msgstr "Mostrar la máscara" + +#: tools/manual\frameviewer\editor\_base.py:628 +#: tools/manual\frameviewer\editor\landmarks.py:44 +#: tools/manual\frameviewer\editor\mask.py:75 +msgid "Magnify/Demagnify the View" +msgstr "Ampliar/Reducir la vista" + +#: tools/manual\frameviewer\editor\bounding_box.py:33 +#: tools/manual\frameviewer\editor\extract_box.py:32 +msgid "Delete Face" +msgstr "Borrar cara" + +#: tools/manual\frameviewer\editor\bounding_box.py:36 +msgid "" +"Bounding Box Editor\n" +"Edit the bounding box being fed into the aligner to recalculate the " +"landmarks.\n" +"\n" +" - Grab the corner anchors to resize the bounding box.\n" +" - Click and drag the bounding box to relocate.\n" +" - Click in empty space to create a new bounding box.\n" +" - Right click a bounding box to delete a face." +msgstr "" +"Editor del cuadro delimitador\n" +"Edite el cuadro delimitador que el alineador usa para recalcular los puntos " +"de referencia.\n" +"\n" +" - Tire de los anclajes de las esquinas para cambiar el tamaño del cuadro " +"delimitador.\n" +" - Haga clic y arrastre el cuadro delimitador para reubicarlo.\n" +" - Haga clic en un espacio vacío para crear un nuevo cuadro delimitador.\n" +" - Haga clic con el botón derecho del ratón en un cuadro delimitador para " +"eliminar una cara." + +#: tools/manual\frameviewer\editor\bounding_box.py:70 +msgid "" +"Aligner to use. FAN will obtain better alignments, but cv2-dnn can be useful " +"if FAN cannot get decent alignments and you want to set a base to edit from." 
+msgstr "" +"Alineador a utilizar. FAN obtendrá mejores alineaciones, pero cv2-dnn puede " +"ser útil si FAN no puede obtener alineaciones decentes y quiere tener una " +"base inicial que luego se vaya a editar." + +#: tools/manual\frameviewer\editor\bounding_box.py:83 +msgid "" +"Normalization method to use for feeding faces to the aligner. This can help " +"the aligner better align faces with difficult lighting conditions. Different " +"methods will yield different results on different sets. NB: This does not " +"impact the output face, just the input to the aligner.\n" +"\tnone: Don't perform normalization on the face.\n" +"\tclahe: Perform Contrast Limited Adaptive Histogram Equalization on the " +"face.\n" +"\thist: Equalize the histograms on the RGB channels.\n" +"\tmean: Normalize the face colors to the mean." +msgstr "" +"Método de normalización a utilizar para las caras que el alineador usará. " +"Esto puede ayudar al alineador a alinear mejor las caras con condiciones de " +"iluminación difíciles. Diferentes métodos darán diferentes resultados en " +"diferentes conjuntos. NB: Esto no afecta a la cara de salida, sólo la " +"entrada al alineador.\n" +"\tninguno: No realizar la normalización en la cara.\n" +"\tclahe: Realiza la ecualización adaptativa del histograma con contraste " +"limitado en la cara.\n" +"\thist: Iguala los histogramas en los canales RGB.\n" +"\tmean: Normaliza los colores de la cara a la media." + +#: tools/manual\frameviewer\editor\extract_box.py:35 +msgid "" +"Extract Box Editor\n" +"Move the extract box that has been generated by the aligner. Click and " +"drag:\n" +"\n" +" - Inside the bounding box to relocate the landmarks.\n" +" - The corner anchors to resize the landmarks.\n" +" - Outside of the corners to rotate the landmarks." +msgstr "" +"Editor de cuadros de extracción\n" +"Mueve el cuadro de extracción que ha sido generada por el alineador. Haga " +"clic y arrastre...\n" +"\n" +" - Dentro del cuadro delimitador para reubicar los puntos de referencia.\n" +" - Los anclajes de las esquinas para cambiar el tamaño de los puntos de " +"referencia.\n" +" - Fuera de las esquinas para girar los puntos de referencia." + +#: tools/manual\frameviewer\editor\landmarks.py:27 +msgid "" +"Landmark Point Editor\n" +"Edit the individual landmark points.\n" +"\n" +" - Click and drag individual points to relocate.\n" +" - Draw a box to select multiple points to relocate." +msgstr "" +"Editor de puntos de referencia\n" +"Edite los puntos de referencia individuales.\n" +"\n" +" - Haga clic y arrastre los puntos individuales para reubicarlos.\n" +" - Dibuje un cuadro para seleccionar varios puntos para reubicarlos." + +#: tools/manual\frameviewer\editor\mask.py:33 +msgid "" +"Mask Editor\n" +"Edit the mask.\n" +" - NB: For Landmark based masks (e.g. components/extended) it is better to " +"make sure the landmarks are correct rather than editing the mask directly. " +"Any change to the landmarks after editing the mask will override your manual " +"edits." +msgstr "" +"Editor de máscaras\n" +"Edite la máscara.\n" +" - Nota: En el caso de las máscaras basadas en puntos de referencia (por " +"ejemplo, componentes/extensión) es mejor asegurarse de que los puntos de " +"referencia son correctos en lugar de editar la máscara directamente. " +"Cualquier cambio en los puntos de referencia después de editar la máscara " +"anulará sus ediciones manuales." 
+ +#: tools/manual\frameviewer\editor\mask.py:77 +msgid "Draw Tool" +msgstr "Herramienta de dibujo" + +#: tools/manual\frameviewer\editor\mask.py:78 +msgid "Erase Tool" +msgstr "Herramienta de borrado" + +#: tools/manual\frameviewer\editor\mask.py:97 +msgid "Select which mask to edit" +msgstr "Seleccionar máscara a editar" + +#: tools/manual\frameviewer\editor\mask.py:104 +msgid "Set the brush size. ([ - decrease, ] - increase)" +msgstr "Seleccionar el tamaño del pincel. ([ - disminuir, ] - aumentar)" + +#: tools/manual\frameviewer\editor\mask.py:111 +msgid "Select the brush cursor color." +msgstr "Seleccionar el color del pincel." + +#: tools/manual\frameviewer\frame.py:78 +msgid "Play/Pause (SPACE)" +msgstr "Reproducir/Pausa (BARRA DE ESPACIO)" + +#: tools/manual\frameviewer\frame.py:79 +msgid "Go to First Frame (HOME)" +msgstr "Ir al primer cuadro (INICIO)" + +#: tools/manual\frameviewer\frame.py:80 +msgid "Go to Previous Frame (Z)" +msgstr "Ir al cuadro anterior (Z)" + +#: tools/manual\frameviewer\frame.py:81 +msgid "Go to Next Frame (X)" +msgstr "Ir al siguiente cuadro (X)" + +#: tools/manual\frameviewer\frame.py:82 +msgid "Go to Last Frame (END)" +msgstr "Ir al último cuadro (FIN)" + +#: tools/manual\frameviewer\frame.py:83 +msgid "Extract the faces to a folder... (Ctrl+E)" +msgstr "Extraer las caras a una carpeta... (Ctrl+E)" + +#: tools/manual\frameviewer\frame.py:84 +msgid "Save the Alignments file (Ctrl+S)" +msgstr "Guardar el fichero de alineamientos (Ctrl+S)" + +#: tools/manual\frameviewer\frame.py:85 +msgid "Filter Frames to only those Containing the Selected Item (F)" +msgstr "Mostrar sólo los cuadros que contengan el elemento seleccionado (F)" + +#: tools/manual\frameviewer\frame.py:86 +msgid "" +"Set the distance from an 'average face' to be considered misaligned. Higher " +"distances are more restrictive" +msgstr "" +"Establezca la distancia desde una 'cara promedio' para que se considere " +"desalineada. Las distancias más altas son más restrictivas" + +#: tools/manual\frameviewer\frame.py:391 +msgid "View alignments" +msgstr "Ver alineamientos" + +#: tools/manual\frameviewer\frame.py:392 +msgid "Bounding box editor" +msgstr "Editor de cuadro delimitador" + +#: tools/manual\frameviewer\frame.py:393 +msgid "Location editor" +msgstr "Editor de ubicación" + +#: tools/manual\frameviewer\frame.py:394 +msgid "Mask editor" +msgstr "Editor de máscara" + +#: tools/manual\frameviewer\frame.py:395 +msgid "Landmark point editor" +msgstr "Editor de puntos de referencia" + +#: tools/manual\frameviewer\frame.py:470 +msgid "Next" +msgstr "Siguiente" + +#: tools/manual\frameviewer\frame.py:470 +msgid "Previous" +msgstr "Anterior" + +#: tools/manual\frameviewer\frame.py:481 +msgid "Revert to saved Alignments ({})" +msgstr "Volver a los alineamientos guardados ({})" + +#: tools/manual\frameviewer\frame.py:487 +msgid "Copy {} Alignments ({})" +msgstr "Copiar los alineamientos del cuadro {} ({})" diff --git a/locales/es/LC_MESSAGES/tools.mask.cli.mo b/locales/es/LC_MESSAGES/tools.mask.cli.mo new file mode 100644 index 0000000000..ad86f74687 Binary files /dev/null and b/locales/es/LC_MESSAGES/tools.mask.cli.mo differ diff --git a/locales/es/LC_MESSAGES/tools.mask.cli.po b/locales/es/LC_MESSAGES/tools.mask.cli.po new file mode 100644 index 0000000000..1d70c85ba9 --- /dev/null +++ b/locales/es/LC_MESSAGES/tools.mask.cli.po @@ -0,0 +1,351 @@ +# SOME DESCRIPTIVE TITLE. +# Copyright (C) YEAR ORGANIZATION +# FIRST AUTHOR <EMAIL@ADDRESS>, YEAR. 
+# +msgid "" +msgstr "" +"Project-Id-Version: faceswap.spanish\n" +"Report-Msgid-Bugs-To: \n" +"POT-Creation-Date: 2024-06-28 13:45+0100\n" +"PO-Revision-Date: 2024-06-28 13:47+0100\n" +"Last-Translator: \n" +"Language-Team: tokafondo\n" +"Language: es_ES\n" +"MIME-Version: 1.0\n" +"Content-Type: text/plain; charset=UTF-8\n" +"Content-Transfer-Encoding: 8bit\n" +"Plural-Forms: nplurals=2; plural=(n != 1);\n" +"Generated-By: pygettext.py 1.5\n" +"X-Generator: Poedit 3.4.4\n" + +#: tools/mask/cli.py:15 +msgid "" +"This tool allows you to generate, import, export or preview masks for " +"existing alignments." +msgstr "" +"Esta herramienta le permite generar, importar, exportar u obtener una vista " +"previa de máscaras para alineaciones existentes." + +#: tools/mask/cli.py:25 +msgid "" +"Mask tool\n" +"Generate, import, export or preview masks for existing alignments files." +msgstr "" +"Herramienta de máscara\n" +"Genere, importe, exporte u obtenga una vista previa de máscaras para " +"archivos de alineaciones existentes." + +#: tools/mask/cli.py:35 tools/mask/cli.py:47 tools/mask/cli.py:58 +#: tools/mask/cli.py:69 +msgid "data" +msgstr "datos" + +#: tools/mask/cli.py:39 +msgid "" +"Full path to the alignments file that contains the masks if not at the " +"default location. NB: If the input-type is faces and you wish to update the " +"corresponding alignments file, then you must provide a value here as the " +"location cannot be automatically detected." +msgstr "" +"Ruta completa al archivo de alineaciones que contiene las máscaras, si no " +"está en la ubicación predeterminada. NB: si el tipo de entrada es caras y " +"desea actualizar el archivo de alineaciones correspondiente, debe " +"proporcionar un valor aquí, ya que la ubicación no se puede detectar " +"automáticamente." + +#: tools/mask/cli.py:51 +msgid "Directory containing extracted faces, source frames, or a video file." +msgstr "" +"Directorio que contiene las caras extraídas, los fotogramas de origen o un " +"archivo de vídeo." + +#: tools/mask/cli.py:61 +msgid "" +"R|Whether the `input` is a folder of faces or a folder of frames/video\n" +"L|faces: The input is a folder containing extracted faces.\n" +"L|frames: The input is a folder containing frames or is a video" +msgstr "" +"R|Si la entrada es una carpeta de caras o una carpeta de fotogramas/vídeo\n" +"L|faces: La entrada es una carpeta que contiene caras extraídas.\n" +"L|frames: La entrada es una carpeta que contiene fotogramas o es un vídeo" + +#: tools/mask/cli.py:71 +msgid "" +"R|Run the mask tool on multiple sources. If selected then the other options " +"should be set as follows:\n" +"L|input: A parent folder containing either all of the video files to be " +"processed, or containing sub-folders of frames/faces.\n" +"L|output-folder: If provided, then sub-folders will be created within the " +"given location to hold the previews for each input.\n" +"L|alignments: Alignments field will be ignored for batch processing. The " +"alignments files must exist at the default location (for frames). For batch " +"processing of masks with 'faces' as the input type, then only the PNG header " +"within the extracted faces will be updated." +msgstr "" +"R|Ejecute la herramienta de máscara en varias fuentes. 
Si se selecciona, las " +"otras opciones deben configurarse de la siguiente manera:\n" +"L|input: una carpeta principal que contiene todos los archivos de vídeo que " +"se procesarán o que contiene subcarpetas de fotogramas/caras.\n" +"L|output-folder: si se proporciona, se crearán subcarpetas dentro de la " +"ubicación dada para contener las vistas previas de cada entrada.\n" +"L|alignments: el campo de alineaciones se ignorará para el procesamiento por " +"lotes. Los archivos de alineaciones deben existir en la ubicación " +"predeterminada (para fotogramas). Para el procesamiento por lotes de " +"máscaras con 'faces' como tipo de entrada, solo se actualizará el encabezado " +"PNG dentro de las caras extraídas." + +#: tools/mask/cli.py:87 tools/mask/cli.py:119 +msgid "process" +msgstr "proceso" + +#: tools/mask/cli.py:89 +msgid "" +"R|Masker to use.\n" +"L|bisenet-fp: Relatively lightweight NN based mask that provides more " +"refined control over the area to be masked including full head masking " +"(configurable in mask settings).\n" +"L|components: Mask designed to provide facial segmentation based on the " +"positioning of landmark locations. A convex hull is constructed around the " +"exterior of the landmarks to create a mask.\n" +"L|custom: A dummy mask that fills the mask area with all 1s or 0s " +"(configurable in settings). This is only required if you intend to manually " +"edit the custom masks yourself in the manual tool. This mask does not use " +"the GPU.\n" +"L|extended: Mask designed to provide facial segmentation based on the " +"positioning of landmark locations. A convex hull is constructed around the " +"exterior of the landmarks and the mask is extended upwards onto the " +"forehead.\n" +"L|vgg-clear: Mask designed to provide smart segmentation of mostly frontal " +"faces clear of obstructions. Profile faces and obstructions may result in " +"sub-par performance.\n" +"L|vgg-obstructed: Mask designed to provide smart segmentation of mostly " +"frontal faces. The mask model has been specifically trained to recognize " +"some facial obstructions (hands and eyeglasses). Profile faces may result in " +"sub-par performance.\n" +"L|unet-dfl: Mask designed to provide smart segmentation of mostly frontal " +"faces. The mask model has been trained by community members. Profile faces " +"may result in sub-par performance." +msgstr "" +"R|Máscara a utilizar.\n" +"L|bisenet-fp: Máscara relativamente ligera basada en NN que proporciona un " +"control más refinado sobre el área a enmascarar, incluido el enmascaramiento " +"completo de la cabeza (configurable en la configuración de la máscara).\n" +"L|components: Máscara diseñada para proporcionar una segmentación facial " +"basada en la posición de los puntos de referencia. Se construye un casco " +"convexo alrededor del exterior de los puntos de referencia para crear una " +"máscara.\n" +"L|custom: Una máscara ficticia que llena el área de la máscara con 1 o 0 " +"(configurable en la configuración). Esto solo es necesario si tiene la " +"intención de editar manualmente las máscaras personalizadas usted mismo en " +"la herramienta manual. Esta máscara no utiliza la GPU.\n" +"L|extended: Máscara diseñada para proporcionar una segmentación facial " +"basada en el posicionamiento de las ubicaciones de los puntos de referencia. 
" +"Se construye un casco convexo alrededor del exterior de los puntos de " +"referencia y la máscara se extiende hacia arriba en la frente.\n" +"L|vgg-clear: Máscara diseñada para proporcionar una segmentación inteligente " +"de rostros principalmente frontales y libres de obstrucciones. Los rostros " +"de perfil y las obstrucciones pueden dar lugar a un rendimiento inferior.\n" +"L|vgg-obstructed: Máscara diseñada para proporcionar una segmentación " +"inteligente de rostros principalmente frontales. El modelo de máscara ha " +"sido entrenado específicamente para reconocer algunas obstrucciones faciales " +"(manos y gafas). Los rostros de perfil pueden dar lugar a un rendimiento " +"inferior.\n" +"L|unet-dfl: Máscara diseñada para proporcionar una segmentación inteligente " +"de rostros principalmente frontales. El modelo de máscara ha sido entrenado " +"por los miembros de la comunidad y necesitará ser probado para una mayor " +"descripción. Los rostros de perfil pueden dar lugar a un rendimiento " +"inferior." + +#: tools/mask/cli.py:121 +msgid "" +"R|The Mask tool process to perform.\n" +"L|all: Update the mask for all faces in the alignments file for the selected " +"'masker'.\n" +"L|missing: Create a mask for all faces in the alignments file where a mask " +"does not previously exist for the selected 'masker'.\n" +"L|output: Don't update the masks, just output the selected 'masker' for " +"review/editing in external tools to the given output folder.\n" +"L|import: Import masks that have been edited outside of faceswap into the " +"alignments file. Note: 'custom' must be the selected 'masker' and the masks " +"must be in the same format as the 'input-type' (frames or faces)" +msgstr "" +"R|Процесс инструмента «Маска», который необходимо выполнить.\n" +"L|all: обновить маску для всех лиц в файле выравниваний для выбранного " +"«masker».\n" +"L|missing: создать маску для всех граней в файле выравниваний, где маска " +"ранее не существовала для выбранного «masker».\n" +"L|output: не обновляйте маски, просто выведите выбранный «masker» для " +"просмотра/редактирования во внешних инструментах в данную выходную папку.\n" +"L|import: импортируйте маски, которые были отредактированы вне Facewap, в " +"файл выравниваний. Примечание. «custom» должен быть выбранным «masker», а " +"маски должны быть в том же формате, что и «input-type» (frames или faces)." + +#: tools/mask/cli.py:135 tools/mask/cli.py:154 tools/mask/cli.py:176 +msgid "import" +msgstr "importar" + +#: tools/mask/cli.py:137 +msgid "" +"R|Import only. The path to the folder that contains masks to be imported.\n" +"L|How the masks are provided is not important, but they will be stored, " +"internally, as 8-bit grayscale images.\n" +"L|If the input are images, then the masks must be named exactly the same as " +"input frames/faces (excluding the file extension).\n" +"L|If the input is a video file, then the filename of the masks is not " +"important but should contain the frame number at the end of the filename " +"(but before the file extension). The frame number can be separated from the " +"rest of the filename by any non-numeric character and can be padded by any " +"number of zeros. The frame number must correspond correctly to the frame " +"number in the original video (starting from frame 1)." +msgstr "" +"R|Sólo importar. 
La ruta a la carpeta que contiene las máscaras que se " +"importarán.\n" +"L|Cómo se proporcionan las máscaras no es importante, pero se almacenarán " +"internamente como imágenes en escala de grises de 8 bits.\n" +"L|Si la entrada son imágenes, entonces las máscaras deben tener exactamente " +"el mismo nombre que los cuadros/caras de entrada (excluyendo la extensión " +"del archivo).\n" +"L|Si la entrada es un archivo de vídeo, entonces el nombre del archivo de " +"las máscaras no es importante pero debe contener el número de fotograma al " +"final del nombre del archivo (pero antes de la extensión del archivo). El " +"número de fotograma se puede separar del resto del nombre del archivo " +"mediante cualquier carácter no numérico y se puede rellenar con cualquier " +"número de ceros. El número de fotograma debe corresponder correctamente al " +"número de fotograma del vídeo original (a partir del fotograma 1)." + +#: tools/mask/cli.py:156 +msgid "" +"R|Import/Output only. When importing masks, this is the centering to use. " +"For output this is only used for outputting custom imported masks, and " +"should correspond to the centering used when importing the mask. Note: For " +"any job other than 'import' and 'output' this option is ignored as mask " +"centering is handled internally.\n" +"L|face: Centers the mask on the center of the face, adjusting for pitch and " +"yaw. Outside of requirements for full head masking/training, this is likely " +"to be the best choice.\n" +"L|head: Centers the mask on the center of the head, adjusting for pitch and " +"yaw. Note: You should only select head centering if you intend to include " +"the full head (including hair) within the mask and are looking to train a " +"full head model.\n" +"L|legacy: The 'original' extraction technique. Centers the mask near the tip " +"of the nose and crops closely to the face. Can result in the edges of the " +"mask appearing outside of the training area." +msgstr "" +"R|Solo importación/salida. Al importar máscaras, este es el centrado que se " +"debe utilizar. Para la salida, esto solo se utiliza para generar máscaras " +"importadas personalizadas y debe corresponder al centrado utilizado al " +"importar la máscara. Nota: Para cualquier trabajo que no sea 'import' u " +"'output', esta opción se ignora ya que el centrado de la máscara se maneja " +"internamente.\n" +"L|face: centra la máscara en el centro de la cara, ajustando el cabeceo y la " +"guiñada. Aparte de los requisitos para el entrenamiento/enmascaramiento de " +"cabeza completa, esta probablemente sea la mejor opción.\n" +"L|head: centra la máscara en el centro de la cabeza, ajustando el cabeceo y " +"la guiñada. Nota: Sólo debe seleccionar el centrado de la cabeza si desea " +"incluir la cabeza completa (incluido el cabello) dentro de la máscara y " +"desea entrenar un modelo de cabeza completa.\n" +"L|legacy: La técnica de extracción 'original'. Centra la máscara cerca de la " +"nariz y la recorta cerca de la cara. Puede provocar que los bordes de la " +"máscara aparezcan fuera del área de entrenamiento." + +#: tools/mask/cli.py:181 +msgid "" +"Import only. The size, in pixels, to internally store the mask at.\n" +"The default is 128 which is fine for nearly all usecases. Larger sizes will " +"result in larger alignments files and longer processing." +msgstr "" +"Sólo importar. El tamaño, en píxeles, para almacenar internamente la " +"máscara.\n" +"El valor predeterminado es 128, que está bien para casi todos los casos de " +"uso. 
Los tamaños más grandes darán como resultado archivos de alineaciones " +"más grandes y un procesamiento más largo." + +#: tools/mask/cli.py:189 tools/mask/cli.py:197 tools/mask/cli.py:211 +#: tools/mask/cli.py:225 tools/mask/cli.py:235 +msgid "output" +msgstr "salida" + +#: tools/mask/cli.py:191 +msgid "" +"Optional output location. If provided, a preview of the masks created will " +"be output in the given folder." +msgstr "" +"Ubicación de salida opcional. Si se proporciona, se obtendrá una vista " +"previa de las máscaras creadas en la carpeta indicada." + +#: tools/mask/cli.py:202 +msgid "" +"Apply gaussian blur to the mask output. Has the effect of smoothing the " +"edges of the mask giving less of a hard edge. The size is in pixels. This " +"value should be odd, if an even number is passed in then it will be rounded " +"to the next odd number. NB: Only affects the output preview. Set to 0 for off" +msgstr "" +"Aplica el desenfoque gaussiano a la salida de la máscara. Tiene el efecto de " +"suavizar los bordes de la máscara, haciendo que el borde sea menos marcado. " +"El tamaño está en píxeles. Este valor debe ser impar; si se pasa un número " +"par, se redondeará al siguiente número impar. NB: Sólo afecta a la vista " +"previa de salida. Si se ajusta a 0, se desactiva" + +#: tools/mask/cli.py:216 +msgid "" +"Helps reduce 'blotchiness' on some masks by making light shades white and " +"dark shades black. Higher values will impact more of the mask. NB: Only " +"affects the output preview. Set to 0 for off" +msgstr "" +"Ayuda a reducir las \"manchas\" en algunas máscaras haciendo que los tonos " +"claros sean blancos y los oscuros negros. Los valores más altos afectarán " +"más a la máscara. NB: Sólo afecta a la vista previa de salida. Si se ajusta " +"a 0, se desactiva" + +#: tools/mask/cli.py:227 +msgid "" +"R|How to format the output when processing is set to 'output'.\n" +"L|combined: The image contains the face/frame, face mask and masked face.\n" +"L|masked: Output the face/frame as rgba image with the face masked.\n" +"L|mask: Only output the mask as a single channel image." +msgstr "" +"R|Cómo formatear la salida cuando el procesamiento se establece en " +"'output'.\n" +"L|combined: La imagen contiene la cara o fotograma, la máscara facial y la " +"cara enmascarada.\n" +"L|masked: Da salida a la cara o fotograma como imagen rgba con la cara " +"enmascarada.\n" +"L|mask: Sólo emite la máscara como una imagen de un solo canal." + +#: tools/mask/cli.py:237 +msgid "" +"R|Whether to output the whole frame or only the face box when using output " +"processing. Only has an effect when using frames as input." +msgstr "" +"R|Marcar esta opción dará como salida el fotograma completo, en vez de sólo " +"el cuadro de la cara, cuando se utiliza el procesamiento de salida. Sólo " +"tiene efecto cuando se utilizan fotogramas como entrada." + +#~ msgid "" +#~ "R|Whether to update all masks in the alignments files, only those faces " +#~ "that do not already have a mask of the given `mask type` or just to " +#~ "output the masks to the `output` location.\n" +#~ "L|all: Update the mask for all faces in the alignments file.\n" +#~ "L|missing: Create a mask for all faces in the alignments file where a " +#~ "mask does not previously exist.\n" +#~ "L|output: Don't update the masks, just output them for review in the " +#~ "given output folder." 
+#~ msgstr "" +#~ "R|Si se actualizan todas las máscaras en los archivos de alineación, sólo " +#~ "aquellas caras que no tienen ya una máscara del \"tipo de máscara\" dado " +#~ "o sólo se envían las máscaras a la ubicación \"de salida\".\n" +#~ "L|all: Actualiza la máscara de todas las caras del archivo de " +#~ "alineación.\n" +#~ "L|missing: Crea una máscara para todas las caras del fichero de " +#~ "alineaciones en las que no existe una máscara previamente.\n" +#~ "L|output: No actualiza las máscaras, sólo las emite para su revisión en " +#~ "la carpeta de salida dada." + +#~ msgid "" +#~ "Full path to the alignments file to add the mask to. NB: if the mask " +#~ "already exists in the alignments file it will be overwritten." +#~ msgstr "" +#~ "Ruta completa del archivo de alineaciones al que se añadirá la máscara. " +#~ "Nota: si la máscara ya existe en el archivo de alineaciones, se " +#~ "sobrescribirá." diff --git a/locales/es/LC_MESSAGES/tools.model.cli.mo b/locales/es/LC_MESSAGES/tools.model.cli.mo new file mode 100644 index 0000000000..55dd5dba0e Binary files /dev/null and b/locales/es/LC_MESSAGES/tools.model.cli.mo differ diff --git a/locales/es/LC_MESSAGES/tools.model.cli.po b/locales/es/LC_MESSAGES/tools.model.cli.po new file mode 100644 index 0000000000..56079517ca --- /dev/null +++ b/locales/es/LC_MESSAGES/tools.model.cli.po @@ -0,0 +1,90 @@ +# SOME DESCRIPTIVE TITLE. +# Copyright (C) YEAR THE PACKAGE'S COPYRIGHT HOLDER +# This file is distributed under the same license as the PACKAGE package. +# FIRST AUTHOR , YEAR. +# +msgid "" +msgstr "" +"Project-Id-Version: \n" +"Report-Msgid-Bugs-To: \n" +"POT-Creation-Date: 2024-03-28 23:51+0000\n" +"PO-Revision-Date: 2024-03-29 00:00+0000\n" +"Last-Translator: \n" +"Language-Team: \n" +"Language: es\n" +"MIME-Version: 1.0\n" +"Content-Type: text/plain; charset=UTF-8\n" +"Content-Transfer-Encoding: 8bit\n" +"Plural-Forms: nplurals=2; plural=(n != 1);\n" +"X-Generator: Poedit 3.4.2\n" + +#: tools/model/cli.py:13 +msgid "This tool lets you perform actions on saved Faceswap models." +msgstr "" +"Esta herramienta le permite realizar acciones en modelos Faceswap guardados." + +#: tools/model/cli.py:22 +msgid "A tool for performing actions on Faceswap trained model files" +msgstr "" +"Una herramienta para realizar acciones en archivos de modelos entrenados " +"Faceswap" + +#: tools/model/cli.py:34 +msgid "" +"Model directory. A directory containing the model you wish to perform an " +"action on." +msgstr "" +"Directorio de modelo. Un directorio que contiene el modelo en el que desea " +"realizar una acción." + +#: tools/model/cli.py:43 +msgid "" +"R|Choose which action you want to perform.\n" +"L|'inference' - Create an inference only copy of the model. Strips any " +"layers from the model which are only required for training. NB: This is for " +"exporting the model for use in external applications. Inference generated " +"models cannot be used within Faceswap. See the 'format' option for " +"specifying the model output format.\n" +"L|'nan-scan' - Scan the model file for NaNs or Infs (invalid data).\n" +"L|'restore' - Restore a model from backup." +msgstr "" +"R|Elige qué acción quieres realizar.\n" +"L|'inference': crea una copia del modelo solo de inferencia. Elimina las " +"capas del modelo que solo se requieren para el entrenamiento. NB: Esto es " +"para exportar el modelo para su uso en aplicaciones externas. Los modelos " +"generados por inferencia no se pueden usar en Faceswap. 
Consulte la opción " +"'format' para especificar el formato de salida del modelo.\n" +"L|'nan-scan': escanea el archivo del modelo en busca de NaN o Inf (datos no " +"válidos).\n" +"L|'restore': restaura un modelo desde una copia de seguridad." + +#: tools/model/cli.py:57 tools/model/cli.py:69 +msgid "inference" +msgstr "inferencia" + +#: tools/model/cli.py:59 +msgid "" +"R|The format to save the model as. Note: Only used for 'inference' job.\n" +"L|'h5' - Standard Keras H5 format. Does not store any custom layer " +"information. Layers will need to be loaded from Faceswap to use.\n" +"L|'saved-model' - Tensorflow's Saved Model format. Contains all information " +"required to load the model outside of Faceswap." +msgstr "" +"R|El formato para guardar el modelo. Nota: Solo se usa para el trabajo de " +"'inference'.\n" +"L|'h5': formato estándar de Keras H5. No almacena ninguna información de " +"capa personalizada. Las capas deberán cargarse desde Faceswap para poder " +"usarlas.\n" +"L|'saved-model': formato de modelo guardado de Tensorflow. Contiene toda la " +"información necesaria para cargar el modelo fuera de Faceswap." + +#: tools/model/cli.py:71 +#, fuzzy +#| msgid "" +#| "Only used for 'inference' job. Generate the inference model for B -> A " +#| "instead of A -> B." +msgid "" +"Only used for 'inference' job. Generate the inference model for B -> A " +"instead of A -> B." +msgstr "" +"Solo se usa para el trabajo de 'inference'. Genere el modelo de inferencia " +"para B -> A en lugar de A -> B." diff --git a/locales/es/LC_MESSAGES/tools.preview.mo b/locales/es/LC_MESSAGES/tools.preview.mo new file mode 100644 index 0000000000..955c957645 Binary files /dev/null and b/locales/es/LC_MESSAGES/tools.preview.mo differ diff --git a/locales/es/LC_MESSAGES/tools.preview.po b/locales/es/LC_MESSAGES/tools.preview.po new file mode 100644 index 0000000000..f9cfb9218a --- /dev/null +++ b/locales/es/LC_MESSAGES/tools.preview.po @@ -0,0 +1,93 @@ +# SOME DESCRIPTIVE TITLE. +# Copyright (C) YEAR ORGANIZATION +# FIRST AUTHOR <EMAIL@ADDRESS>, YEAR. +# +msgid "" +msgstr "" +"Project-Id-Version: faceswap.spanish\n" +"Report-Msgid-Bugs-To: \n" +"POT-Creation-Date: 2024-03-28 23:53+0000\n" +"PO-Revision-Date: 2024-03-29 00:00+0000\n" +"Last-Translator: \n" +"Language-Team: tokafondo\n" +"Language: es_ES\n" +"MIME-Version: 1.0\n" +"Content-Type: text/plain; charset=UTF-8\n" +"Content-Transfer-Encoding: 8bit\n" +"Plural-Forms: nplurals=2; plural=(n != 1);\n" +"Generated-By: pygettext.py 1.5\n" +"X-Generator: Poedit 3.4.2\n" + +#: tools/preview/cli.py:15 +msgid "This command allows you to preview swaps to tweak convert settings." +msgstr "" +"Este comando permite previsualizar los intercambios para ajustar la " +"configuración de la conversión." + +#: tools/preview/cli.py:30 +msgid "" +"Preview tool\n" +"Allows you to configure your convert settings with a live preview" +msgstr "" +"Herramienta de vista previa\n" +"Permite configurar los ajustes de conversión con una vista previa en directo" + +#: tools/preview/cli.py:47 tools/preview/cli.py:57 tools/preview/cli.py:65 +msgid "data" +msgstr "datos" + +#: tools/preview/cli.py:50 +msgid "" +"Input directory or video. Either a directory containing the image files you " +"wish to process or path to a video file." +msgstr "" +"Directorio o vídeo de entrada. Un directorio que contenga los archivos de " +"imagen que desea procesar o la ruta a un archivo de vídeo." 
+ +#: tools/preview/cli.py:60 +msgid "" +"Path to the alignments file for the input, if not at the default location" +msgstr "" +"Ruta del archivo de alineaciones para la entrada, si no está en la ubicación " +"por defecto" + +#: tools/preview/cli.py:68 +msgid "" +"Model directory. A directory containing the trained model you wish to " +"process." +msgstr "" +"Directorio del modelo. Un directorio que contiene el modelo entrenado que " +"desea procesar." + +#: tools/preview/cli.py:74 +msgid "Swap the model. Instead of A -> B, swap B -> A" +msgstr "Intercambiar el modelo. En lugar de convertir A en B, convierte B en A" + +#: tools/preview/control_panels.py:510 +msgid "Save full config" +msgstr "Guardar la configuración completa" + +#: tools/preview/control_panels.py:513 +msgid "Reset full config to default values" +msgstr "Restablecer la configuración completa a los valores por defecto" + +#: tools/preview/control_panels.py:516 +msgid "Reset full config to saved values" +msgstr "Restablecer la configuración completa a los valores guardados" + +#: tools/preview/control_panels.py:667 +#, python-brace-format +msgid "Save {title} config" +msgstr "Guardar la configuración de {title}" + +#: tools/preview/control_panels.py:670 +#, python-brace-format +msgid "Reset {title} config to default values" +msgstr "" +"Restablecer la configuración de {title} a los valores por defecto" + +#: tools/preview/control_panels.py:673 +#, python-brace-format +msgid "Reset {title} config to saved values" +msgstr "" +"Restablecer la configuración de {title} a los valores guardados" diff --git a/locales/es/LC_MESSAGES/tools.sort.cli.mo b/locales/es/LC_MESSAGES/tools.sort.cli.mo new file mode 100644 index 0000000000..1eea2cf248 Binary files /dev/null and b/locales/es/LC_MESSAGES/tools.sort.cli.mo differ diff --git a/locales/es/LC_MESSAGES/tools.sort.cli.po b/locales/es/LC_MESSAGES/tools.sort.cli.po new file mode 100644 index 0000000000..0914d6011c --- /dev/null +++ b/locales/es/LC_MESSAGES/tools.sort.cli.po @@ -0,0 +1,551 @@ +# SOME DESCRIPTIVE TITLE. +# Copyright (C) YEAR ORGANIZATION +# FIRST AUTHOR <EMAIL@ADDRESS>, YEAR. +# +msgid "" +msgstr "" +"Project-Id-Version: faceswap.spanish\n" +"Report-Msgid-Bugs-To: \n" +"POT-Creation-Date: 2024-03-28 23:53+0000\n" +"PO-Revision-Date: 2024-03-29 00:03+0000\n" +"Last-Translator: \n" +"Language-Team: tokafondo\n" +"Language: es_ES\n" +"MIME-Version: 1.0\n" +"Content-Type: text/plain; charset=UTF-8\n" +"Content-Transfer-Encoding: 8bit\n" +"Plural-Forms: nplurals=2; plural=(n != 1);\n" +"Generated-By: pygettext.py 1.5\n" +"X-Generator: Poedit 3.4.2\n" + +#: tools/sort/cli.py:15 +msgid "This command lets you sort images using various methods." +msgstr "" +"Este comando le permite ordenar las imágenes utilizando varios métodos." + +#: tools/sort/cli.py:21 +msgid "" +" Adjust the '-t' ('--threshold') parameter to control the strength of " +"grouping." +msgstr "" +" Ajuste el parámetro '-t' ('--threshold') para controlar la fuerza de la " +"agrupación." + +#: tools/sort/cli.py:22 +msgid "" +" Adjust the '-b' ('--bins') parameter to control the number of bins for " +"grouping. Each image is allocated to a bin by the percentage of color pixels " +"that appear in the image." +msgstr "" +" Ajuste el parámetro '-b' ('--bins') para controlar el número de " +"contenedores para agrupar. Cada imagen se asigna a un contenedor por el " +"porcentaje de píxeles de color que aparecen en la imagen." 
+ +#: tools/sort/cli.py:25 +msgid "" +" Adjust the '-b' ('--bins') parameter to control the number of bins for " +"grouping. Each image is allocated to a bin by the number of degrees the face " +"is orientated from center." +msgstr "" +" Ajuste el parámetro '-b' ('--bins') para controlar el número de " +"contenedores para agrupar. Cada imagen se asigna a un contenedor por el " +"número de grados que la cara está orientada desde el centro." + +#: tools/sort/cli.py:28 +msgid "" +" Adjust the '-b' ('--bins') parameter to control the number of bins for " +"grouping. The minimum and maximum values are taken for the chosen sort " +"metric. The bins are then populated with the results from the group sorting." +msgstr "" +" Ajuste el parámetro '-b' ('--bins') para controlar el número de " +"contenedores para agrupar. Los valores mínimo y máximo se toman para la " +"métrica de clasificación elegida. Luego, los contenedores se llenan con los " +"resultados de la clasificación de grupos." + +#: tools/sort/cli.py:32 +msgid "faces by blurriness." +msgstr "caras por desenfoque." + +#: tools/sort/cli.py:33 +msgid "faces by fft filtered blurriness." +msgstr "caras por desenfoque filtrado mediante fft." + +#: tools/sort/cli.py:34 +msgid "" +"faces by the estimated distance of the alignments from an 'average' face. " +"This can be useful for eliminating misaligned faces. Sorts from most like an " +"average face to least like an average face." +msgstr "" +"caras por la distancia estimada de las alineaciones desde una cara " +"'promedio'. Esto puede ser útil para eliminar caras desalineadas. Ordena de " +"más parecido a un rostro promedio a menos parecido a un rostro promedio." + +#: tools/sort/cli.py:37 +msgid "" +"faces using VGG Face2 by face similarity. This uses a pairwise clustering " +"algorithm to check the distances between 512 features on every face in your " +"set and order them appropriately." +msgstr "" +"caras usando VGG Face2 por similitud de caras. Esto utiliza un algoritmo de " +"agrupamiento por pares para verificar las distancias entre 512 " +"características en cada cara de su conjunto y ordenarlas apropiadamente." + +#: tools/sort/cli.py:40 +msgid "faces by their landmarks." +msgstr "caras por sus puntos de referencia." + +#: tools/sort/cli.py:41 +msgid "Like 'face-cnn' but sorts by dissimilarity." +msgstr "Como 'face-cnn' pero ordena por disimilitud." + +#: tools/sort/cli.py:42 +msgid "faces by Yaw (rotation left to right)." +msgstr "caras por guiñada (rotación de izquierda a derecha)." + +#: tools/sort/cli.py:43 +msgid "faces by Pitch (rotation up and down)." +msgstr "caras por cabeceo (rotación arriba y abajo)." + +#: tools/sort/cli.py:44 +msgid "" +"faces by Roll (rotation). Aligned faces should have a roll value close to " +"zero. The further the Roll value from zero the higher likelihood the face is " +"misaligned." +msgstr "" +"caras por Roll (rotación). Las caras alineadas deben tener un valor de " +"balanceo cercano a cero. Cuanto más lejos esté el valor de Roll de cero, " +"mayor será la probabilidad de que la cara esté desalineada." + +#: tools/sort/cli.py:46 +msgid "faces by their color histogram." +msgstr "caras por su histograma de color." + +#: tools/sort/cli.py:47 +msgid "Like 'hist' but sorts by dissimilarity." +msgstr "Como 'hist' pero ordena por disimilitud." + +#: tools/sort/cli.py:48 +msgid "" +"images by the average intensity of the converted grayscale color channel." +msgstr "" +"imágenes por la intensidad media del canal de color convertido a escala de " +"grises." 
+ +#: tools/sort/cli.py:49 +msgid "" +"images by their number of black pixels. Useful when faces are near borders " +"and a large part of the image is black." +msgstr "" +"imágenes por su número de píxeles negros. Útil cuando las caras están cerca " +"de los bordes y una gran parte de la imagen es negra." + +#: tools/sort/cli.py:51 +msgid "" +"images by the average intensity of the converted Y color channel. Bright " +"lighting and oversaturated images will be ranked first." +msgstr "" +"imágenes por la intensidad media del canal de color Y convertido. La " +"iluminación brillante y las imágenes sobresaturadas se clasificarán en " +"primer lugar." + +#: tools/sort/cli.py:53 +msgid "" +"images by the average intensity of the converted Cg color channel. Green " +"images will be ranked first and red images will be last." +msgstr "" +"imágenes por la intensidad media del canal de color Cg convertido. Las " +"imágenes verdes se clasificarán primero y las imágenes rojas serán las " +"últimas." + +#: tools/sort/cli.py:55 +msgid "" +"images by the average intensity of the converted Co color channel. Orange " +"images will be ranked first and blue images will be last." +msgstr "" +"imágenes por la intensidad media del canal de color Co convertido. Las " +"imágenes naranjas se clasificarán en primer lugar y las imágenes azules en " +"último lugar." + +#: tools/sort/cli.py:57 +msgid "" +"images by their size in the original frame. Faces further from the camera " +"and from lower resolution sources will be sorted first, whilst faces closer " +"to the camera and from higher resolution sources will be sorted last." +msgstr "" +"imágenes por su tamaño en el fotograma original. Las caras más alejadas de " +"la cámara y de fuentes de menor resolución se ordenarán primero, mientras " +"que las caras más cercanas a la cámara y de fuentes de mayor resolución se " +"ordenarán en último lugar." + +#: tools/sort/cli.py:81 +msgid "Sort faces using a number of different techniques" +msgstr "Clasificar los rostros mediante diferentes técnicas" + +#: tools/sort/cli.py:91 tools/sort/cli.py:98 tools/sort/cli.py:110 +#: tools/sort/cli.py:150 +msgid "data" +msgstr "datos" + +#: tools/sort/cli.py:92 +msgid "Input directory of aligned faces." +msgstr "Directorio de entrada de caras alineadas." + +#: tools/sort/cli.py:100 +msgid "" +"Output directory for sorted aligned faces. If not provided and 'keep' is " +"selected then a new folder called 'sorted' will be created within the input " +"folder to house the output. If not provided and 'keep' is not selected then " +"the images will be sorted in-place, overwriting the original contents of the " +"'input_dir'" +msgstr "" +"Directorio de salida para caras alineadas ordenadas. Si no se proporciona y " +"se selecciona 'keep', se creará una nueva carpeta llamada 'sorted' dentro de " +"la carpeta de entrada para albergar la salida. Si no se proporciona y no se " +"selecciona 'keep', las imágenes se ordenarán en el mismo lugar, " +"sobrescribiendo el contenido original de 'input_dir'" + +#: tools/sort/cli.py:112 +msgid "" +"R|If selected then the input_dir should be a parent folder containing " +"multiple folders of faces you wish to sort. The faces will be output to " +"separate sub-folders in the output_dir" +msgstr "" +"R|Si se selecciona, input_dir debe ser una carpeta principal que contenga " +"varias carpetas de caras que desea ordenar. 
Las caras se enviarán a " +"subcarpetas separadas en output_dir" + +#: tools/sort/cli.py:121 +msgid "sort settings" +msgstr "ajustes de ordenación" + +#: tools/sort/cli.py:124 +msgid "" +"R|Choose how images are sorted. Selecting a sort method gives the images a " +"new filename based on the order the image appears within the given method.\n" +"L|'none': Don't sort the images. When a 'group-by' method is selected, " +"selecting 'none' means that the files will be moved/copied into their " +"respective bins, but the files will keep their original filenames. Selecting " +"'none' for both 'sort-by' and 'group-by' will do nothing" +msgstr "" +"R|Elige cómo se ordenan las imágenes. Al seleccionar un método de " +"clasificación, las imágenes reciben un nuevo nombre de archivo basado en el " +"orden en que aparece la imagen dentro del método dado.\n" +"L|'none': No ordenar las imágenes. Cuando se selecciona un método de 'group-" +"by', seleccionar 'none' significa que los archivos se moverán/copiarán en " +"sus contenedores respectivos, pero los archivos mantendrán sus nombres de " +"archivo originales. Seleccionar 'none' para 'sort-by' y 'group-by' no hará " +"nada" + +#: tools/sort/cli.py:136 tools/sort/cli.py:164 tools/sort/cli.py:184 +msgid "group settings" +msgstr "ajustes de grupo" + +#: tools/sort/cli.py:139 +#, fuzzy +#| msgid "" +#| "R|Selecting a group by method will move/copy files into numbered bins " +#| "based on the selected method.\n" +#| "L|'none': Don't bin the images. Folders will be sorted by the selected " +#| "'sort-by' but will not be binned, instead they will be sorted into a " +#| "single folder. Selecting 'none' for both 'sort-by' and 'group-by' will " +#| "do nothing" +msgid "" +"R|Selecting a group by method will move/copy files into numbered bins based " +"on the selected method.\n" +"L|'none': Don't bin the images. Folders will be sorted by the selected 'sort-" +"by' but will not be binned, instead they will be sorted into a single " +"folder. Selecting 'none' for both 'sort-by' and 'group-by' will do nothing" +msgstr "" +"R|Al seleccionar un método de agrupación ('group-by'), los archivos se " +"moverán/copiarán en contenedores numerados según el método seleccionado.\n" +"L|'none': No agrupar las imágenes. Las carpetas se ordenarán por el 'sort-" +"by' seleccionado, pero no se agruparán, sino que se ordenarán en una sola " +"carpeta. Seleccionar 'none' para 'sort-by' y 'group-by' no hará nada" + +#: tools/sort/cli.py:152 +msgid "" +"Whether to keep the original files in their original location. Choosing a " +"'sort-by' method means that the files have to be renamed. Selecting 'keep' " +"means that the original files will be kept, and the renamed files will be " +"created in the specified output folder. Unselecting keep means that the " +"original files will be moved and renamed based on the selected sort/group " +"criteria." +msgstr "" +"Indica si se deben mantener los archivos originales en su ubicación " +"original. Elegir un método de 'sort-by' significa que los archivos tienen " +"que ser renombrados. Seleccionar 'keep' significa que los archivos " +"originales se mantendrán y los archivos renombrados se crearán en la carpeta " +"de salida especificada. Deseleccionar 'keep' significa que los archivos " +"originales se moverán y cambiarán de nombre en función de los criterios de " +"clasificación/grupo seleccionados." + +#: tools/sort/cli.py:167 +msgid "" +"R|Float value. 
Minimum threshold to use for grouping comparison with 'face-" +"cnn' 'hist' and 'face' methods.\n" +"The lower the value the more discriminating the grouping is. Leaving -1.0 " +"will allow Faceswap to choose the default value.\n" +"L|For 'face-cnn' 7.2 should be enough, with 4 being very discriminating. \n" +"L|For 'hist' 0.3 should be enough, with 0.2 being very discriminating. \n" +"L|For 'face' between 0.1 (more bins) to 0.5 (fewer bins) should be about " +"right.\n" +"Be careful setting a value that's too extreme in a directory with many " +"images, as this could result in a lot of folders being created. Defaults: " +"face-cnn 7.2, hist 0.3, face 0.25" +msgstr "" +"R|Valor flotante. Umbral mínimo a usar en la comparación para agrupar con " +"los métodos 'face-cnn', 'hist' y 'face'.\n" +"Cuanto más bajo es el valor, más discriminatoria es la agrupación. Dejar " +"-1.0 permitirá que Faceswap elija el valor predeterminado.\n" +"L|Para 'face-cnn' 7.2 debería ser suficiente, siendo 4 muy discriminatorio.\n" +"L|Para 'hist' 0.3 debería ser suficiente, siendo 0.2 muy discriminatorio.\n" +"L|Para 'face', entre 0.1 (más contenedores) y 0.5 (menos contenedores) " +"debería ser correcto.\n" +"Tenga cuidado al establecer un valor que sea demasiado extremo en un " +"directorio con muchas imágenes, ya que esto podría resultar en la creación " +"de muchas carpetas. Valores predeterminados: face-cnn 7.2, hist 0.3, face " +"0.25" + +#: tools/sort/cli.py:187 +#, fuzzy, python-format +#| msgid "" +#| "R|Integer value. Used to control the number of bins created for grouping " +#| "by: any 'blur' methods, 'color' methods or 'face metric' methods " +#| "('distance', 'size') and 'orientation; methods ('yaw', 'pitch'). For any " +#| "other grouping methods see the '-t' ('--threshold') option.\n" +#| "L|For 'face metric' methods the bins are filled, according the the " +#| "distribution of faces between the minimum and maximum chosen metric.\n" +#| "L|For 'color' methods the number of bins represents the divider of the " +#| "percentage of colored pixels. Eg. For a bin number of '5': The first " +#| "folder will have the faces with 0%% to 20%% colored pixels, second 21%% " +#| "to 40%%, etc. Any empty bins will be deleted, so you may end up with " +#| "fewer bins than selected.\n" +#| "L|For 'blur' methods folder 0 will be the least blurry, while the last " +#| "folder will be the blurriest.\n" +#| "L|For 'orientation' methods the number of bins is dictated by how much " +#| "180 degrees is divided. Eg. If 18 is selected, then each folder will be a " +#| "10 degree increment. Folder 0 will contain faces looking the most to the " +#| "left/down whereas the last folder will contain the faces looking the most " +#| "to the right/up. NB: Some bins may be empty if faces do not fit the " +#| "criteria.\n" +#| "Default value: 5" +msgid "" +"R|Integer value. Used to control the number of bins created for grouping by: " +"any 'blur' methods, 'color' methods or 'face metric' methods ('distance', " +"'size') and 'orientation' methods ('yaw', 'pitch'). For any other grouping " +"methods see the '-t' ('--threshold') option.\n" +"L|For 'face metric' methods the bins are filled, according to the " +"distribution of faces between the minimum and maximum chosen metric.\n" +"L|For 'color' methods the number of bins represents the divider of the " +"percentage of colored pixels. Eg. For a bin number of '5': The first folder " +"will have the faces with 0%% to 20%% colored pixels, second 21%% to 40%%, " +"etc. 
Any empty bins will be deleted, so you may end up with fewer bins than " +"selected.\n" +"L|For 'blur' methods folder 0 will be the least blurry, while the last " +"folder will be the blurriest.\n" +"L|For 'orientation' methods the number of bins is dictated by how much 180 " +"degrees is divided. Eg. If 18 is selected, then each folder will be a 10 " +"degree increment. Folder 0 will contain faces looking the most to the left/" +"down whereas the last folder will contain the faces looking the most to the " +"right/up. NB: Some bins may be empty if faces do not fit the criteria. \n" +"Default value: 5" +msgstr "" +"R|Valor entero. Se utiliza para controlar el número de contenedores creados " +"para agrupar por: cualquier método de 'blur', método de 'color' o método de " +"'face metric' ('distance', 'size') y 'orientación; métodos ('yaw', 'pitch'). " +"Para cualquier otro método de agrupación, consulte la opción '-t' ('--" +"threshold').\n" +"L|Para los métodos de 'face metric', los contenedores se llenan de acuerdo " +"con la distribución de caras entre la métrica mínima y máxima elegida.\n" +"L|Para los métodos de 'color', el número de contenedores representa el " +"divisor del porcentaje de píxeles coloreados. P.ej. Para un número de " +"contenedor de '5': la primera carpeta tendrá las caras con 0%% a 20%% " +"píxeles de color, la segunda 21%% a 40%%, etc. Se eliminarán todos los " +"contenedores vacíos, por lo que puede terminar con menos contenedores que " +"los seleccionados.\n" +"L|Para los métodos 'blur', la carpeta 0 será la menos borrosa, mientras que " +"la última carpeta será la más borrosa.\n" +"L|Para los métodos de 'orientation', el número de contenedores está dictado " +"por cuánto se dividen 180 grados. P.ej. Si se selecciona 18, cada carpeta " +"tendrá un incremento de 10 grados. La carpeta 0 contendrá las caras que " +"miran más hacia la izquierda/abajo, mientras que la última carpeta contendrá " +"las caras que miran más hacia la derecha/arriba. NB: algunos contenedores " +"pueden estar vacíos si las caras no se ajustan a los criterios.\n" +"Valor predeterminado: 5" + +#: tools/sort/cli.py:207 tools/sort/cli.py:217 +msgid "settings" +msgstr "ajustes" + +#: tools/sort/cli.py:210 +msgid "" +"Logs file renaming changes if grouping by renaming, or it logs the file " +"copying/movement if grouping by folders. If no log file is specified with " +"'--log-file', then a 'sort_log.json' file will be created in the input " +"directory." +msgstr "" +"Registra los cambios en el nombre de los archivos si se agrupa por nombre, o " +"registra la copia o movimiento de archivos si se agrupa por carpetas. Si no " +"se especifica ningún archivo de registro con '--log-file', se creará un " +"archivo 'sort_log.json' en el directorio de entrada." + +#: tools/sort/cli.py:221 +msgid "" +"Specify a log file to use for saving the renaming or grouping information. " +"If specified extension isn't 'json' or 'yaml', then json will be used as the " +"serializer, with the supplied filename. Default: sort_log.json" +msgstr "" +"Especifica un archivo de registro que se utilizará para guardar la " +"información de renombrado o agrupación. Si la extensión especificada no es " +"'json' o 'yaml', se utilizará json como serializador, con el nombre de " +"archivo suministrado. Por defecto: sort_log.json" + +#~ msgid " option is deprecated. Use 'yaw'" +#~ msgstr " la opción está en desuso. Usa 'yaw'" + +#~ msgid " option is deprecated. Use 'color-black'" +#~ msgstr " la opción está en desuso. 
Usa 'color-black'" + +#~ msgid "output" +#~ msgstr "salida" + +#~ msgid "" +#~ "Deprecated and no longer used. The final processing will be dictated by " +#~ "the sort/group by methods and whether 'keep_original' is selected." +#~ msgstr "" +#~ "En desuso y ya no se usa. El procesamiento final será dictado por los " +#~ "métodos de ordenación/agrupación y si se selecciona 'keep_original'." + +#~ msgid "Output directory for sorted aligned faces." +#~ msgstr "Directorio de salida para las caras alineadas ordenadas." + +#~ msgid "" +#~ "R|Sort by method. Choose how images are sorted. \n" +#~ "L|'blur': Sort faces by blurriness.\n" +#~ "L|'blur-fft': Sort faces by fft filtered blurriness.\n" +#~ "L|'distance' Sort faces by the estimated distance of the alignments from " +#~ "an 'average' face. This can be useful for eliminating misaligned faces.\n" +#~ "L|'face': Use VGG Face to sort by face similarity. This uses a pairwise " +#~ "clustering algorithm to check the distances between 512 features on every " +#~ "face in your set and order them appropriately.\n" +#~ "L|'face-cnn': Sort faces by their landmarks. You can adjust the threshold " +#~ "with the '-t' (--ref_threshold) option.\n" +#~ "L|'face-cnn-dissim': Like 'face-cnn' but sorts by dissimilarity.\n" +#~ "L|'face-yaw': Sort faces by Yaw (rotation left to right).\n" +#~ "L|'hist': Sort faces by their color histogram. You can adjust the " +#~ "threshold with the '-t' (--ref_threshold) option.\n" +#~ "L|'hist-dissim': Like 'hist' but sorts by dissimilarity.\n" +#~ "L|'color-gray': Sort images by the average intensity of the converted " +#~ "grayscale color channel.\n" +#~ "L|'color-luma': Sort images by the average intensity of the converted Y " +#~ "color channel. Bright lighting and oversaturated images will be ranked " +#~ "first.\n" +#~ "L|'color-green': Sort images by the average intensity of the converted Cg " +#~ "color channel. Green images will be ranked first and red images will be " +#~ "last.\n" +#~ "L|'color-orange': Sort images by the average intensity of the converted " +#~ "Co color channel. Orange images will be ranked first and blue images will " +#~ "be last.\n" +#~ "L|'size': Sort images by their size in the original frame. Faces closer " +#~ "to the camera and from higher resolution sources will be sorted first, " +#~ "whilst faces further from the camera and from lower resolution sources " +#~ "will be sorted last.\n" +#~ "L|'black-pixels': Sort images by their number of black pixels. Useful " +#~ "when faces are near borders and a large part of the image is black.\n" +#~ "Default: face" +#~ msgstr "" +#~ "R|Método de ordenación. Elige cómo se ordenan las imágenes. \n" +#~ "L|'blur': Ordena las caras por desenfoque.\n" +#~ "L|'blur-fft': Ordena las caras por desenfoque filtrado mediante fft.\n" +#~ "L|'distance' Ordena las caras por la distancia estimada de las " +#~ "alineaciones desde una cara \"promedio\". Esto puede resultar útil para " +#~ "eliminar caras desalineadas.\n" +#~ "L|'face': Utiliza VGG Face para ordenar por similitud de caras. Esto " +#~ "utiliza un algoritmo de agrupación por pares para comprobar las " +#~ "distancias entre 512 características en cada cara en su conjunto y " +#~ "ordenarlas adecuadamente.\n" +#~ "L|'face-cnn': Ordena las caras por sus puntos de referencia. 
Puedes " +#~ "ajustar el umbral con la opción '-t' (--ref_threshold).\n" +#~ "L|'face-cnn-dissim': Como 'face-cnn' pero ordena por disimilitud.\n" +#~ "L|'face-yaw': Ordena las caras por Yaw (rotación de izquierda a " +#~ "derecha).\n" +#~ "L|'hist': Ordena las caras por su histograma de color. Puedes ajustar el " +#~ "umbral con la opción '-t' (--ref_threshold).\n" +#~ "L|'hist-dissim': Como 'hist' pero ordena por disimilitud.\n" +#~ "L|'color-gray': Ordena las imágenes por la intensidad media del canal de " +#~ "color previa conversión a escala de grises convertido.\n" +#~ "L|'color-luma': Ordena las imágenes por la intensidad media del canal de " +#~ "color Y. Las imágenes muy brillantes y sobresaturadas se clasificarán " +#~ "primero.\n" +#~ "L|'color-green': Ordena las imágenes por la intensidad media del canal de " +#~ "color Cg. Las imágenes verdes serán clasificadas primero y las rojas " +#~ "serán las últimas.\n" +#~ "L|'color-orange': Ordena las imágenes por la intensidad media del canal " +#~ "de color Co. Las imágenes naranjas serán clasificadas primero y las " +#~ "azules serán las últimas.\n" +#~ "L|'size': Ordena las imágenes por su tamaño en el marco original. Los " +#~ "rostros más cercanos a la cámara y de fuentes de mayor resolución se " +#~ "ordenarán primero, mientras que los rostros más alejados de la cámara y " +#~ "de fuentes de menor resolución se ordenarán en último lugar.\n" +#~ "\vL|'black-pixels': Ordene las imágenes por su número de píxeles negros. " +#~ "Útil cuando los rostros están cerca de los bordes y una gran parte de la " +#~ "imagen es negra .\n" +#~ "Por defecto: face" + +#~ msgid "" +#~ "Keeps the original files in the input directory. Be careful when using " +#~ "this with rename grouping and no specified output directory as this would " +#~ "keep the original and renamed files in the same directory." +#~ msgstr "" +#~ "Mantiene los archivos originales en el directorio de entrada. Tenga " +#~ "cuidado al usar esto con la agrupación de renombre y sin especificar el " +#~ "directorio de salida, ya que esto mantendría los archivos originales y " +#~ "renombrados en el mismo directorio." + +#~ msgid "" +#~ "R|Default: rename.\n" +#~ "L|'folders': files are sorted using the -s/--sort-by method, then they " +#~ "are organized into folders using the -g/--group-by grouping method.\n" +#~ "L|'rename': files are sorted using the -s/--sort-by then they are renamed." +#~ msgstr "" +#~ "R|Por defecto: renombrar.\n" +#~ "L|'folders': los archivos se ordenan utilizando el método -s/--sort-by, y " +#~ "luego se organizan en carpetas utilizando el método de agrupación -g/--" +#~ "group-by.\n" +#~ "L|'rename': los archivos se ordenan utilizando el método -s/--sort-by y " +#~ "luego se renombran." + +#~ msgid "" +#~ "Group by method. When -fp/--final-processing by folders choose the how " +#~ "the images are grouped after sorting. Default: hist" +#~ msgstr "" +#~ "Método de agrupamiento. Elija la forma de agrupar las imágenes, en el " +#~ "caso de hacerlo por carpetas, después de la clasificación. Por defecto: " +#~ "hist" + +#, python-format +#~ msgid "" +#~ "Integer value. Number of folders that will be used to group by blur, face-" +#~ "yaw and black-pixels. For blur folder 0 will be the least blurry, while " +#~ "the last folder will be the blurriest. For face-yaw the number of bins is " +#~ "by how much 180 degrees is divided. So if you use 18, then each folder " +#~ "will be a 10 degree increment. 
Folder 0 will contain faces looking the " +#~ "most to the left whereas the last folder will contain the faces looking " +#~ "the most to the right. If the number of images doesn't divide evenly into " +#~ "the number of bins, the remaining images get put in the last bin. For " +#~ "black-pixels it represents the divider of the percentage of black pixels. " +#~ "For 10, first folder will have the faces with 0 to 10%% black pixels, " +#~ "second 11 to 20%%, etc. Default value: 5" +#~ msgstr "" +#~ "Valor entero. Número de carpetas que se utilizarán al agrupar por 'blur' " +#~ "y 'face-yaw'. Para 'blur' la carpeta 0 será la menos borrosa, mientras " +#~ "que la última carpeta será la más borrosa. Para 'face-yaw' el número de " +#~ "carpetas es por cuanto se dividen los 180 grados. Así que si usas 18, " +#~ "entonces cada carpeta será un incremento de 10 grados. La carpeta 0 " +#~ "contendrá las caras que miren más a la izquierda, mientras que la última " +#~ "carpeta contendrá las caras que miren más a la derecha. Si el número de " +#~ "imágenes no se divide uniformemente en el número de carpetas, las " +#~ "imágenes restantes se colocan en la última carpeta. Para píxeles negros, " +#~ "representa el divisor del porcentaje de píxeles negros. Para 10, la " +#~ "primera carpeta tendrá las caras con 0 a 10%% de píxeles negros, la " +#~ "segunda de 11 a 20%%, etc. Valor por defecto: 5" diff --git a/locales/faceswap.pot b/locales/faceswap.pot new file mode 100644 index 0000000000..8979c87555 --- /dev/null +++ b/locales/faceswap.pot @@ -0,0 +1,33 @@ +# SOME DESCRIPTIVE TITLE. +# Copyright (C) YEAR ORGANIZATION +# FIRST AUTHOR , YEAR. +# +msgid "" +msgstr "" +"Project-Id-Version: PACKAGE VERSION\n" +"POT-Creation-Date: 2021-02-18 23:48-0000\n" +"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n" +"Last-Translator: FULL NAME \n" +"Language-Team: LANGUAGE \n" +"MIME-Version: 1.0\n" +"Content-Type: text/plain; charset=cp1252\n" +"Content-Transfer-Encoding: 8bit\n" +"Generated-By: pygettext.py 1.5\n" + + +#: faceswap.py:43 +msgid "Extract the faces from pictures or a video" +msgstr "" + +#: faceswap.py:44 +msgid "Train a model for the two faces A and B" +msgstr "" + +#: faceswap.py:47 +msgid "Convert source pictures or video to a new one with the face swapped" +msgstr "" + +#: faceswap.py:48 +msgid "Launch the Faceswap Graphical User Interface" +msgstr "" + diff --git a/locales/gui.menu.pot b/locales/gui.menu.pot new file mode 100644 index 0000000000..a20a799f11 --- /dev/null +++ b/locales/gui.menu.pot @@ -0,0 +1,154 @@ +# SOME DESCRIPTIVE TITLE. +# Copyright (C) YEAR THE PACKAGE'S COPYRIGHT HOLDER +# This file is distributed under the same license as the PACKAGE package. +# FIRST AUTHOR , YEAR. 
+# +#, fuzzy +msgid "" +msgstr "" +"Project-Id-Version: PACKAGE VERSION\n" +"Report-Msgid-Bugs-To: \n" +"POT-Creation-Date: 2023-06-07 13:54+0100\n" +"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n" +"Last-Translator: FULL NAME \n" +"Language-Team: LANGUAGE \n" +"Language: \n" +"MIME-Version: 1.0\n" +"Content-Type: text/plain; charset=CHARSET\n" +"Content-Transfer-Encoding: 8bit\n" + +#: ./lib/gui/menu.py:37 +msgid "faceswap.dev - Guides and Forum" +msgstr "" + +#: ./lib/gui/menu.py:38 +msgid "Patreon - Support this project" +msgstr "" + +#: ./lib/gui/menu.py:39 +msgid "Discord - The FaceSwap Discord server" +msgstr "" + +#: ./lib/gui/menu.py:40 +msgid "Github - Our Source Code" +msgstr "" + +#: ./lib/gui/menu.py:60 +msgid "File" +msgstr "" + +#: ./lib/gui/menu.py:61 +msgid "Settings" +msgstr "" + +#: ./lib/gui/menu.py:62 +msgid "Help" +msgstr "" + +#: ./lib/gui/menu.py:85 +msgid "Configure Settings..." +msgstr "" + +#: ./lib/gui/menu.py:116 +msgid "New Project..." +msgstr "" + +#: ./lib/gui/menu.py:121 +msgid "Open Project..." +msgstr "" + +#: ./lib/gui/menu.py:126 +msgid "Save Project" +msgstr "" + +#: ./lib/gui/menu.py:131 +msgid "Save Project as..." +msgstr "" + +#: ./lib/gui/menu.py:136 +msgid "Reload Project from Disk" +msgstr "" + +#: ./lib/gui/menu.py:141 +msgid "Close Project" +msgstr "" + +#: ./lib/gui/menu.py:147 +msgid "Open Task..." +msgstr "" + +#: ./lib/gui/menu.py:154 +msgid "Open recent" +msgstr "" + +#: ./lib/gui/menu.py:156 +msgid "Quit" +msgstr "" + +#: ./lib/gui/menu.py:211 +msgid "{} Task" +msgstr "" + +#: ./lib/gui/menu.py:223 +msgid "Clear recent files" +msgstr "" + +#: ./lib/gui/menu.py:391 +msgid "Check for updates..." +msgstr "" + +#: ./lib/gui/menu.py:394 +msgid "Update Faceswap..." +msgstr "" + +#: ./lib/gui/menu.py:398 +msgid "Switch Branch" +msgstr "" + +#: ./lib/gui/menu.py:401 +msgid "Resources" +msgstr "" + +#: ./lib/gui/menu.py:404 +msgid "Output System Information" +msgstr "" + +#: ./lib/gui/menu.py:589 +msgid "currently selected Task" +msgstr "" + +#: ./lib/gui/menu.py:589 +msgid "Project" +msgstr "" + +#: ./lib/gui/menu.py:591 +msgid "Reload {} from disk" +msgstr "" + +#: ./lib/gui/menu.py:593 +msgid "Create a new {}..." +msgstr "" + +#: ./lib/gui/menu.py:595 +msgid "Reset {} to default" +msgstr "" + +#: ./lib/gui/menu.py:597 +msgid "Save {}" +msgstr "" + +#: ./lib/gui/menu.py:599 +msgid "Save {} as..." +msgstr "" + +#: ./lib/gui/menu.py:603 +msgid " from a task or project file" +msgstr "" + +#: ./lib/gui/menu.py:604 +msgid "Load {}..." +msgstr "" + +#: ./lib/gui/menu.py:659 +msgid "Configure {} settings..." +msgstr "" diff --git a/locales/gui.tooltips.pot b/locales/gui.tooltips.pot new file mode 100644 index 0000000000..f6973d6152 --- /dev/null +++ b/locales/gui.tooltips.pot @@ -0,0 +1,193 @@ +# SOME DESCRIPTIVE TITLE. +# Copyright (C) YEAR ORGANIZATION +# FIRST AUTHOR , YEAR. +# +msgid "" +msgstr "" +"Project-Id-Version: PACKAGE VERSION\n" +"POT-Creation-Date: 2021-03-22 18:37+0000\n" +"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n" +"Last-Translator: FULL NAME \n" +"Language-Team: LANGUAGE \n" +"MIME-Version: 1.0\n" +"Content-Type: text/plain; charset=cp1252\n" +"Content-Transfer-Encoding: 8bit\n" +"Generated-By: pygettext.py 1.5\n" + + +#: ./lib/gui/command.py:184 +msgid "Output command line options to the console" +msgstr "" + +#: ./lib/gui/command.py:195 +msgid "Run the {} script" +msgstr "" + +#: ./lib/gui/control_helper.py:1234 +msgid "Select a folder..." 
+msgstr "" + +#: ./lib/gui/control_helper.py:1235 ./lib/gui/control_helper.py:1236 +msgid "Select a file..." +msgstr "" + +#: ./lib/gui/control_helper.py:1237 +msgid "Select a folder of images..." +msgstr "" + +#: ./lib/gui/control_helper.py:1238 +msgid "Select a video..." +msgstr "" + +#: ./lib/gui/control_helper.py:1239 +msgid "Select a model folder..." +msgstr "" + +#: ./lib/gui/control_helper.py:1240 +msgid "Select one or more files..." +msgstr "" + +#: ./lib/gui/control_helper.py:1241 +msgid "Select a file or folder..." +msgstr "" + +#: ./lib/gui/control_helper.py:1242 +msgid "Select a save location..." +msgstr "" + +#: ./lib/gui/display.py:71 +msgid "Summary statistics for each training session" +msgstr "" + +#: ./lib/gui/display.py:113 +msgid "Preview updates every 5 seconds" +msgstr "" + +#: ./lib/gui/display.py:122 +msgid "Graph showing Loss vs Iterations" +msgstr "" + +#: ./lib/gui/display.py:125 +msgid "Training preview. Updated on every save iteration" +msgstr "" + +#: ./lib/gui/display_analysis.py:342 +msgid "Load/Refresh stats for the currently training session" +msgstr "" + +#: ./lib/gui/display_analysis.py:344 +msgid "Clear currently displayed session stats" +msgstr "" + +#: ./lib/gui/display_analysis.py:346 +msgid "Save session stats to csv" +msgstr "" + +#: ./lib/gui/display_analysis.py:348 +msgid "Load saved session stats" +msgstr "" + +#: ./lib/gui/display_command.py:94 +msgid "Preview updates at every model save. Click to refresh now." +msgstr "" + +#: ./lib/gui/display_command.py:261 +msgid "Graph updates at every model save. Click to refresh now." +msgstr "" + +#: ./lib/gui/display_command.py:275 +msgid "Display the raw loss data" +msgstr "" + +#: ./lib/gui/display_command.py:287 +msgid "Display the smoothed loss data" +msgstr "" + +#: ./lib/gui/display_command.py:294 +msgid "Set the smoothing amount. 0 is no smoothing, 0.99 is maximum smoothing." +msgstr "" + +#: ./lib/gui/display_command.py:324 +msgid "Set the number of iterations to display. 0 displays the full session." +msgstr "" + +#: ./lib/gui/display_page.py:238 +msgid "Save {}(s) to file" +msgstr "" + +#: ./lib/gui/display_page.py:250 +msgid "Enable or disable {} display" +msgstr "" + +#: ./lib/gui/popup_configure.py:209 +msgid "Close without saving" +msgstr "" + +#: ./lib/gui/popup_configure.py:210 +msgid "Save this page's config" +msgstr "" + +#: ./lib/gui/popup_configure.py:211 +msgid "Reset this page's config to default values" +msgstr "" + +#: ./lib/gui/popup_configure.py:213 +msgid "Save all settings for the currently selected config" +msgstr "" + +#: ./lib/gui/popup_configure.py:216 +msgid "Reset all settings for the currently selected config to default values" +msgstr "" + +#: ./lib/gui/popup_configure.py:538 +msgid "Select a plugin to configure:" +msgstr "" + +#: ./lib/gui/popup_session.py:191 +msgid "Display {}" +msgstr "" + +#: ./lib/gui/popup_session.py:342 +msgid "Refresh graph" +msgstr "" + +#: ./lib/gui/popup_session.py:344 +msgid "Save display data to csv" +msgstr "" + +#: ./lib/gui/popup_session.py:346 +msgid "Number of data points to sample for rolling average" +msgstr "" + +#: ./lib/gui/popup_session.py:348 +msgid "Set the smoothing amount. 0 is no smoothing, 0.99 is maximum smoothing" +msgstr "" + +#: ./lib/gui/popup_session.py:350 +msgid "Flatten data points that fall more than 1 standard deviation from the mean to the mean value." 
+msgstr "" + +#: ./lib/gui/popup_session.py:353 +msgid "Display rolling average of the data" +msgstr "" + +#: ./lib/gui/popup_session.py:355 +msgid "Smooth the data" +msgstr "" + +#: ./lib/gui/popup_session.py:357 +msgid "Display raw data" +msgstr "" + +#: ./lib/gui/popup_session.py:359 +msgid "Display polynormal data trend" +msgstr "" + +#: ./lib/gui/popup_session.py:361 +msgid "Set the data to display" +msgstr "" + +#: ./lib/gui/popup_session.py:363 +msgid "Change y-axis scale" +msgstr "" + diff --git a/locales/kr/LC_MESSAGES/faceswap.mo b/locales/kr/LC_MESSAGES/faceswap.mo new file mode 100644 index 0000000000..4613eb7345 Binary files /dev/null and b/locales/kr/LC_MESSAGES/faceswap.mo differ diff --git a/locales/kr/LC_MESSAGES/faceswap.po b/locales/kr/LC_MESSAGES/faceswap.po new file mode 100644 index 0000000000..c4829dac52 --- /dev/null +++ b/locales/kr/LC_MESSAGES/faceswap.po @@ -0,0 +1,34 @@ +# SOME DESCRIPTIVE TITLE. +# Copyright (C) YEAR ORGANIZATION +# FIRST AUTHOR , YEAR. +# +msgid "" +msgstr "" +"Project-Id-Version: \n" +"POT-Creation-Date: 2021-02-18 23:48-0000\n" +"PO-Revision-Date: 2022-11-24 12:21+0900\n" +"Last-Translator: \n" +"Language-Team: \n" +"Language: ko_KR\n" +"MIME-Version: 1.0\n" +"Content-Type: text/plain; charset=UTF-8\n" +"Content-Transfer-Encoding: 8bit\n" +"Plural-Forms: nplurals=1; plural=0;\n" +"Generated-By: pygettext.py 1.5\n" +"X-Generator: Poedit 3.2\n" + +#: faceswap.py:43 +msgid "Extract the faces from pictures or a video" +msgstr "그림들 또는 비디오에서 얼굴을 추출합니다" + +#: faceswap.py:44 +msgid "Train a model for the two faces A and B" +msgstr "얼굴들 A와 B에 대한 모델을 훈련시킵니다" + +#: faceswap.py:47 +msgid "Convert source pictures or video to a new one with the face swapped" +msgstr "원본 이미지 또는 비디오를 얼굴이 뒤바뀐 새로운 이미지 또는 영상으로 변환합니다" + +#: faceswap.py:48 +msgid "Launch the Faceswap Graphical User Interface" +msgstr "Faceswap GUI를 실행합니다" diff --git a/locales/kr/LC_MESSAGES/gui.menu.mo b/locales/kr/LC_MESSAGES/gui.menu.mo new file mode 100644 index 0000000000..2bab76f5b0 Binary files /dev/null and b/locales/kr/LC_MESSAGES/gui.menu.mo differ diff --git a/locales/kr/LC_MESSAGES/gui.menu.po b/locales/kr/LC_MESSAGES/gui.menu.po new file mode 100644 index 0000000000..b20b5dca54 --- /dev/null +++ b/locales/kr/LC_MESSAGES/gui.menu.po @@ -0,0 +1,155 @@ +# SOME DESCRIPTIVE TITLE. +# Copyright (C) YEAR ORGANIZATION +# FIRST AUTHOR , YEAR. +# +msgid "" +msgstr "" +"Project-Id-Version: \n" +"Report-Msgid-Bugs-To: \n" +"POT-Creation-Date: 2023-06-07 13:54+0100\n" +"PO-Revision-Date: 2023-06-07 14:11+0100\n" +"Last-Translator: \n" +"Language-Team: \n" +"Language: ko_KR\n" +"MIME-Version: 1.0\n" +"Content-Type: text/plain; charset=UTF-8\n" +"Content-Transfer-Encoding: 8bit\n" +"Plural-Forms: nplurals=1; plural=0;\n" +"Generated-By: pygettext.py 1.5\n" +"X-Generator: Poedit 3.3.1\n" + +#: lib/gui/menu.py:37 +msgid "faceswap.dev - Guides and Forum" +msgstr "faceswap.dev - Guides and Forum" + +#: lib/gui/menu.py:38 +msgid "Patreon - Support this project" +msgstr "Patreon - Support this project" + +#: lib/gui/menu.py:39 +msgid "Discord - The FaceSwap Discord server" +msgstr "Discord - The FaceSwap Discord server" + +#: lib/gui/menu.py:40 +msgid "Github - Our Source Code" +msgstr "Github - Our Source Code" + +#: lib/gui/menu.py:60 +msgid "File" +msgstr "" + +#: lib/gui/menu.py:61 +msgid "Settings" +msgstr "" + +#: lib/gui/menu.py:62 +msgid "Help" +msgstr "" + +#: lib/gui/menu.py:85 +msgid "Configure Settings..." +msgstr "" + +#: lib/gui/menu.py:116 +msgid "New Project..." 
+msgstr "" + +#: lib/gui/menu.py:121 +msgid "Open Project..." +msgstr "" + +#: lib/gui/menu.py:126 +msgid "Save Project" +msgstr "" + +#: lib/gui/menu.py:131 +msgid "Save Project as..." +msgstr "" + +#: lib/gui/menu.py:136 +msgid "Reload Project from Disk" +msgstr "" + +#: lib/gui/menu.py:141 +msgid "Close Project" +msgstr "" + +#: lib/gui/menu.py:147 +msgid "Open Task..." +msgstr "" + +#: lib/gui/menu.py:154 +msgid "Open recent" +msgstr "" + +#: lib/gui/menu.py:156 +msgid "Quit" +msgstr "" + +#: lib/gui/menu.py:211 +msgid "{} Task" +msgstr "" + +#: lib/gui/menu.py:223 +msgid "Clear recent files" +msgstr "" + +#: lib/gui/menu.py:391 +msgid "Check for updates..." +msgstr "" + +#: lib/gui/menu.py:394 +msgid "Update Faceswap..." +msgstr "" + +#: lib/gui/menu.py:398 +msgid "Switch Branch" +msgstr "" + +#: lib/gui/menu.py:401 +msgid "Resources" +msgstr "" + +#: lib/gui/menu.py:404 +msgid "Output System Information" +msgstr "" + +#: lib/gui/menu.py:589 +msgid "currently selected Task" +msgstr "현재 선택된 작업" + +#: lib/gui/menu.py:589 +msgid "Project" +msgstr "프로젝트" + +#: lib/gui/menu.py:591 +msgid "Reload {} from disk" +msgstr "디스크에서 {}를 다시 가져옵니다" + +#: lib/gui/menu.py:593 +msgid "Create a new {}..." +msgstr "새로운 {}를 만들기." + +#: lib/gui/menu.py:595 +msgid "Reset {} to default" +msgstr "{} 기본으로 재설정" + +#: lib/gui/menu.py:597 +msgid "Save {}" +msgstr "{} 저장" + +#: lib/gui/menu.py:599 +msgid "Save {} as..." +msgstr "{}를 다른 이름으로 저장." + +#: lib/gui/menu.py:603 +msgid " from a task or project file" +msgstr " 작업 또는 프로젝트 파일에서" + +#: lib/gui/menu.py:604 +msgid "Load {}..." +msgstr "{} 가져오기." + +#: lib/gui/menu.py:659 +msgid "Configure {} settings..." +msgstr "{} 세팅 설정하기." diff --git a/locales/kr/LC_MESSAGES/gui.tooltips.mo b/locales/kr/LC_MESSAGES/gui.tooltips.mo new file mode 100644 index 0000000000..bce4cb2148 Binary files /dev/null and b/locales/kr/LC_MESSAGES/gui.tooltips.mo differ diff --git a/locales/kr/LC_MESSAGES/gui.tooltips.po b/locales/kr/LC_MESSAGES/gui.tooltips.po new file mode 100644 index 0000000000..16c1631fcc --- /dev/null +++ b/locales/kr/LC_MESSAGES/gui.tooltips.po @@ -0,0 +1,205 @@ +# SOME DESCRIPTIVE TITLE. +# Copyright (C) YEAR ORGANIZATION +# FIRST AUTHOR , YEAR. +# +msgid "" +msgstr "" +"Project-Id-Version: \n" +"POT-Creation-Date: 2021-03-22 18:37+0000\n" +"PO-Revision-Date: 2023-06-07 14:13+0100\n" +"Last-Translator: \n" +"Language-Team: \n" +"Language: ko_KR\n" +"MIME-Version: 1.0\n" +"Content-Type: text/plain; charset=UTF-8\n" +"Content-Transfer-Encoding: 8bit\n" +"Plural-Forms: nplurals=1; plural=0;\n" +"Generated-By: pygettext.py 1.5\n" +"X-Generator: Poedit 3.3.1\n" + +#: lib/gui/command.py:184 +msgid "Output command line options to the console" +msgstr "명령어 옵션들을 콘솔에 출력" + +#: lib/gui/command.py:195 +msgid "Run the {} script" +msgstr "{} 스크립트 실행" + +#: lib/gui/control_helper.py:1234 +msgid "Select a folder..." +msgstr "폴더 선택." + +#: lib/gui/control_helper.py:1235 lib/gui/control_helper.py:1236 +msgid "Select a file..." +msgstr "파일 선택." + +#: lib/gui/control_helper.py:1237 +msgid "Select a folder of images..." +msgstr "이미지들의 폴더 선택." + +#: lib/gui/control_helper.py:1238 +msgid "Select a video..." +msgstr "비디오 선택." + +#: lib/gui/control_helper.py:1239 +msgid "Select a model folder..." +msgstr "모델 폴더 선택하기." + +#: lib/gui/control_helper.py:1240 +msgid "Select one or more files..." +msgstr "하나 이상의 파일들 선택." + +#: lib/gui/control_helper.py:1241 +msgid "Select a file or folder..." +msgstr "파일 또는 폴더 선택." + +#: lib/gui/control_helper.py:1242 +msgid "Select a save location..." 
+msgstr "저장 위치 선택." + +#: lib/gui/display.py:71 +msgid "Summary statistics for each training session" +msgstr "각 훈련 세션들에 대한 통계 요약" + +#: lib/gui/display.py:113 +msgid "Preview updates every 5 seconds" +msgstr "5초마다 미리보기를 업데이트하기" + +#: lib/gui/display.py:122 +msgid "Graph showing Loss vs Iterations" +msgstr "반복에 따른 손실율 그래프" + +#: lib/gui/display.py:125 +msgid "Training preview. Updated on every save iteration" +msgstr "훈련 미리보기. 매 저장된 반복마다 업데이트됩니다" + +#: lib/gui/display_analysis.py:342 +msgid "Load/Refresh stats for the currently training session" +msgstr "현재 훈련 세션에 대한 통계 가져오기/새로고침" + +#: lib/gui/display_analysis.py:344 +msgid "Clear currently displayed session stats" +msgstr "현재 보여지는 세션 통계 지우기" + +#: lib/gui/display_analysis.py:346 +msgid "Save session stats to csv" +msgstr "세션 통계 csv로 저장하기" + +#: lib/gui/display_analysis.py:348 +msgid "Load saved session stats" +msgstr "저장된 세션 통계 가져오기" + +#: lib/gui/display_command.py:94 +msgid "Preview updates at every model save. Click to refresh now." +msgstr "" +"모델을 저장할 때마다 미리보기를 업데이트합니다. 지금 새로고침하기 위해 누르세" +"요." + +#: lib/gui/display_command.py:261 +msgid "Graph updates at every model save. Click to refresh now." +msgstr "" +"모델을 저장할 때마다 그래프를 업데이트합니다. 지금 새로고침하기 위해 누르세" +"요." + +#: lib/gui/display_command.py:275 +msgid "Display the raw loss data" +msgstr "원시 손실 데이터 보이기" + +#: lib/gui/display_command.py:287 +msgid "Display the smoothed loss data" +msgstr "매끄러운 손실 데이터 보이기" + +#: lib/gui/display_command.py:294 +msgid "Set the smoothing amount. 0 is no smoothing, 0.99 is maximum smoothing." +msgstr "" +"매끄러움 정도를 설정합니다. 0이면 매끄러움이 없고, 0.99이면 최대로 매끄러워집" +"니다." + +#: lib/gui/display_command.py:324 +msgid "Set the number of iterations to display. 0 displays the full session." +msgstr "" +"화면에 보여질 반복 횟수를 설정합니다. 0 displays는 모든 세션에서 보여줍니다." + +#: lib/gui/display_page.py:238 +msgid "Save {}(s) to file" +msgstr "{}(s)를 파일에 저장합니다" + +#: lib/gui/display_page.py:250 +msgid "Enable or disable {} display" +msgstr "{} display를 활성화 또는 비활성화" + +#: lib/gui/popup_configure.py:209 +msgid "Close without saving" +msgstr "저장하지 않고 닫기" + +#: lib/gui/popup_configure.py:210 +msgid "Save this page's config" +msgstr "이 페이지의 설정을 저장" + +#: lib/gui/popup_configure.py:211 +msgid "Reset this page's config to default values" +msgstr "이 페이지의 설정을 기본값으로 재설정" + +#: lib/gui/popup_configure.py:213 +msgid "Save all settings for the currently selected config" +msgstr "현재 선택된 모든 설정을 저장" + +#: lib/gui/popup_configure.py:216 +msgid "Reset all settings for the currently selected config to default values" +msgstr "현재 선택된 모든 설정을 기본값으로 재설정" + +#: lib/gui/popup_configure.py:538 +msgid "Select a plugin to configure:" +msgstr "구성할 플러그인 선택:" + +#: lib/gui/popup_session.py:191 +msgid "Display {}" +msgstr "{} 보이기" + +#: lib/gui/popup_session.py:342 +msgid "Refresh graph" +msgstr "그래프 새로고침" + +#: lib/gui/popup_session.py:344 +msgid "Save display data to csv" +msgstr "디스플레이 데이터를 csv로 저장" + +#: lib/gui/popup_session.py:346 +msgid "Number of data points to sample for rolling average" +msgstr "샘플의 이동평균 데이터 포인트 개수" + +#: lib/gui/popup_session.py:348 +msgid "Set the smoothing amount. 0 is no smoothing, 0.99 is maximum smoothing" +msgstr "" +"매끄러움 정도를 설정합니다. 0이면 매끄러움이 없고, 0.99이면 최대로 매끄러워집" +"니다" + +#: lib/gui/popup_session.py:350 +msgid "" +"Flatten data points that fall more than 1 standard deviation from the mean " +"to the mean value." +msgstr "평균에서 값까지 1 표준 편차보다 더 멀리 떨어진 데이터들 펴기." 
+
+#: lib/gui/popup_session.py:353
+msgid "Display rolling average of the data"
+msgstr "데이터의 이동평균 보이기"
+
+#: lib/gui/popup_session.py:355
+msgid "Smooth the data"
+msgstr "데이터 매끄럽게 하기"
+
+#: lib/gui/popup_session.py:357
+msgid "Display raw data"
+msgstr "원시 데이터 보이기"
+
+#: lib/gui/popup_session.py:359
+msgid "Display polynormal data trend"
+msgstr "다항 데이터 트렌드 보이기"
+
+#: lib/gui/popup_session.py:361
+msgid "Set the data to display"
+msgstr "표시할 데이터 설정하기"
+
+#: lib/gui/popup_session.py:363
+msgid "Change y-axis scale"
+msgstr "y축 범위를 변경합니다"
diff --git a/locales/kr/LC_MESSAGES/lib.cli.args.mo b/locales/kr/LC_MESSAGES/lib.cli.args.mo
new file mode 100644
index 0000000000..bc4b9fce78
Binary files /dev/null and b/locales/kr/LC_MESSAGES/lib.cli.args.mo differ
diff --git a/locales/kr/LC_MESSAGES/lib.cli.args.po b/locales/kr/LC_MESSAGES/lib.cli.args.po
new file mode 100644
index 0000000000..ec473df1f1
--- /dev/null
+++ b/locales/kr/LC_MESSAGES/lib.cli.args.po
@@ -0,0 +1,57 @@
+# SOME DESCRIPTIVE TITLE.
+# Copyright (C) YEAR THE PACKAGE'S COPYRIGHT HOLDER
+# This file is distributed under the same license as the PACKAGE package.
+# FIRST AUTHOR , YEAR.
+#
+msgid ""
+msgstr ""
+"Project-Id-Version: \n"
+"Report-Msgid-Bugs-To: \n"
+"POT-Creation-Date: 2024-03-28 18:06+0000\n"
+"PO-Revision-Date: 2024-03-28 18:17+0000\n"
+"Last-Translator: \n"
+"Language-Team: \n"
+"Language: ko_KR\n"
+"MIME-Version: 1.0\n"
+"Content-Type: text/plain; charset=UTF-8\n"
+"Content-Transfer-Encoding: 8bit\n"
+"Plural-Forms: nplurals=1; plural=0;\n"
+"X-Generator: Poedit 3.4.2\n"
+
+#: lib/cli/args.py:188 lib/cli/args.py:199 lib/cli/args.py:208
+#: lib/cli/args.py:219
+msgid "Global Options"
+msgstr "전역 옵션들"
+
+#: lib/cli/args.py:190
+msgid ""
+"R|Exclude GPUs from use by Faceswap. Select the number(s) which correspond "
+"to any GPU(s) that you do not wish to be made available to Faceswap. "
+"Selecting all GPUs here will force Faceswap into CPU mode.\n"
+"L|{}"
+msgstr ""
+"R|Faceswap에서 사용되는 GPUs를 제외합니다. Faceswap에서 사용되게 하고 싶지 "
+"않은 GPU(s)에 해당하는 번호를 선택하세요. 모든 GPUs를 선택하면 Faceswap으로 "
+"하여금 CPU mode를 강제로 사용하게 합니다.\n"
+"L|{}"
+
+#: lib/cli/args.py:201
+msgid ""
+"Optionally overide the saved config with the path to a custom config file."
+msgstr ""
+"선택적으로 사용자 지정 설정 파일의 경로를 지정하여 저장된 설정을 대체합니다."
+
+#: lib/cli/args.py:210
+msgid ""
+"Log level. Stick with INFO or VERBOSE unless you need to file an error "
+"report. Be careful with TRACE as it will generate a lot of data"
+msgstr ""
+"로그 레벨. 오류 리포트가 필요하지 않다면 INFO와 VERBOSE를 사용하세요. 단, 굉"
+"장히 많은 데이터를 생성할 수 있는 TRACE는 조심하세요"
+
+#: lib/cli/args.py:220
+msgid "Path to store the logfile. Leave blank to store in the faceswap folder"
+msgstr "로그파일을 저장할 경로. faceswap 폴더에 저장하고 싶으면 비워두세요"
+
+#: lib/cli/args.py:319
+msgid "Output to Shell console instead of GUI console"
+msgstr "결과를 GUI 콘솔이 아닌 쉘 콘솔에 출력합니다"
diff --git a/locales/kr/LC_MESSAGES/lib.cli.args_extract_convert.mo b/locales/kr/LC_MESSAGES/lib.cli.args_extract_convert.mo
new file mode 100644
index 0000000000..1f0c43722c
Binary files /dev/null and b/locales/kr/LC_MESSAGES/lib.cli.args_extract_convert.mo differ
diff --git a/locales/kr/LC_MESSAGES/lib.cli.args_extract_convert.po b/locales/kr/LC_MESSAGES/lib.cli.args_extract_convert.po
new file mode 100644
index 0000000000..a504625624
--- /dev/null
+++ b/locales/kr/LC_MESSAGES/lib.cli.args_extract_convert.po
@@ -0,0 +1,656 @@
+# SOME DESCRIPTIVE TITLE.
+# Copyright (C) YEAR THE PACKAGE'S COPYRIGHT HOLDER
+# This file is distributed under the same license as the PACKAGE package. 
+# FIRST AUTHOR , YEAR.
+#
+msgid ""
+msgstr ""
+"Project-Id-Version: \n"
+"Report-Msgid-Bugs-To: \n"
+"POT-Creation-Date: 2024-04-12 11:56+0100\n"
+"PO-Revision-Date: 2024-04-12 12:00+0100\n"
+"Last-Translator: \n"
+"Language-Team: \n"
+"Language: ko_KR\n"
+"MIME-Version: 1.0\n"
+"Content-Type: text/plain; charset=UTF-8\n"
+"Content-Transfer-Encoding: 8bit\n"
+"Plural-Forms: nplurals=1; plural=0;\n"
+"X-Generator: Poedit 3.4.2\n"
+
+#: lib/cli/args_extract_convert.py:46 lib/cli/args_extract_convert.py:56
+#: lib/cli/args_extract_convert.py:64 lib/cli/args_extract_convert.py:122
+#: lib/cli/args_extract_convert.py:483 lib/cli/args_extract_convert.py:492
+msgid "Data"
+msgstr "데이터"
+
+#: lib/cli/args_extract_convert.py:48
+msgid ""
+"Input directory or video. Either a directory containing the image files you "
+"wish to process or path to a video file. NB: This should be the source video/"
+"frames NOT the source faces."
+msgstr ""
+"폴더나 비디오를 입력하세요. 당신이 처리하고 싶은 이미지 파일들을 가진 폴더 또"
+"는 비디오 파일의 경로여야 합니다. NB: 원본 얼굴이 아니라 원본 비디오/프레임이"
+"어야 합니다."
+
+#: lib/cli/args_extract_convert.py:57
+msgid "Output directory. This is where the converted files will be saved."
+msgstr "출력 폴더. 변환된 파일들이 저장될 곳입니다."
+
+#: lib/cli/args_extract_convert.py:66
+msgid ""
+"Optional path to an alignments file. Leave blank if the alignments file is "
+"at the default location."
+msgstr ""
+"(선택적) alignments 파일의 경로. alignments 파일이 기본 위치에 있다면 비워 두"
+"세요."
+
+#: lib/cli/args_extract_convert.py:97
+msgid ""
+"Extract faces from image or video sources.\n"
+"Extraction plugins can be configured in the 'Settings' Menu"
+msgstr ""
+"얼굴들을 이미지 또는 비디오에서 추출합니다.\n"
+"추출 플러그인은 '설정' 메뉴에서 설정할 수 있습니다"
+
+#: lib/cli/args_extract_convert.py:124
+msgid ""
+"R|If selected then the input_dir should be a parent folder containing "
+"multiple videos and/or folders of images you wish to extract from. The faces "
+"will be output to separate sub-folders in the output_dir."
+msgstr ""
+"R|만약 선택된다면 input_dir은 당신이 추출하고자 하는 여러개의 비디오 그리고/"
+"또는 이미지들을 가진 부모 폴더가 되어야 합니다. 얼굴들은 output_dir에 분리된 "
+"하위 폴더에 저장됩니다."
+
+#: lib/cli/args_extract_convert.py:133 lib/cli/args_extract_convert.py:152
+#: lib/cli/args_extract_convert.py:167 lib/cli/args_extract_convert.py:206
+#: lib/cli/args_extract_convert.py:224 lib/cli/args_extract_convert.py:237
+#: lib/cli/args_extract_convert.py:247 lib/cli/args_extract_convert.py:257
+#: lib/cli/args_extract_convert.py:503 lib/cli/args_extract_convert.py:529
+#: lib/cli/args_extract_convert.py:568
+msgid "Plugins"
+msgstr "플러그인들"
+
+#: lib/cli/args_extract_convert.py:135
+msgid ""
+"R|Detector to use. Some of these have configurable settings in '/config/"
+"extract.ini' or 'Settings > Configure Extract 'Plugins':\n"
+"L|cv2-dnn: A CPU only extractor which is the least reliable and least "
+"resource intensive. Use this if not using a GPU and time is important.\n"
+"L|mtcnn: Good detector. Fast on CPU, faster on GPU. Uses fewer resources "
+"than other GPU detectors but can often return more false positives.\n"
+"L|s3fd: Best detector. Slow on CPU, faster on GPU. Can detect more faces and "
+"fewer false positives than other GPU detectors, but is a lot more resource "
+"intensive.\n"
+"L|external: Import a face detection bounding box from a json file. "
+"(configurable in Detect settings)"
+msgstr ""
+"R|사용할 감지기. 몇몇 감지기들은 '/config/extract.ini' 또는 '설정 > 추출 플러"
+"그인 설정'에서 설정이 가능합니다:\n"
+"L|cv2-dnn: 가장 신뢰성이 낮고 가장 자원을 덜 사용하며 CPU만을 사용하는 추출기"
+"입니다. 만약 GPU를 사용하지 않고 시간이 중요하다면 사용하세요.\n"
+"L|mtcnn: 좋은 감지기. CPU에서 빠르고 GPU에서는 더 빠릅니다. 
다른 GPU 감지기들" +"보다 더 적은 자원을 사용하지만 가끔 더 많은 false positives를 돌려줄 수 있습" +"니다.\n" +"L|s3fd: 가장 좋은 감지기. CPU에선 느리고 GPU에선 빠릅니다. 다른 GPU 감지기들" +"보다 더 많은 얼굴들을 감지할 수 있고 과 더 적은 false positives를 돌려주지만 " +"자원을 굉장히 많이 사용합니다.\n" +"L|external: JSON 파일에서 얼굴 감지 경계 박스를 가져옵니다. (설정 감지에서 구" +"성 가능)" + +#: lib/cli/args_extract_convert.py:154 +msgid "" +"R|Aligner to use.\n" +"L|cv2-dnn: A CPU only landmark detector. Faster, less resource intensive, " +"but less accurate. Only use this if not using a GPU and time is important.\n" +"L|fan: Best aligner. Fast on GPU, slow on CPU.\n" +"L|external: Import 68 point 2D landmarks or an aligned bounding box from a " +"json file. (configurable in Align settings)" +msgstr "" +"R|사용할 Aligner.\n" +"L|cv2-dnn: CPU만을 사용하는 특징점 감지기. 빠르고 자원을 덜 사용하지만 부정확" +"합니다. GPU를 사용하지 않고 시간이 중요할 때에만 사용하세요.\n" +"L|fan: 가장 좋은 aligner. GPU에선 빠르고 CPU에선 느립니다.\n" +"L|external: JSON 파일에서 68 포인트 2D 랜드 마크 또는 정렬 된 경계 상자를 가" +"져옵니다. (정렬 설정에서 구성 가능)" + +#: lib/cli/args_extract_convert.py:169 +msgid "" +"R|Additional Masker(s) to use. The masks generated here will all take up GPU " +"RAM. You can select none, one or multiple masks, but the extraction may take " +"longer the more you select. NB: The Extended and Components (landmark based) " +"masks are automatically generated on extraction.\n" +"L|bisenet-fp: Relatively lightweight NN based mask that provides more " +"refined control over the area to be masked including full head masking " +"(configurable in mask settings).\n" +"L|custom: A dummy mask that fills the mask area with all 1s or 0s " +"(configurable in settings). This is only required if you intend to manually " +"edit the custom masks yourself in the manual tool. This mask does not use " +"the GPU so will not use any additional VRAM.\n" +"L|vgg-clear: Mask designed to provide smart segmentation of mostly frontal " +"faces clear of obstructions. Profile faces and obstructions may result in " +"sub-par performance.\n" +"L|vgg-obstructed: Mask designed to provide smart segmentation of mostly " +"frontal faces. The mask model has been specifically trained to recognize " +"some facial obstructions (hands and eyeglasses). Profile faces may result in " +"sub-par performance.\n" +"L|unet-dfl: Mask designed to provide smart segmentation of mostly frontal " +"faces. The mask model has been trained by community members and will need " +"testing for further description. Profile faces may result in sub-par " +"performance.\n" +"The auto generated masks are as follows:\n" +"L|components: Mask designed to provide facial segmentation based on the " +"positioning of landmark locations. A convex hull is constructed around the " +"exterior of the landmarks to create a mask.\n" +"L|extended: Mask designed to provide facial segmentation based on the " +"positioning of landmark locations. A convex hull is constructed around the " +"exterior of the landmarks and the mask is extended upwards onto the " +"forehead.\n" +"(eg: `-M unet-dfl vgg-clear`, `--masker vgg-obstructed`)" +msgstr "" +"R|사용할 추가 Mask입니다. 여기서 생성된 마스크는 모두 GPU RAM을 차지합니다. " +"마스크를 0개, 1개 또는 여러 개 선택할 수 있지만 더 많이 선택할수록 추출에 시" +"간이 더 걸릴 수 있습니다. NB: 확장 및 구성 요소(특징점 기반) 마스크는 추출 " +"시 자동으로 생성됩니다.\n" +"L|bisnet-fp: 전체 헤드 마스킹(마스크 설정에서 구성 가능)을 포함하여 마스킹할 " +"영역에 대한 보다 정교한 제어를 제공하는 비교적 가벼운 NN 기반 마스크입니다.\n" +"L|custom: 마스크 영역을 모든 1 또는 0으로 채우는 dummy 마스크입니다(설정에서 " +"구성 가능). 수동 도구에서 사용자 정의 마스크를 직접 수동으로 편집하려는 경우" +"에만 필요합니다. 이 마스크는 GPU를 사용하지 않으므로 추가 VRAM을 사용하지 않" +"습니다.\n" +"L|vgg-clear: 대부분의 정면에 장애물이 없는 스마트한 분할을 제공하도록 설계된 " +"마스크입니다. 
프로필 얼굴들 및 장애물들로 인해 성능이 저하될 수 있습니다.\n"
+"L|vgg-obstructed: 대부분의 정면 얼굴을 스마트하게 분할할 수 있도록 설계된 마"
+"스크입니다. 마스크 모델은 일부 안면 장애물(손과 안경)을 인식하도록 특별히 훈"
+"련되었습니다. 프로필 얼굴들은 평균 이하의 성능을 초래할 수 있습니다.\n"
+"L|unet-dfl: 대부분 정면 얼굴을 스마트하게 분할하도록 설계된 마스크. 마스크 모"
+"델은 커뮤니티 구성원들에 의해 훈련되었으며 추가 설명을 위해 테스트가 필요합니"
+"다. 프로필 얼굴들은 평균 이하의 성능을 초래할 수 있습니다.\n"
+"자동 생성 마스크는 다음과 같습니다.\n"
+"L|components: 특징점 위치를 기반으로 얼굴 분할을 제공하도록 설계된 마스크입니"
+"다. 특징점의 외부에는 마스크를 만들기 위해 convex hull가 형성되어 있습니다.\n"
+"L|extended: 특징점 위치를 기반으로 얼굴 분할을 제공하도록 설계된 마스크입니"
+"다. 특징점의 외부에는 convex hull가 형성되어 있으며, 마스크는 이마 위로 뻗어 "
+"있습니다.\n"
+"(예: '-M unet-dfl vgg-clear', '--masker vgg-obstructed')"
+
+#: lib/cli/args_extract_convert.py:208
+msgid ""
+"R|Performing normalization can help the aligner better align faces with "
+"difficult lighting conditions at an extraction speed cost. Different methods "
+"will yield different results on different sets. NB: This does not impact the "
+"output face, just the input to the aligner.\n"
+"L|none: Don't perform normalization on the face.\n"
+"L|clahe: Perform Contrast Limited Adaptive Histogram Equalization on the "
+"face.\n"
+"L|hist: Equalize the histograms on the RGB channels.\n"
+"L|mean: Normalize the face colors to the mean."
+msgstr ""
+"R|정규화를 수행하면 aligner가 추출 속도 비용으로 어려운 조명 조건의 얼굴을 "
+"더 잘 정렬할 수 있습니다. 방법이 다르면 세트마다 결과가 다릅니다. NB: 출력 얼"
+"굴에는 영향을 주지 않으며 aligner에 대한 입력에만 영향을 줍니다.\n"
+"L|none: 얼굴에 정규화를 수행하지 마십시오.\n"
+"L|clahe: 얼굴에 Contrast Limited Adaptive Histogram Equalization를 수행합니"
+"다.\n"
+"L|hist: RGB 채널의 히스토그램을 동일하게 합니다.\n"
+"L|mean: 얼굴 색상을 평균으로 정규화합니다."
+
+#: lib/cli/args_extract_convert.py:226
+msgid ""
+"The number of times to re-feed the detected face into the aligner. Each time "
+"the face is re-fed into the aligner the bounding box is adjusted by a small "
+"amount. The final landmarks are then averaged from each iteration. Helps to "
+"remove 'micro-jitter' but at the cost of slower extraction speed. The more "
+"times the face is re-fed into the aligner, the less micro-jitter should "
+"occur but the longer extraction will take."
+msgstr ""
+"검출된 얼굴을 aligner에 다시 공급하는 횟수입니다. 얼굴이 aligner에 다시 공급"
+"될 때마다 경계 상자가 소량 조정됩니다. 그런 다음 각 반복에서 최종 특징점의 평"
+"균을 구합니다. 'micro-jitter'를 제거하는 데 도움이 되지만 추출 속도가 느려집"
+"니다. 얼굴이 aligner에 다시 공급되는 횟수가 많을수록 micro-jitter가 적게 발생"
+"하지만 추출에 더 오랜 시간이 걸립니다."
+
+#: lib/cli/args_extract_convert.py:239
+msgid ""
+"Re-feed the initially found aligned face through the aligner. Can help "
+"produce better alignments for faces that are rotated beyond 45 degrees in "
+"the frame or are at extreme angles. Slows down extraction."
+msgstr ""
+"aligner를 통해 처음 발견된 정렬된 얼굴을 재공급합니다. 프레임에서 45도 이상 "
+"회전하거나 극단적인 각도에 있는 얼굴을 더 잘 정렬할 수 있습니다. 추출 속도가 "
+"느려집니다."
+
+#: lib/cli/args_extract_convert.py:249
+msgid ""
+"If a face isn't found, rotate the images to try to find a face. Can find "
+"more faces at the cost of extraction speed. Pass in a single number to use "
+"increments of that size up to 360, or pass in a list of numbers to enumerate "
+"exactly what angles to check."
+msgstr ""
+"얼굴이 발견되지 않으면 이미지를 회전하여 얼굴을 찾습니다. 추출 속도를 희생하"
+"면서 더 많은 얼굴을 찾을 수 있습니다. 단일 숫자를 입력하여 해당 크기의 증분"
+"을 360까지 사용하거나 숫자 목록을 입력하여 확인할 각도를 정확하게 열거합니다."
+
+#: lib/cli/args_extract_convert.py:259
+msgid ""
+"Obtain and store face identity encodings from VGGFace2. Slows down extract a "
+"little, but will save time if using 'sort by face'"
+msgstr ""
+"VGGFace2에서 얼굴 식별 인코딩을 가져와 저장합니다. 추출 속도를 약간 늦추지만 "
+"'얼굴별로 정렬'을 사용하면 시간을 절약할 수 있습니다."
+
+#: lib/cli/args_extract_convert.py:269 lib/cli/args_extract_convert.py:280
+#: lib/cli/args_extract_convert.py:293 lib/cli/args_extract_convert.py:307
+#: lib/cli/args_extract_convert.py:614 lib/cli/args_extract_convert.py:623
+#: lib/cli/args_extract_convert.py:638 lib/cli/args_extract_convert.py:651
+#: lib/cli/args_extract_convert.py:665
+msgid "Face Processing"
+msgstr "얼굴 처리"
+
+#: lib/cli/args_extract_convert.py:271
+msgid ""
+"Filters out faces detected below this size. Length, in pixels across the "
+"diagonal of the bounding box. Set to 0 for off"
+msgstr ""
+"이 크기 미만으로 탐지된 얼굴을 필터링합니다. 길이, 경계 상자의 대각선에 걸친 "
+"픽셀 단위입니다. 0으로 설정하면 꺼집니다"
+
+#: lib/cli/args_extract_convert.py:282
+msgid ""
+"Optionally filter out people who you do not wish to extract by passing in "
+"images of those people. Should be a small variety of images at different "
+"angles and in different conditions. A folder containing the required images "
+"or multiple image files, space separated, can be selected."
+msgstr ""
+"선택적으로 추출하지 않을 사람의 이미지들을 전달하여 그 사람들을 제외합니다. "
+"각도와 조건이 다른 소량의 다양한 이미지여야 합니다. 필요한 이미지들이 들어 "
+"있는 폴더 또는 공백으로 구분된 여러 이미지 파일을 선택할 수 있습니다."
+
+#: lib/cli/args_extract_convert.py:295
+msgid ""
+"Optionally select people you wish to extract by passing in images of that "
+"person. Should be a small variety of images at different angles and in "
+"different conditions A folder containing the required images or multiple "
+"image files, space separated, can be selected."
+msgstr ""
+"선택적으로 추출하고 싶은 사람의 이미지를 전달하여 그 사람을 선택합니다. 각도"
+"와 조건이 다른 소량의 다양한 이미지여야 합니다. 필요한 이미지들이 들어 있는 "
+"폴더 또는 공백으로 구분된 여러 이미지 파일을 선택할 수 있습니다."
+
+#: lib/cli/args_extract_convert.py:307
+msgid ""
+"For use with the optional nfilter/filter files. Threshold for positive face "
+"recognition. Higher values are stricter."
+msgstr ""
+"옵션인 nfilter/filter 파일과 함께 사용합니다. 긍정적인 얼굴 인식을 위한 임계"
+"값. 값이 높을수록 엄격합니다."
+
+#: lib/cli/args_extract_convert.py:318 lib/cli/args_extract_convert.py:331
+#: lib/cli/args_extract_convert.py:344 lib/cli/args_extract_convert.py:356
+msgid "output"
+msgstr "출력"
+
+#: lib/cli/args_extract_convert.py:320
+msgid ""
+"The output size of extracted faces. Make sure that the model you intend to "
+"train supports your required size. This will only need to be changed for hi-"
+"res models."
+msgstr ""
+"추출된 얼굴의 출력 크기입니다. 훈련하려는 모델이 필요한 크기를 지원하는지 꼭 "
+"확인하세요. 이것은 고해상도 모델에 대해서만 변경하면 됩니다."
+
+#: lib/cli/args_extract_convert.py:333
+msgid ""
+"Extract every 'nth' frame. This option will skip frames when extracting "
+"faces. For example a value of 1 will extract faces from every frame, a value "
+"of 10 will extract faces from every 10th frame."
+msgstr ""
+"모든 'n번째' 프레임을 추출합니다. 이 옵션은 얼굴을 추출할 때 건너뛸 프레임을 "
+"설정합니다. 예를 들어, 값이 1이면 모든 프레임에서 얼굴이 추출되고, 값이 10이"
+"면 모든 10번째 프레임에서 얼굴이 추출됩니다."
+
+#: lib/cli/args_extract_convert.py:346
+msgid ""
+"Automatically save the alignments file after a set amount of frames. By "
+"default the alignments file is only saved at the end of the extraction "
+"process. NB: If extracting in 2 passes then the alignments file will only "
+"start to be saved out during the second pass. WARNING: Don't interrupt the "
+"script when writing the file because it might get corrupted. Set to 0 to "
+"turn off"
+msgstr ""
+"프레임 수가 설정된 후 alignments 파일을 자동으로 저장합니다. 기본적으로 "
+"alignments 파일은 추출 프로세스가 끝날 때만 저장됩니다. NB: 2번의 패스로 추출"
+"하는 경우 alignments 파일은 두 번째 패스 중에만 저장되기 시작합니다. 경고: 파"
+"일이 손상될 수 있으므로 파일을 쓰는 동안 스크립트를 중단하지 마십시오. 
해제하려" +"면 0으로 설정" + +#: lib/cli/args_extract_convert.py:357 +msgid "Draw landmarks on the ouput faces for debugging purposes." +msgstr "디버깅을 위해 출력 얼굴에 특징점을 그립니다." + +#: lib/cli/args_extract_convert.py:363 lib/cli/args_extract_convert.py:373 +#: lib/cli/args_extract_convert.py:381 lib/cli/args_extract_convert.py:388 +#: lib/cli/args_extract_convert.py:678 lib/cli/args_extract_convert.py:691 +#: lib/cli/args_extract_convert.py:712 lib/cli/args_extract_convert.py:718 +msgid "settings" +msgstr "설정" + +#: lib/cli/args_extract_convert.py:365 +msgid "" +"Don't run extraction in parallel. Will run each part of the extraction " +"process separately (one after the other) rather than all at the same time. " +"Useful if VRAM is at a premium." +msgstr "" +"추출을 병렬로 실행하지 마십시오. 추출 프로세스의 각 부분을 동시에 모두 실행하" +"는 것이 아니라 개별적으로(하나씩) 실행합니다. VRAM이 프리미엄인 경우 유용합니" +"다." + +#: lib/cli/args_extract_convert.py:375 +msgid "" +"Skips frames that have already been extracted and exist in the alignments " +"file" +msgstr "이미 추출되었거나 alignments 파일에 존재하는 프레임들을 스킵합니다" + +#: lib/cli/args_extract_convert.py:382 +msgid "Skip frames that already have detected faces in the alignments file" +msgstr "이미 얼굴을 탐지하여 alignments 파일에 존재하는 프레임들을 스킵합니다" + +#: lib/cli/args_extract_convert.py:389 +msgid "Skip saving the detected faces to disk. Just create an alignments file" +msgstr "" +"탐지된 얼굴을 디스크에 저장하지 않습니다. 그저 alignments 파일을 만듭니다" + +#: lib/cli/args_extract_convert.py:463 +msgid "" +"Swap the original faces in a source video/images to your final faces.\n" +"Conversion plugins can be configured in the 'Settings' Menu" +msgstr "" +"원본 비디오/이미지의 원래 얼굴을 최종 얼굴으로 바꿉니다.\n" +"변환 플러그인은 '설정' 메뉴에서 구성할 수 있습니다" + +#: lib/cli/args_extract_convert.py:485 +msgid "" +"Only required if converting from images to video. Provide The original video " +"that the source frames were extracted from (for extracting the fps and " +"audio)." +msgstr "" +"이미지에서 비디오로 변환하는 경우에만 필요합니다. 소스 프레임이 추출된 원본 " +"비디오(fps 및 오디오 추출용)를 입력하세요." + +#: lib/cli/args_extract_convert.py:494 +msgid "" +"Model directory. The directory containing the trained model you wish to use " +"for conversion." +msgstr "" +"모델 폴더. 당신이 변환에 사용하고자 하는 훈련된 모델을 가진 폴더입니다." + +#: lib/cli/args_extract_convert.py:505 +msgid "" +"R|Performs color adjustment to the swapped face. Some of these options have " +"configurable settings in '/config/convert.ini' or 'Settings > Configure " +"Convert Plugins':\n" +"L|avg-color: Adjust the mean of each color channel in the swapped " +"reconstruction to equal the mean of the masked area in the original image.\n" +"L|color-transfer: Transfers the color distribution from the source to the " +"target image using the mean and standard deviations of the L*a*b* color " +"space.\n" +"L|manual-balance: Manually adjust the balance of the image in a variety of " +"color spaces. Best used with the Preview tool to set correct values.\n" +"L|match-hist: Adjust the histogram of each color channel in the swapped " +"reconstruction to equal the histogram of the masked area in the original " +"image.\n" +"L|seamless-clone: Use cv2's seamless clone function to remove extreme " +"gradients at the mask seam by smoothing colors. Generally does not give very " +"satisfactory results.\n" +"L|none: Don't perform color adjustment." +msgstr "" +"R|스왑된 얼굴의 색상 조정을 수행합니다. 
이러한 옵션 중 일부에는 '/config/"
+"convert.ini' 또는 '설정 > 변환 플러그인 구성'에서 구성 가능한 설정이 있습니"
+"다.\n"
+"L|avg-color: 스왑된 재구성에서 각 색상 채널의 평균이 원본 영상에서 마스킹된 "
+"영역의 평균과 동일하도록 조정합니다.\n"
+"L|color-transfer: L*a*b* 색 공간의 평균 및 표준 편차를 사용하여 소스에서 대"
+"상 이미지로 색 분포를 전송합니다.\n"
+"L|manual-balance: 다양한 색 공간에서 이미지의 밸런스를 수동으로 조정합니다. "
+"올바른 값을 설정하려면 미리 보기 도구와 함께 사용하는 것이 좋습니다.\n"
+"L|match-hist: 스왑된 재구성에서 각 색상 채널의 히스토그램을 조정하여 원래 영"
+"상에서 마스킹된 영역의 히스토그램과 동일하게 만듭니다.\n"
+"L|seamless-clone: cv2의 원활한 복제 기능을 사용하여 색상을 평활화하여 마스크 "
+"심에서 극단적인 gradients을 제거합니다. 일반적으로 매우 만족스러운 결과를 제"
+"공하지 않습니다.\n"
+"L|none: 색상 조정을 수행하지 않습니다."
+
+#: lib/cli/args_extract_convert.py:531
+msgid ""
+"R|Masker to use. NB: The mask you require must exist within the alignments "
+"file. You can add additional masks with the Mask Tool.\n"
+"L|none: Don't use a mask.\n"
+"L|bisenet-fp_face: Relatively lightweight NN based mask that provides more "
+"refined control over the area to be masked (configurable in mask settings). "
+"Use this version of bisenet-fp if your model is trained with 'face' or "
+"'legacy' centering.\n"
+"L|bisenet-fp_head: Relatively lightweight NN based mask that provides more "
+"refined control over the area to be masked (configurable in mask settings). "
+"Use this version of bisenet-fp if your model is trained with 'head' "
+"centering.\n"
+"L|custom_face: Custom user created, face centered mask.\n"
+"L|custom_head: Custom user created, head centered mask.\n"
+"L|components: Mask designed to provide facial segmentation based on the "
+"positioning of landmark locations. A convex hull is constructed around the "
+"exterior of the landmarks to create a mask.\n"
+"L|extended: Mask designed to provide facial segmentation based on the "
+"positioning of landmark locations. A convex hull is constructed around the "
+"exterior of the landmarks and the mask is extended upwards onto the "
+"forehead.\n"
+"L|vgg-clear: Mask designed to provide smart segmentation of mostly frontal "
+"faces clear of obstructions. Profile faces and obstructions may result in "
+"sub-par performance.\n"
+"L|vgg-obstructed: Mask designed to provide smart segmentation of mostly "
+"frontal faces. The mask model has been specifically trained to recognize "
+"some facial obstructions (hands and eyeglasses). Profile faces may result in "
+"sub-par performance.\n"
+"L|unet-dfl: Mask designed to provide smart segmentation of mostly frontal "
+"faces. The mask model has been trained by community members and will need "
+"testing for further description. Profile faces may result in sub-par "
+"performance.\n"
+"L|predicted: If the 'Learn Mask' option was enabled during training, this "
+"will use the mask that was created by the trained model."
+msgstr ""
+"R|사용할 마스크. NB: 필요한 마스크는 alignments 파일 내에 있어야 합니다. 마스"
+"크 도구를 사용하여 마스크를 추가할 수 있습니다.\n"
+"L|none: 마스크 쓰지 마세요.\n"
+"L|bisenet-fp_face: 마스크할 영역을 보다 정교하게 제어할 수 있는 비교적 가벼운 "
+"NN 기반 마스크입니다(마스크 설정에서 구성 가능). 모델이 '얼굴' 또는 '레거시' "
+"중심으로 훈련된 경우 이 버전의 bisenet-fp를 사용하십시오.\n"
+"L|bisenet-fp_head: 마스크할 영역을 보다 정교하게 제어할 수 있는 비교적 가벼운 "
+"NN 기반 마스크입니다(마스크 설정에서 구성 가능). 모델이 '헤드' 중심으로 훈련"
+"된 경우 이 버전의 bisenet-fp를 사용하십시오.\n"
+"L|custom_face: 사용자 지정 사용자가 생성한 얼굴 중심 마스크입니다.\n"
+"L|custom_head: 사용자 지정 사용자가 생성한 머리 중심 마스크입니다.\n"
+"L|components: 특징점 위치의 배치를 기반으로 얼굴 분할을 제공하도록 설계된 마"
+"스크입니다. 특징점의 외부에는 마스크를 만들기 위해 convex hull가 형성되어 있"
+"습니다.\n"
+"L|extended: 특징점 위치의 배치를 기반으로 얼굴 분할을 제공하도록 설계된 마스"
+"크입니다. 특징점의 외부에는 convex hull가 형성되어 있으며, 마스크는 이마 위"
+"로 뻗어 있습니다.\n"
+"L|vgg-clear: 대부분의 정면에 장애물이 없는 스마트한 분할을 제공하도록 설계된 "
+"마스크입니다. 
옆 얼굴 및 장애물로 인해 성능이 저하될 수 있습니다.\n"
+"L|vgg-obstructed: 대부분의 정면 얼굴을 스마트하게 분할할 수 있도록 설계된 마"
+"스크입니다. 마스크 모델은 일부 안면 장애물(손과 안경)을 인식하도록 특별히 훈"
+"련되었습니다. 옆 얼굴은 평균 이하의 성능을 초래할 수 있습니다.\n"
+"L|unet-dfl: 대부분 정면 얼굴을 스마트하게 분할하도록 설계된 마스크. 마스크 모"
+"델은 커뮤니티 구성원들에 의해 훈련되었으며 추가 설명을 위해 테스트가 필요합니"
+"다. 옆 얼굴은 평균 이하의 성능을 초래할 수 있습니다.\n"
+"L|predicted: 교육 중에 'Learn Mask(마스크 학습)' 옵션이 활성화된 경우에는 교"
+"육을 받은 모델이 만든 마스크가 사용됩니다."
+
+#: lib/cli/args_extract_convert.py:570
+msgid ""
+"R|The plugin to use to output the converted images. The writers are "
+"configurable in '/config/convert.ini' or 'Settings > Configure Convert "
+"Plugins:'\n"
+"L|ffmpeg: [video] Writes out the convert straight to video. When the input "
+"is a series of images then the '-ref' (--reference-video) parameter must be "
+"set.\n"
+"L|gif: [animated image] Create an animated gif.\n"
+"L|opencv: [images] The fastest image writer, but less options and formats "
+"than other plugins.\n"
+"L|patch: [images] Outputs the raw swapped face patch, along with the "
+"transformation matrix required to re-insert the face back into the original "
+"frame. Use this option if you wish to post-process and composite the final "
+"face within external tools.\n"
+"L|pillow: [images] Slower than opencv, but has more options and supports "
+"more formats."
+msgstr ""
+"R|변환된 이미지를 출력하는 데 사용할 플러그인입니다. 기록 장치는 '/config/"
+"convert.ini' 또는 '설정 > 변환 플러그인 구성:'에서 구성할 수 있습니다.\n"
+"L|ffmpeg: [video] 변환된 결과를 바로 video로 씁니다. 입력이 영상 시리즈인 경"
+"우 '-ref'(--reference-video) 파라미터를 설정해야 합니다.\n"
+"L|gif: [애니메이션 이미지] 애니메이션 gif를 만듭니다.\n"
+"L|opencv: [이미지] 가장 빠른 이미지 작성기이지만 다른 플러그인에 비해 옵션과 "
+"형식이 적습니다.\n"
+"L|patch: [이미지] 원래 프레임에 얼굴을 다시 삽입하는 데 필요한 변환 행렬과 함"
+"께 원시 교체된 얼굴 패치를 출력합니다. 외부 도구에서 최종 얼굴을 후처리하고 "
+"합성하려면 이 옵션을 사용하세요.\n"
+"L|pillow: [images] opencv보다 느리지만 더 많은 옵션이 있고 더 많은 형식을 지"
+"원합니다."
+
+#: lib/cli/args_extract_convert.py:591 lib/cli/args_extract_convert.py:600
+#: lib/cli/args_extract_convert.py:703
+msgid "Frame Processing"
+msgstr "프레임 처리"
+
+#: lib/cli/args_extract_convert.py:593
+#, python-format
+msgid ""
+"Scale the final output frames by this amount. 100%% will output the frames "
+"at source dimensions. 50%% at half size 200%% at double size"
+msgstr ""
+"최종 출력 프레임의 크기를 이 양만큼 조정합니다. 100%%는 원본의 차원에서 프레"
+"임을 출력합니다. 50%%는 절반 크기에서, 200%%는 두 배 크기에서"
+
+#: lib/cli/args_extract_convert.py:602
+msgid ""
+"Frame ranges to apply transfer to e.g. For frames 10 to 50 and 90 to 100 use "
+"--frame-ranges 10-50 90-100. Frames falling outside of the selected range "
+"will be discarded unless '-k' (--keep-unchanged) is selected. NB: If you are "
+"converting from images, then the filenames must end with the frame-number!"
+msgstr ""
+"예를 들어 전송을 적용할 프레임 범위 프레임 10 - 50 및 90 - 100의 경우 --"
+"frame-ranges 10-50 90-100을 사용합니다. '-k'(--keep-unchanged)를 선택하지 않"
+"으면 선택한 범위를 벗어나는 프레임이 삭제됩니다. NB: 이미지에서 변환하는 경"
+"우 파일 이름은 프레임 번호로 끝나야 합니다!"
+
+#: lib/cli/args_extract_convert.py:616
+msgid ""
+"Scale the swapped face by this percentage. Positive values will enlarge the "
+"face, Negative values will shrink the face."
+msgstr ""
+"이 백분율로 교체된 면의 크기를 조정합니다. 양수 값은 얼굴을 확대하고, 음수 값"
+"은 얼굴을 축소합니다."
+
+#: lib/cli/args_extract_convert.py:625
+msgid ""
+"If you have not cleansed your alignments file, then you can filter out faces "
+"by defining a folder here that contains the faces extracted from your input "
+"files/video. If this folder is defined, then only faces that exist within "
+"your alignments file and also exist within the specified folder will be "
+"converted. 
Leaving this blank will convert all faces that exist within the "
+"alignments file."
+msgstr ""
+"만약 alignments 파일을 지우지 않은 경우 입력 파일/비디오에서 추출된 얼굴이 포"
+"함된 폴더를 정의하여 얼굴을 걸러낼 수 있습니다. 이 폴더가 정의된 경우 "
+"alignments 파일 내에 존재하고 지정된 폴더 내에도 존재하는 얼굴만 변환됩니다. "
+"이 항목을 공백으로 두면 alignments 파일 내에 있는 모든 얼굴이 변환됩니다."
+
+#: lib/cli/args_extract_convert.py:640
+msgid ""
+"Optionally filter out people who you do not wish to process by passing in an "
+"image of that person. Should be a front portrait with a single person in the "
+"image. Multiple images can be added space separated. NB: Using face filter "
+"will significantly decrease extraction speed and its accuracy cannot be "
+"guaranteed."
+msgstr ""
+"선택적으로 처리하고 싶지 않은 사람의 이미지를 전달하여 그 사람을 걸러낼 수 있"
+"습니다. 이미지는 한 사람의 정면 모습이어야 합니다. 여러 이미지를 공백으로 구"
+"분하여 추가할 수 있습니다. 주의: 얼굴 필터를 사용하면 추출 속도가 현저히 감소"
+"하므로 정확성을 보장할 수 없습니다."
+
+#: lib/cli/args_extract_convert.py:653
+msgid ""
+"Optionally select people you wish to process by passing in an image of that "
+"person. Should be a front portrait with a single person in the image. "
+"Multiple images can be added space separated. NB: Using face filter will "
+"significantly decrease extraction speed and its accuracy cannot be "
+"guaranteed."
+msgstr ""
+"선택적으로 해당 사용자의 이미지를 전달하여 처리할 사용자를 선택합니다. 이미지"
+"에 한 사람이 있는 정면 초상화여야 합니다. 여러 이미지를 공백으로 구분하여 추"
+"가할 수 있습니다. 주의: 얼굴 필터를 사용하면 추출 속도가 현저히 감소하므로 정"
+"확성을 보장할 수 없습니다."
+
+#: lib/cli/args_extract_convert.py:667
+msgid ""
+"For use with the optional nfilter/filter files. Threshold for positive face "
+"recognition. Lower values are stricter. NB: Using face filter will "
+"significantly decrease extraction speed and its accuracy cannot be "
+"guaranteed."
+msgstr ""
+"옵션인 nfilter/filter 파일을 함께 사용합니다. 긍정적인 얼굴 인식을 위한 임계"
+"값. 낮은 값이 더 엄격합니다. 주의: 얼굴 필터를 사용하면 추출 속도가 현저히 감"
+"소하므로 정확성을 보장할 수 없습니다."
+
+#: lib/cli/args_extract_convert.py:680
+msgid ""
+"The maximum number of parallel processes for performing conversion. "
+"Converting images is system RAM heavy so it is possible to run out of memory "
+"if you have a lot of processes and not enough RAM to accommodate them all. "
+"Setting this to 0 will use the maximum available. No matter what you set "
+"this to, it will never attempt to use more processes than are available on "
+"your system. If singleprocess is enabled this setting will be ignored."
+msgstr ""
+"변환을 수행하기 위한 최대 병렬 프로세스 수입니다. 이미지 변환은 시스템 RAM에 "
+"부담이 크기 때문에 프로세스가 많고 모든 프로세스를 수용할 RAM이 충분하지 않"
+"은 경우 메모리가 부족할 수 있습니다. 이것을 0으로 설정하면 사용 가능한 최대값"
+"을 사용합니다. 얼마를 설정하든 시스템에서 사용 가능한 것보다 더 많은 프로세스"
+"를 사용하려고 시도하지 않습니다. 단일 프로세스가 활성화된 경우 이 설정은 무시"
+"됩니다."
+
+#: lib/cli/args_extract_convert.py:693
+msgid ""
+"Enable On-The-Fly Conversion. NOT recommended. You should generate a clean "
+"alignments file for your destination video. However, if you wish you can "
+"generate the alignments on-the-fly by enabling this option. This will use an "
+"inferior extraction pipeline and will lead to substandard results. If an "
+"alignments file is found, this option will be ignored."
+msgstr ""
+"실시간 변환을 활성화합니다. 권장하지 않습니다. 당신은 변환 비디오에 대한 깨끗"
+"한 alignments 파일을 생성해야 합니다. 그러나 원하는 경우 이 옵션을 활성화하"
+"여 즉시 alignments 파일을 생성할 수 있습니다. 이것은 안좋은 추출 과정을 사용"
+"하고 표준 이하의 결과로 이어질 것입니다. alignments 파일이 발견되면 이 옵션"
+"은 무시됩니다."
+
+#: lib/cli/args_extract_convert.py:705
+msgid ""
+"When used with --frame-ranges outputs the unchanged frames that are not "
+"processed instead of discarding them."
+msgstr ""
+"--frame-ranges와 함께 사용하면 처리되지 않은 프레임을 버리는 대신 변경되지 않"
+"은 상태로 출력합니다." 
+
+#: lib/cli/args_extract_convert.py:713
+msgid "Swap the model. Instead converting from of A -> B, converts B -> A"
+msgstr "모델을 바꿉니다. A -> B에서 변환하는 대신 B -> A로 변환"
+
+#: lib/cli/args_extract_convert.py:719
+msgid "Disable multiprocessing. Slower but less resource intensive."
+msgstr "멀티프로세싱을 쓰지 않습니다. 느리지만 자원을 덜 소모합니다."
+
+#~ msgid ""
+#~ "[LEGACY] This only needs to be selected if a legacy model is being loaded "
+#~ "or if there are multiple models in the model folder"
+#~ msgstr ""
+#~ "[LEGACY] 이것은 레거시 모델을 로드 중이거나 모델 폴더에 여러 모델이 있는 "
+#~ "경우에만 선택되어야 합니다"
diff --git a/locales/kr/LC_MESSAGES/lib.cli.args_train.mo b/locales/kr/LC_MESSAGES/lib.cli.args_train.mo
new file mode 100644
index 0000000000..0576c21604
Binary files /dev/null and b/locales/kr/LC_MESSAGES/lib.cli.args_train.mo differ
diff --git a/locales/kr/LC_MESSAGES/lib.cli.args_train.po b/locales/kr/LC_MESSAGES/lib.cli.args_train.po
new file mode 100644
index 0000000000..1cc9ad1291
--- /dev/null
+++ b/locales/kr/LC_MESSAGES/lib.cli.args_train.po
@@ -0,0 +1,362 @@
+# SOME DESCRIPTIVE TITLE.
+# Copyright (C) YEAR THE PACKAGE'S COPYRIGHT HOLDER
+# This file is distributed under the same license as the PACKAGE package.
+# FIRST AUTHOR , YEAR.
+#
+msgid ""
+msgstr ""
+"Project-Id-Version: \n"
+"Report-Msgid-Bugs-To: \n"
+"POT-Creation-Date: 2025-12-15 20:02+0000\n"
+"PO-Revision-Date: 2025-12-19 23:26+0000\n"
+"Last-Translator: \n"
+"Language-Team: \n"
+"Language: ko_KR\n"
+"MIME-Version: 1.0\n"
+"Content-Type: text/plain; charset=UTF-8\n"
+"Content-Transfer-Encoding: 8bit\n"
+"Plural-Forms: nplurals=1; plural=0;\n"
+"X-Generator: Poedit 3.8\n"
+
+#: lib/cli/args_train.py:30
+msgid ""
+"Train a model on extracted original (A) and swap (B) faces.\n"
+"Training models can take a long time. Anything from 24hrs to over a week\n"
+"Model plugins can be configured in the 'Settings' Menu"
+msgstr ""
+"추출된 원래(A) 얼굴과 스왑(B) 얼굴에 대한 모델을 훈련합니다.\n"
+"모델을 훈련하는 데 시간이 오래 걸릴 수 있습니다. 24시간에서 일주일 이상의 시"
+"간이 필요합니다.\n"
+"모델 플러그인은 '설정' 메뉴에서 구성할 수 있습니다"
+
+#: lib/cli/args_train.py:49 lib/cli/args_train.py:58
+msgid "faces"
+msgstr "얼굴들"
+
+#: lib/cli/args_train.py:51
+msgid ""
+"Input directory. A directory containing training images for face A. This is "
+"the original face, i.e. the face that you want to remove and replace with "
+"face B."
+msgstr ""
+"입력 디렉토리. 얼굴 A에 대한 훈련 이미지가 포함된 디렉토리입니다. 이것은 원"
+"래 얼굴, 즉 제거하고 B 얼굴로 대체하려는 얼굴입니다."
+
+#: lib/cli/args_train.py:60
+msgid ""
+"Input directory. A directory containing training images for face B. This is "
+"the swap face, i.e. the face that you want to place onto the head of person "
+"A."
+msgstr ""
+"입력 디렉토리. 얼굴 B에 대한 훈련 이미지를 포함하는 디렉토리입니다. 이것은 대"
+"체 얼굴, 즉 사람 A의 머리에 얹으려는 얼굴입니다."
+
+#: lib/cli/args_train.py:67 lib/cli/args_train.py:80 lib/cli/args_train.py:97
+#: lib/cli/args_train.py:123 lib/cli/args_train.py:133
+msgid "model"
+msgstr "모델"
+
+#: lib/cli/args_train.py:69
+msgid ""
+"Model directory. This is where the training data will be stored. You should "
+"always specify a new folder for new models. If starting a new model, select "
+"either an empty folder, or a folder which does not exist (which will be "
+"created). If continuing to train an existing model, specify the location of "
+"the existing model."
+msgstr ""
+"모델 디렉토리. 여기에 훈련 데이터가 저장됩니다. 새 모델의 경우 항상 새 폴더"
+"를 지정해야 합니다. 새 모델을 시작할 경우 빈 폴더 또는 존재하지 않는 폴더(생"
+"성될 폴더)를 선택합니다. 기존 모델을 계속 학습하는 경우 기존 모델의 위치를 지"
+"정합니다."
+
+#: lib/cli/args_train.py:82
+msgid ""
+"R|Load the weights from a pre-existing model into a newly created model. 
For " +"most models this will load weights from the Encoder of the given model into " +"the encoder of the newly created model. Some plugins may have specific " +"configuration options allowing you to load weights from other layers. " +"Weights will only be loaded when creating a new model. This option will be " +"ignored if you are resuming an existing model. Generally you will also want " +"to 'freeze-weights' whilst the rest of your model catches up with your " +"Encoder.\n" +"NB: Weights can only be loaded from models of the same plugin as you intend " +"to train." +msgstr "" +"R|기존 모델의 가중치를 새로 생성된 모델로 로드합니다. 대부분의 모델에서는 주" +"어진 모델의 인코더에서 새로 생성된 모델의 인코더로 가중치를 로드합니다. 일부 " +"플러그인에는 다른 층에서 가중치를 로드할 수 있는 특정 구성 옵션이 있을 수 있" +"습니다. 가중치는 새 모델을 생성할 때만 로드됩니다. 기존 모델을 재개하는 경우 " +"이 옵션은 무시됩니다. 일반적으로 나머지 모델이 인코더를 따라잡는 동안에도 '가" +"중치 동결'이 필요합니다.\n" +"주의: 가중치는 훈련하려는 플러그인 모델에서만 로드할 수 있습니다." + +#: lib/cli/args_train.py:99 +msgid "" +"R|Select which trainer to use. Trainers can be configured from the Settings " +"menu or the config folder.\n" +"L|original: The original model created by /u/deepfakes.\n" +"L|dfaker: 64px in/128px out model from dfaker. Enable 'warp-to-landmarks' " +"for full dfaker method.\n" +"L|dfl-h128: 128px in/out model from deepfacelab\n" +"L|dfl-sae: Adaptable model from deepfacelab\n" +"L|dlight: A lightweight, high resolution DFaker variant.\n" +"L|iae: A model that uses intermediate layers to try to get better details\n" +"L|lightweight: A lightweight model for low-end cards. Don't expect great " +"results. Can train as low as 1.6GB with batch size 8.\n" +"L|realface: A high detail, dual density model based on DFaker, with " +"customizable in/out resolution. The autoencoders are unbalanced so B>A swaps " +"won't work so well. By andenixa et al. Very configurable.\n" +"L|unbalanced: 128px in/out model from andenixa. The autoencoders are " +"unbalanced so B>A swaps won't work so well. Very configurable.\n" +"L|villain: 128px in/out model from villainguy. Very resource hungry (You " +"will require a GPU with a fair amount of VRAM). Good for details, but more " +"susceptible to color differences." +msgstr "" +"R|사용할 훈련 모델을 선택합니다. 훈련 모델은 설정 메뉴 또는 구성 폴더에서 구" +"성할 수 있습니다.\n" +"L|original: /u/deepfakes로 만든 원래 모델입니다.\n" +"L|dfaker: 64px in/128px out 모델 from dfaker. Full dfaker 메서드에 대해 '특징" +"점으로 변환'를 활성화합니다.\n" +"L|dfl-h128: Deepfake lab의 128px in/out 모델\n" +"L|dfl-sae: Deepface Lab의 적응형 모델\n" +"L|dlight: 경량, 고해상도 DFaker 변형입니다.\n" +"L|iae: 중간 층들을 사용하여 더 나은 세부 정보를 얻기 위해 노력하는 모델.\n" +"L|lightweight: 저가형 카드용 경량 모델. 좋은 결과를 기대하지 마세요. 최대한 " +"낮게 잡아서 배치 사이즈 8에 1.6GB까지 훈련이 가능합니다.\n" +"L|realface: DFaker를 기반으로 한 높은 디테일의 이중 밀도 모델로, 사용자 정의 " +"가능한 입/출력 해상도를 제공합니다. 오토인코더가 불균형하여 B>A 스왑이 잘 작" +"동하지 않습니다. Andenixa 등에 의해. 매우 구성 가능합니다.\n" +"L|unbalanced: andenixa의 128px in/out 모델. 오토인코더가 불균형하여 B>A 스왑" +"이 잘 작동하지 않습니다. 매우 구성 가능합니다.\n" +"L|villain : villainguy의 128px in/out 모델. 리소스가 매우 부족합니다( 상당한 " +"양의 VRAM이 있는 GPU가 필요합니다). 세부 사항에는 좋지만 색상 차이에 더 취약" +"합니다." + +#: lib/cli/args_train.py:125 +msgid "" +"Output a summary of the model and exit. If a model folder is provided then a " +"summary of the saved model is displayed. Otherwise a summary of the model " +"that would be created by the chosen plugin and configuration settings is " +"displayed." +msgstr "" +"모델 요약을 출력하고 종료합니다. 모델 폴더가 제공되면 저장된 모델의 요약이 표" +"시됩니다. 그렇지 않으면 선택한 플러그인 및 구성 설정에 의해 생성되는 모델 요" +"약이 표시됩니다." + +#: lib/cli/args_train.py:135 +msgid "" +"Freeze the weights of the model. 
Freezing weights means that some of the "
+"parameters in the model will no longer continue to learn, but those that are "
+"not frozen will continue to learn. For most models, this will freeze the "
+"encoder, but some models may have configuration options for freezing other "
+"layers."
+msgstr ""
+"모델의 가중치를 동결합니다. 가중치를 고정하면 모델의 일부 매개변수가 더 이상 "
+"학습되지 않지만 고정되지 않은 매개변수는 계속 학습됩니다. 대부분의 모델에서 "
+"이렇게 하면 인코더가 고정되지만 일부 모델에는 다른 레이어를 고정하기 위한 구"
+"성 옵션이 있을 수 있습니다."
+
+#: lib/cli/args_train.py:147 lib/cli/args_train.py:160
+#: lib/cli/args_train.py:174 lib/cli/args_train.py:183
+#: lib/cli/args_train.py:190 lib/cli/args_train.py:199
+msgid "training"
+msgstr "훈련"
+
+#: lib/cli/args_train.py:149
+msgid ""
+"Batch size. This is the number of images processed through the model for "
+"each side per iteration. NB: As the model is fed 2 sides at a time, the "
+"actual number of images within the model at any one time is double the "
+"number that you set here. Larger batches require more GPU RAM."
+msgstr ""
+"배치 크기. 반복당 각 측면에 대해 모델을 통해 처리되는 이미지 수입니다. NB: "
+"한 번에 모델에게 2개의 측면이 공급되므로 한 번에 모델 내의 실제 이미지 수는 "
+"여기에서 설정한 수의 두 배입니다. 더 큰 배치에는 더 많은 GPU RAM이 필요합니"
+"다."
+
+#: lib/cli/args_train.py:162
+msgid ""
+"Length of training in iterations. This is only really used for automation. "
+"There is no 'correct' number of iterations a model should be trained for. "
+"You should stop training when you are happy with the previews. However, if "
+"you want the model to stop automatically at a set number of iterations, you "
+"can set that value here."
+msgstr ""
+"훈련 길이(반복 횟수 기준)입니다. 이것은 사실상 자동화에만 사용됩니다. 모델을 "
+"훈련해야 하는 '올바른' 반복 횟수는 없습니다. 미리 보기에 만족하면 훈련을 "
+"중단하면 됩니다. 그러나 설정된 반복 횟수에서 모델이 자동으로 중지되도록 "
+"하려면 여기에서 해당 값을 설정할 수 있습니다."
+
+#: lib/cli/args_train.py:176
+msgid ""
+"Learning rate warmup. Linearly increase the learning rate from 0 to the "
+"chosen target rate over the number of iterations given here. 0 to disable."
+msgstr ""
+"학습률 워밍업. 여기에 주어진 반복 횟수에 걸쳐 학습률을 0에서 선택한 목표 "
+"학습률까지 선형적으로 증가시킵니다. 0으로 설정하면 비활성화됩니다."
+
+#: lib/cli/args_train.py:184
+msgid "Use distibuted training on multi-gpu setups."
+msgstr "멀티 GPU 환경에서 분산 훈련을 사용합니다."
+
+#: lib/cli/args_train.py:192
+msgid ""
+"Disables TensorBoard logging. NB: Disabling logs means that you will not be "
+"able to use the graph or analysis for this session in the GUI."
+msgstr ""
+"텐서보드 로깅을 비활성화합니다. 주의: 로그를 비활성화하면 GUI에서 이 세션에 "
+"대한 그래프 또는 분석을 사용할 수 없습니다."
+
+#: lib/cli/args_train.py:201
+msgid ""
+"Use the Learning Rate Finder to discover the optimal learning rate for "
+"training. For new models, this will calculate the optimal learning rate for "
+"the model. For existing models this will use the optimal learning rate that "
+"was discovered when initializing the model. Setting this option will ignore "
+"the manually configured learning rate (configurable in train settings)."
+msgstr ""
+"학습률 탐색기(Learning Rate Finder)를 사용하여 훈련에 최적인 학습률을 "
+"찾습니다. 새 모델의 경우 모델에 대한 최적의 학습률을 계산합니다. 기존 모델의 "
+"경우 모델을 초기화할 때 발견된 최적의 학습률을 사용합니다. 이 옵션을 설정하면 "
+"수동으로 구성된 학습률(훈련 설정에서 구성 가능)이 무시됩니다."
+
+#: lib/cli/args_train.py:214 lib/cli/args_train.py:224
+msgid "Saving"
+msgstr "저장"
+
+#: lib/cli/args_train.py:215
+msgid "Sets the number of iterations between each model save."
+msgstr "각 모델 저장 사이의 반복 횟수를 설정합니다."
+
+#: lib/cli/args_train.py:226
+msgid ""
+"Sets the number of iterations before saving a backup snapshot of the model "
+"in it's current state. Set to 0 for off."
+msgstr ""
+"모델의 현재 상태를 백업 스냅샷으로 저장할 반복 간격을 설정합니다. 0으로 "
+"설정하면 꺼집니다."
+
+#: lib/cli/args_train.py:233 lib/cli/args_train.py:245
+#: lib/cli/args_train.py:257
+msgid "timelapse"
+msgstr "타임랩스"
+
+#: lib/cli/args_train.py:235
+msgid ""
+"Optional for creating a timelapse. Timelapse will save an image of your "
+"selected faces into the timelapse-output folder at every save iteration. "
+"This should be the input folder of 'A' faces that you would like to use for "
+"creating the timelapse. You must also supply a --timelapse-output and a --"
+"timelapse-input-B parameter."
+msgstr ""
+"타임랩스 생성을 위한 선택적 옵션입니다. 타임랩스는 저장이 반복될 때마다 "
+"선택한 얼굴의 이미지를 timelapse-output 폴더에 저장합니다. 이 항목은 "
+"타임랩스를 만드는 데 사용할 'A' 얼굴의 입력 폴더여야 합니다. 또한 "
+"--timelapse-output 및 --timelapse-input-B 매개 변수도 제공해야 합니다."
+
+#: lib/cli/args_train.py:247
+msgid ""
+"Optional for creating a timelapse. Timelapse will save an image of your "
+"selected faces into the timelapse-output folder at every save iteration. "
+"This should be the input folder of 'B' faces that you would like to use for "
+"creating the timelapse. You must also supply a --timelapse-output and a --"
+"timelapse-input-A parameter."
+msgstr ""
+"타임랩스 생성을 위한 선택적 옵션입니다. 타임랩스는 저장이 반복될 때마다 "
+"선택한 얼굴의 이미지를 timelapse-output 폴더에 저장합니다. 이 항목은 "
+"타임랩스를 만드는 데 사용할 'B' 얼굴의 입력 폴더여야 합니다. 또한 "
+"--timelapse-output 및 --timelapse-input-A 매개 변수도 제공해야 합니다."
+
+#: lib/cli/args_train.py:259
+msgid ""
+"Optional for creating a timelapse. Timelapse will save an image of your "
+"selected faces into the timelapse-output folder at every save iteration. If "
+"the input folders are supplied but no output folder, it will default to your "
+"model folder/timelapse/"
+msgstr ""
+"타임랩스 생성을 위한 선택적 옵션입니다. 타임랩스는 저장이 반복될 때마다 "
+"선택한 얼굴의 이미지를 timelapse-output 폴더에 저장합니다. 입력 폴더가 "
+"제공되었지만 출력 폴더가 없는 경우 기본값은 모델 폴더/timelapse/ 가 됩니다"
+
+#: lib/cli/args_train.py:268 lib/cli/args_train.py:275
+msgid "preview"
+msgstr "미리보기"
+
+#: lib/cli/args_train.py:269
+msgid "Show training preview output. in a separate window."
+msgstr "훈련 미리보기 결과를 별도의 창에 표시합니다."
+
+#: lib/cli/args_train.py:277
+msgid ""
+"Writes the training result to a file. The image will be stored in the root "
+"of your FaceSwap folder."
+msgstr ""
+"훈련 결과를 파일로 기록합니다. 이미지는 FaceSwap 폴더의 최상위 폴더에 "
+"저장됩니다."
+
+#: lib/cli/args_train.py:284 lib/cli/args_train.py:294
+#: lib/cli/args_train.py:304 lib/cli/args_train.py:314
+msgid "augmentation"
+msgstr "증강"
+
+#: lib/cli/args_train.py:286
+msgid ""
+"Warps training faces to closely matched Landmarks from the opposite face-set "
+"rather than randomly warping the face. This is the 'dfaker' way of doing "
+"warping."
+msgstr ""
+"무작위로 얼굴을 워핑하는 대신 반대쪽 얼굴 세트의 가장 비슷한 특징점에 맞춰 "
+"훈련 얼굴을 워핑합니다. 이것은 'dfaker' 방식의 워핑입니다."
+
+#: lib/cli/args_train.py:296
+msgid ""
+"To effectively learn, a random set of images are flipped horizontally. "
+"Sometimes it is desirable for this not to occur. Generally this should be "
+"left off except for during 'fit training'."
+msgstr ""
+"효과적인 학습을 위해 임의의 이미지 세트를 수평으로 뒤집습니다. 때로는 이것이 "
+"일어나지 않는 것이 바람직할 수 있습니다. 일반적으로 'fit training' 중이 "
+"아니라면 이 옵션은 꺼 두어야 합니다."
+
+#: lib/cli/args_train.py:306
+msgid ""
+"Color augmentation helps make the model less susceptible to color "
+"differences between the A and B sets, at an increased training time cost. "
+"Enable this option to disable color augmentation."
+msgstr ""
+"색상 증강은 훈련 시간이 늘어나는 대신 모델이 A와 B 세트 사이의 색상 차이에 "
+"덜 민감해지도록 도와줍니다. 색상 증강을 사용하지 않으려면 이 옵션을 "
+"활성화하세요."
+
+#: lib/cli/args_train.py:316
+msgid ""
+"Warping is integral to training the Neural Network. This option should only "
+"be enabled towards the very end of training to try to bring out more detail. "
+"Think of it as 'fine-tuning'. Enabling this option from the beginning is "
+"likely to kill a model and lead to terrible results."
+msgstr ""
+"워핑은 신경망 훈련에 필수적입니다. 이 옵션은 더 세부적인 디테일을 끌어내기 "
+"위해 훈련의 맨 마지막 단계에서만 활성화해야 합니다. '미세 조정'이라고 "
+"생각하면 됩니다. 처음부터 이 옵션을 활성화하면 모델을 망가뜨려 끔찍한 결과를 "
+"초래할 수 있습니다."
+
+#~ msgid ""
+#~ "R|Select the distribution stategy to use.\n"
+#~ "L|default: Use Tensorflow's default distribution strategy.\n"
+#~ "L|central-storage: Centralizes variables on the CPU whilst operations are "
+#~ "performed on 1 or more local GPUs. This can help save some VRAM at the "
+#~ "cost of some speed by not storing variables on the GPU. Note: Mixed-"
+#~ "Precision is not supported on multi-GPU setups.\n"
+#~ "L|mirrored: Supports synchronous distributed training across multiple "
+#~ "local GPUs. A copy of the model and all variables are loaded onto each "
+#~ "GPU with batches distributed to each GPU at each iteration."
+#~ msgstr ""
+#~ "R|사용할 분산 전략을 선택합니다.\n"
+#~ "L|default: Tensorflow의 기본 분산 전략을 사용합니다.\n"
+#~ "L|central-storage: 작업이 1개 이상의 로컬 GPU에서 수행되는 동안 CPU의 변수"
+#~ "를 중앙 집중화합니다. 이렇게 하면 GPU에 변수를 저장하지 않음으로써 약간의 "
+#~ "속도를 희생하여 일부 VRAM을 절약할 수 있습니다. 참고: 혼합 정밀도(Mixed-"
+#~ "Precision)는 다중 GPU 설정에서 지원되지 않습니다.\n"
+#~ "L|mirrored: 여러 로컬 GPU에서 동기식 분산 훈련을 지원합니다. 모델의 복사본"
+#~ "과 모든 변수가 각 GPU에 로드되며 각 반복마다 배치들이 각 GPU에 분배됩니"
+#~ "다."
diff --git a/locales/kr/LC_MESSAGES/tools.alignments.cli.mo b/locales/kr/LC_MESSAGES/tools.alignments.cli.mo
new file mode 100644
index 0000000000..0c4f74d355
Binary files /dev/null and b/locales/kr/LC_MESSAGES/tools.alignments.cli.mo differ
diff --git a/locales/kr/LC_MESSAGES/tools.alignments.cli.po b/locales/kr/LC_MESSAGES/tools.alignments.cli.po
new file mode 100644
index 0000000000..1bedaa0a22
--- /dev/null
+++ b/locales/kr/LC_MESSAGES/tools.alignments.cli.po
@@ -0,0 +1,254 @@
+# SOME DESCRIPTIVE TITLE.
+# Copyright (C) YEAR THE PACKAGE'S COPYRIGHT HOLDER
+# This file is distributed under the same license as the PACKAGE package.
+# FIRST AUTHOR , YEAR.
+#
+msgid ""
+msgstr ""
+"Project-Id-Version: \n"
+"Report-Msgid-Bugs-To: \n"
+"POT-Creation-Date: 2024-04-19 11:28+0100\n"
+"PO-Revision-Date: 2024-04-19 11:30+0100\n"
+"Last-Translator: \n"
+"Language-Team: \n"
+"Language: ko_KR\n"
+"MIME-Version: 1.0\n"
+"Content-Type: text/plain; charset=UTF-8\n"
+"Content-Transfer-Encoding: 8bit\n"
+"Plural-Forms: nplurals=1; plural=0;\n"
+"X-Generator: Poedit 3.4.2\n"
+
+#: tools/alignments/cli.py:16
+msgid ""
+"This command lets you perform various tasks pertaining to an alignments file."
+msgstr ""
+"이 명령을 사용하여 alignments 파일과 관련된 다양한 작업을 수행할 수 있습니다."
+
+#: tools/alignments/cli.py:31
+msgid ""
+"Alignments tool\n"
+"This tool allows you to perform numerous actions on or using an alignments "
+"file against its corresponding faceset/frame source."
+msgstr ""
+"alignments 도구\n"
+"이 도구를 사용하면 해당 얼굴 세트/프레임 소스에 대해 alignments 파일에 대한, "
+"또는 alignments 파일을 사용하는 다양한 작업을 수행할 수 있습니다."
+
+#: tools/alignments/cli.py:43
+msgid " Must Pass in a frames folder/source video file (-r)."
+msgstr " 프레임 폴더/원본 비디오 파일을 반드시 전달해야 합니다 (-r)."
+
+#: tools/alignments/cli.py:44
+msgid " Must Pass in a faces folder (-c)."
+msgstr " 얼굴 폴더를 반드시 전달해야 합니다 (-c)."
+
+#: tools/alignments/cli.py:45
+msgid ""
+" Must Pass in either a frames folder/source video file OR a faces folder (-r "
+"or -c)."
+msgstr "" +" 프레임 폴더나 원본 비디오 파일 또는 얼굴 폴더중 하나를 무조건 전달해야 합니" +"다 (-r and -c)." + +#: tools/alignments/cli.py:47 +msgid "" +" Must Pass in a frames folder/source video file AND a faces folder (-r and -" +"c)." +msgstr "" +" 프레임 폴더나 원본 비디오 파일 그리고 얼굴 폴더를 무조건 전달해야 합니다 (-" +"r and -c)." + +#: tools/alignments/cli.py:49 +msgid " Use the output option (-o) to process results." +msgstr " 결과를 진행하려면 (-o) 출력 옵션을 사용하세요." + +#: tools/alignments/cli.py:58 tools/alignments/cli.py:104 +msgid "processing" +msgstr "처리" + +#: tools/alignments/cli.py:61 +#, python-brace-format +msgid "" +"R|Choose which action you want to perform. NB: All actions require an " +"alignments file (-a) to be passed in.\n" +"L|'draw': Draw landmarks on frames in the selected folder/video. A subfolder " +"will be created within the frames folder to hold the output.{0}\n" +"L|'export': Export the contents of an alignments file to a json file. Can be " +"used for editing alignment information in external tools and then re-" +"importing by using Faceswap's Extract 'Import' plugins. Note: masks and " +"identity vectors will not be included in the exported file, so will be re-" +"generated when the json file is imported back into Faceswap. All data is " +"exported with the origin (0, 0) at the top left of the canvas.\n" +"L|'extract': Re-extract faces from the source frames/video based on " +"alignment data. This is a lot quicker than re-detecting faces. Can pass in " +"the '-een' (--extract-every-n) parameter to only extract every nth frame." +"{1}\n" +"L|'from-faces': Generate alignment file(s) from a folder of extracted faces. " +"if the folder of faces comes from multiple sources, then multiple alignments " +"files will be created. NB: for faces which have been extracted from folders " +"of source images, rather than a video, a single alignments file will be " +"created as there is no way for the process to know how many folders of " +"images were originally used. You do not need to provide an alignments file " +"path to run this job. {3}\n" +"L|'missing-alignments': Identify frames that do not exist in the alignments " +"file.{2}{0}\n" +"L|'missing-frames': Identify frames in the alignments file that do not " +"appear within the frames folder/video.{2}{0}\n" +"L|'multi-faces': Identify where multiple faces exist within the alignments " +"file.{2}{4}\n" +"L|'no-faces': Identify frames that exist within the alignment file but no " +"faces were detected.{2}{0}\n" +"L|'remove-faces': Remove deleted faces from an alignments file. The original " +"alignments file will be backed up.{3}\n" +"L|'rename' - Rename faces to correspond with their parent frame and position " +"index in the alignments file (i.e. how they are named after running extract)." +"{3}\n" +"L|'sort': Re-index the alignments from left to right. For alignments with " +"multiple faces this will ensure that the left-most face is at index 0.\n" +"L|'spatial': Perform spatial and temporal filtering to smooth alignments " +"(EXPERIMENTAL!)" +msgstr "" +"R|실행할 작업을 선택합니다. 주의: 모든 작업을 수행하려면 alignments 파일(-a)" +"을 전달해야 합니다.\n" +"L|'draw': 선택한 폴더/비디오의 프레임에 특징점을 그립니다. 출력을 저장할 하" +"위 폴더가 프레임 폴더 내에 생성됩니다.{0}\n" +"L|'export': 정렬 파일의 내용을 JSON 파일로 내보내십시오. 외부 도구에서 정렬 " +"정보를 편집 한 다음 FaceSwap의 추출물 'Import'플러그인을 사용하여 다시 인상하" +"는 데 사용할 수 있습니다. 참고 : 마스크 및 ID 벡터는 내보내기 파일에 포함되" +"지 않으므로 JSON 파일이 다시 FaceSwap으로 가져 오면 다시 생성됩니다. 모든 데" +"이터는 캔버스의 왼쪽 상단에있는 원점 (0, 0)으로 내 보냅니다.\n" +"L|'extract': alignments 데이터를 기반으로 소스 프레임/비디오에서 얼굴을 재추" +"출합니다. 이것은 얼굴을 재감지하는 것보다 훨씬 더 빠릅니다. 
'-een'(--extract-"
+"every-n) 매개 변수를 전달하여 n번째 프레임마다 추출하도록 할 수 있습니다."
+"{1}\n"
+"L|'from-faces': 추출된 얼굴 폴더에서 alignments 파일을 생성합니다. 폴더 내의 "
+"얼굴들이 여러 소스에서 온 경우 여러 alignments 파일이 생성됩니다. 참고: "
+"비디오가 아닌 원본 이미지 폴더에서 추출된 얼굴의 경우, 원래 사용된 이미지 "
+"폴더 수를 알 수 없으므로 단일 alignments 파일이 생성됩니다. 이 작업을 실행하"
+"기 위해 alignments 파일 경로를 제공할 필요는 없습니다. {3}\n"
+"L|'missing-alignments': alignments 파일에 없는 프레임을 식별합니다.{2}{0}\n"
+"L|'missing-frames': alignments 파일에는 있지만 프레임 폴더/비디오 내에 나타나"
+"지 않는 프레임을 식별합니다.{2}{0}\n"
+"L|'multi-faces': alignments 파일 내에서 여러 얼굴이 있는 위치를 식별합니다."
+"{2}{4}\n"
+"L|'no-faces': alignments 파일 내에 존재하지만 얼굴이 탐지되지 않은 프레임을 "
+"식별합니다.{2}{0}\n"
+"L|'remove-faces': alignments 파일에서 삭제된 얼굴을 제거합니다. 원래 "
+"alignments 파일은 백업됩니다.{3}\n"
+"L|'rename': alignments 파일의 상위 프레임 및 위치 색인에 해당하도록 얼굴 이"
+"름을 바꿉니다(즉, 추출을 실행한 후에 이름이 지어지는 방식).{3}\n"
+"L|'sort': alignments를 왼쪽에서 오른쪽으로 다시 인덱싱합니다. 얼굴이 여러 개"
+"인 alignments의 경우 맨 왼쪽 얼굴이 색인 0에 오게 됩니다.\n"
+"L|'spatial': 공간 및 시간 필터링을 수행하여 alignments를 부드럽게 만듭니다"
+"(실험적!)"
+
+#: tools/alignments/cli.py:107
+msgid ""
+"R|How to output discovered items ('faces' and 'frames' only):\n"
+"L|'console': Print the list of frames to the screen. (DEFAULT)\n"
+"L|'file': Output the list of frames to a text file (stored within the source "
+"directory).\n"
+"L|'move': Move the discovered items to a sub-folder within the source "
+"directory."
+msgstr ""
+"R|검색된 항목을 출력하는 방법('faces' 및 'frames'만 해당):\n"
+"L|'console': 프레임 목록을 화면에 출력합니다. (기본값)\n"
+"L|'file': 프레임 목록을 텍스트 파일(소스 디렉토리 내에 저장)로 출력합니다.\n"
+"L|'move': 검색된 항목을 소스 디렉토리 내의 하위 폴더로 이동합니다."
+
+#: tools/alignments/cli.py:118 tools/alignments/cli.py:141
+#: tools/alignments/cli.py:148
+msgid "data"
+msgstr "데이터"
+
+#: tools/alignments/cli.py:125
+msgid ""
+"Full path to the alignments file to be processed. If you have input a "
+"'frames_dir' and don't provide this option, the process will try to find the "
+"alignments file at the default location. All jobs require an alignments file "
+"with the exception of 'from-faces' when the alignments file will be "
+"generated in the specified faces folder."
+msgstr ""
+"처리할 alignments 파일의 전체 경로입니다. 'frames_dir'을 입력했는데 이 옵션"
+"을 제공하지 않으면 프로세스는 기본 위치에서 alignments 파일을 찾으려고 합니"
+"다. alignments 파일이 지정된 얼굴 폴더에 생성되는 'from-faces'를 제외한 모든 "
+"작업에는 alignments 파일이 필요합니다."
+
+#: tools/alignments/cli.py:142
+msgid "Directory containing source frames that faces were extracted from."
+msgstr "얼굴 추출의 소스로 쓰인 원본 프레임이 저장된 디렉토리."
+
+#: tools/alignments/cli.py:150
+msgid ""
+"R|Run the aligmnents tool on multiple sources. The following jobs support "
+"batch mode:\n"
+"L|draw, extract, from-faces, missing-alignments, missing-frames, no-faces, "
+"sort, spatial.\n"
+"If batch mode is selected then the other options should be set as follows:\n"
+"L|alignments_file: For 'sort' and 'spatial' this should point to the parent "
+"folder containing the alignments files to be processed. For all other jobs "
+"this option is ignored, and the alignments files must exist at their default "
+"location relative to the original frames folder/video.\n"
+"L|faces_dir: For 'from-faces' this should be a parent folder, containing sub-"
+"folders of extracted faces from which to generate alignments files. For "
+"'extract' this should be a parent folder where sub-folders will be created "
+"for each extraction to be run. 
For all other jobs this option is ignored.\n"
+"L|frames_dir: For 'draw', 'extract', 'missing-alignments', 'missing-frames' "
+"and 'no-faces' this should be a parent folder containing video files or sub-"
+"folders of images to perform the alignments job on. The alignments file "
+"should exist at the default location. For all other jobs this option is "
+"ignored."
+msgstr ""
+"R|여러 소스에서 alignments 도구를 실행합니다. 다음 작업이 배치 모드를 지원합"
+"니다.\n"
+"L|draw, extract, from-faces, missing-alignments, missing-frames, no-faces, "
+"sort, spatial.\n"
+"배치 모드를 선택한 경우 다른 옵션을 다음과 같이 설정해야 합니다.\n"
+"L|alignments_file: 'sort' 및 'spatial'의 경우 처리할 alignments 파일이 포함"
+"된 상위 폴더를 가리켜야 합니다. 다른 모든 작업의 경우 이 옵션은 무시되며 "
+"alignments 파일은 원본 프레임 폴더/비디오에 상대적인 기본 위치에 있어야 합니"
+"다.\n"
+"L|faces_dir: 'from-faces'의 경우 alignments 파일을 생성할 추출된 얼굴의 하위 "
+"폴더를 포함하는 상위 폴더여야 합니다. 'extract'의 경우 실행할 각 추출에 대해 "
+"하위 폴더가 생성되는 상위 폴더여야 합니다. 다른 모든 작업의 경우 이 옵션은 무"
+"시됩니다.\n"
+"L|frames_dir: 'draw', 'extract', 'missing-alignments', 'missing-frames' 및 "
+"'no-faces'의 경우 alignments 작업을 수행할 비디오 파일 또는 이미지 하위 폴더"
+"를 포함하는 상위 폴더여야 합니다. alignments 파일은 기본 위치에 있어야 합니"
+"다. 다른 모든 작업의 경우 이 옵션은 무시됩니다."
+
+#: tools/alignments/cli.py:176 tools/alignments/cli.py:188
+#: tools/alignments/cli.py:198
+msgid "extract"
+msgstr "추출"
+
+#: tools/alignments/cli.py:178
+msgid ""
+"[Extract only] Extract every 'nth' frame. This option will skip frames when "
+"extracting faces. For example a value of 1 will extract faces from every "
+"frame, a value of 10 will extract faces from every 10th frame."
+msgstr ""
+"[Extract only] 'n번째' 프레임마다 추출합니다. 이 옵션은 얼굴을 추출할 때 "
+"프레임을 건너뜁니다. 예를 들어, 값이 1이면 모든 프레임에서 얼굴이 추출되고, "
+"값이 10이면 10번째 프레임마다 얼굴이 추출됩니다."
+
+#: tools/alignments/cli.py:189
+msgid "[Extract only] The output size of extracted faces."
+msgstr "[Extract only] 추출된 얼굴들의 결과 크기입니다."
+
+#: tools/alignments/cli.py:200
+msgid ""
+"[Extract only] Only extract faces that have been resized by this percent or "
+"more to meet the specified extract size (`-sz`, `--size`). Useful for "
+"excluding low-res images from a training set. Set to 0 to extract all faces. "
+"Eg: For an extract size of 512px, A setting of 50 will only include faces "
+"that have been resized from 256px or above. Setting to 100 will only extract "
+"faces that have been resized from 512px or above. A setting of 200 will only "
+"extract faces that have been downscaled from 1024px or above."
+msgstr ""
+"[Extract only] 지정된 추출 크기(`-sz`, `--size`)를 맞추기 위하여 크기가 이 비"
+"율 이상 조정된 얼굴만 추출합니다. 훈련 세트에서 저해상도 이미지를 제외하는 "
+"데 유용합니다. 모든 얼굴을 추출하려면 0으로 설정합니다. 예: 추출 크기가 "
+"512px인 경우, 50으로 설정하면 256px 이상에서 크기가 조정된 얼굴만 포함됩니"
+"다. 100으로 설정하면 512px 이상에서 크기가 조정된 얼굴만 추출됩니다. 200으로 "
+"설정하면 1024px 이상에서 축소된 얼굴만 추출됩니다."
+
+#~ msgid "Directory containing extracted faces."
+#~ msgstr "추출된 얼굴들이 저장된 디렉토리."
diff --git a/locales/kr/LC_MESSAGES/tools.effmpeg.cli.mo b/locales/kr/LC_MESSAGES/tools.effmpeg.cli.mo
new file mode 100644
index 0000000000..f6f563913f
Binary files /dev/null and b/locales/kr/LC_MESSAGES/tools.effmpeg.cli.mo differ
diff --git a/locales/kr/LC_MESSAGES/tools.effmpeg.cli.po b/locales/kr/LC_MESSAGES/tools.effmpeg.cli.po
new file mode 100644
index 0000000000..58b106c585
--- /dev/null
+++ b/locales/kr/LC_MESSAGES/tools.effmpeg.cli.po
@@ -0,0 +1,185 @@
+# SOME DESCRIPTIVE TITLE.
+# Copyright (C) YEAR ORGANIZATION
+# FIRST AUTHOR , YEAR. 
+#
+msgid ""
+msgstr ""
+"Project-Id-Version: \n"
+"Report-Msgid-Bugs-To: \n"
+"POT-Creation-Date: 2024-03-28 23:50+0000\n"
+"PO-Revision-Date: 2024-03-29 00:05+0000\n"
+"Last-Translator: \n"
+"Language-Team: \n"
+"Language: ko_KR\n"
+"MIME-Version: 1.0\n"
+"Content-Type: text/plain; charset=UTF-8\n"
+"Content-Transfer-Encoding: 8bit\n"
+"Plural-Forms: nplurals=1; plural=0;\n"
+"Generated-By: pygettext.py 1.5\n"
+"X-Generator: Poedit 3.4.2\n"
+
+#: tools/effmpeg/cli.py:15
+msgid "This command allows you to easily execute common ffmpeg tasks."
+msgstr ""
+"이 명령어는 일반적인 ffmpeg 작업을 쉽게 실행할 수 있게 해줍니다."
+
+#: tools/effmpeg/cli.py:52
+msgid "A wrapper for ffmpeg for performing image <> video converting."
+msgstr "이미지 <> 비디오 변환을 수행하기 위한 ffmpeg용 wrapper입니다."
+
+#: tools/effmpeg/cli.py:64
+msgid ""
+"R|Choose which action you want ffmpeg ffmpeg to do.\n"
+"L|'extract': turns videos into images \n"
+"L|'gen-vid': turns images into videos \n"
+"L|'get-fps' returns the chosen video's fps.\n"
+"L|'get-info' returns information about a video.\n"
+"L|'mux-audio' add audio from one video to another.\n"
+"L|'rescale' resize video.\n"
+"L|'rotate' rotate video.\n"
+"L|'slice' cuts a portion of the video into a separate video file."
+msgstr ""
+"R|ffmpeg이 수행할 작업을 선택합니다.\n"
+"L|'extract': 비디오를 이미지로 바꿉니다.\n"
+"L|'gen-vid': 이미지를 비디오로 바꿉니다.\n"
+"L|'get-fps'는 선택한 비디오의 fps를 반환합니다.\n"
+"L|'get-info'는 동영상에 대한 정보를 반환합니다.\n"
+"L|'mux-audio'는 한 비디오의 오디오를 다른 비디오에 추가합니다.\n"
+"L|'rescale'은 비디오 크기를 조정합니다.\n"
+"L|'rotate'는 비디오를 회전합니다.\n"
+"L|'slice'는 동영상의 일부를 별도의 동영상 파일로 잘라냅니다."
+
+#: tools/effmpeg/cli.py:78
+msgid "Input file."
+msgstr "입력 파일."
+
+#: tools/effmpeg/cli.py:79 tools/effmpeg/cli.py:86 tools/effmpeg/cli.py:100
+msgid "data"
+msgstr "데이터"
+
+#: tools/effmpeg/cli.py:89
+msgid ""
+"Output file. If no output is specified then: if the output is meant to be a "
+"video then a video called 'out.mkv' will be created in the input directory; "
+"if the output is meant to be a directory then a directory called 'out' will "
+"be created inside the input directory. Note: the chosen output file "
+"extension will determine the file encoding."
+msgstr ""
+"출력 파일. 출력이 지정되지 않은 경우: 출력이 비디오여야 한다면 입력 디렉토리"
+"에 'out.mkv'라는 비디오가 생성됩니다. 출력이 디렉토리여야 한다면 입력 디렉토"
+"리 내에 'out'이라는 디렉토리가 생성됩니다. 참고: 선택한 출력 파일 확장자가 파"
+"일 인코딩을 결정합니다."
+
+#: tools/effmpeg/cli.py:102
+msgid "Path to reference video if 'input' was not a video."
+msgstr "'input'이 비디오가 아닌 경우 참조 비디오의 경로."
+
+#: tools/effmpeg/cli.py:108 tools/effmpeg/cli.py:118 tools/effmpeg/cli.py:156
+#: tools/effmpeg/cli.py:185
+msgid "output"
+msgstr "출력"
+
+#: tools/effmpeg/cli.py:110
+msgid ""
+"Provide video fps. Can be an integer, float or fraction. Negative values "
+"will will make the program try to get the fps from the input or reference "
+"videos."
+msgstr ""
+"비디오 fps를 제공합니다. 정수, 실수 또는 분수를 사용할 수 있습니다. 음수 값"
+"을 지정하면 프로그램이 입력 또는 참조 비디오에서 fps를 가져오려고 시도합니다."
+
+#: tools/effmpeg/cli.py:120
+msgid ""
+"Image format that extracted images should be saved as. '.bmp' will offer the "
+"fastest extraction speed, but will take the most storage space. '.png' will "
+"be slower but will take less storage."
+msgstr ""
+"추출된 이미지를 저장할 이미지 형식입니다. '.bmp'는 가장 빠른 추출 속도를 "
+"제공하지만 저장 공간을 가장 많이 차지합니다. '.png'는 더 느리지만 저장 "
+"공간을 더 적게 차지합니다."
+
+#: tools/effmpeg/cli.py:127 tools/effmpeg/cli.py:136 tools/effmpeg/cli.py:145
+msgid "clip"
+msgstr "클립"
+
+#: tools/effmpeg/cli.py:129
+msgid ""
+"Enter the start time from which an action is to be applied. Default: "
+"00:00:00, in HH:MM:SS format. 
You can also enter the time with or without "
+"the colons, e.g. 00:0000 or 026010."
+msgstr ""
+"작업을 적용할 시작 시간을 입력합니다. 기본값: 00:00:00, HH:MM:SS 형식입니다. "
+"콜론을 포함하거나 포함하지 않은 시간(예: 00:0000 또는 026010)을 입력할 수도 "
+"있습니다."
+
+#: tools/effmpeg/cli.py:138
+msgid ""
+"Enter the end time to which an action is to be applied. If both an end time "
+"and duration are set, then the end time will be used and the duration will "
+"be ignored. Default: 00:00:00, in HH:MM:SS."
+msgstr ""
+"작업을 적용할 종료 시간을 입력합니다. 종료 시간과 지속 시간이 모두 설정된 "
+"경우 종료 시간이 사용되고 지속 시간은 무시됩니다. 기본값: 00:00:00, HH:MM:SS."
+
+#: tools/effmpeg/cli.py:147
+msgid ""
+"Enter the duration of the chosen action, for example if you enter 00:00:10 "
+"for slice, then the first 10 seconds after and including the start time will "
+"be cut out into a new video. Default: 00:00:00, in HH:MM:SS format. You can "
+"also enter the time with or without the colons, e.g. 00:0000 or 026010."
+msgstr ""
+"선택한 작업의 지속 시간을 입력합니다. 예를 들어 슬라이스에 00:00:10을 입력하"
+"면 시작 시간을 포함한 이후 첫 10초가 잘려 나와 새 비디오가 됩니다. 기본값: "
+"00:00:00, HH:MM:SS 형식입니다. 콜론을 포함하거나 포함하지 않은 시간(예: "
+"00:0000 또는 026010)을 입력할 수도 있습니다."
+
+#: tools/effmpeg/cli.py:158
+msgid ""
+"Mux the audio from the reference video into the input video. This option is "
+"only used for the 'gen-vid' action. 'mux-audio' action has this turned on "
+"implicitly."
+msgstr ""
+"참조 비디오의 오디오를 입력 비디오에 병합합니다. 이 옵션은 'gen-vid' 작업에"
+"만 사용됩니다. 'mux-audio' 작업은 이 옵션을 암시적으로 활성화합니다."
+
+#: tools/effmpeg/cli.py:169 tools/effmpeg/cli.py:179
+msgid "rotate"
+msgstr "회전"
+
+#: tools/effmpeg/cli.py:171
+msgid ""
+"Transpose the video. If transpose is set, then degrees will be ignored. For "
+"cli you can enter either the number or the long command name, e.g. to use "
+"(1, 90Clockwise) -tr 1 or -tr 90Clockwise"
+msgstr ""
+"비디오를 전치합니다. 전치를 설정하면 각도가 무시됩니다. cli에서는 숫자 또는 "
+"긴 명령 이름을 입력할 수 있습니다(예: (1, 90Clockwise)를 사용하려면 -tr 1 "
+"또는 -tr 90Clockwise)"
+
+#: tools/effmpeg/cli.py:180
+msgid "Rotate the video clockwise by the given number of degrees."
+msgstr "비디오를 주어진 각도만큼 시계 방향으로 회전합니다."
+
+#: tools/effmpeg/cli.py:187
+msgid "Set the new resolution scale if the chosen action is 'rescale'."
+msgstr "선택한 작업이 'rescale'이라면 새로운 해상도 크기를 설정합니다."
+
+#: tools/effmpeg/cli.py:192 tools/effmpeg/cli.py:200
+msgid "settings"
+msgstr "설정"
+
+#: tools/effmpeg/cli.py:194
+msgid ""
+"Reduces output verbosity so that only serious errors are printed. If both "
+"quiet and verbose are set, verbose will override quiet."
+msgstr ""
+"출력 상세도를 줄여 심각한 오류만 출력합니다. quiet와 verbose가 모두 설정된 경"
+"우 verbose가 quiet를 재정의합니다."
+
+#: tools/effmpeg/cli.py:202
+msgid ""
+"Increases output verbosity. If both quiet and verbose are set, verbose will "
+"override quiet."
+msgstr ""
+"출력 상세도를 높입니다. quiet와 verbose가 모두 설정된 경우 verbose가 quiet를 "
+"재정의합니다."
diff --git a/locales/kr/LC_MESSAGES/tools.manual.mo b/locales/kr/LC_MESSAGES/tools.manual.mo
new file mode 100644
index 0000000000..2a801da9a2
Binary files /dev/null and b/locales/kr/LC_MESSAGES/tools.manual.mo differ
diff --git a/locales/kr/LC_MESSAGES/tools.manual.po b/locales/kr/LC_MESSAGES/tools.manual.po
new file mode 100644
index 0000000000..0fbcc99572
--- /dev/null
+++ b/locales/kr/LC_MESSAGES/tools.manual.po
@@ -0,0 +1,283 @@
+# SOME DESCRIPTIVE TITLE.
+# Copyright (C) YEAR ORGANIZATION
+# FIRST AUTHOR , YEAR. 
+#
+msgid ""
+msgstr ""
+"Project-Id-Version: \n"
+"Report-Msgid-Bugs-To: \n"
+"POT-Creation-Date: 2024-03-28 23:55+0000\n"
+"PO-Revision-Date: 2024-03-29 00:05+0000\n"
+"Last-Translator: \n"
+"Language-Team: \n"
+"Language: ko_KR\n"
+"MIME-Version: 1.0\n"
+"Content-Type: text/plain; charset=UTF-8\n"
+"Content-Transfer-Encoding: 8bit\n"
+"Plural-Forms: nplurals=1; plural=0;\n"
+"Generated-By: pygettext.py 1.5\n"
+"X-Generator: Poedit 3.4.2\n"
+
+#: tools/manual/cli.py:13
+msgid ""
+"This command lets you perform various actions on frames, faces and "
+"alignments files using visual tools."
+msgstr ""
+"이 명령어는 시각적 도구를 사용하여 프레임, 얼굴, alignments 파일에 대한 "
+"다양한 작업을 수행할 수 있게 해줍니다."
+
+#: tools/manual/cli.py:23
+msgid ""
+"A tool to perform various actions on frames, faces and alignments files "
+"using visual tools"
+msgstr ""
+"시각적 도구를 사용하여 프레임, 얼굴, alignments 파일에 대한 다양한 작업을 "
+"수행하는 도구"
+
+#: tools/manual/cli.py:35 tools/manual/cli.py:44
+msgid "data"
+msgstr "데이터"
+
+#: tools/manual/cli.py:38
+msgid ""
+"Path to the alignments file for the input, if not at the default location"
+msgstr ""
+"기본 위치에 있지 않은 경우, 입력에 대한 alignments 파일의 경로"
+
+#: tools/manual/cli.py:46
+msgid ""
+"Video file or directory containing source frames that faces were extracted "
+"from."
+msgstr "얼굴이 추출된 소스 프레임을 가지고 있는 비디오 파일 또는 디렉토리."
+
+#: tools/manual/cli.py:53 tools/manual/cli.py:62
+msgid "options"
+msgstr "설정"
+
+#: tools/manual/cli.py:55
+msgid ""
+"Force regeneration of the low resolution jpg thumbnails in the alignments "
+"file."
+msgstr "alignments 파일의 저해상도 jpg 썸네일을 강제로 재생성합니다."
+
+#: tools/manual/cli.py:64
+msgid ""
+"The process attempts to speed up generation of thumbnails by extracting from "
+"the video in parallel threads. For some videos, this causes the caching "
+"process to hang. If this happens, then set this option to generate the "
+"thumbnails in a slower, but more stable single thread."
+msgstr ""
+"프로세스는 병렬 스레드로 비디오에서 추출하여 썸네일 생성 속도를 높이려고 시도"
+"합니다. 일부 비디오의 경우 이로 인해 캐싱 프로세스가 중단될 수 있습니다. 이"
+"런 경우 이 옵션을 설정하여 더 느리지만 더 안정적인 단일 스레드로 썸네일을 생"
+"성하십시오."
+
+#: tools/manual\faceviewer\frame.py:163
+msgid "Display the landmarks mesh"
+msgstr "특징점 메시 표시"
+
+#: tools/manual\faceviewer\frame.py:164
+msgid "Display the mask"
+msgstr "마스크 표시"
+
+#: tools/manual\frameviewer\editor\_base.py:628
+#: tools/manual\frameviewer\editor\landmarks.py:44
+#: tools/manual\frameviewer\editor\mask.py:75
+msgid "Magnify/Demagnify the View"
+msgstr "보기를 확대/축소합니다"
+
+#: tools/manual\frameviewer\editor\bounding_box.py:33
+#: tools/manual\frameviewer\editor\extract_box.py:32
+msgid "Delete Face"
+msgstr "얼굴 삭제"
+
+#: tools/manual\frameviewer\editor\bounding_box.py:36
+msgid ""
+"Bounding Box Editor\n"
+"Edit the bounding box being fed into the aligner to recalculate the "
+"landmarks.\n"
+"\n"
+" - Grab the corner anchors to resize the bounding box.\n"
+" - Click and drag the bounding box to relocate.\n"
+" - Click in empty space to create a new bounding box.\n"
+" - Right click a bounding box to delete a face."
+msgstr ""
+"경계 상자 편집기\n"
+"aligner에 공급되는 경계 상자를 편집하여 특징점을 다시 계산합니다.\n"
+"\n"
+"- 모서리 앵커를 잡아 경계 상자의 크기를 조정합니다.\n"
+"- 경계 상자를 클릭하고 끌어서 재배치합니다.\n"
+"- 빈 공간을 클릭하여 새 경계 상자를 만듭니다.\n"
+"- 경계 상자를 마우스 오른쪽 버튼으로 클릭하여 얼굴을 삭제합니다."
+
+#: tools/manual\frameviewer\editor\bounding_box.py:70
+msgid ""
+"Aligner to use. FAN will obtain better alignments, but cv2-dnn can be useful "
+"if FAN cannot get decent alignments and you want to set a base to edit from."
+msgstr ""
+"사용할 aligner. 
FAN은 더 나은 alignments를 얻을 수 있지만, FAN이 적절한 "
+"alignments를 얻지 못해 편집을 시작할 기준점을 설정하려는 경우에는 cv2-dnn이 "
+"유용할 수 있습니다."
+
+#: tools/manual\frameviewer\editor\bounding_box.py:83
+msgid ""
+"Normalization method to use for feeding faces to the aligner. This can help "
+"the aligner better align faces with difficult lighting conditions. Different "
+"methods will yield different results on different sets. NB: This does not "
+"impact the output face, just the input to the aligner.\n"
+"\tnone: Don't perform normalization on the face.\n"
+"\tclahe: Perform Contrast Limited Adaptive Histogram Equalization on the "
+"face.\n"
+"\thist: Equalize the histograms on the RGB channels.\n"
+"\tmean: Normalize the face colors to the mean."
+msgstr ""
+"aligner에 얼굴을 공급하는 데 사용할 정규화 방법입니다. 이렇게 하면 aligner"
+"가 어려운 조명 조건에서 얼굴을 더 잘 정렬할 수 있습니다. 방법이 다르면 세트마"
+"다 결과가 다릅니다. NB: 출력 얼굴에는 영향을 주지 않으며 aligner에게 주는 입"
+"력에만 영향을 줍니다.\n"
+"\tnone: 얼굴에 정규화를 수행하지 않습니다.\n"
+"\tclahe: 얼굴에 Contrast Limited Adaptive Histogram Equalization를 수행합니"
+"다.\n"
+"\thist: RGB 채널의 히스토그램을 균등화합니다.\n"
+"\tmean: 얼굴 색상을 평균으로 정규화합니다."
+
+#: tools/manual\frameviewer\editor\extract_box.py:35
+msgid ""
+"Extract Box Editor\n"
+"Move the extract box that has been generated by the aligner. Click and "
+"drag:\n"
+"\n"
+" - Inside the bounding box to relocate the landmarks.\n"
+" - The corner anchors to resize the landmarks.\n"
+" - Outside of the corners to rotate the landmarks."
+msgstr ""
+"추출 상자 편집기\n"
+"aligner가 생성한 추출 상자를 이동합니다. 클릭하고 드래그하세요:\n"
+"\n"
+"- 경계 상자 내부: 특징점을 재배치합니다.\n"
+"- 모서리 앵커: 특징점의 크기를 조정합니다.\n"
+"- 모서리 바깥쪽: 특징점을 회전합니다."
+
+#: tools/manual\frameviewer\editor\landmarks.py:27
+msgid ""
+"Landmark Point Editor\n"
+"Edit the individual landmark points.\n"
+"\n"
+" - Click and drag individual points to relocate.\n"
+" - Draw a box to select multiple points to relocate."
+msgstr ""
+"특징점 편집기\n"
+"개별 특징점들을 편집합니다.\n"
+"\n"
+" - 개별 특징점들을 클릭 & 드래그하여 재배치합니다.\n"
+" - 재배치할 여러 개의 점들을 박스를 그려서 선택합니다."
+
+#: tools/manual\frameviewer\editor\mask.py:33
+msgid ""
+"Mask Editor\n"
+"Edit the mask.\n"
+" - NB: For Landmark based masks (e.g. components/extended) it is better to "
+"make sure the landmarks are correct rather than editing the mask directly. "
+"Any change to the landmarks after editing the mask will override your manual "
+"edits."
+msgstr ""
+"마스크 편집기\n"
+"마스크를 편집합니다.\n"
+"- 주의: 특징점 기반 마스크(예: components/extended)의 경우 마스크를 직접 "
+"편집하기보다는 특징점이 올바른지 확인하는 것이 좋습니다. 마스크를 편집한 후 "
+"특징점을 변경하면 수동으로 편집한 내용을 덮어쓰게 됩니다."
+
+#: tools/manual\frameviewer\editor\mask.py:77
+msgid "Draw Tool"
+msgstr "그리기 도구"
+
+#: tools/manual\frameviewer\editor\mask.py:78
+msgid "Erase Tool"
+msgstr "지우개 도구"
+
+#: tools/manual\frameviewer\editor\mask.py:97
+msgid "Select which mask to edit"
+msgstr "편집할 마스크를 선택"
+
+#: tools/manual\frameviewer\editor\mask.py:104
+msgid "Set the brush size. ([ - decrease, ] - increase)"
+msgstr "붓 크기 설정. ([ - 줄이기, ] - 늘리기)"
+
+#: tools/manual\frameviewer\editor\mask.py:111
+msgid "Select the brush cursor color."
+msgstr "붓 커서 색깔 선택." 
+
+#: tools/manual\frameviewer\frame.py:78
+msgid "Play/Pause (SPACE)"
+msgstr "재생/멈춤 (스페이스 바)"
+
+#: tools/manual\frameviewer\frame.py:79
+msgid "Go to First Frame (HOME)"
+msgstr "첫 번째 프레임으로 이동 (HOME)"
+
+#: tools/manual\frameviewer\frame.py:80
+msgid "Go to Previous Frame (Z)"
+msgstr "이전 프레임으로 이동 (Z)"
+
+#: tools/manual\frameviewer\frame.py:81
+msgid "Go to Next Frame (X)"
+msgstr "다음 프레임으로 이동 (X)"
+
+#: tools/manual\frameviewer\frame.py:82
+msgid "Go to Last Frame (END)"
+msgstr "마지막 프레임으로 이동 (END)"
+
+#: tools/manual\frameviewer\frame.py:83
+msgid "Extract the faces to a folder... (Ctrl+E)"
+msgstr "폴더에 얼굴 추출... (Ctrl+E)"
+
+#: tools/manual\frameviewer\frame.py:84
+msgid "Save the Alignments file (Ctrl+S)"
+msgstr "Alignments 파일 저장 (Ctrl+S)"
+
+#: tools/manual\frameviewer\frame.py:85
+msgid "Filter Frames to only those Containing the Selected Item (F)"
+msgstr "선택된 항목이 포함된 프레임만 표시하도록 필터링 (F)"
+
+#: tools/manual\frameviewer\frame.py:86
+msgid ""
+"Set the distance from an 'average face' to be considered misaligned. Higher "
+"distances are more restrictive"
+msgstr ""
+"'평균 얼굴'로부터 얼마나 떨어지면 잘못 정렬된 것으로 간주할지 거리를 설정합니"
+"다. 값이 클수록 더 제한적입니다"
+
+#: tools/manual\frameviewer\frame.py:391
+msgid "View alignments"
+msgstr "Alignments 보기"
+
+#: tools/manual\frameviewer\frame.py:392
+msgid "Bounding box editor"
+msgstr "경계 상자 편집기"
+
+#: tools/manual\frameviewer\frame.py:393
+msgid "Location editor"
+msgstr "위치 편집기"
+
+#: tools/manual\frameviewer\frame.py:394
+msgid "Mask editor"
+msgstr "마스크 편집기"
+
+#: tools/manual\frameviewer\frame.py:395
+msgid "Landmark point editor"
+msgstr "특징점 편집기"
+
+#: tools/manual\frameviewer\frame.py:470
+msgid "Next"
+msgstr "다음"
+
+#: tools/manual\frameviewer\frame.py:470
+msgid "Previous"
+msgstr "이전"
+
+#: tools/manual\frameviewer\frame.py:481
+msgid "Revert to saved Alignments ({})"
+msgstr "저장된 Alignments로 돌아가기 ({})"
+
+#: tools/manual\frameviewer\frame.py:487
+msgid "Copy {} Alignments ({})"
+msgstr "{} Alignments 복사 ({})"
diff --git a/locales/kr/LC_MESSAGES/tools.mask.cli.mo b/locales/kr/LC_MESSAGES/tools.mask.cli.mo
new file mode 100644
index 0000000000..9de146e0bc
Binary files /dev/null and b/locales/kr/LC_MESSAGES/tools.mask.cli.mo differ
diff --git a/locales/kr/LC_MESSAGES/tools.mask.cli.po b/locales/kr/LC_MESSAGES/tools.mask.cli.po
new file mode 100644
index 0000000000..94080f7faf
--- /dev/null
+++ b/locales/kr/LC_MESSAGES/tools.mask.cli.po
@@ -0,0 +1,318 @@
+# SOME DESCRIPTIVE TITLE.
+# Copyright (C) YEAR THE PACKAGE'S COPYRIGHT HOLDER
+# This file is distributed under the same license as the PACKAGE package.
+# FIRST AUTHOR , YEAR.
+#
+msgid ""
+msgstr ""
+"Project-Id-Version: \n"
+"Report-Msgid-Bugs-To: \n"
+"POT-Creation-Date: 2024-06-28 13:45+0100\n"
+"PO-Revision-Date: 2024-06-28 13:48+0100\n"
+"Last-Translator: \n"
+"Language-Team: \n"
+"Language: ko_KR\n"
+"MIME-Version: 1.0\n"
+"Content-Type: text/plain; charset=UTF-8\n"
+"Content-Transfer-Encoding: 8bit\n"
+"Plural-Forms: nplurals=1; plural=0;\n"
+"X-Generator: Poedit 3.4.4\n"
+
+#: tools/mask/cli.py:15
+msgid ""
+"This tool allows you to generate, import, export or preview masks for "
+"existing alignments."
+msgstr ""
+"이 도구를 사용하면 기존 alignments에 대한 마스크를 생성, 가져오기, 내보내기 "
+"또는 미리 볼 수 있습니다."
+
+#: tools/mask/cli.py:25
+msgid ""
+"Mask tool\n"
+"Generate, import, export or preview masks for existing alignments files."
+msgstr ""
+"마스크 도구\n"
+"기존 alignments 파일에 대한 마스크를 생성, 가져오기, 내보내기 또는 미리 봅니"
+"다."
+
+#: tools/mask/cli.py:35 tools/mask/cli.py:47 tools/mask/cli.py:58
+#: tools/mask/cli.py:69
+msgid "data"
+msgstr "데이터"
+
+#: tools/mask/cli.py:39
+msgid ""
+"Full path to the alignments file that contains the masks if not at the "
+"default location. NB: If the input-type is faces and you wish to update the "
+"corresponding alignments file, then you must provide a value here as the "
+"location cannot be automatically detected."
+msgstr ""
+"기본 위치에 있지 않은 경우 마스크가 포함된 alignments 파일의 전체 경로입니"
+"다. NB: 입력 유형이 얼굴(faces)이고 해당 alignments 파일을 업데이트하려는 경"
+"우 위치를 자동으로 감지할 수 없으므로 여기에 값을 제공해야 합니다."
+
+#: tools/mask/cli.py:51
+msgid "Directory containing extracted faces, source frames, or a video file."
+msgstr "추출된 얼굴들, 원본 프레임들, 또는 비디오 파일이 존재하는 디렉토리."
+
+#: tools/mask/cli.py:61
+msgid ""
+"R|Whether the `input` is a folder of faces or a folder frames/video\n"
+"L|faces: The input is a folder containing extracted faces.\n"
+"L|frames: The input is a folder containing frames or is a video"
+msgstr ""
+"R|'입력'이 얼굴 폴더인지 아니면 프레임 폴더/비디오인지 여부\n"
+"L|faces: 입력은 추출된 얼굴이 포함된 폴더입니다.\n"
+"L|frames: 입력은 프레임이 포함된 폴더이거나 비디오입니다"
+
+#: tools/mask/cli.py:71
+msgid ""
+"R|Run the mask tool on multiple sources. If selected then the other options "
+"should be set as follows:\n"
+"L|input: A parent folder containing either all of the video files to be "
+"processed, or containing sub-folders of frames/faces.\n"
+"L|output-folder: If provided, then sub-folders will be created within the "
+"given location to hold the previews for each input.\n"
+"L|alignments: Alignments field will be ignored for batch processing. The "
+"alignments files must exist at the default location (for frames). For batch "
+"processing of masks with 'faces' as the input type, then only the PNG header "
+"within the extracted faces will be updated."
+msgstr ""
+"R|여러 소스에서 마스크 도구를 실행합니다. 선택한 경우 다른 옵션을 다음과 같"
+"이 설정해야 합니다.\n"
+"L|input: 처리할 모든 비디오 파일을 포함하거나 프레임/얼굴의 하위 폴더를 포함"
+"하는 상위 폴더입니다.\n"
+"L|output-folder: 제공된 경우 각 입력에 대한 미리 보기를 보관하기 위해 지정된 "
+"위치 내에 하위 폴더가 생성됩니다.\n"
+"L|alignments: 일괄 처리에서는 alignments 필드가 무시됩니다. alignments 파일"
+"은 (프레임의 경우) 기본 위치에 있어야 합니다. 입력 유형이 '얼굴'인 마스크를 "
+"일괄 처리하는 경우 추출된 얼굴 내의 PNG 헤더만 업데이트됩니다."
+
+#: tools/mask/cli.py:87 tools/mask/cli.py:119
+msgid "process"
+msgstr "처리"
+
+#: tools/mask/cli.py:89
+msgid ""
+"R|Masker to use.\n"
+"L|bisenet-fp: Relatively lightweight NN based mask that provides more "
+"refined control over the area to be masked including full head masking "
+"(configurable in mask settings).\n"
+"L|components: Mask designed to provide facial segmentation based on the "
+"positioning of landmark locations. A convex hull is constructed around the "
+"exterior of the landmarks to create a mask.\n"
+"L|custom: A dummy mask that fills the mask area with all 1s or 0s "
+"(configurable in settings). This is only required if you intend to manually "
+"edit the custom masks yourself in the manual tool. This mask does not use "
+"the GPU.\n"
+"L|extended: Mask designed to provide facial segmentation based on the "
+"positioning of landmark locations. A convex hull is constructed around the "
+"exterior of the landmarks and the mask is extended upwards onto the "
+"forehead.\n"
+"L|vgg-clear: Mask designed to provide smart segmentation of mostly frontal "
+"faces clear of obstructions. Profile faces and obstructions may result in "
+"sub-par performance.\n"
+"L|vgg-obstructed: Mask designed to provide smart segmentation of mostly "
+"frontal faces. The mask model has been specifically trained to recognize "
+"some facial obstructions (hands and eyeglasses). 
Profile faces may result in "
+"sub-par performance.\n"
+"L|unet-dfl: Mask designed to provide smart segmentation of mostly frontal "
+"faces. The mask model has been trained by community members. Profile faces "
+"may result in sub-par performance."
+msgstr ""
+"R|사용할 마스크.\n"
+"L|bisenet-fp: 전체 머리 마스킹(마스크 설정에서 구성 가능)을 포함하여 마스킹"
+"할 영역에 대한 보다 정교한 제어를 제공하는 비교적 가벼운 NN 기반 마스크입니"
+"다.\n"
+"L|components: 특징점 위치를 기반으로 얼굴 분할을 제공하도록 설계된 마스크입니"
+"다. 특징점 외곽을 따라 convex hull을 구성하여 마스크를 만듭니다.\n"
+"L|custom: 마스크 영역을 모두 1 또는 0으로 채우는 더미 마스크입니다(설정에서 "
+"구성 가능). 수동 도구에서 사용자 정의 마스크를 직접 편집하려는 경우에만 필요"
+"합니다. 이 마스크는 GPU를 사용하지 않습니다.\n"
+"L|extended: 특징점 위치를 기반으로 얼굴 분할을 제공하도록 설계된 마스크입니"
+"다. 특징점 외곽을 따라 convex hull을 구성하고 마스크를 이마 위까지 확장합니"
+"다.\n"
+"L|vgg-clear: 장애물이 없는 대부분 정면인 얼굴의 스마트 분할을 제공하도록 설계"
+"된 마스크입니다. 옆얼굴과 장애물은 평균 이하의 성능을 초래할 수 있습니다.\n"
+"L|vgg-obstructed: 대부분 정면인 얼굴의 스마트 분할을 제공하도록 설계된 마스크"
+"입니다. 마스크 모델은 일부 얼굴 장애물(손과 안경)을 인식하도록 특별히 훈련되"
+"었습니다. 옆얼굴은 평균 이하의 성능을 초래할 수 있습니다.\n"
+"L|unet-dfl: 대부분 정면인 얼굴의 스마트 분할을 제공하도록 설계된 마스크입니"
+"다. 마스크 모델은 커뮤니티 구성원들에 의해 훈련되었습니다. 옆얼굴은 평균 이"
+"하의 성능을 초래할 수 있습니다."
+
+#: tools/mask/cli.py:121
+msgid ""
+"R|The Mask tool process to perform.\n"
+"L|all: Update the mask for all faces in the alignments file for the selected "
+"'masker'.\n"
+"L|missing: Create a mask for all faces in the alignments file where a mask "
+"does not previously exist for the selected 'masker'.\n"
+"L|output: Don't update the masks, just output the selected 'masker' for "
+"review/editing in external tools to the given output folder.\n"
+"L|import: Import masks that have been edited outside of faceswap into the "
+"alignments file. Note: 'custom' must be the selected 'masker' and the masks "
+"must be in the same format as the 'input-type' (frames or faces)"
+msgstr ""
+"R|수행할 마스크 도구 프로세스입니다.\n"
+"L|all: 선택한 'masker'에 대해 alignments 파일의 모든 얼굴의 마스크를 업데이트"
+"합니다.\n"
+"L|missing: 선택한 'masker'에 대해 이전에 마스크가 존재하지 않았던 alignments "
+"파일의 모든 얼굴에 대한 마스크를 생성합니다.\n"
+"L|output: 마스크를 업데이트하지 않고, 외부 도구에서 검토/편집할 수 있도록 선"
+"택한 'masker'를 지정된 출력 폴더로 출력합니다.\n"
+"L|import: faceswap 외부에서 편집된 마스크를 alignments 파일로 가져옵니다. 참"
+"고: 'custom'이 선택된 'masker'여야 하며 마스크는 'input-type'(frames 또는 "
+"faces)과 동일한 형식이어야 합니다."
+
+#: tools/mask/cli.py:135 tools/mask/cli.py:154 tools/mask/cli.py:176
+msgid "import"
+msgstr "가져오기"
+
+#: tools/mask/cli.py:137
+msgid ""
+"R|Import only. The path to the folder that contains masks to be imported.\n"
+"L|How the masks are provided is not important, but they will be stored, "
+"internally, as 8-bit grayscale images.\n"
+"L|If the input are images, then the masks must be named exactly the same as "
+"input frames/faces (excluding the file extension).\n"
+"L|If the input is a video file, then the filename of the masks is not "
+"important but should contain the frame number at the end of the filename "
+"(but before the file extension). The frame number can be separated from the "
+"rest of the filename by any non-numeric character and can be padded by any "
+"number of zeros. The frame number must correspond correctly to the frame "
+"number in the original video (starting from frame 1)."
+msgstr ""
+"R|가져오기 전용. 가져올 마스크가 포함된 폴더의 경로입니다.\n"
+"L|마스크 제공 방법은 중요하지 않지만 내부적으로 8비트 그레이스케일 이미지로 "
+"저장됩니다.\n"
+"L|입력이 이미지인 경우 마스크 이름은 입력 프레임/얼굴과 정확히 동일하게 지정"
+"되어야 합니다(파일 확장자 제외).\n"
+"L|입력이 비디오 파일인 경우 마스크의 파일 이름은 중요하지 않지만 파일 이름 끝"
+"에(파일 확장자 앞에) 프레임 번호가 포함되어야 합니다. 프레임 번호는 숫자가 아"
+"닌 문자로 파일 이름의 나머지 부분과 구분될 수 있으며 임의 개수의 0으로 채워"
+"질 수 있습니다. 프레임 번호는 원본 비디오의 프레임 번호(프레임 1부터 시작)와 "
+"정확하게 일치해야 합니다."
+
+#: tools/mask/cli.py:156
+msgid ""
+"R|Import/Output only. When importing masks, this is the centering to use. "
+"For output this is only used for outputting custom imported masks, and "
+"should correspond to the centering used when importing the mask. Note: For "
+"any job other than 'import' and 'output' this option is ignored as mask "
+"centering is handled internally.\n"
+"L|face: Centers the mask on the center of the face, adjusting for pitch and "
+"yaw. Outside of requirements for full head masking/training, this is likely "
+"to be the best choice.\n"
+"L|head: Centers the mask on the center of the head, adjusting for pitch and "
+"yaw. Note: You should only select head centering if you intend to include "
+"the full head (including hair) within the mask and are looking to train a "
+"full head model.\n"
+"L|legacy: The 'original' extraction technique. Centers the mask near the of "
+"the nose with and crops closely to the face. Can result in the edges of the "
+"mask appearing outside of the training area."
+msgstr ""
+"R|가져오기/출력 전용. 마스크를 가져올 때 사용할 중앙 정렬입니다. 출력의 경우 "
+"사용자 정의로 가져온 마스크를 출력하는 데만 사용되며, 마스크를 가져올 때 사용"
+"한 중앙 정렬과 일치해야 합니다. 참고: 'import' 및 'output' 이외의 모든 작업"
+"의 경우 마스크 중앙 정렬이 내부적으로 처리되므로 이 옵션은 무시됩니다.\n"
+"L|face: 피치와 요를 조정하여 마스크를 얼굴 중앙에 배치합니다. 전체 머리 마스"
+"킹/훈련이 필요한 경우가 아니라면 이것이 최선의 선택일 가능성이 높습니다.\n"
+"L|head: 피치와 요를 조정하여 마스크를 머리 중앙에 배치합니다. 참고: 마스크 내"
+"에 머리 전체(머리카락 포함)를 포함하여 전체 머리 모델을 훈련시키려는 경우에"
+"만 head 중앙 정렬을 선택해야 합니다.\n"
+"L|legacy: '원래' 추출 기술입니다. 마스크를 코 근처에 중앙 배치하고 얼굴에 가"
+"깝게 자릅니다. 마스크 가장자리가 훈련 영역 밖에 나타날 수 있습니다."
+
+#: tools/mask/cli.py:181
+msgid ""
+"Import only. The size, in pixels to internally store the mask at.\n"
+"The default is 128 which is fine for nearly all usecases. Larger sizes will "
+"result in larger alignments files and longer processing."
+msgstr ""
+"가져오기 전용. 마스크를 내부적으로 저장할 크기(픽셀)입니다.\n"
+"기본값은 128이며 거의 모든 사용 사례에 적합합니다. 크기가 클수록 alignments "
+"파일이 커지고 처리 시간도 길어집니다."
+
+#: tools/mask/cli.py:189 tools/mask/cli.py:197 tools/mask/cli.py:211
+#: tools/mask/cli.py:225 tools/mask/cli.py:235
+msgid "output"
+msgstr "출력"
+
+#: tools/mask/cli.py:191
+msgid ""
+"Optional output location. If provided, a preview of the masks created will "
+"be output in the given folder."
+msgstr ""
+"선택적 출력 위치. 만약 값이 제공된다면 생성된 마스크의 미리 보기가 주어진 폴"
+"더에 출력됩니다."
+
+#: tools/mask/cli.py:202
+msgid ""
+"Apply gaussian blur to the mask output. Has the effect of smoothing the "
+"edges of the mask giving less of a hard edge. the size is in pixels. This "
+"value should be odd, if an even number is passed in then it will be rounded "
+"to the next odd number. NB: Only effects the output preview. Set to 0 for off"
+msgstr ""
+"마스크 출력에 가우시안 블러를 적용합니다. 마스크의 가장자리를 매끄럽게 하여 "
+"딱딱한 가장자리를 완화하는 효과가 있습니다. 크기는 픽셀 단위입니다. 이 값은 "
+"홀수여야 하며 짝수가 전달되면 다음 홀수로 반올림됩니다. NB: 출력 미리 보기에"
+"만 영향을 줍니다. 0으로 설정하면 꺼집니다"
+
+#: tools/mask/cli.py:216
+msgid ""
+"Helps reduce 'blotchiness' on some masks by making light shades white and "
+"dark shades black. Higher values will impact more of the mask. NB: Only "
+"effects the output preview. Set to 0 for off"
+msgstr ""
+"밝은 색조를 흰색으로, 어두운 색조를 검은색으로 만들어 일부 마스크의 '얼룩짐'"
+"을 줄이는 데 도움이 됩니다. 값이 클수록 마스크에 더 많은 영향을 미칩니다. "
+"NB: 출력 미리 보기에만 영향을 줍니다. 
0으로 설정하면 꺼집니다"
+
+#: tools/mask/cli.py:227
+msgid ""
+"R|How to format the output when processing is set to 'output'.\n"
+"L|combined: The image contains the face/frame, face mask and masked face.\n"
+"L|masked: Output the face/frame as rgba image with the face masked.\n"
+"L|mask: Only output the mask as a single channel image."
+msgstr ""
+"R|처리가 'output'으로 설정되어 있을 때 출력을 구성하는 방법.\n"
+"L|combined: 이미지에 얼굴/프레임, 얼굴 마스크 및 마스크된 얼굴이 포함됩니"
+"다.\n"
+"L|masked: 얼굴이 마스크된 얼굴/프레임을 RGBA 이미지로 출력합니다.\n"
+"L|mask: 마스크를 단일 채널 이미지로만 출력합니다."
+
+#: tools/mask/cli.py:237
+msgid ""
+"R|Whether to output the whole frame or only the face box when using output "
+"processing. Only has an effect when using frames as input."
+msgstr ""
+"R|출력 처리를 사용할 때 전체 프레임을 출력할지 또는 얼굴 상자만 출력할지 여"
+"부. 프레임을 입력으로 사용할 때만 효과가 있습니다."
+
+#~ msgid ""
+#~ "R|Whether to update all masks in the alignments files, only those faces "
+#~ "that do not already have a mask of the given `mask type` or just to "
+#~ "output the masks to the `output` location.\n"
+#~ "L|all: Update the mask for all faces in the alignments file.\n"
+#~ "L|missing: Create a mask for all faces in the alignments file where a "
+#~ "mask does not previously exist.\n"
+#~ "L|output: Don't update the masks, just output them for review in the "
+#~ "given output folder."
+#~ msgstr ""
+#~ "R|alignments 파일의 모든 마스크를 업데이트할지, 지정된 '마스크 유형'의 마"
+#~ "스크가 아직 없는 얼굴만 업데이트할지, 아니면 단순히 '출력' 위치로 마스크"
+#~ "를 출력할지 여부.\n"
+#~ "L|all: alignments 파일의 모든 얼굴에 대한 마스크를 업데이트합니다.\n"
+#~ "L|missing: 마스크가 없었던 alignments 파일의 모든 얼굴에 대한 마스크를 만"
+#~ "듭니다.\n"
+#~ "L|output: 마스크를 업데이트하지 말고 지정된 출력 폴더에서 검토할 수 있도"
+#~ "록 출력하십시오."
+
+#~ msgid ""
+#~ "Full path to the alignments file to add the mask to. NB: if the mask "
+#~ "already exists in the alignments file it will be overwritten."
+#~ msgstr ""
+#~ "마스크를 추가할 alignments 파일의 전체 경로입니다. 주의: alignments 파일"
+#~ "에 마스크가 이미 있으면 덮어 씌워집니다."
diff --git a/locales/kr/LC_MESSAGES/tools.model.cli.mo b/locales/kr/LC_MESSAGES/tools.model.cli.mo
new file mode 100644
index 0000000000..9ccdfde6dc
Binary files /dev/null and b/locales/kr/LC_MESSAGES/tools.model.cli.mo differ
diff --git a/locales/kr/LC_MESSAGES/tools.model.cli.po b/locales/kr/LC_MESSAGES/tools.model.cli.po
new file mode 100644
index 0000000000..524943a5c0
--- /dev/null
+++ b/locales/kr/LC_MESSAGES/tools.model.cli.po
@@ -0,0 +1,82 @@
+# SOME DESCRIPTIVE TITLE.
+# Copyright (C) YEAR THE PACKAGE'S COPYRIGHT HOLDER
+# This file is distributed under the same license as the PACKAGE package.
+# FIRST AUTHOR , YEAR.
+#
+msgid ""
+msgstr ""
+"Project-Id-Version: \n"
+"Report-Msgid-Bugs-To: \n"
+"POT-Creation-Date: 2024-03-28 23:51+0000\n"
+"PO-Revision-Date: 2024-03-29 00:05+0000\n"
+"Last-Translator: \n"
+"Language-Team: \n"
+"Language: ko_KR\n"
+"MIME-Version: 1.0\n"
+"Content-Type: text/plain; charset=UTF-8\n"
+"Content-Transfer-Encoding: 8bit\n"
+"Plural-Forms: nplurals=1; plural=0;\n"
+"X-Generator: Poedit 3.4.2\n"
+
+#: tools/model/cli.py:13
+msgid "This tool lets you perform actions on saved Faceswap models."
+msgstr ""
+"이 도구를 사용하여 저장된 Faceswap 모델에서 작업을 수행할 수 있습니다."
+
+#: tools/model/cli.py:22
+msgid "A tool for performing actions on Faceswap trained model files"
+msgstr "Faceswap으로 훈련된 모델 파일에 대한 작업을 수행하기 위한 도구"
+
+#: tools/model/cli.py:34
+msgid ""
+"Model directory. A directory containing the model you wish to perform an "
+"action on."
+msgstr "모델 디렉토리. 작업을 수행할 모델이 들어 있는 디렉토리입니다."
+ +#: tools/model/cli.py:43 +msgid "" +"R|Choose which action you want to perform.\n" +"L|'inference' - Create an inference only copy of the model. Strips any " +"layers from the model which are only required for training. NB: This is for " +"exporting the model for use in external applications. Inference generated " +"models cannot be used within Faceswap. See the 'format' option for " +"specifying the model output format.\n" +"L|'nan-scan' - Scan the model file for NaNs or Infs (invalid data).\n" +"L|'restore' - Restore a model from backup." +msgstr "" +"R|실행할 작업을 선택합니다.\n" +"L|'inference' - 모델의 추론 전용 사본을 만듭니다. 모델에서 훈련에만 필요한 " +"모든 레이어를 제거합니다. NB: 이것은 외부 응용 프로그램에서 사용하기 위해 모" +"델을 내보내기 위한 것입니다. 추론 생성 모델은 Faceswap 내에서 사용할 수 없습" +"니다. 모델 출력 형식을 지정하려면 'format' 옵션을 참조하십시오.\n" +"L|'nan-scan' - 모델 파일에서 NaN 또는 Infs(잘못된 데이터)를 검색합니다.\n" +"L|'restore' - 백업에서 모델을 복원합니다." + +#: tools/model/cli.py:57 tools/model/cli.py:69 +msgid "inference" +msgstr "추론" + +#: tools/model/cli.py:59 +msgid "" +"R|The format to save the model as. Note: Only used for 'inference' job.\n" +"L|'h5' - Standard Keras H5 format. Does not store any custom layer " +"information. Layers will need to be loaded from Faceswap to use.\n" +"L|'saved-model' - Tensorflow's Saved Model format. Contains all information " +"required to load the model outside of Faceswap." +msgstr "" +"R|모델을 저장할 형식입니다. 참고: '추론' 작업에만 사용됩니다.\n" +"L|'h5' - 표준 케라스 H5 형식. 사용자 지정 레이어 정보를 저장하지 않습니다. " +"사용하려면 Faceswap에서 레이어를 로드해야 합니다.\n" +"L| 'saved-model' - 텐서플로의 저장된 모델 형식. Faceswap 외부에서 모델을 로" +"드하는 데 필요한 모든 정보를 포함합니다." + +#: tools/model/cli.py:71 +#, fuzzy +#| msgid "" +#| "Only used for 'inference' job. Generate the inference model for B -> A " +#| "instead of A -> B." +msgid "" +"Only used for 'inference' job. Generate the inference model for B -> A " +"instead of A -> B." +msgstr "" +"'추론' 작업에만 쓰입니다. A -> B 대신 B -> A에 대한 추론 모델을 생성합니다." diff --git a/locales/kr/LC_MESSAGES/tools.preview.mo b/locales/kr/LC_MESSAGES/tools.preview.mo new file mode 100644 index 0000000000..13d6841ba1 Binary files /dev/null and b/locales/kr/LC_MESSAGES/tools.preview.mo differ diff --git a/locales/kr/LC_MESSAGES/tools.preview.po b/locales/kr/LC_MESSAGES/tools.preview.po new file mode 100644 index 0000000000..03e8b6631f --- /dev/null +++ b/locales/kr/LC_MESSAGES/tools.preview.po @@ -0,0 +1,87 @@ +# SOME DESCRIPTIVE TITLE. +# Copyright (C) YEAR ORGANIZATION +# FIRST AUTHOR , YEAR. +# +msgid "" +msgstr "" +"Project-Id-Version: \n" +"Report-Msgid-Bugs-To: \n" +"POT-Creation-Date: 2024-03-28 23:53+0000\n" +"PO-Revision-Date: 2024-03-29 00:04+0000\n" +"Last-Translator: \n" +"Language-Team: \n" +"Language: ko_KR\n" +"MIME-Version: 1.0\n" +"Content-Type: text/plain; charset=UTF-8\n" +"Content-Transfer-Encoding: 8bit\n" +"Plural-Forms: nplurals=1; plural=0;\n" +"Generated-By: pygettext.py 1.5\n" +"X-Generator: Poedit 3.4.2\n" + +#: tools/preview/cli.py:15 +msgid "This command allows you to preview swaps to tweak convert settings." +msgstr "" +"이 명령어는 변환 설정을 변경하기 위한 변환 미리보기를 가능하게 해줍니다." + +#: tools/preview/cli.py:30 +msgid "" +"Preview tool\n" +"Allows you to configure your convert settings with a live preview" +msgstr "" +"미리보기 도구\n" +"라이브로 미리보기를 보면서 변환 설정을 구성할 수 있도록 해줍니다" + +#: tools/preview/cli.py:47 tools/preview/cli.py:57 tools/preview/cli.py:65 +msgid "data" +msgstr "데이터" + +#: tools/preview/cli.py:50 +msgid "" +"Input directory or video. Either a directory containing the image files you " +"wish to process or path to a video file." +msgstr "" +"입력 디렉토리 또는 비디오. 
처리할 이미지 파일이 들어 있는 디렉토리 또는 비디" +"오 파일의 경로입니다." + +#: tools/preview/cli.py:60 +msgid "" +"Path to the alignments file for the input, if not at the default location" +msgstr "입력 alignments 파일의 경로, 만약 제공되지 않는다면 기본 위치" + +#: tools/preview/cli.py:68 +msgid "" +"Model directory. A directory containing the trained model you wish to " +"process." +msgstr "" +"모델 디렉토리. 사용자가 처리하고 싶어하는 훈련된 모델이 있는 디렉토리." + +#: tools/preview/cli.py:74 +msgid "Swap the model. Instead of A -> B, swap B -> A" +msgstr "모델을 스왑함. A -> B 대신, B -> A로 스왑함" + +#: tools/preview/control_panels.py:510 +msgid "Save full config" +msgstr "전체 설정을 저장" + +#: tools/preview/control_panels.py:513 +msgid "Reset full config to default values" +msgstr "전체 설정을 기본 값으로 초기화" + +#: tools/preview/control_panels.py:516 +msgid "Reset full config to saved values" +msgstr "전체 설정을 저장된 값으로 초기화" + +#: tools/preview/control_panels.py:667 +#, python-brace-format +msgid "Save {title} config" +msgstr "{title} 설정 저장" + +#: tools/preview/control_panels.py:670 +#, python-brace-format +msgid "Reset {title} config to default values" +msgstr "{title} 설정을 기본 값으로 초기화" + +#: tools/preview/control_panels.py:673 +#, python-brace-format +msgid "Reset {title} config to saved values" +msgstr "{title} 설정을 저장된 값으로 초기화" diff --git a/locales/kr/LC_MESSAGES/tools.sort.cli.mo b/locales/kr/LC_MESSAGES/tools.sort.cli.mo new file mode 100644 index 0000000000..39509c675d Binary files /dev/null and b/locales/kr/LC_MESSAGES/tools.sort.cli.mo differ diff --git a/locales/kr/LC_MESSAGES/tools.sort.cli.po b/locales/kr/LC_MESSAGES/tools.sort.cli.po new file mode 100644 index 0000000000..19d99c1628 --- /dev/null +++ b/locales/kr/LC_MESSAGES/tools.sort.cli.po @@ -0,0 +1,388 @@ +# SOME DESCRIPTIVE TITLE. +# Copyright (C) YEAR THE PACKAGE'S COPYRIGHT HOLDER +# This file is distributed under the same license as the PACKAGE package. +# FIRST AUTHOR , YEAR. +# +msgid "" +msgstr "" +"Project-Id-Version: \n" +"Report-Msgid-Bugs-To: \n" +"POT-Creation-Date: 2024-03-28 23:53+0000\n" +"PO-Revision-Date: 2024-03-29 00:04+0000\n" +"Last-Translator: \n" +"Language-Team: \n" +"Language: ko_KR\n" +"MIME-Version: 1.0\n" +"Content-Type: text/plain; charset=UTF-8\n" +"Content-Transfer-Encoding: 8bit\n" +"Plural-Forms: nplurals=1; plural=0;\n" +"X-Generator: Poedit 3.4.2\n" + +#: tools/sort/cli.py:15 +msgid "This command lets you sort images using various methods." +msgstr "이 명령어는 다양한 메소드를 이용하여 이미지를 정렬해줍니다." + +#: tools/sort/cli.py:21 +msgid "" +" Adjust the '-t' ('--threshold') parameter to control the strength of " +"grouping." +msgstr " 그룹화의 강도를 제어하기 위해 '-t' ('--threshold') 인자를 조정하세요." + +#: tools/sort/cli.py:22 +msgid "" +" Adjust the '-b' ('--bins') parameter to control the number of bins for " +"grouping. Each image is allocated to a bin by the percentage of color pixels " +"that appear in the image." +msgstr "" +" '-b'('--bins') 매개 변수를 조정하여 그룹화할 bins의 수를 제어합니다. 각 이미" +"지는 이미지에 나타나는 색상 픽셀의 백분율에 따라 bin에 할당됩니다." + +#: tools/sort/cli.py:25 +msgid "" +" Adjust the '-b' ('--bins') parameter to control the number of bins for " +"grouping. Each image is allocated to a bin by the number of degrees the face " +"is orientated from center." +msgstr "" +" '-b'('--bins') 매개 변수를 조정하여 그룹화할 bins의 수를 제어합니다. 각 이미" +"지는 얼굴이 이미지 중심에서 떨어진 각도에 따라 bin에 할당됩니다." + +#: tools/sort/cli.py:28 +msgid "" +" Adjust the '-b' ('--bins') parameter to control the number of bins for " +"grouping. The minimum and maximum values are taken for the chosen sort " +"metric. The bins are then populated with the results from the group sorting." 
+msgstr "" +" '-b'('--bins') 매개 변수를 조정하여 그룹화할 bins의 수를 제어합니다. 선택한 " +"정렬 방법에 대해 최소값과 최대값이 사용됩니다. 그런 다음 bins가 그룹 정렬의 " +"결과로 채워집니다." + +#: tools/sort/cli.py:32 +msgid "faces by blurriness." +msgstr "흐릿한 얼굴." + +#: tools/sort/cli.py:33 +msgid "faces by fft filtered blurriness." +msgstr "fft 필터링된 흐릿한 얼굴." + +#: tools/sort/cli.py:34 +msgid "" +"faces by the estimated distance of the alignments from an 'average' face. " +"This can be useful for eliminating misaligned faces. Sorts from most like an " +"average face to least like an average face." +msgstr "" +"'평균' 얼굴에서 alignments의 추정 거리를 기준으로 하는 얼굴. 이는 잘못 정렬" +"된 얼굴을 제거하는 데 유용할 수 있습니다. 가장 평균 얼굴에서 가장 덜 평균 얼" +"굴순으로 정렬합니다." + +#: tools/sort/cli.py:37 +msgid "" +"faces using VGG Face2 by face similarity. This uses a pairwise clustering " +"algorithm to check the distances between 512 features on every face in your " +"set and order them appropriately." +msgstr "" +"얼굴 유사성에 따라 VGG Face2를 사용하는 얼굴. 이 알고리즘은 쌍별 클러스터링 " +"알고리즘을 사용하여 세트의 모든 얼굴에서 512개의 특징 사이의 거리를 확인하고 " +"적절하게 정렬합니다." + +#: tools/sort/cli.py:40 +msgid "faces by their landmarks." +msgstr "특징점이 있는 얼굴." + +#: tools/sort/cli.py:41 +msgid "Like 'face-cnn' but sorts by dissimilarity." +msgstr "'face-cnn'과 비슷하지만 비유사성에 따라 정렬된." + +#: tools/sort/cli.py:42 +msgid "faces by Yaw (rotation left to right)." +msgstr "yaw (왼쪽에서 오른쪽으로 회전)에 의한 얼굴." + +#: tools/sort/cli.py:43 +msgid "faces by Pitch (rotation up and down)." +msgstr "pitch (위에서 아래로 회전)에 의한 얼굴." + +#: tools/sort/cli.py:44 +msgid "" +"faces by Roll (rotation). Aligned faces should have a roll value close to " +"zero. The further the Roll value from zero the higher liklihood the face is " +"misaligned." +msgstr "" +"이동 (회전)에 의한 얼굴. 정렬된 얼굴들은 0에 가까운 이동 값을 가져야 한다. 이" +"동 값이 0에서 멀수록 얼굴들이 잘못 정렬되었을 가능성이 높습니다." + +#: tools/sort/cli.py:46 +msgid "faces by their color histogram." +msgstr "색상 히스토그램에 의한 얼굴." + +#: tools/sort/cli.py:47 +msgid "Like 'hist' but sorts by dissimilarity." +msgstr "'hist' 같지만 비유사성에 따라 정렬된." + +#: tools/sort/cli.py:48 +msgid "" +"images by the average intensity of the converted grayscale color channel." +msgstr "변환된 회색 계열 색상 채널의 평균 강도에 따른 이미지." + +#: tools/sort/cli.py:49 +msgid "" +"images by their number of black pixels. Useful when faces are near borders " +"and a large part of the image is black." +msgstr "" +"검은색 픽셀의 개수에 따른 이미지들. 얼굴이 테두리 근처에 있고 이미지의 대부분" +"이 검은색일 때 유용합니다." + +#: tools/sort/cli.py:51 +msgid "" +"images by the average intensity of the converted Y color channel. Bright " +"lighting and oversaturated images will be ranked first." +msgstr "" +"변환된 Y 색상 채널의 평균 강도를 기준으로 한 이미지. 밝은 조명과 과포화 이미" +"지가 1위를 차지할 것이다." + +#: tools/sort/cli.py:53 +msgid "" +"images by the average intensity of the converted Cg color channel. Green " +"images will be ranked first and red images will be last." +msgstr "" +"변환된 Cg 컬러 채널의 평균 강도를 기준으로 한 이미지. 녹색 이미지가 먼저 순위" +"가 매겨지고 빨간색 이미지가 마지막 순위가 됩니다." + +#: tools/sort/cli.py:55 +msgid "" +"images by the average intensity of the converted Co color channel. Orange " +"images will be ranked first and blue images will be last." +msgstr "" +"변환된 Co 색상 채널의 평균 강도를 기준으로 한 이미지. 주황색 이미지가 먼저 순" +"위가 매겨지고 파란색 이미지가 마지막 순위가 됩니다." + +#: tools/sort/cli.py:57 +msgid "" +"images by their size in the original frame. Faces further from the camera " +"and from lower resolution sources will be sorted first, whilst faces closer " +"to the camera and from higher resolution sources will be sorted last." +msgstr "" +"이미지를 원래 프레임의 크기별로 표시합니다. 
카메라에서 더 멀리 떨어져 있고 저"
+"해상도 원본에서 온 얼굴이 먼저 정렬되고, 카메라에 더 가까이 있고 고해상도 원"
+"본에서 온 얼굴이 마지막으로 정렬됩니다."
+
+#: tools/sort/cli.py:81
+msgid "Sort faces using a number of different techniques"
+msgstr "다양한 기술을 사용하여 얼굴을 정렬합니다"
+
+#: tools/sort/cli.py:91 tools/sort/cli.py:98 tools/sort/cli.py:110
+#: tools/sort/cli.py:150
+msgid "data"
+msgstr "데이터"
+
+#: tools/sort/cli.py:92
+msgid "Input directory of aligned faces."
+msgstr "정렬된 얼굴들의 입력 디렉토리."
+
+#: tools/sort/cli.py:100
+msgid ""
+"Output directory for sorted aligned faces. If not provided and 'keep' is "
+"selected then a new folder called 'sorted' will be created within the input "
+"folder to house the output. If not provided and 'keep' is not selected then "
+"the images will be sorted in-place, overwriting the original contents of the "
+"'input_dir'"
+msgstr ""
+"정렬된 aligned 얼굴의 출력 디렉토리입니다. 제공되지 않은 상태에서 'keep'을 선"
+"택하면 출력을 저장하기 위해 입력 폴더 내에 'sorted'라는 새 폴더가 생성됩니"
+"다. 제공되지 않고 'keep'을 선택하지 않으면 이미지가 제자리에 정렬되어 "
+"'input_dir'의 원래 내용을 덮어씁니다."
+
+#: tools/sort/cli.py:112
+msgid ""
+"R|If selected then the input_dir should be a parent folder containing "
+"multiple folders of faces you wish to sort. The faces will be output to "
+"separate sub-folders in the output_dir"
+msgstr ""
+"R|선택되면 input_dir는 정렬할 여러 개의 얼굴 폴더를 포함하는 상위 폴더여야 합"
+"니다. 얼굴은 output_dir의 별도 하위 폴더로 출력됩니다"
+
+#: tools/sort/cli.py:121
+msgid "sort settings"
+msgstr "정렬 설정"
+
+#: tools/sort/cli.py:124
+msgid ""
+"R|Choose how images are sorted. Selecting a sort method gives the images a "
+"new filename based on the order the image appears within the given method.\n"
+"L|'none': Don't sort the images. When a 'group-by' method is selected, "
+"selecting 'none' means that the files will be moved/copied into their "
+"respective bins, but the files will keep their original filenames. Selecting "
+"'none' for both 'sort-by' and 'group-by' will do nothing"
+msgstr ""
+"R|이미지 정렬 방법을 선택합니다. 정렬 방법을 선택하면 이미지가 주어진 방법 내"
+"에 나타나는 순서에 따라 이미지에 새 파일 이름이 지정됩니다.\n"
+"L|'none': 이미지를 정렬하지 않습니다. 'group-by' 메서드를 선택한 경우 "
+"'none'을 선택하면 파일이 각 bin으로 이동/복사되지만 파일은 원래 파일 이름을 "
+"유지합니다. 'sort-by' 및 'group-by' 모두에 대해 'none'을 선택해도 아무 효과"
+"가 없습니다"
+
+#: tools/sort/cli.py:136 tools/sort/cli.py:164 tools/sort/cli.py:184
+msgid "group settings"
+msgstr "그룹 설정"
+
+#: tools/sort/cli.py:139
+#, fuzzy
+#| msgid ""
+#| "R|Selecting a group by method will move/copy files into numbered bins "
+#| "based on the selected method.\n"
+#| "L|'none': Don't bin the images. Folders will be sorted by the selected "
+#| "'sort-by' but will not be binned, instead they will be sorted into a "
+#| "single folder. Selecting 'none' for both 'sort-by' and 'group-by' will "
+#| "do nothing"
+msgid ""
+"R|Selecting a group by method will move/copy files into numbered bins based "
+"on the selected method.\n"
+"L|'none': Don't bin the images. Folders will be sorted by the selected 'sort-"
+"by' but will not be binned, instead they will be sorted into a single "
+"folder. Selecting 'none' for both 'sort-by' and 'group-by' will do nothing"
+msgstr ""
+"R|그룹화 방법을 선택하면 선택한 방법에 따라 파일이 번호가 매겨진 bin으로 이"
+"동/복사됩니다.\n"
+"L|'none': 이미지를 bin으로 나누지 않습니다. 폴더는 선택한 'sort-by'에 따라 정"
+"렬되지만 bin으로 나뉘지는 않고 단일 폴더로 정렬됩니다. 'sort-by' 및 'group-"
+"by' 모두에 대해 'none'을 선택해도 아무 효과가 없습니다"
+
+#: tools/sort/cli.py:152
+msgid ""
+"Whether to keep the original files in their original location. Choosing a "
+"'sort-by' method means that the files have to be renamed. 
Selecting 'keep' "
+"means that the original files will be kept, and the renamed files will be "
+"created in the specified output folder. Unselecting keep means that the "
+"original files will be moved and renamed based on the selected sort/group "
+"criteria."
+msgstr ""
+"원본 파일을 원래 위치에 유지할지 여부입니다. '정렬 기준' 방법을 선택하면 파"
+"일 이름을 변경해야 합니다. 'keep'을 선택하면 원래 파일이 유지되고 이름이 변경"
+"된 파일이 지정된 출력 폴더에 생성됩니다. keep을 선택취소하면 선택한 정렬/그"
+"룹 기준에 따라 원래 파일이 이동되고 이름이 변경됩니다."
+
+#: tools/sort/cli.py:167
+msgid ""
+"R|Float value. Minimum threshold to use for grouping comparison with 'face-"
+"cnn', 'hist' and 'face' methods.\n"
+"The lower the value the more discriminating the grouping is. Leaving -1.0 "
+"will allow Faceswap to choose the default value.\n"
+"L|For 'face-cnn' 7.2 should be enough, with 4 being very discriminating. \n"
+"L|For 'hist' 0.3 should be enough, with 0.2 being very discriminating. \n"
+"L|For 'face' between 0.1 (more bins) to 0.5 (fewer bins) should be about "
+"right.\n"
+"Be careful setting a value that's too extreme in a directory with many "
+"images, as this could result in a lot of folders being created. Defaults: "
+"face-cnn 7.2, hist 0.3, face 0.25"
+msgstr ""
+"R|float 값. 'face-cnn', 'hist' 및 'face' 메서드와의 그룹 비교에 사용할 최소 "
+"임계값입니다.\n"
+"값이 낮을수록 그룹을 더 잘 구별할 수 있습니다. -1.0을 그대로 두면 Faceswap에"
+"서 기본값을 선택할 수 있습니다.\n"
+"L|'face-cnn'의 경우 7.2이면 충분하며, 4는 매우 엄격한 값입니다. \n"
+"L|'hist'의 경우 0.3이면 충분하며, 0.2는 매우 엄격한 값입니다. \n"
+"L|'face'의 경우 0.1(bin 많음)에서 0.5(bin 적음) 사이가 적당합니다.\n"
+"이미지가 많은 디렉터리에서 너무 극단적인 값을 설정하면 폴더가 많이 생성될 수 "
+"있으므로 주의하십시오. 기본값: face-cnn 7.2, hist 0.3, face 0.25"
+
+#: tools/sort/cli.py:187
+#, fuzzy, python-format
+#| msgid ""
+#| "R|Integer value. Used to control the number of bins created for grouping "
+#| "by: any 'blur' methods, 'color' methods or 'face metric' methods "
+#| "('distance', 'size') and 'orientation; methods ('yaw', 'pitch'). For any "
+#| "other grouping methods see the '-t' ('--threshold') option.\n"
+#| "L|For 'face metric' methods the bins are filled, according the the "
+#| "distribution of faces between the minimum and maximum chosen metric.\n"
+#| "L|For 'color' methods the number of bins represents the divider of the "
+#| "percentage of colored pixels. Eg. For a bin number of '5': The first "
+#| "folder will have the faces with 0%% to 20%% colored pixels, second 21%% "
+#| "to 40%%, etc. Any empty bins will be deleted, so you may end up with "
+#| "fewer bins than selected.\n"
+#| "L|For 'blur' methods folder 0 will be the least blurry, while the last "
+#| "folder will be the blurriest.\n"
+#| "L|For 'orientation' methods the number of bins is dictated by how much "
+#| "180 degrees is divided. Eg. If 18 is selected, then each folder will be a "
+#| "10 degree increment. Folder 0 will contain faces looking the most to the "
+#| "left/down whereas the last folder will contain the faces looking the most "
+#| "to the right/up. NB: Some bins may be empty if faces do not fit the "
+#| "criteria.\n"
+#| "Default value: 5"
+msgid ""
+"R|Integer value. Used to control the number of bins created for grouping by: "
+"any 'blur' methods, 'color' methods or 'face metric' methods ('distance', "
+"'size') and 'orientation' methods ('yaw', 'pitch'). 
For any other grouping "
+"methods see the '-t' ('--threshold') option.\n"
+"L|For 'face metric' methods the bins are filled, according to the "
+"distribution of faces between the minimum and maximum chosen metric.\n"
+"L|For 'color' methods the number of bins represents the divider of the "
+"percentage of colored pixels. Eg. For a bin number of '5': The first folder "
+"will have the faces with 0%% to 20%% colored pixels, second 21%% to 40%%, "
+"etc. Any empty bins will be deleted, so you may end up with fewer bins than "
+"selected.\n"
+"L|For 'blur' methods folder 0 will be the least blurry, while the last "
+"folder will be the blurriest.\n"
+"L|For 'orientation' methods the number of bins is dictated by how much 180 "
+"degrees is divided. Eg. If 18 is selected, then each folder will be a 10 "
+"degree increment. Folder 0 will contain faces looking the most to the left/"
+"down whereas the last folder will contain the faces looking the most to the "
+"right/up. NB: Some bins may be empty if faces do not fit the criteria. \n"
+"Default value: 5"
+msgstr ""
+"R|정수 값. 그룹화를 위해 생성된 bins의 수를 제어하는 데 사용됩니다. 임의의 "
+"'blur' 방법, 'color' 방법 또는 'face metric' 방법('distance', 'size'), "
+"'orientation' 방법('yaw', 'pitch'). 다른 그룹화 방법은 '-t'('--threshold') "
+"옵션을 참조하십시오.\n"
+"L|'face metric' 방법의 경우 선택한 최소 메트릭과 최대 메트릭 사이의 얼굴 분포"
+"에 따라 bins가 채워집니다.\n"
+"L|'color' 방법의 경우 bins의 수는 색상 픽셀의 백분율을 나눈 값을 나타냅니다. "
+"예: bin 번호가 '5'인 경우: 첫 번째 폴더는 0%%에서 20%%의 색상 픽셀을 가진 얼"
+"굴을 가질 것이고, 두 번째는 21%%에서 40%% 등을 가질 것입니다. 텅 빈 bins는 삭"
+"제되므로 선택한 bins보다 더 적은 bins을 가질 수 있습니다.\n"
+"L|'blur' 메서드의 경우 폴더 0이 가장 흐림이 적으며 마지막 폴더가 가장 흐림이 "
+"많습니다.\n"
+"L|'orientation' 방법의 경우 bins의 수는 180도를 얼마나 나누느냐에 따라 결정됩"
+"니다. 예: 18을 선택하면 각 폴더가 10도씩 증가합니다. 폴더 0은 왼쪽/아래쪽 얼"
+"굴을 가장 많이 포함하는 반면, 마지막 폴더는 오른쪽/위 얼굴을 가장 많이 포함합"
+"니다. 주의: 얼굴이 기준에 맞지 않으면 일부 bins가 비어 있을 수 있습니다.\n"
+"기본값: 5"
+
+#: tools/sort/cli.py:207 tools/sort/cli.py:217
+msgid "settings"
+msgstr "설정"
+
+#: tools/sort/cli.py:210
+msgid ""
+"Logs file renaming changes if grouping by renaming, or it logs the file "
+"copying/movement if grouping by folders. If no log file is specified with "
+"'--log-file', then a 'sort_log.json' file will be created in the input "
+"directory."
+msgstr ""
+"renaming으로 그룹화하는 경우 파일 이름 변경 내역을 기록하고, 폴더별로 그룹화"
+"하는 경우 파일 복사/이동 내역을 기록합니다. '--log-file'로 로그 파일을 지정하"
+"지 않으면 'sort_log.json' 파일이 입력 디렉토리에 생성됩니다."
+
+#: tools/sort/cli.py:221
+msgid ""
+"Specify a log file to use for saving the renaming or grouping information. "
+"If specified extension isn't 'json' or 'yaml', then json will be used as the "
+"serializer, with the supplied filename. Default: sort_log.json"
+msgstr ""
+"renaming 또는 grouping 정보를 저장하는 데 사용할 로그 파일을 지정합니다. 지"
+"정된 확장자가 'json' 또는 'yaml'이 아니면 json이 제공된 파일 이름과 함께 직렬"
+"화기로 사용됩니다. 기본값: sort_log.json"
+
+#~ msgid " option is deprecated. Use 'yaw'"
+#~ msgstr " 이 옵션은 더 이상 사용되지 않습니다. 'yaw'를 사용하세요"
+
+#~ msgid " option is deprecated. Use 'color-black'"
+#~ msgstr " 이 옵션은 더 이상 사용되지 않습니다. 'color-black'을 사용하세요"
+
+#~ msgid "output"
+#~ msgstr "출력"
+
+#~ msgid ""
+#~ "Deprecated and no longer used. The final processing will be dictated by "
+#~ "the sort/group by methods and whether 'keep_original' is selected."
+#~ msgstr ""
+#~ "폐기되었고 더 이상 사용되지 않습니다. 최종 처리는 sort/group-by 메서드와 "
+#~ "'keep_original'이 선택되었는지 여부에 의해 결정됩니다."
diff --git a/locales/lib.cli.args.pot b/locales/lib.cli.args.pot
new file mode 100644
index 0000000000..03a77c5a74
--- /dev/null
+++ b/locales/lib.cli.args.pot
@@ -0,0 +1,50 @@
+# SOME DESCRIPTIVE TITLE. 
+# Copyright (C) YEAR THE PACKAGE'S COPYRIGHT HOLDER
+# This file is distributed under the same license as the PACKAGE package.
+# FIRST AUTHOR <EMAIL@ADDRESS>, YEAR.
+#
+#, fuzzy
+msgid ""
+msgstr ""
+"Project-Id-Version: PACKAGE VERSION\n"
+"Report-Msgid-Bugs-To: \n"
+"POT-Creation-Date: 2024-03-28 18:06+0000\n"
+"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
+"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
+"Language-Team: LANGUAGE <LL@li.org>\n"
+"Language: \n"
+"MIME-Version: 1.0\n"
+"Content-Type: text/plain; charset=CHARSET\n"
+"Content-Transfer-Encoding: 8bit\n"
+
+#: lib/cli/args.py:188 lib/cli/args.py:199 lib/cli/args.py:208
+#: lib/cli/args.py:219
+msgid "Global Options"
+msgstr ""
+
+#: lib/cli/args.py:190
+msgid ""
+"R|Exclude GPUs from use by Faceswap. Select the number(s) which correspond "
+"to any GPU(s) that you do not wish to be made available to Faceswap. "
+"Selecting all GPUs here will force Faceswap into CPU mode.\n"
+"L|{}"
+msgstr ""
+
+#: lib/cli/args.py:201
+msgid ""
+"Optionally override the saved config with the path to a custom config file."
+msgstr ""
+
+#: lib/cli/args.py:210
+msgid ""
+"Log level. Stick with INFO or VERBOSE unless you need to file an error "
+"report. Be careful with TRACE as it will generate a lot of data"
+msgstr ""
+
+#: lib/cli/args.py:220
+msgid "Path to store the logfile. Leave blank to store in the faceswap folder"
+msgstr ""
+
+#: lib/cli/args.py:319
+msgid "Output to Shell console instead of GUI console"
+msgstr ""
diff --git a/locales/lib.cli.args_extract_convert.pot b/locales/lib.cli.args_extract_convert.pot
new file mode 100644
index 0000000000..d650f81997
--- /dev/null
+++ b/locales/lib.cli.args_extract_convert.pot
@@ -0,0 +1,455 @@
+# SOME DESCRIPTIVE TITLE.
+# Copyright (C) YEAR THE PACKAGE'S COPYRIGHT HOLDER
+# This file is distributed under the same license as the PACKAGE package.
+# FIRST AUTHOR <EMAIL@ADDRESS>, YEAR.
+#
+#, fuzzy
+msgid ""
+msgstr ""
+"Project-Id-Version: PACKAGE VERSION\n"
+"Report-Msgid-Bugs-To: \n"
+"POT-Creation-Date: 2024-04-12 11:56+0100\n"
+"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
+"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
+"Language-Team: LANGUAGE <LL@li.org>\n"
+"Language: \n"
+"MIME-Version: 1.0\n"
+"Content-Type: text/plain; charset=CHARSET\n"
+"Content-Transfer-Encoding: 8bit\n"
+
+#: lib/cli/args_extract_convert.py:46 lib/cli/args_extract_convert.py:56
+#: lib/cli/args_extract_convert.py:64 lib/cli/args_extract_convert.py:122
+#: lib/cli/args_extract_convert.py:483 lib/cli/args_extract_convert.py:492
+msgid "Data"
+msgstr ""
+
+#: lib/cli/args_extract_convert.py:48
+msgid ""
+"Input directory or video. Either a directory containing the image files you "
+"wish to process or path to a video file. NB: This should be the source video/"
+"frames NOT the source faces."
+msgstr ""
+
+#: lib/cli/args_extract_convert.py:57
+msgid "Output directory. This is where the converted files will be saved."
+msgstr ""
+
+#: lib/cli/args_extract_convert.py:66
+msgid ""
+"Optional path to an alignments file. Leave blank if the alignments file is "
+"at the default location."
+msgstr ""
+
+#: lib/cli/args_extract_convert.py:97
+msgid ""
+"Extract faces from image or video sources.\n"
+"Extraction plugins can be configured in the 'Settings' Menu"
+msgstr ""
+
+#: lib/cli/args_extract_convert.py:124
+msgid ""
+"R|If selected then the input_dir should be a parent folder containing "
+"multiple videos and/or folders of images you wish to extract from. The faces "
+"will be output to separate sub-folders in the output_dir." 
+msgstr "" + +#: lib/cli/args_extract_convert.py:133 lib/cli/args_extract_convert.py:152 +#: lib/cli/args_extract_convert.py:167 lib/cli/args_extract_convert.py:206 +#: lib/cli/args_extract_convert.py:224 lib/cli/args_extract_convert.py:237 +#: lib/cli/args_extract_convert.py:247 lib/cli/args_extract_convert.py:257 +#: lib/cli/args_extract_convert.py:503 lib/cli/args_extract_convert.py:529 +#: lib/cli/args_extract_convert.py:568 +msgid "Plugins" +msgstr "" + +#: lib/cli/args_extract_convert.py:135 +msgid "" +"R|Detector to use. Some of these have configurable settings in '/config/" +"extract.ini' or 'Settings > Configure Extract 'Plugins':\n" +"L|cv2-dnn: A CPU only extractor which is the least reliable and least " +"resource intensive. Use this if not using a GPU and time is important.\n" +"L|mtcnn: Good detector. Fast on CPU, faster on GPU. Uses fewer resources " +"than other GPU detectors but can often return more false positives.\n" +"L|s3fd: Best detector. Slow on CPU, faster on GPU. Can detect more faces and " +"fewer false positives than other GPU detectors, but is a lot more resource " +"intensive.\n" +"L|external: Import a face detection bounding box from a json file. " +"(configurable in Detect settings)" +msgstr "" + +#: lib/cli/args_extract_convert.py:154 +msgid "" +"R|Aligner to use.\n" +"L|cv2-dnn: A CPU only landmark detector. Faster, less resource intensive, " +"but less accurate. Only use this if not using a GPU and time is important.\n" +"L|fan: Best aligner. Fast on GPU, slow on CPU.\n" +"L|external: Import 68 point 2D landmarks or an aligned bounding box from a " +"json file. (configurable in Align settings)" +msgstr "" + +#: lib/cli/args_extract_convert.py:169 +msgid "" +"R|Additional Masker(s) to use. The masks generated here will all take up GPU " +"RAM. You can select none, one or multiple masks, but the extraction may take " +"longer the more you select. NB: The Extended and Components (landmark based) " +"masks are automatically generated on extraction.\n" +"L|bisenet-fp: Relatively lightweight NN based mask that provides more " +"refined control over the area to be masked including full head masking " +"(configurable in mask settings).\n" +"L|custom: A dummy mask that fills the mask area with all 1s or 0s " +"(configurable in settings). This is only required if you intend to manually " +"edit the custom masks yourself in the manual tool. This mask does not use " +"the GPU so will not use any additional VRAM.\n" +"L|vgg-clear: Mask designed to provide smart segmentation of mostly frontal " +"faces clear of obstructions. Profile faces and obstructions may result in " +"sub-par performance.\n" +"L|vgg-obstructed: Mask designed to provide smart segmentation of mostly " +"frontal faces. The mask model has been specifically trained to recognize " +"some facial obstructions (hands and eyeglasses). Profile faces may result in " +"sub-par performance.\n" +"L|unet-dfl: Mask designed to provide smart segmentation of mostly frontal " +"faces. The mask model has been trained by community members and will need " +"testing for further description. Profile faces may result in sub-par " +"performance.\n" +"The auto generated masks are as follows:\n" +"L|components: Mask designed to provide facial segmentation based on the " +"positioning of landmark locations. A convex hull is constructed around the " +"exterior of the landmarks to create a mask.\n" +"L|extended: Mask designed to provide facial segmentation based on the " +"positioning of landmark locations. 
A convex hull is constructed around the " +"exterior of the landmarks and the mask is extended upwards onto the " +"forehead.\n" +"(eg: `-M unet-dfl vgg-clear`, `--masker vgg-obstructed`)" +msgstr "" + +#: lib/cli/args_extract_convert.py:208 +msgid "" +"R|Performing normalization can help the aligner better align faces with " +"difficult lighting conditions at an extraction speed cost. Different methods " +"will yield different results on different sets. NB: This does not impact the " +"output face, just the input to the aligner.\n" +"L|none: Don't perform normalization on the face.\n" +"L|clahe: Perform Contrast Limited Adaptive Histogram Equalization on the " +"face.\n" +"L|hist: Equalize the histograms on the RGB channels.\n" +"L|mean: Normalize the face colors to the mean." +msgstr "" + +#: lib/cli/args_extract_convert.py:226 +msgid "" +"The number of times to re-feed the detected face into the aligner. Each time " +"the face is re-fed into the aligner the bounding box is adjusted by a small " +"amount. The final landmarks are then averaged from each iteration. Helps to " +"remove 'micro-jitter' but at the cost of slower extraction speed. The more " +"times the face is re-fed into the aligner, the less micro-jitter should " +"occur but the longer extraction will take." +msgstr "" + +#: lib/cli/args_extract_convert.py:239 +msgid "" +"Re-feed the initially found aligned face through the aligner. Can help " +"produce better alignments for faces that are rotated beyond 45 degrees in " +"the frame or are at extreme angles. Slows down extraction." +msgstr "" + +#: lib/cli/args_extract_convert.py:249 +msgid "" +"If a face isn't found, rotate the images to try to find a face. Can find " +"more faces at the cost of extraction speed. Pass in a single number to use " +"increments of that size up to 360, or pass in a list of numbers to enumerate " +"exactly what angles to check." +msgstr "" + +#: lib/cli/args_extract_convert.py:259 +msgid "" +"Obtain and store face identity encodings from VGGFace2. Slows down extract a " +"little, but will save time if using 'sort by face'" +msgstr "" + +#: lib/cli/args_extract_convert.py:269 lib/cli/args_extract_convert.py:280 +#: lib/cli/args_extract_convert.py:293 lib/cli/args_extract_convert.py:307 +#: lib/cli/args_extract_convert.py:614 lib/cli/args_extract_convert.py:623 +#: lib/cli/args_extract_convert.py:638 lib/cli/args_extract_convert.py:651 +#: lib/cli/args_extract_convert.py:665 +msgid "Face Processing" +msgstr "" + +#: lib/cli/args_extract_convert.py:271 +msgid "" +"Filters out faces detected below this size. Length, in pixels across the " +"diagonal of the bounding box. Set to 0 for off" +msgstr "" + +#: lib/cli/args_extract_convert.py:282 +msgid "" +"Optionally filter out people who you do not wish to extract by passing in " +"images of those people. Should be a small variety of images at different " +"angles and in different conditions. A folder containing the required images " +"or multiple image files, space separated, can be selected." +msgstr "" + +#: lib/cli/args_extract_convert.py:295 +msgid "" +"Optionally select people you wish to extract by passing in images of that " +"person. Should be a small variety of images at different angles and in " +"different conditions A folder containing the required images or multiple " +"image files, space separated, can be selected." +msgstr "" + +#: lib/cli/args_extract_convert.py:309 +msgid "" +"For use with the optional nfilter/filter files. Threshold for positive face " +"recognition. 
Higher values are stricter."
+msgstr ""
+
+#: lib/cli/args_extract_convert.py:318 lib/cli/args_extract_convert.py:331
+#: lib/cli/args_extract_convert.py:344 lib/cli/args_extract_convert.py:356
+msgid "output"
+msgstr ""
+
+#: lib/cli/args_extract_convert.py:320
+msgid ""
+"The output size of extracted faces. Make sure that the model you intend to "
+"train supports your required size. This will only need to be changed for hi-"
+"res models."
+msgstr ""
+
+#: lib/cli/args_extract_convert.py:333
+msgid ""
+"Extract every 'nth' frame. This option will skip frames when extracting "
+"faces. For example a value of 1 will extract faces from every frame, a value "
+"of 10 will extract faces from every 10th frame."
+msgstr ""
+
+#: lib/cli/args_extract_convert.py:346
+msgid ""
+"Automatically save the alignments file after a set amount of frames. By "
+"default the alignments file is only saved at the end of the extraction "
+"process. NB: If extracting in 2 passes then the alignments file will only "
+"start to be saved out during the second pass. WARNING: Don't interrupt the "
+"script when writing the file because it might get corrupted. Set to 0 to "
+"turn off"
+msgstr ""
+
+#: lib/cli/args_extract_convert.py:357
+msgid "Draw landmarks on the output faces for debugging purposes."
+msgstr ""
+
+#: lib/cli/args_extract_convert.py:363 lib/cli/args_extract_convert.py:373
+#: lib/cli/args_extract_convert.py:381 lib/cli/args_extract_convert.py:388
+#: lib/cli/args_extract_convert.py:678 lib/cli/args_extract_convert.py:691
+#: lib/cli/args_extract_convert.py:712 lib/cli/args_extract_convert.py:718
+msgid "settings"
+msgstr ""
+
+#: lib/cli/args_extract_convert.py:365
+msgid ""
+"Don't run extraction in parallel. Will run each part of the extraction "
+"process separately (one after the other) rather than all at the same time. "
+"Useful if VRAM is at a premium."
+msgstr ""
+
+#: lib/cli/args_extract_convert.py:375
+msgid ""
+"Skips frames that have already been extracted and exist in the alignments "
+"file"
+msgstr ""
+
+#: lib/cli/args_extract_convert.py:382
+msgid "Skip frames that already have detected faces in the alignments file"
+msgstr ""
+
+#: lib/cli/args_extract_convert.py:389
+msgid "Skip saving the detected faces to disk. Just create an alignments file"
+msgstr ""
+
+#: lib/cli/args_extract_convert.py:463
+msgid ""
+"Swap the original faces in a source video/images to your final faces.\n"
+"Conversion plugins can be configured in the 'Settings' Menu"
+msgstr ""
+
+#: lib/cli/args_extract_convert.py:485
+msgid ""
+"Only required if converting from images to video. Provide the original video "
+"that the source frames were extracted from (for extracting the fps and "
+"audio)."
+msgstr ""
+
+#: lib/cli/args_extract_convert.py:494
+msgid ""
+"Model directory. The directory containing the trained model you wish to use "
+"for conversion."
+msgstr ""
+
+#: lib/cli/args_extract_convert.py:505
+msgid ""
+"R|Performs color adjustment to the swapped face. Some of these options have "
+"configurable settings in '/config/convert.ini' or 'Settings > Configure "
+"Convert Plugins':\n"
+"L|avg-color: Adjust the mean of each color channel in the swapped "
+"reconstruction to equal the mean of the masked area in the original image.\n"
+"L|color-transfer: Transfers the color distribution from the source to the "
+"target image using the mean and standard deviations of the L*a*b* color "
+"space.\n"
+"L|manual-balance: Manually adjust the balance of the image in a variety of "
+"color spaces. 
Best used with the Preview tool to set correct values.\n" +"L|match-hist: Adjust the histogram of each color channel in the swapped " +"reconstruction to equal the histogram of the masked area in the original " +"image.\n" +"L|seamless-clone: Use cv2's seamless clone function to remove extreme " +"gradients at the mask seam by smoothing colors. Generally does not give very " +"satisfactory results.\n" +"L|none: Don't perform color adjustment." +msgstr "" + +#: lib/cli/args_extract_convert.py:531 +msgid "" +"R|Masker to use. NB: The mask you require must exist within the alignments " +"file. You can add additional masks with the Mask Tool.\n" +"L|none: Don't use a mask.\n" +"L|bisenet-fp_face: Relatively lightweight NN based mask that provides more " +"refined control over the area to be masked (configurable in mask settings). " +"Use this version of bisenet-fp if your model is trained with 'face' or " +"'legacy' centering.\n" +"L|bisenet-fp_head: Relatively lightweight NN based mask that provides more " +"refined control over the area to be masked (configurable in mask settings). " +"Use this version of bisenet-fp if your model is trained with 'head' " +"centering.\n" +"L|custom_face: Custom user created, face centered mask.\n" +"L|custom_head: Custom user created, head centered mask.\n" +"L|components: Mask designed to provide facial segmentation based on the " +"positioning of landmark locations. A convex hull is constructed around the " +"exterior of the landmarks to create a mask.\n" +"L|extended: Mask designed to provide facial segmentation based on the " +"positioning of landmark locations. A convex hull is constructed around the " +"exterior of the landmarks and the mask is extended upwards onto the " +"forehead.\n" +"L|vgg-clear: Mask designed to provide smart segmentation of mostly frontal " +"faces clear of obstructions. Profile faces and obstructions may result in " +"sub-par performance.\n" +"L|vgg-obstructed: Mask designed to provide smart segmentation of mostly " +"frontal faces. The mask model has been specifically trained to recognize " +"some facial obstructions (hands and eyeglasses). Profile faces may result in " +"sub-par performance.\n" +"L|unet-dfl: Mask designed to provide smart segmentation of mostly frontal " +"faces. The mask model has been trained by community members and will need " +"testing for further description. Profile faces may result in sub-par " +"performance.\n" +"L|predicted: If the 'Learn Mask' option was enabled during training, this " +"will use the mask that was created by the trained model." +msgstr "" + +#: lib/cli/args_extract_convert.py:570 +msgid "" +"R|The plugin to use to output the converted images. The writers are " +"configurable in '/config/convert.ini' or 'Settings > Configure Convert " +"Plugins:'\n" +"L|ffmpeg: [video] Writes out the convert straight to video. When the input " +"is a series of images then the '-ref' (--reference-video) parameter must be " +"set.\n" +"L|gif: [animated image] Create an animated gif.\n" +"L|opencv: [images] The fastest image writer, but less options and formats " +"than other plugins.\n" +"L|patch: [images] Outputs the raw swapped face patch, along with the " +"transformation matrix required to re-insert the face back into the original " +"frame. Use this option if you wish to post-process and composite the final " +"face within external tools.\n" +"L|pillow: [images] Slower than opencv, but has more options and supports " +"more formats." 
+msgstr "" + +#: lib/cli/args_extract_convert.py:591 lib/cli/args_extract_convert.py:600 +#: lib/cli/args_extract_convert.py:703 +msgid "Frame Processing" +msgstr "" + +#: lib/cli/args_extract_convert.py:593 +#, python-format +msgid "" +"Scale the final output frames by this amount. 100%% will output the frames " +"at source dimensions. 50%% at half size 200%% at double size" +msgstr "" + +#: lib/cli/args_extract_convert.py:602 +msgid "" +"Frame ranges to apply transfer to e.g. For frames 10 to 50 and 90 to 100 use " +"--frame-ranges 10-50 90-100. Frames falling outside of the selected range " +"will be discarded unless '-k' (--keep-unchanged) is selected. NB: If you are " +"converting from images, then the filenames must end with the frame-number!" +msgstr "" + +#: lib/cli/args_extract_convert.py:616 +msgid "" +"Scale the swapped face by this percentage. Positive values will enlarge the " +"face, Negative values will shrink the face." +msgstr "" + +#: lib/cli/args_extract_convert.py:625 +msgid "" +"If you have not cleansed your alignments file, then you can filter out faces " +"by defining a folder here that contains the faces extracted from your input " +"files/video. If this folder is defined, then only faces that exist within " +"your alignments file and also exist within the specified folder will be " +"converted. Leaving this blank will convert all faces that exist within the " +"alignments file." +msgstr "" + +#: lib/cli/args_extract_convert.py:640 +msgid "" +"Optionally filter out people who you do not wish to process by passing in an " +"image of that person. Should be a front portrait with a single person in the " +"image. Multiple images can be added space separated. NB: Using face filter " +"will significantly decrease extraction speed and its accuracy cannot be " +"guaranteed." +msgstr "" + +#: lib/cli/args_extract_convert.py:653 +msgid "" +"Optionally select people you wish to process by passing in an image of that " +"person. Should be a front portrait with a single person in the image. " +"Multiple images can be added space separated. NB: Using face filter will " +"significantly decrease extraction speed and its accuracy cannot be " +"guaranteed." +msgstr "" + +#: lib/cli/args_extract_convert.py:667 +msgid "" +"For use with the optional nfilter/filter files. Threshold for positive face " +"recognition. Lower values are stricter. NB: Using face filter will " +"significantly decrease extraction speed and its accuracy cannot be " +"guaranteed." +msgstr "" + +#: lib/cli/args_extract_convert.py:680 +msgid "" +"The maximum number of parallel processes for performing conversion. " +"Converting images is system RAM heavy so it is possible to run out of memory " +"if you have a lot of processes and not enough RAM to accommodate them all. " +"Setting this to 0 will use the maximum available. No matter what you set " +"this to, it will never attempt to use more processes than are available on " +"your system. If singleprocess is enabled this setting will be ignored." +msgstr "" + +#: lib/cli/args_extract_convert.py:693 +msgid "" +"Enable On-The-Fly Conversion. NOT recommended. You should generate a clean " +"alignments file for your destination video. However, if you wish you can " +"generate the alignments on-the-fly by enabling this option. This will use an " +"inferior extraction pipeline and will lead to substandard results. If an " +"alignments file is found, this option will be ignored." 
+msgstr "" + +#: lib/cli/args_extract_convert.py:705 +msgid "" +"When used with --frame-ranges outputs the unchanged frames that are not " +"processed instead of discarding them." +msgstr "" + +#: lib/cli/args_extract_convert.py:713 +msgid "Swap the model. Instead converting from of A -> B, converts B -> A" +msgstr "" + +#: lib/cli/args_extract_convert.py:719 +msgid "Disable multiprocessing. Slower but less resource intensive." +msgstr "" diff --git a/locales/lib.cli.args_train.pot b/locales/lib.cli.args_train.pot new file mode 100644 index 0000000000..40a785e681 --- /dev/null +++ b/locales/lib.cli.args_train.pot @@ -0,0 +1,252 @@ +# SOME DESCRIPTIVE TITLE. +# Copyright (C) YEAR THE PACKAGE'S COPYRIGHT HOLDER +# This file is distributed under the same license as the PACKAGE package. +# FIRST AUTHOR , YEAR. +# +#, fuzzy +msgid "" +msgstr "" +"Project-Id-Version: PACKAGE VERSION\n" +"Report-Msgid-Bugs-To: \n" +"POT-Creation-Date: 2025-12-15 20:02+0000\n" +"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n" +"Last-Translator: FULL NAME \n" +"Language-Team: LANGUAGE \n" +"Language: \n" +"MIME-Version: 1.0\n" +"Content-Type: text/plain; charset=CHARSET\n" +"Content-Transfer-Encoding: 8bit\n" + +#: lib/cli/args_train.py:30 +msgid "" +"Train a model on extracted original (A) and swap (B) faces.\n" +"Training models can take a long time. Anything from 24hrs to over a week\n" +"Model plugins can be configured in the 'Settings' Menu" +msgstr "" + +#: lib/cli/args_train.py:49 lib/cli/args_train.py:58 +msgid "faces" +msgstr "" + +#: lib/cli/args_train.py:51 +msgid "" +"Input directory. A directory containing training images for face A. This is " +"the original face, i.e. the face that you want to remove and replace with " +"face B." +msgstr "" + +#: lib/cli/args_train.py:60 +msgid "" +"Input directory. A directory containing training images for face B. This is " +"the swap face, i.e. the face that you want to place onto the head of person " +"A." +msgstr "" + +#: lib/cli/args_train.py:67 lib/cli/args_train.py:80 lib/cli/args_train.py:97 +#: lib/cli/args_train.py:123 lib/cli/args_train.py:133 +msgid "model" +msgstr "" + +#: lib/cli/args_train.py:69 +msgid "" +"Model directory. This is where the training data will be stored. You should " +"always specify a new folder for new models. If starting a new model, select " +"either an empty folder, or a folder which does not exist (which will be " +"created). If continuing to train an existing model, specify the location of " +"the existing model." +msgstr "" + +#: lib/cli/args_train.py:82 +msgid "" +"R|Load the weights from a pre-existing model into a newly created model. For " +"most models this will load weights from the Encoder of the given model into " +"the encoder of the newly created model. Some plugins may have specific " +"configuration options allowing you to load weights from other layers. " +"Weights will only be loaded when creating a new model. This option will be " +"ignored if you are resuming an existing model. Generally you will also want " +"to 'freeze-weights' whilst the rest of your model catches up with your " +"Encoder.\n" +"NB: Weights can only be loaded from models of the same plugin as you intend " +"to train." +msgstr "" + +#: lib/cli/args_train.py:99 +msgid "" +"R|Select which trainer to use. Trainers can be configured from the Settings " +"menu or the config folder.\n" +"L|original: The original model created by /u/deepfakes.\n" +"L|dfaker: 64px in/128px out model from dfaker. 
Enable 'warp-to-landmarks' "
+"for full dfaker method.\n"
+"L|dfl-h128: 128px in/out model from deepfacelab\n"
+"L|dfl-sae: Adaptable model from deepfacelab\n"
+"L|dlight: A lightweight, high resolution DFaker variant.\n"
+"L|iae: A model that uses intermediate layers to try to get better details\n"
+"L|lightweight: A lightweight model for low-end cards. Don't expect great "
+"results. Can train as low as 1.6GB with batch size 8.\n"
+"L|realface: A high detail, dual density model based on DFaker, with "
+"customizable in/out resolution. The autoencoders are unbalanced so B>A swaps "
+"won't work so well. By andenixa et al. Very configurable.\n"
+"L|unbalanced: 128px in/out model from andenixa. The autoencoders are "
+"unbalanced so B>A swaps won't work so well. Very configurable.\n"
+"L|villain: 128px in/out model from villainguy. Very resource hungry (You "
+"will require a GPU with a fair amount of VRAM). Good for details, but more "
+"susceptible to color differences."
+msgstr ""
+
+#: lib/cli/args_train.py:125
+msgid ""
+"Output a summary of the model and exit. If a model folder is provided then a "
+"summary of the saved model is displayed. Otherwise a summary of the model "
+"that would be created by the chosen plugin and configuration settings is "
+"displayed."
+msgstr ""
+
+#: lib/cli/args_train.py:135
+msgid ""
+"Freeze the weights of the model. Freezing weights means that some of the "
+"parameters in the model will no longer continue to learn, but those that are "
+"not frozen will continue to learn. For most models, this will freeze the "
+"encoder, but some models may have configuration options for freezing other "
+"layers."
+msgstr ""
+
+#: lib/cli/args_train.py:147 lib/cli/args_train.py:160
+#: lib/cli/args_train.py:174 lib/cli/args_train.py:183
+#: lib/cli/args_train.py:190 lib/cli/args_train.py:199
+msgid "training"
+msgstr ""
+
+#: lib/cli/args_train.py:149
+msgid ""
+"Batch size. This is the number of images processed through the model for "
+"each side per iteration. NB: As the model is fed 2 sides at a time, the "
+"actual number of images within the model at any one time is double the "
+"number that you set here. Larger batches require more GPU RAM."
+msgstr ""
+
+#: lib/cli/args_train.py:162
+msgid ""
+"Length of training in iterations. This is only really used for automation. "
+"There is no 'correct' number of iterations a model should be trained for. "
+"You should stop training when you are happy with the previews. However, if "
+"you want the model to stop automatically at a set number of iterations, you "
+"can set that value here."
+msgstr ""
+
+#: lib/cli/args_train.py:176
+msgid ""
+"Learning rate warmup. Linearly increase the learning rate from 0 to the "
+"chosen target rate over the number of iterations given here. 0 to disable."
+msgstr ""
+
+#: lib/cli/args_train.py:184
+msgid "Use distributed training on multi-gpu setups."
+msgstr ""
+
+#: lib/cli/args_train.py:192
+msgid ""
+"Disables TensorBoard logging. NB: Disabling logs means that you will not be "
+"able to use the graph or analysis for this session in the GUI."
+msgstr ""
+
+#: lib/cli/args_train.py:201
+msgid ""
+"Use the Learning Rate Finder to discover the optimal learning rate for "
+"training. For new models, this will calculate the optimal learning rate for "
+"the model. For existing models this will use the optimal learning rate that "
+"was discovered when initializing the model. 
Setting this option will ignore "
+"the manually configured learning rate (configurable in train settings)."
+msgstr ""
+
+#: lib/cli/args_train.py:214 lib/cli/args_train.py:224
+msgid "Saving"
+msgstr ""
+
+#: lib/cli/args_train.py:215
+msgid "Sets the number of iterations between each model save."
+msgstr ""
+
+#: lib/cli/args_train.py:226
+msgid ""
+"Sets the number of iterations before saving a backup snapshot of the model "
+"in its current state. Set to 0 for off."
+msgstr ""
+
+#: lib/cli/args_train.py:233 lib/cli/args_train.py:245
+#: lib/cli/args_train.py:257
+msgid "timelapse"
+msgstr ""
+
+#: lib/cli/args_train.py:235
+msgid ""
+"Optional for creating a timelapse. Timelapse will save an image of your "
+"selected faces into the timelapse-output folder at every save iteration. "
+"This should be the input folder of 'A' faces that you would like to use for "
+"creating the timelapse. You must also supply a --timelapse-output and a --"
+"timelapse-input-B parameter."
+msgstr ""
+
+#: lib/cli/args_train.py:247
+msgid ""
+"Optional for creating a timelapse. Timelapse will save an image of your "
+"selected faces into the timelapse-output folder at every save iteration. "
+"This should be the input folder of 'B' faces that you would like to use for "
+"creating the timelapse. You must also supply a --timelapse-output and a --"
+"timelapse-input-A parameter."
+msgstr ""
+
+#: lib/cli/args_train.py:259
+msgid ""
+"Optional for creating a timelapse. Timelapse will save an image of your "
+"selected faces into the timelapse-output folder at every save iteration. If "
+"the input folders are supplied but no output folder, it will default to your "
+"model folder/timelapse/"
+msgstr ""
+
+#: lib/cli/args_train.py:268 lib/cli/args_train.py:275
+msgid "preview"
+msgstr ""
+
+#: lib/cli/args_train.py:269
+msgid "Show training preview output in a separate window."
+msgstr ""
+
+#: lib/cli/args_train.py:277
+msgid ""
+"Writes the training result to a file. The image will be stored in the root "
+"of your FaceSwap folder."
+msgstr ""
+
+#: lib/cli/args_train.py:284 lib/cli/args_train.py:294
+#: lib/cli/args_train.py:304 lib/cli/args_train.py:314
+msgid "augmentation"
+msgstr ""
+
+#: lib/cli/args_train.py:286
+msgid ""
+"Warps training faces to closely matched Landmarks from the opposite face-set "
+"rather than randomly warping the face. This is the 'dfaker' way of doing "
+"warping."
+msgstr ""
+
+#: lib/cli/args_train.py:296
+msgid ""
+"To effectively learn, a random set of images are flipped horizontally. "
+"Sometimes it is desirable for this not to occur. Generally this should be "
+"left off except for during 'fit training'."
+msgstr ""
+
+#: lib/cli/args_train.py:306
+msgid ""
+"Color augmentation helps make the model less susceptible to color "
+"differences between the A and B sets, at an increased training time cost. "
+"Enable this option to disable color augmentation."
+msgstr ""
+
+#: lib/cli/args_train.py:316
+msgid ""
+"Warping is integral to training the Neural Network. This option should only "
+"be enabled towards the very end of training to try to bring out more detail. "
+"Think of it as 'fine-tuning'. Enabling this option from the beginning is "
+"likely to kill a model and lead to terrible results."
+msgstr ""
diff --git a/locales/lib.config.objects.pot b/locales/lib.config.objects.pot
new file mode 100644
index 0000000000..1290692e57
--- /dev/null
+++ b/locales/lib.config.objects.pot
@@ -0,0 +1,61 @@
+# SOME DESCRIPTIVE TITLE. 
+# Copyright (C) YEAR THE PACKAGE'S COPYRIGHT HOLDER
+# This file is distributed under the same license as the PACKAGE package.
+# FIRST AUTHOR <EMAIL@ADDRESS>, YEAR.
+#
+#, fuzzy
+msgid ""
+msgstr ""
+"Project-Id-Version: PACKAGE VERSION\n"
+"Report-Msgid-Bugs-To: \n"
+"POT-Creation-Date: 2025-12-11 19:02+0000\n"
+"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
+"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
+"Language-Team: LANGUAGE <LL@li.org>\n"
+"Language: \n"
+"MIME-Version: 1.0\n"
+"Content-Type: text/plain; charset=CHARSET\n"
+"Content-Transfer-Encoding: 8bit\n"
+
+#: lib/config/objects.py:115
+msgid ""
+"\n"
+"This option can be updated for existing models.\n"
+msgstr ""
+
+#: lib/config/objects.py:117
+msgid ""
+"\n"
+"If selecting multiple options then each option should be separated by a "
+"space or a comma (e.g. item1, item2, item3)\n"
+msgstr ""
+
+#: lib/config/objects.py:120
+msgid ""
+"\n"
+"Choose from: {}"
+msgstr ""
+
+#: lib/config/objects.py:122
+msgid ""
+"\n"
+"Choose from: True, False"
+msgstr ""
+
+#: lib/config/objects.py:126
+msgid ""
+"\n"
+"Select an integer between {} and {}"
+msgstr ""
+
+#: lib/config/objects.py:130
+msgid ""
+"\n"
+"Select a decimal number between {} and {}"
+msgstr ""
+
+#: lib/config/objects.py:132
+msgid ""
+"\n"
+"[Default: {}]"
+msgstr ""
diff --git a/locales/plugins.extract.extract_config.pot b/locales/plugins.extract.extract_config.pot
new file mode 100644
index 0000000000..fc012eb3ff
--- /dev/null
+++ b/locales/plugins.extract.extract_config.pot
@@ -0,0 +1,105 @@
+# SOME DESCRIPTIVE TITLE.
+# Copyright (C) YEAR THE PACKAGE'S COPYRIGHT HOLDER
+# This file is distributed under the same license as the PACKAGE package.
+# FIRST AUTHOR <EMAIL@ADDRESS>, YEAR.
+#
+#, fuzzy
+msgid ""
+msgstr ""
+"Project-Id-Version: PACKAGE VERSION\n"
+"Report-Msgid-Bugs-To: \n"
+"POT-Creation-Date: 2025-12-12 13:11+0000\n"
+"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
+"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
+"Language-Team: LANGUAGE <LL@li.org>\n"
+"Language: \n"
+"MIME-Version: 1.0\n"
+"Content-Type: text/plain; charset=CHARSET\n"
+"Content-Transfer-Encoding: 8bit\n"
+
+#: plugins/extract/extract_config.py:23
+msgid "Options that apply to all extraction plugins"
+msgstr ""
+
+#: plugins/extract/extract_config.py:30 plugins/extract/extract_config.py:45
+#: plugins/extract/extract_config.py:60 plugins/extract/extract_config.py:72
+#: plugins/extract/extract_config.py:85 plugins/extract/extract_config.py:95
+#: plugins/extract/extract_config.py:107
+msgid "filters"
+msgstr ""
+
+#: plugins/extract/extract_config.py:32
+msgid ""
+"Filters out faces below this size. This is a multiplier of the minimum "
+"dimension of the frame (i.e. 1280x720 = 720). If the original face extract "
+"box is smaller than the minimum dimension times this multiplier, it is "
+"considered a false positive and discarded. Faces which are found to be "
+"unusually smaller than the frame tend to be misaligned images, except in "
+"extreme long-shots. These can usually be safely discarded."
+msgstr ""
+
+#: plugins/extract/extract_config.py:47
+msgid ""
+"Filters out faces above this size. This is a multiplier of the minimum "
+"dimension of the frame (i.e. 1280x720 = 720). If the original face extract "
+"box is larger than the minimum dimension times this multiplier, it is "
+"considered a false positive and discarded. Faces which are found to be "
+"unusually larger than the frame tend to be misaligned images, except in "
+"extreme close-ups. These can usually be safely discarded." 
+msgstr "" + +#: plugins/extract/extract_config.py:62 +msgid "" +"Filters out faces who's landmarks are above this distance from an 'average' " +"face. Values above 15 tend to be fairly safe. Values above 10 will remove " +"more false positives, but may also filter out some faces at extreme angles." +msgstr "" + +#: plugins/extract/extract_config.py:74 +msgid "" +"Filters out faces who's calculated roll is greater than zero +/- this value " +"in degrees. Aligned faces should have a roll value close to zero. Values " +"that are a significant distance from 0 degrees tend to be misaligned images. " +"These can usually be safely disgarded." +msgstr "" + +#: plugins/extract/extract_config.py:87 +msgid "" +"Filters out faces where the lowest point of the aligned face's eye or " +"eyebrow is lower than the highest point of the aligned face's mouth. Any " +"faces where this occurs are misaligned and can be safely disgarded." +msgstr "" + +#: plugins/extract/extract_config.py:97 +msgid "" +"If enabled, and 're-feed' has been selected for extraction, then interim " +"alignments will be filtered prior to averaging the final landmarks. This can " +"help improve the final alignments by removing any obvious misaligns from the " +"interim results, and may also help pick up difficult alignments. If " +"disabled, then all re-feed results will be averaged." +msgstr "" + +#: plugins/extract/extract_config.py:109 +msgid "" +"If enabled, saves any filtered out images into a sub-folder during the " +"extraction process. If disabled, filtered faces are deleted. Note: The faces " +"will always be filtered out of the alignments file, regardless of whether " +"you keep the faces or not." +msgstr "" + +#: plugins/extract/extract_config.py:118 plugins/extract/extract_config.py:128 +msgid "re-align" +msgstr "" + +#: plugins/extract/extract_config.py:120 +msgid "" +"If enabled, and 're-align' has been selected for extraction, then all re-" +"feed iterations are re-aligned. If disabled, then only the final averaged " +"output from re-feed will be re-aligned." +msgstr "" + +#: plugins/extract/extract_config.py:130 +msgid "" +"If enabled, and 're-align' has been selected for extraction, then any " +"alignments which would be filtered out will not be re-aligned." +msgstr "" diff --git a/locales/plugins.train.train_config.pot b/locales/plugins.train.train_config.pot new file mode 100644 index 0000000000..c9bff80353 --- /dev/null +++ b/locales/plugins.train.train_config.pot @@ -0,0 +1,749 @@ +# SOME DESCRIPTIVE TITLE. +# Copyright (C) YEAR THE PACKAGE'S COPYRIGHT HOLDER +# This file is distributed under the same license as the PACKAGE package. +# FIRST AUTHOR , YEAR. +# +#, fuzzy +msgid "" +msgstr "" +"Project-Id-Version: PACKAGE VERSION\n" +"Report-Msgid-Bugs-To: \n" +"POT-Creation-Date: 2025-12-13 13:39+0000\n" +"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n" +"Last-Translator: FULL NAME \n" +"Language-Team: LANGUAGE \n" +"Language: \n" +"MIME-Version: 1.0\n" +"Content-Type: text/plain; charset=CHARSET\n" +"Content-Transfer-Encoding: 8bit\n" + +#: plugins/train/train_config.py:21 +msgid "" +"\n" +"NB: Unless specifically stated, values changed here will only take effect " +"when creating a new model." +msgstr "" + +#: plugins/train/train_config.py:30 +msgid "Options that apply to all models" +msgstr "" + +#: plugins/train/train_config.py:43 plugins/train/train_config.py:66 +#: plugins/train/train_config.py:86 +msgid "face" +msgstr "" + +#: plugins/train/train_config.py:45 +msgid "" +"How to center the training image. 
The extracted images are centered on the " +"middle of the skull based on the face's estimated pose. A subsection of " +"these images are used for training. The centering used dictates how this " +"subsection will be cropped from the aligned images.\n" +"\tface: Centers the training image on the center of the face, adjusting for " +"pitch and yaw.\n" +"\thead: Centers the training image on the center of the head, adjusting for " +"pitch and yaw. NB: You should only select head centering if you intend to " +"include the full head (including hair) in the final swap. This may give " +"mixed results. Additionally, it is only worth choosing head centering if you " +"are training with a mask that includes the hair (e.g. BiSeNet-FP-Head).\n" +"\tlegacy: The 'original' extraction technique. Centers the training image " +"near the tip of the nose with no adjustment. Can result in the edges of the " +"face appearing outside of the training area." +msgstr "" + +#: plugins/train/train_config.py:68 +msgid "" +"How much of the extracted image to train on. A lower coverage will limit the " +"model's scope to a zoomed-in central area while higher amounts can include " +"the entire face. A trade-off exists between lower amounts given more detail " +"versus higher amounts avoiding noticeable swap transitions. For 'Face' " +"centering you will want to leave this above 75%. For Head centering you will " +"most likely want to set this to 100%. Sensible values for 'Legacy' centering " +"are:\n" +"\t62.5% spans from eyebrow to eyebrow.\n" +"\t75.0% spans from temple to temple.\n" +"\t87.5% spans from ear to ear.\n" +"\t100.0% is a mugshot." +msgstr "" + +#: plugins/train/train_config.py:88 +msgid "" +"How much to adjust the vertical position of the aligned face as a percentage " +"of face image size. Negative values move the face up (expose more chin and " +"less forehead). Positive values move the face down (expose less chin and " +"more forehead)" +msgstr "" + +#: plugins/train/train_config.py:99 plugins/train/train_config.py:109 +msgid "initialization" +msgstr "" + +#: plugins/train/train_config.py:101 +msgid "" +"Use ICNR to tile the default initializer in a repeating pattern. This " +"strategy is designed for pairing with sub-pixel / pixel shuffler to reduce " +"the 'checkerboard effect' in image reconstruction. \n" +"\t https://arxiv.org/ftp/arxiv/papers/1707/1707.02937.pdf" +msgstr "" + +#: plugins/train/train_config.py:111 +msgid "" +"Use Convolution Aware Initialization for convolutional layers. This can help " +"eradicate the vanishing and exploding gradient problem as well as lead to " +"higher accuracy, lower loss and faster convergence.\n" +"NB:\n" +"\t This can use more VRAM when creating a new model so you may want to lower " +"the batch size for the first run. The batch size can be raised again when " +"reloading the model.\n" +"\t Multi-GPU is not supported for this option, so you should start the model " +"on a single GPU. Once training has started, you can stop training, enable " +"multi-GPU and resume.\n" +"\t Building the model will likely take several minutes as the calculations " +"for this initialization technique are expensive. This will only impact " +"starting a new model." +msgstr "" + +#: plugins/train/train_config.py:126 plugins/train/train_config.py:138 +#: plugins/train/train_config.py:155 +msgid "Learning Rate Finder" +msgstr "" + +#: plugins/train/train_config.py:128 +msgid "" +"The number of iterations to process to find the optimal learning rate. 
" +"Higher values will take longer, but will be more accurate." +msgstr "" + +#: plugins/train/train_config.py:140 +msgid "" +"The operation mode for the learning rate finder. Only applicable to new " +"models. For existing models this will always default to 'set'.\n" +"\tset - Train with the discovered optimal learning rate.\n" +"\tgraph_and_set - Output a graph in the training folder showing the " +"discovered learning rates and train with the optimal learning rate.\n" +"\tgraph_and_exit - Output a graph in the training folder with the discovered " +"learning rates and exit." +msgstr "" + +#: plugins/train/train_config.py:157 +msgid "" +"How aggressively to set the Learning Rate. More aggressive can learn faster, " +"but is more likely to lead to exploding gradients.\n" +"\tdefault - The default optimal learning rate. A safe choice for nearly all " +"use cases.\n" +"\taggressive - Set's a higher learning rate than the default. May learn " +"faster but with a higher chance of exploding gradients.\n" +"\textreme - The highest optimal learning rate. A much higher risk of " +"exploding gradients." +msgstr "" + +#: plugins/train/train_config.py:172 plugins/train/train_config.py:183 +#: plugins/train/train_config.py:199 +msgid "network" +msgstr "" + +#: plugins/train/train_config.py:174 +msgid "" +"Use reflection padding rather than zero padding with convolutions. Each " +"convolution must pad the image boundaries to maintain the proper sizing. " +"More complex padding schemes can reduce artifacts at the border of the " +"image.\n" +"\t http://www-cs.engr.ccny.cuny.edu/~wolberg/cs470/hw/hw2_pad.txt" +msgstr "" + +#: plugins/train/train_config.py:185 +msgid "" +"NVIDIA GPUs can run operations in float16 faster than in float32. Mixed " +"precision allows you to use a mix of float16 with float32, to get the " +"performance benefits from float16 and the numeric stability benefits from " +"float32.\n" +"\n" +"This is untested on non-Nvidia cards, but will run on most Nvidia models. it " +"will only speed up training on more recent GPUs. Those with compute " +"capability 7.0 or higher will see the greatest performance benefit from " +"mixed precision because they have Tensor Cores. Older GPUs offer no math " +"performance benefit for using mixed precision, however memory and bandwidth " +"savings can enable some speedups. Generally RTX GPUs and later will offer " +"the most benefit." +msgstr "" + +#: plugins/train/train_config.py:201 +msgid "" +"If a 'NaN' is generated in the model, this means that the model has " +"corrupted and the model is likely to start deteriorating from this point on. " +"Enabling NaN protection will stop training immediately in the event of a " +"NaN. The last save will not contain the NaN, so you may still be able to " +"rescue your model." +msgstr "" + +#: plugins/train/train_config.py:211 +msgid "convert" +msgstr "" + +#: plugins/train/train_config.py:213 +msgid "" +"[GPU Only]. The number of faces to feed through the model at once when " +"running the Convert process.\n" +"\n" +"NB: Increasing this figure is unlikely to improve convert speed, however, if " +"you are getting Out of Memory errors, then you may want to reduce the batch " +"size." +msgstr "" + +#: plugins/train/train_config.py:224 +msgid "" +"Focal Frequency Loss. Analyzes the frequency spectrum of the images rather " +"than the images themselves. 
This loss function can be used on its own, but " +"the original paper found increased benefits when using it as a complementary " +"loss to another spacial loss function (e.g. MSE). Ref: Focal Frequency Loss " +"for Image Reconstruction and Synthesis https://arxiv.org/pdf/2012.12821.pdf " +"NB: This loss does not currently work on AMD cards." +msgstr "" + +#: plugins/train/train_config.py:231 +msgid "" +"Nvidia FLIP. A perceptual loss measure that approximates the difference " +"perceived by humans as they alternate quickly (or flip) between two images. " +"Used on its own and this loss function creates a distinct grid on the " +"output. However it can be helpful when used as a complimentary loss " +"function. Ref: FLIP: A Difference Evaluator for Alternating Images: https://" +"research.nvidia.com/sites/default/files/node/3260/FLIP_Paper.pdf" +msgstr "" + +#: plugins/train/train_config.py:238 +msgid "" +"Gradient Magnitude Similarity Deviation seeks to match the global standard " +"deviation of the pixel to pixel differences between two images. Similar in " +"approach to SSIM. Ref: Gradient Magnitude Similarity Deviation: An Highly " +"Efficient Perceptual Image Quality Index https://arxiv.org/ftp/arxiv/papers/" +"1308/1308.3052.pdf" +msgstr "" + +#: plugins/train/train_config.py:243 +msgid "" +"The L_inf norm will reduce the largest individual pixel error in an image. " +"As each largest error is minimized sequentially, the overall error is " +"improved. This loss will be extremely focused on outliers." +msgstr "" + +#: plugins/train/train_config.py:247 +msgid "" +"Laplacian Pyramid Loss. Attempts to improve results by focussing on edges " +"using Laplacian Pyramids. As this loss function gives priority to edges over " +"other low-frequency information, like color, it should not be used on its " +"own. The original implementation uses this loss as a complimentary function " +"to MSE. Ref: Optimizing the Latent Space of Generative Networks https://" +"arxiv.org/abs/1707.05776" +msgstr "" + +#: plugins/train/train_config.py:254 +msgid "" +"LPIPS is a perceptual loss that uses the feature outputs of other pretrained " +"models as a loss metric. Be aware that this loss function will use more " +"VRAM. Used on its own and this loss will create a distinct moire pattern on " +"the output, however it can be helpful as a complimentary loss function. The " +"output of this function is strong, so depending on your chosen primary loss " +"function, you are unlikely going to want to set the weight above about 25%. " +"Ref: The Unreasonable Effectiveness of Deep Features as a Perceptual Metric " +"http://arxiv.org/abs/1801.03924\n" +"This variant uses the AlexNet backbone. A fairly light and old model which " +"performed best in the paper's original implementation.\n" +"NB: For AMD Users the final linear layer is not implemented." +msgstr "" + +#: plugins/train/train_config.py:264 +msgid "" +"Same as lpips_alex, but using the SqueezeNet backbone. A more lightweight " +"version of AlexNet.\n" +"NB: For AMD Users the final linear layer is not implemented." +msgstr "" + +#: plugins/train/train_config.py:267 +msgid "" +"Same as lpips_alex, but using the VGG16 backbone. A more heavyweight model.\n" +"NB: For AMD Users the final linear layer is not implemented." +msgstr "" + +#: plugins/train/train_config.py:270 +msgid "" +"log(cosh(x)) acts similar to MSE for small errors and to MAE for large " +"errors. Like MSE, it is very stable and prevents overshoots when errors are " +"near zero. 
Like MAE, it is robust to outliers." +msgstr "" + +#: plugins/train/train_config.py:274 +msgid "" +"Mean absolute error will guide reconstructions of each pixel towards its " +"median value in the training dataset. Robust to outliers but as a median, it " +"can potentially ignore some infrequent image types in the dataset." +msgstr "" + +#: plugins/train/train_config.py:278 +msgid "" +"Mean squared error will guide reconstructions of each pixel towards its " +"average value in the training dataset. As an avg, it will be susceptible to " +"outliers and typically produces slightly blurrier results. Ref: Multi-Scale " +"Structural Similarity for Image Quality Assessment https://www.cns.nyu.edu/" +"pub/eero/wang03b.pdf" +msgstr "" + +#: plugins/train/train_config.py:283 +msgid "" +"Multiscale Structural Similarity Index Metric is similar to SSIM except that " +"it performs the calculations along multiple scales of the input image." +msgstr "" + +#: plugins/train/train_config.py:286 +msgid "" +"Smooth_L1 is a modification of the MAE loss to correct two of its " +"disadvantages. This loss has improved stability and guidance for small " +"errors. Ref: A General and Adaptive Robust Loss Function https://arxiv.org/" +"pdf/1701.03077.pdf" +msgstr "" + +#: plugins/train/train_config.py:290 +msgid "" +"Structural Similarity Index Metric is a perception-based loss that considers " +"changes in texture, luminance, contrast, and local spatial statistics of an " +"image. Potentially delivers more realistic looking images. Ref: Image " +"Quality Assessment: From Error Visibility to Structural Similarity http://" +"www.cns.nyu.edu/pub/eero/wang03-reprint.pdf" +msgstr "" + +#: plugins/train/train_config.py:295 +msgid "" +"Instead of minimizing the difference between the absolute value of each " +"pixel in two reference images, compute the pixel to pixel spatial difference " +"in each image and then minimize that difference between two images. Allows " +"for large color shifts, but maintains the structure of the image." +msgstr "" + +#: plugins/train/train_config.py:299 +msgid "Do not use an additional loss function." +msgstr "" + +#: plugins/train/train_config.py:315 +msgid "" +"Loss configuration options\n" +"Loss is the mechanism by which a Neural Network judges how well it thinks " +"that it is recreating a face." +msgstr "" + +#: plugins/train/train_config.py:321 plugins/train/train_config.py:331 +#: plugins/train/train_config.py:343 plugins/train/train_config.py:362 +#: plugins/train/train_config.py:372 plugins/train/train_config.py:391 +#: plugins/train/train_config.py:402 plugins/train/train_config.py:421 +#: plugins/train/train_config.py:436 plugins/train/train_config.py:450 +#: plugins/train/train_config.py:464 +msgid "loss" +msgstr "" + +#: plugins/train/train_config.py:322 +msgid "The loss function to use." +msgstr "" + +#: plugins/train/train_config.py:333 +msgid "" +"The second loss function to use. If using a structural based loss (such as " +"SSIM, MS-SSIM or GMSD) it is common to add an L1 regularization(MAE) or L2 " +"regularization (MSE) function. You can adjust the weighting of this loss " +"function with the loss_weight_2 option.\n" +"\n" +"\t\n" +"\n" +"\t" +msgstr "" + +#: plugins/train/train_config.py:345 +msgid "" +"The amount of weight to apply to the second loss function.\n" +"\n" +"\n" +"\n" +"The value given here is as a percentage denoting how much the selected " +"function should contribute to the overall loss cost of the model. 
For " +"example:\n" +"\t 100 - The loss calculated for the second loss function will be applied at " +"its full amount towards the overall loss score. \n" +"\t 25 - The loss calculated for the second loss function will be reduced by " +"a quarter prior to adding to the overall loss score. \n" +"\t 400 - The loss calculated for the second loss function will be mulitplied " +"4 times prior to adding to the overall loss score. \n" +"\t 0 - Disables the second loss function altogether." +msgstr "" + +#: plugins/train/train_config.py:363 +msgid "" +"The third loss function to use. You can adjust the weighting of this loss " +"function with the loss_weight_3 option.\n" +"\n" +"\t\n" +"\n" +"\t" +msgstr "" + +#: plugins/train/train_config.py:374 +msgid "" +"The amount of weight to apply to the third loss function.\n" +"\n" +"\n" +"\n" +"The value given here is as a percentage denoting how much the selected " +"function should contribute to the overall loss cost of the model. For " +"example:\n" +"\t 100 - The loss calculated for the third loss function will be applied at " +"its full amount towards the overall loss score. \n" +"\t 25 - The loss calculated for the third loss function will be reduced by a " +"quarter prior to adding to the overall loss score. \n" +"\t 400 - The loss calculated for the third loss function will be mulitplied " +"4 times prior to adding to the overall loss score. \n" +"\t 0 - Disables the third loss function altogether." +msgstr "" + +#: plugins/train/train_config.py:393 +msgid "" +"The fourth loss function to use. You can adjust the weighting of this loss " +"function with the loss_weight_3 option.\n" +"\n" +"\t\n" +"\n" +"\t" +msgstr "" + +#: plugins/train/train_config.py:404 +msgid "" +"The amount of weight to apply to the fourth loss function.\n" +"\n" +"\n" +"\n" +"The value given here is as a percentage denoting how much the selected " +"function should contribute to the overall loss cost of the model. For " +"example:\n" +"\t 100 - The loss calculated for the fourth loss function will be applied at " +"its full amount towards the overall loss score. \n" +"\t 25 - The loss calculated for the fourth loss function will be reduced by " +"a quarter prior to adding to the overall loss score. \n" +"\t 400 - The loss calculated for the fourth loss function will be mulitplied " +"4 times prior to adding to the overall loss score. \n" +"\t 0 - Disables the fourth loss function altogether." +msgstr "" + +#: plugins/train/train_config.py:423 +msgid "" +"The loss function to use when learning a mask.\n" +"\t MAE - Mean absolute error will guide reconstructions of each pixel " +"towards its median value in the training dataset. Robust to outliers but as " +"a median, it can potentially ignore some infrequent image types in the " +"dataset.\n" +"\t MSE - Mean squared error will guide reconstructions of each pixel towards " +"its average value in the training dataset. As an average, it will be " +"susceptible to outliers and typically produces slightly blurrier results." +msgstr "" + +#: plugins/train/train_config.py:438 +msgid "" +"The amount of priority to give to the eyes.\n" +"\n" +"The value given here is as a multiplier of the main loss score. For " +"example:\n" +"\t 1 - The eyes will receive the same priority as the rest of the face. \n" +"\t 10 - The eyes will be given a score 10 times higher than the rest of the " +"face.\n" +"\n" +"NB: Penalized Mask Loss must be enable to use this option." 
+msgstr "" + +#: plugins/train/train_config.py:452 +msgid "" +"The amount of priority to give to the mouth.\n" +"\n" +"The value given here is as a multiplier of the main loss score. For " +"Example:\n" +"\t 1 - The mouth will receive the same priority as the rest of the face. \n" +"\t 10 - The mouth will be given a score 10 times higher than the rest of the " +"face.\n" +"\n" +"NB: Penalized Mask Loss must be enable to use this option." +msgstr "" + +#: plugins/train/train_config.py:466 +msgid "" +"Image loss function is weighted by mask presence. For areas of the image " +"without the facial mask, reconstruction errors will be ignored while the " +"masked face area is prioritized. May increase overall quality by focusing " +"attention on the core face area." +msgstr "" + +#: plugins/train/train_config.py:473 plugins/train/train_config.py:514 +#: plugins/train/train_config.py:525 plugins/train/train_config.py:539 +#: plugins/train/train_config.py:549 +msgid "mask" +msgstr "" + +#: plugins/train/train_config.py:475 +msgid "" +"The mask to be used for training. If you have selected 'Learn Mask' or " +"'Penalized Mask Loss' you must select a value other than 'none'. The " +"required mask should have been selected as part of the Extract process. If " +"it does not exist in the alignments file then it will be generated prior to " +"training commencing.\n" +"\tnone: Don't use a mask.\n" +"\tbisenet-fp_face: Relatively lightweight NN based mask that provides more " +"refined control over the area to be masked (configurable in mask settings). " +"Use this version of bisenet-fp if your model is trained with 'face' or " +"'legacy' centering.\n" +"\tbisenet-fp_head: Relatively lightweight NN based mask that provides more " +"refined control over the area to be masked (configurable in mask settings). " +"Use this version of bisenet-fp if your model is trained with 'head' " +"centering.\n" +"\tcomponents: Mask designed to provide facial segmentation based on the " +"positioning of landmark locations. A convex hull is constructed around the " +"exterior of the landmarks to create a mask.\n" +"\tcustom_face: Custom user created, face centered mask.\n" +"\tcustom_head: Custom user created, head centered mask.\n" +"\textended: Mask designed to provide facial segmentation based on the " +"positioning of landmark locations. A convex hull is constructed around the " +"exterior of the landmarks and the mask is extended upwards onto the " +"forehead.\n" +"\tvgg-clear: Mask designed to provide smart segmentation of mostly frontal " +"faces clear of obstructions. Profile faces and obstructions may result in " +"sub-par performance.\n" +"\tvgg-obstructed: Mask designed to provide smart segmentation of mostly " +"frontal faces. The mask model has been specifically trained to recognize " +"some facial obstructions (hands and eyeglasses). Profile faces may result in " +"sub-par performance.\n" +"\tunet-dfl: Mask designed to provide smart segmentation of mostly frontal " +"faces. The mask model has been trained by community members and will need " +"testing for further description. Profile faces may result in sub-par " +"performance." +msgstr "" + +#: plugins/train/train_config.py:516 +msgid "" +"Dilate or erode the mask. Negative values erode the mask (make it smaller). " +"Positive values dilate the mask (make it larger). The value given is a " +"percentage of the total mask size." +msgstr "" + +#: plugins/train/train_config.py:527 +msgid "" +"Apply gaussian blur to the mask input. 
This has the effect of smoothing the " +"edges of the mask, which can help with poorly calculated masks and give less " +"of a hard edge to the predicted mask. The size is in pixels (calculated from " +"a 128px mask). Set to 0 to not apply gaussian blur. This value should be " +"odd, if an even number is passed in then it will be rounded to the next odd " +"number." +msgstr "" + +#: plugins/train/train_config.py:541 +msgid "" +"Sets pixels that are near white to white and near black to black. Set to 0 " +"for off." +msgstr "" + +#: plugins/train/train_config.py:551 +msgid "" +"Dedicate a portion of the model to learning how to duplicate the input mask. " +"Increases VRAM usage in exchange for learning a quick ability to try to " +"replicate more complex mask models." +msgstr "" + +#: plugins/train/train_config.py:559 +msgid "" +"Optimizer configuration options\n" +"The optimizer applies the output of the loss function to the model.\n" +msgstr "" + +#: plugins/train/train_config.py:565 plugins/train/train_config.py:600 +#: plugins/train/train_config.py:613 plugins/train/train_config.py:634 +msgid "optimizer" +msgstr "" + +#: plugins/train/train_config.py:567 +msgid "" +"The optimizer to use.\n" +"\t adabelief - Adapting Stepsizes by the Belief in Observed Gradients. An " +"optimizer with the aim to converge faster, generalize better and remain more " +"stable. (https://arxiv.org/abs/2010.07468). NB: Epsilon for AdaBelief needs " +"to be set to a smaller value than other Optimizers. Generally setting the " +"'Epsilon Exponent' to around '-16' should work.\n" +"\t adam - Adaptive Moment Optimization. A stochastic gradient descent method " +"that is based on adaptive estimation of first-order and second-order " +"moments.\n" +"\t adamax - a variant of Adam based on the infinity norm. Due to its " +"capability of adjusting the learning rate based on data characteristics, it " +"is suited to learn time-variant process, parameters follow those provided in " +"the paper\n" +"\t adamw - Like 'adam' but with an added method to decay weights per the " +"techniques discussed in the paper (https://arxiv.org/abs/1711.05101). NB: " +"Weight decay should be set at 0.004 for default implementation.\n" +"\t lion - A method that uses the sign operator to control the magnitude of " +"the update, rather than relying on second-order moments (Adam). saves VRAM " +"by only tracking the momentum. Performance gains should be better with " +"larger batch sizes. A suitable learning rate for Lion is typically 3-10x " +"smaller than that for AdamW. The weight decay for Lion should be 3-10x " +"larger than that for AdamW to maintain a similar strength.\n" +"\t nadam - Adaptive Moment Optimization with Nesterov Momentum. Much like " +"Adam but uses a different formula for calculating momentum.\n" +"\t rms-prop - Root Mean Square Propagation. Maintains a moving (discounted) " +"average of the square of the gradients. Divides the gradient by the root of " +"this average." +msgstr "" + +#: plugins/train/train_config.py:602 +msgid "" +"Learning rate - how fast your network will learn (how large are the " +"modifications to the model weights after one batch of training). Values that " +"are too large might result in model crashes and the inability of the model " +"to find the best solution. Values that are too small might be unable to " +"escape from dead-ends and find the best global minimum." 
+msgstr "" + +#: plugins/train/train_config.py:615 +msgid "" +"The epsilon adds a small constant to weight updates to attempt to avoid " +"'divide by zero' errors. Unless you are using the AdaBelief Optimizer, then " +"Generally this option should be left at default value, For AdaBelief, " +"setting this to around '-16' should work.\n" +"In all instances if you are getting 'NaN' loss values, and have been unable " +"to resolve the issue any other way (for example, increasing batch size, or " +"lowering learning rate), then raising the epsilon can lead to a more stable " +"model. It may, however, come at the cost of slower training and a less " +"accurate final result.\n" +"Note: The value given here is the 'exponent' to the epsilon. For example, " +"choosing '-7' will set the epsilon to 1e-7. Choosing '-3' will set the " +"epsilon to 0.001 (1e-3).\n" +"Note: Not used by the Lion optimizer" +msgstr "" + +#: plugins/train/train_config.py:636 +msgid "" +"When to save the Optimizer Weights. Saving the optimizer weights is not " +"necessary and will increase the model file size 3x (and by extension the " +"amount of time it takes to save the model). However, it can be useful to " +"save these weights if you want to guarantee that a resumed model carries off " +"exactly from where it left off, rather than spending a few hundred " +"iterations catching up.\n" +"\t never - Don't save optimizer weights.\n" +"\t always - Save the optimizer weights at every save iteration. Model saving " +"will take longer, due to the increased file size, but you will always have " +"the last saved optimizer state in your model file.\n" +"\t exit - Only save the optimizer weights when explicitly terminating a " +"model. This can be when the model is actively stopped or when the target " +"iterations are met. Note: If the training session ends because of another " +"reason (e.g. power outage, Out of Memory Error, NaN detected) then the " +"optimizer weights will NOT be saved." +msgstr "" + +#: plugins/train/train_config.py:657 plugins/train/train_config.py:676 +#: plugins/train/train_config.py:695 +msgid "clipping" +msgstr "" + +#: plugins/train/train_config.py:659 +msgid "" +"Apply clipping to the gradients. Can help prevent NaNs and improve model " +"optimization at the expense of VRAM.\n" +"\tautoclip: Analyzes the gradient weights and adjusts the normalization " +"value dynamically to fit the data\n" +"\tglobal_norm: Clips the gradient of each weight so that the global norm is " +"no higher than the given value.\n" +"\tnorm: Clips the gradient of each weight so that its norm is no higher than " +"the given value.\n" +"\tvalue: Clips the gradient of each weight so that it is no higher than the " +"given value.\n" +"\tnone: Don't perform any clipping to the gradients." +msgstr "" + +#: plugins/train/train_config.py:678 +msgid "" +"The amount of clipping to perform.\n" +"\tautoclip: The percentile to clip at. A value of 1.0 will clip at the 10th " +"percentile a value of 2.5 will clip at the 25th percentile etc. Default: " +"1.0\n" +"\tglobal_norm: The gradient of each weight is clipped so that the global " +"norm is no higher than this value.\n" +"\tnorm: The gradient of each weight is clipped so that its norm is no higher " +"than this value.\n" +"\tvalue: The gradient of each weight is clipped to be no higher than this " +"value.\n" +"\tnone: This option is ignored." 
+msgstr "" + +#: plugins/train/train_config.py:697 +msgid "" +"The maximum number of prior iterations for autoclipper to analyze when " +"calculating the normalization amount. 0 to always include all prior " +"iterations." +msgstr "" + +#: plugins/train/train_config.py:706 plugins/train/train_config.py:715 +msgid "updates" +msgstr "" + +#: plugins/train/train_config.py:707 +msgid "" +"If set, weight decay is applied. 0.0 for no weight decay. Default is 0.0 for " +"all optimizers except AdamW (0.004)" +msgstr "" + +#: plugins/train/train_config.py:717 +msgid "" +"Values above 1 will enable Gradient Accumulation. Updates will not be at " +"every iteration; instead they will occur every number of iterations given " +"here. The update will be the average value of the gradients since the last " +"update. Can be useful when your batch size is very small, in order to reduce " +"gradient noise at each update iteration." +msgstr "" + +#: plugins/train/train_config.py:728 plugins/train/train_config.py:738 +#: plugins/train/train_config.py:749 +msgid "exponential moving average" +msgstr "" + +#: plugins/train/train_config.py:730 +msgid "" +"Enable exponential moving average (EMA). EMA consists of computing an " +"exponential moving average of the weights of the model (as the weight values " +"change after each training batch), and periodically overwriting the weights " +"with their moving average" +msgstr "" + +#: plugins/train/train_config.py:740 +msgid "" +"Only used if use_ema is enabled. This is the momentum to use when computing " +"the EMA of the model's weights: new_average = ema_momentum * old_average + " +"(1 - ema_momentum) * current_variable_value." +msgstr "" + +#: plugins/train/train_config.py:751 +msgid "" +"Only used if use_ema is enabled. Set the number of iterations, to overwrite " +"the model variable by its moving average. " +msgstr "" + +#: plugins/train/train_config.py:759 plugins/train/train_config.py:770 +#: plugins/train/train_config.py:781 +msgid "optimizer specific" +msgstr "" + +#: plugins/train/train_config.py:761 +msgid "" +"The exponential decay rate for the 1st moment estimates. Used for the " +"following Optimizers: AdaBelief, Adam, Adamax, AdamW, Lion, nAdam. Ignored " +"for all others." +msgstr "" + +#: plugins/train/train_config.py:772 +msgid "" +"The exponential decay rate for the 2nd moment estimates. Used for the " +"following Optimizers: AdaBelief, Adam, Adamax, AdamW, Lion, nAdam. Ignored " +"for all others." +msgstr "" + +#: plugins/train/train_config.py:783 +msgid "" +"Whether to apply AMSGrad variant of the algorithm from the paper 'On the " +"Convergence of Adam and beyond. Used for the following Optimizers: " +"AdaBelief, Adam, AdamW. Ignored for all others.'" +msgstr "" diff --git a/locales/plugins.train.trainer.trainer_config.pot b/locales/plugins.train.trainer.trainer_config.pot new file mode 100644 index 0000000000..8c44e5e618 --- /dev/null +++ b/locales/plugins.train.trainer.trainer_config.pot @@ -0,0 +1,110 @@ +# SOME DESCRIPTIVE TITLE. +# Copyright (C) YEAR THE PACKAGE'S COPYRIGHT HOLDER +# This file is distributed under the same license as the PACKAGE package. +# FIRST AUTHOR , YEAR. 
+#
+#, fuzzy
+msgid ""
+msgstr ""
+"Project-Id-Version: PACKAGE VERSION\n"
+"Report-Msgid-Bugs-To: \n"
+"POT-Creation-Date: 2025-12-12 20:45+0000\n"
+"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
+"Last-Translator: FULL NAME \n"
+"Language-Team: LANGUAGE \n"
+"Language: \n"
+"MIME-Version: 1.0\n"
+"Content-Type: text/plain; charset=CHARSET\n"
+"Content-Transfer-Encoding: 8bit\n"
+
+#: plugins/train/trainer/trainer_config.py:30
+#, python-format
+msgid ""
+"Data Augmentation Options.\n"
+"WARNING: The defaults for augmentation will be fine for 99.9% of use cases. "
+"Only change them if you absolutely know what you are doing!"
+msgstr ""
+
+#: plugins/train/trainer/trainer_config.py:42
+#: plugins/train/trainer/trainer_config.py:50
+#: plugins/train/trainer/trainer_config.py:60
+msgid "evaluation"
+msgstr ""
+
+#: plugins/train/trainer/trainer_config.py:43
+msgid ""
+"Number of sample faces to display for each side in the preview when training."
+msgstr ""
+
+#: plugins/train/trainer/trainer_config.py:51
+msgid ""
+"The opacity of the mask overlay in the training preview. Lower values are "
+"more transparent."
+msgstr ""
+
+#: plugins/train/trainer/trainer_config.py:61
+msgid "The RGB hex color to use for the mask overlay in the training preview."
+msgstr ""
+
+#: plugins/train/trainer/trainer_config.py:66
+#: plugins/train/trainer/trainer_config.py:74
+#: plugins/train/trainer/trainer_config.py:82
+#: plugins/train/trainer/trainer_config.py:91
+msgid "image augmentation"
+msgstr ""
+
+#: plugins/train/trainer/trainer_config.py:67
+msgid "Percentage amount to randomly zoom each training image in and out."
+msgstr ""
+
+#: plugins/train/trainer/trainer_config.py:75
+msgid "Percentage amount to randomly rotate each training image."
+msgstr ""
+
+#: plugins/train/trainer/trainer_config.py:83
+msgid ""
+"Percentage amount to randomly shift each training image horizontally and "
+"vertically."
+msgstr ""
+
+#: plugins/train/trainer/trainer_config.py:92
+msgid ""
+"Percentage chance to randomly flip each training image horizontally.\n"
+"NB: This is ignored if the 'no-flip' option is enabled"
+msgstr ""
+
+#: plugins/train/trainer/trainer_config.py:100
+#: plugins/train/trainer/trainer_config.py:109
+#: plugins/train/trainer/trainer_config.py:119
+#: plugins/train/trainer/trainer_config.py:130
+msgid "color augmentation"
+msgstr ""
+
+#: plugins/train/trainer/trainer_config.py:101
+msgid ""
+"Percentage amount to randomly alter the lightness of each training image.\n"
+"NB: This is ignored if the 'no-augment-color' option is enabled"
+msgstr ""
+
+#: plugins/train/trainer/trainer_config.py:110
+msgid ""
+"Percentage amount to randomly alter the 'a' and 'b' colors of the L*a*b* "
+"color space of each training image.\n"
+"NB: This is ignored if the 'no-augment-color' option is enabled"
+msgstr ""
+
+#: plugins/train/trainer/trainer_config.py:120
+msgid ""
+"Percentage chance to perform Contrast Limited Adaptive Histogram "
+"Equalization on each training image.\n"
+"NB: This is ignored if the 'no-augment-color' option is enabled"
+msgstr ""
+
+#: plugins/train/trainer/trainer_config.py:131
+msgid ""
+"The grid size dictates how much Contrast Limited Adaptive Histogram "
+"Equalization is performed on any training image selected for clahe. Contrast "
+"will be applied randomly with a grid size of 0 up to the maximum. 
This value " +"is a multiplier calculated from the training image size.\n" +"NB: This is ignored if the 'no-augment-color' option is enabled" +msgstr "" diff --git a/locales/ru/LC_MESSAGES/faceswap.mo b/locales/ru/LC_MESSAGES/faceswap.mo new file mode 100644 index 0000000000..db2449a789 Binary files /dev/null and b/locales/ru/LC_MESSAGES/faceswap.mo differ diff --git a/locales/ru/LC_MESSAGES/faceswap.po b/locales/ru/LC_MESSAGES/faceswap.po new file mode 100644 index 0000000000..7c37585784 --- /dev/null +++ b/locales/ru/LC_MESSAGES/faceswap.po @@ -0,0 +1,34 @@ +# SOME DESCRIPTIVE TITLE. +# Copyright (C) YEAR ORGANIZATION +# FIRST AUTHOR , YEAR. +# +msgid "" +msgstr "" +"Project-Id-Version: \n" +"POT-Creation-Date: 2021-02-18 23:48-0000\n" +"PO-Revision-Date: 2023-04-11 12:56+0700\n" +"Last-Translator: \n" +"Language-Team: \n" +"Language: ru\n" +"MIME-Version: 1.0\n" +"Content-Type: text/plain; charset=UTF-8\n" +"Content-Transfer-Encoding: 8bit\n" +"Plural-Forms: nplurals=3; plural=(n%10==1 && n%100!=11 ? 0 : n%10>=2 && n%10<=4 && (n%100<12 || n%100>14) ? 1 : 2);\n" +"Generated-By: pygettext.py 1.5\n" +"X-Generator: Poedit 3.2.2\n" + +#: faceswap.py:43 +msgid "Extract the faces from pictures or a video" +msgstr "Извлечение лиц из картинок или видео" + +#: faceswap.py:44 +msgid "Train a model for the two faces A and B" +msgstr "Обучить модель для двух лиц A и B" + +#: faceswap.py:47 +msgid "Convert source pictures or video to a new one with the face swapped" +msgstr "Преобразование исходных изображений или видео в новое с заменой лиц" + +#: faceswap.py:48 +msgid "Launch the Faceswap Graphical User Interface" +msgstr "Запуск графического интерфейса Faceswap" diff --git a/locales/ru/LC_MESSAGES/gui.menu.mo b/locales/ru/LC_MESSAGES/gui.menu.mo new file mode 100644 index 0000000000..15df09a5cd Binary files /dev/null and b/locales/ru/LC_MESSAGES/gui.menu.mo differ diff --git a/locales/ru/LC_MESSAGES/gui.menu.po b/locales/ru/LC_MESSAGES/gui.menu.po new file mode 100644 index 0000000000..581dfde23f --- /dev/null +++ b/locales/ru/LC_MESSAGES/gui.menu.po @@ -0,0 +1,156 @@ +# SOME DESCRIPTIVE TITLE. +# Copyright (C) YEAR ORGANIZATION +# FIRST AUTHOR , YEAR. +# +msgid "" +msgstr "" +"Project-Id-Version: \n" +"Report-Msgid-Bugs-To: \n" +"POT-Creation-Date: 2023-06-07 13:54+0100\n" +"PO-Revision-Date: 2023-06-07 20:29+0700\n" +"Last-Translator: \n" +"Language-Team: \n" +"Language: ru\n" +"MIME-Version: 1.0\n" +"Content-Type: text/plain; charset=UTF-8\n" +"Content-Transfer-Encoding: 8bit\n" +"Plural-Forms: nplurals=3; plural=(n%10==1 && n%100!=11 ? 0 : n%10>=2 && " +"n%10<=4 && (n%100<12 || n%100>14) ? 1 : 2);\n" +"Generated-By: pygettext.py 1.5\n" +"X-Generator: Poedit 3.3.1\n" + +#: lib/gui/menu.py:37 +msgid "faceswap.dev - Guides and Forum" +msgstr "faceswap.dev - Руководства и Форум" + +#: lib/gui/menu.py:38 +msgid "Patreon - Support this project" +msgstr "Patreon - Поддержите этот проект" + +#: lib/gui/menu.py:39 +msgid "Discord - The FaceSwap Discord server" +msgstr "Discord - Discord сервер Faceswap" + +#: lib/gui/menu.py:40 +msgid "Github - Our Source Code" +msgstr "Github - Наш исходный код" + +#: lib/gui/menu.py:60 +msgid "File" +msgstr "Файл" + +#: lib/gui/menu.py:61 +msgid "Settings" +msgstr "Настройки" + +#: lib/gui/menu.py:62 +msgid "Help" +msgstr "Помощь" + +#: lib/gui/menu.py:85 +msgid "Configure Settings..." +msgstr "Настройки..." + +#: lib/gui/menu.py:116 +msgid "New Project..." +msgstr "Новый проект..." + +#: lib/gui/menu.py:121 +msgid "Open Project..." +msgstr "Открыть проект..." 
+
+#: lib/gui/menu.py:126
+msgid "Save Project"
+msgstr "Сохранить проект"
+
+#: lib/gui/menu.py:131
+msgid "Save Project as..."
+msgstr "Сохранить проект как..."
+
+#: lib/gui/menu.py:136
+msgid "Reload Project from Disk"
+msgstr "Перезагрузить проект с диска"
+
+#: lib/gui/menu.py:141
+msgid "Close Project"
+msgstr "Закрыть проект"
+
+#: lib/gui/menu.py:147
+msgid "Open Task..."
+msgstr "Открыть задачу..."
+
+#: lib/gui/menu.py:154
+msgid "Open recent"
+msgstr "Открытые недавно"
+
+#: lib/gui/menu.py:156
+msgid "Quit"
+msgstr "Выход"
+
+#: lib/gui/menu.py:211
+msgid "{} Task"
+msgstr "{} Задача"
+
+#: lib/gui/menu.py:223
+msgid "Clear recent files"
+msgstr "Очистить недавние файлы"
+
+#: lib/gui/menu.py:391
+msgid "Check for updates..."
+msgstr "Проверить обновления..."
+
+#: lib/gui/menu.py:394
+msgid "Update Faceswap..."
+msgstr "Обновить Faceswap..."
+
+#: lib/gui/menu.py:398
+msgid "Switch Branch"
+msgstr "Сменить ветку"
+
+#: lib/gui/menu.py:401
+msgid "Resources"
+msgstr "Ресурсы"
+
+#: lib/gui/menu.py:404
+msgid "Output System Information"
+msgstr "Вывести информацию о системе"
+
+#: lib/gui/menu.py:589
+msgid "currently selected Task"
+msgstr "текущую выбранную задачу"
+
+#: lib/gui/menu.py:589
+msgid "Project"
+msgstr "Проект"
+
+#: lib/gui/menu.py:591
+msgid "Reload {} from disk"
+msgstr "Перезагрузить {} с диска"
+
+#: lib/gui/menu.py:593
+msgid "Create a new {}..."
+msgstr "Создать новый {}..."
+
+#: lib/gui/menu.py:595
+msgid "Reset {} to default"
+msgstr "Сбросить {} по умолчанию"
+
+#: lib/gui/menu.py:597
+msgid "Save {}"
+msgstr "Сохранить {}"
+
+#: lib/gui/menu.py:599
+msgid "Save {} as..."
+msgstr "Сохранить {} как..."
+
+#: lib/gui/menu.py:603
+msgid " from a task or project file"
+msgstr " из файла задачи или проекта"
+
+#: lib/gui/menu.py:604
+msgid "Load {}..."
+msgstr "Загрузить {}..."
+
+#: lib/gui/menu.py:659
+msgid "Configure {} settings..."
+msgstr "Настройка параметров {}..."
diff --git a/locales/ru/LC_MESSAGES/gui.tooltips.mo b/locales/ru/LC_MESSAGES/gui.tooltips.mo
new file mode 100644
index 0000000000..390771565d
Binary files /dev/null and b/locales/ru/LC_MESSAGES/gui.tooltips.mo differ
diff --git a/locales/ru/LC_MESSAGES/gui.tooltips.po b/locales/ru/LC_MESSAGES/gui.tooltips.po
new file mode 100644
index 0000000000..c3b59f0cc3
--- /dev/null
+++ b/locales/ru/LC_MESSAGES/gui.tooltips.po
@@ -0,0 +1,210 @@
+# SOME DESCRIPTIVE TITLE.
+# Copyright (C) YEAR ORGANIZATION
+# FIRST AUTHOR , YEAR.
+#
+msgid ""
+msgstr ""
+"Project-Id-Version: \n"
+"POT-Creation-Date: 2021-03-22 18:37+0000\n"
+"PO-Revision-Date: 2023-06-07 20:31+0700\n"
+"Last-Translator: \n"
+"Language-Team: \n"
+"Language: ru\n"
+"MIME-Version: 1.0\n"
+"Content-Type: text/plain; charset=UTF-8\n"
+"Content-Transfer-Encoding: 8bit\n"
+"Plural-Forms: nplurals=3; plural=(n%10==1 && n%100!=11 ? 0 : n%10>=2 && "
+"n%10<=4 && (n%100<12 || n%100>14) ? 1 : 2);\n"
+"Generated-By: pygettext.py 1.5\n"
+"X-Generator: Poedit 3.3.1\n"
+
+#: lib/gui/command.py:184
+msgid "Output command line options to the console"
+msgstr "Вывод опций командной строки в консоль"
+
+#: lib/gui/command.py:195
+msgid "Run the {} script"
+msgstr "Запуск сценария {}"
+
+#: lib/gui/control_helper.py:1234
+msgid "Select a folder..."
+msgstr "Выбрать папку..."
+
+#: lib/gui/control_helper.py:1235 lib/gui/control_helper.py:1236
+msgid "Select a file..."
+msgstr "Выбрать файл..."
+
+#: lib/gui/control_helper.py:1237
+msgid "Select a folder of images..."
+msgstr "Выбрать папку с изображениями..."
+
+#: lib/gui/control_helper.py:1238
+msgid "Select a video..."
+msgstr "Выбрать видео..."
+
+#: lib/gui/control_helper.py:1239
+msgid "Select a model folder..."
+msgstr "Выбрать папку с моделью..."
+
+#: lib/gui/control_helper.py:1240
+msgid "Select one or more files..."
+msgstr "Выбрать один или несколько файлов..."
+
+#: lib/gui/control_helper.py:1241
+msgid "Select a file or folder..."
+msgstr "Выбрать файл или папку..."
+
+#: lib/gui/control_helper.py:1242
+msgid "Select a save location..."
+msgstr "Выбрать место сохранения..."
+
+#: lib/gui/display.py:71
+msgid "Summary statistics for each training session"
+msgstr "Сводная статистика для каждой тренировки"
+
+#: lib/gui/display.py:113
+msgid "Preview updates every 5 seconds"
+msgstr "Предпросмотр обновляется каждые 5 секунд"
+
+#: lib/gui/display.py:122
+msgid "Graph showing Loss vs Iterations"
+msgstr "График зависимости потерь от количества итераций"
+
+#: lib/gui/display.py:125
+msgid "Training preview. Updated on every save iteration"
+msgstr "Предпросмотр тренировки. Обновляется при каждой итерации сохранения"
+
+#: lib/gui/display_analysis.py:342
+msgid "Load/Refresh stats for the currently training session"
+msgstr "Загрузить/обновить статистику для текущей тренировки"
+
+#: lib/gui/display_analysis.py:344
+msgid "Clear currently displayed session stats"
+msgstr "Очистить отображаемую статистику сессии"
+
+#: lib/gui/display_analysis.py:346
+msgid "Save session stats to csv"
+msgstr "Сохранить статистику сессии в csv файл"
+
+#: lib/gui/display_analysis.py:348
+msgid "Load saved session stats"
+msgstr "Загрузить сохраненную статистику"
+
+#: lib/gui/display_command.py:94
+msgid "Preview updates at every model save. Click to refresh now."
+msgstr ""
+"Предпросмотр обновляется при каждом сохранении модели. Нажмите, чтобы "
+"обновить сейчас."
+
+#: lib/gui/display_command.py:261
+msgid "Graph updates at every model save. Click to refresh now."
+msgstr ""
+"График обновляется при каждом сохранении модели. Нажмите, чтобы обновить "
+"сейчас."
+
+#: lib/gui/display_command.py:275
+msgid "Display the raw loss data"
+msgstr "Показать необработанные данные о потерях"
+
+#: lib/gui/display_command.py:287
+msgid "Display the smoothed loss data"
+msgstr "Показать сглаженные данные о потерях"
+
+#: lib/gui/display_command.py:294
+msgid "Set the smoothing amount. 0 is no smoothing, 0.99 is maximum smoothing."
+msgstr ""
+"Установите величину сглаживания. 0 - нет сглаживания, 0.99 - максимальное "
+"сглаживание."
+
+#: lib/gui/display_command.py:324
+msgid "Set the number of iterations to display. 0 displays the full session."
+msgstr ""
+"Установите количество итераций для отображения. 0 отображает полный сеанс."
+
+#: lib/gui/display_page.py:238
+msgid "Save {}(s) to file"
+msgstr "Сохранить {}(ы) в файл"
+
+#: lib/gui/display_page.py:250
+msgid "Enable or disable {} display"
+msgstr "Включить или выключить отображение {}"
+
+#: lib/gui/popup_configure.py:209
+msgid "Close without saving"
+msgstr "Закрыть без сохранения"
+
+#: lib/gui/popup_configure.py:210
+msgid "Save this page's config"
+msgstr "Сохранить конфигурацию этой страницы"
+
+#: lib/gui/popup_configure.py:211
+msgid "Reset this page's config to default values"
+msgstr "Сбросить конфигурацию этой страницы до заводских значений"
+
+#: lib/gui/popup_configure.py:213
+msgid "Save all settings for the currently selected config"
+msgstr "Сохранить все настройки для текущей выбранной конфигурации"
+
+#: lib/gui/popup_configure.py:216
+msgid "Reset all settings for the currently selected config to default values"
+msgstr ""
+"Сбросить все настройки для текущей выбранной конфигурации до заводских "
+"значений"
+
+#: lib/gui/popup_configure.py:538
+msgid "Select a plugin to configure:"
+msgstr "Выбрать плагин для настройки:"
+
+#: lib/gui/popup_session.py:191
+msgid "Display {}"
+msgstr "Показать {}"
+
+#: lib/gui/popup_session.py:342
+msgid "Refresh graph"
+msgstr "Обновить график"
+
+#: lib/gui/popup_session.py:344
+msgid "Save display data to csv"
+msgstr "Сохранить отображаемые данные в csv файл"
+
+#: lib/gui/popup_session.py:346
+msgid "Number of data points to sample for rolling average"
+msgstr "Количество точек данных для расчёта скользящего среднего"
+
+#: lib/gui/popup_session.py:348
+msgid "Set the smoothing amount. 0 is no smoothing, 0.99 is maximum smoothing"
+msgstr ""
+"Установите величину сглаживания. 0 - нет сглаживания, 0.99 - максимальное "
+"сглаживание"
+
+#: lib/gui/popup_session.py:350
+msgid ""
+"Flatten data points that fall more than 1 standard deviation from the mean "
+"to the mean value."
+msgstr ""
+"Сглаживание точек данных, которые отклоняются от среднего значения более чем "
+"на 1 стандартное отклонение, до среднего значения."
+
+#: lib/gui/popup_session.py:353
+msgid "Display rolling average of the data"
+msgstr "Показать скользящее среднее данных"
+
+#: lib/gui/popup_session.py:355
+msgid "Smooth the data"
+msgstr "Сгладить данные"
+
+#: lib/gui/popup_session.py:357
+msgid "Display raw data"
+msgstr "Показать необработанные данные"
+
+#: lib/gui/popup_session.py:359
+msgid "Display polynomial data trend"
+msgstr "Отображение полиномиального тренда данных"
+
+#: lib/gui/popup_session.py:361
+msgid "Set the data to display"
+msgstr "Указать данные для отображения"
+
+#: lib/gui/popup_session.py:363
+msgid "Change y-axis scale"
+msgstr "Изменить масштаб оси y"
diff --git a/locales/ru/LC_MESSAGES/lib.cli.args.mo b/locales/ru/LC_MESSAGES/lib.cli.args.mo
new file mode 100644
index 0000000000..5b51831151
Binary files /dev/null and b/locales/ru/LC_MESSAGES/lib.cli.args.mo differ
diff --git a/locales/ru/LC_MESSAGES/lib.cli.args.po b/locales/ru/LC_MESSAGES/lib.cli.args.po
new file mode 100755
index 0000000000..2f6a55435f
--- /dev/null
+++ b/locales/ru/LC_MESSAGES/lib.cli.args.po
@@ -0,0 +1,63 @@
+# SOME DESCRIPTIVE TITLE.
+# Copyright (C) YEAR THE PACKAGE'S COPYRIGHT HOLDER
+# This file is distributed under the same license as the PACKAGE package.
+# FIRST AUTHOR , YEAR. 
+#
+msgid ""
+msgstr ""
+"Project-Id-Version: \n"
+"Report-Msgid-Bugs-To: \n"
+"POT-Creation-Date: 2024-03-28 18:06+0000\n"
+"PO-Revision-Date: 2024-03-28 18:23+0000\n"
+"Last-Translator: \n"
+"Language-Team: \n"
+"Language: ru\n"
+"MIME-Version: 1.0\n"
+"Content-Type: text/plain; charset=UTF-8\n"
+"Content-Transfer-Encoding: 8bit\n"
+"Plural-Forms: nplurals=3; plural=(n%10==1 && n%100!=11 ? 0 : n%10>=2 && "
+"n%10<=4 && (n%100<12 || n%100>14) ? 1 : 2);\n"
+"X-Generator: Poedit 3.4.2\n"
+
+#: lib/cli/args.py:188 lib/cli/args.py:199 lib/cli/args.py:208
+#: lib/cli/args.py:219
+msgid "Global Options"
+msgstr "Глобальные настройки"
+
+#: lib/cli/args.py:190
+msgid ""
+"R|Exclude GPUs from use by Faceswap. Select the number(s) which correspond "
+"to any GPU(s) that you do not wish to be made available to Faceswap. "
+"Selecting all GPUs here will force Faceswap into CPU mode.\n"
+"L|{}"
+msgstr ""
+"R|Исключить GPU из использования Faceswap. Выберите номер (номера), "
+"соответствующие любому GPU, который вы не хотите предоставлять Faceswap. "
+"Если выбрать здесь все GPU, Faceswap перейдет в режим CPU.\n"
+"L|{}"
+
+#: lib/cli/args.py:201
+msgid ""
+"Optionally override the saved config with the path to a custom config file."
+msgstr ""
+"Опционально переопределите сохраненную конфигурацию, указав путь к "
+"пользовательскому файлу конфигурации."
+
+#: lib/cli/args.py:210
+msgid ""
+"Log level. Stick with INFO or VERBOSE unless you need to file an error "
+"report. Be careful with TRACE as it will generate a lot of data"
+msgstr ""
+"Уровень логирования. Придерживайтесь INFO или VERBOSE, если только вам не "
+"нужно отправить отчет об ошибке. Будьте осторожны с TRACE, поскольку он "
+"генерирует много данных"
+
+#: lib/cli/args.py:220
+msgid "Path to store the logfile. Leave blank to store in the faceswap folder"
+msgstr ""
+"Путь для хранения файла журнала. Оставьте пустым, чтобы хранить в папке "
+"faceswap"
+
+#: lib/cli/args.py:319
+msgid "Output to Shell console instead of GUI console"
+msgstr "Вывод в консоль Shell вместо консоли GUI"
diff --git a/locales/ru/LC_MESSAGES/lib.cli.args_extract_convert.mo b/locales/ru/LC_MESSAGES/lib.cli.args_extract_convert.mo
new file mode 100644
index 0000000000..51d7f9676f
Binary files /dev/null and b/locales/ru/LC_MESSAGES/lib.cli.args_extract_convert.mo differ
diff --git a/locales/ru/LC_MESSAGES/lib.cli.args_extract_convert.po b/locales/ru/LC_MESSAGES/lib.cli.args_extract_convert.po
new file mode 100755
index 0000000000..e95bf84dd7
--- /dev/null
+++ b/locales/ru/LC_MESSAGES/lib.cli.args_extract_convert.po
@@ -0,0 +1,710 @@
+# SOME DESCRIPTIVE TITLE.
+# Copyright (C) YEAR THE PACKAGE'S COPYRIGHT HOLDER
+# This file is distributed under the same license as the PACKAGE package.
+# FIRST AUTHOR , YEAR.
+#
+msgid ""
+msgstr ""
+"Project-Id-Version: \n"
+"Report-Msgid-Bugs-To: \n"
+"POT-Creation-Date: 2024-04-12 11:56+0100\n"
+"PO-Revision-Date: 2024-04-12 11:59+0100\n"
+"Last-Translator: \n"
+"Language-Team: \n"
+"Language: ru\n"
+"MIME-Version: 1.0\n"
+"Content-Type: text/plain; charset=UTF-8\n"
+"Content-Transfer-Encoding: 8bit\n"
+"Plural-Forms: nplurals=3; plural=(n%10==1 && n%100!=11 ? 0 : n%10>=2 && "
+"n%10<=4 && (n%100<12 || n%100>14) ? 
1 : 2);\n"
+"X-Generator: Poedit 3.4.2\n"
+
+#: lib/cli/args_extract_convert.py:46 lib/cli/args_extract_convert.py:56
+#: lib/cli/args_extract_convert.py:64 lib/cli/args_extract_convert.py:122
+#: lib/cli/args_extract_convert.py:483 lib/cli/args_extract_convert.py:492
+msgid "Data"
+msgstr "Данные"
+
+#: lib/cli/args_extract_convert.py:48
+msgid ""
+"Input directory or video. Either a directory containing the image files you "
+"wish to process or path to a video file. NB: This should be the source video/"
+"frames NOT the source faces."
+msgstr ""
+"Входная папка или видео. Либо каталог, содержащий файлы изображений, которые "
+"вы хотите обработать, либо путь к видеофайлу. ПРИМЕЧАНИЕ: Это должно быть "
+"исходное видео/кадры, а не исходные лица."
+
+#: lib/cli/args_extract_convert.py:57
+msgid "Output directory. This is where the converted files will be saved."
+msgstr "Выходная папка. Здесь будут сохранены преобразованные файлы."
+
+#: lib/cli/args_extract_convert.py:66
+msgid ""
+"Optional path to an alignments file. Leave blank if the alignments file is "
+"at the default location."
+msgstr ""
+"Необязательный путь к файлу выравниваний. Оставьте пустым, если файл "
+"выравнивания находится в месте по умолчанию."
+
+#: lib/cli/args_extract_convert.py:97
+msgid ""
+"Extract faces from image or video sources.\n"
+"Extraction plugins can be configured in the 'Settings' Menu"
+msgstr ""
+"Извлечение лиц из источников изображений или видео.\n"
+"Плагины извлечения можно настроить в меню \"Настройки\""
+
+#: lib/cli/args_extract_convert.py:124
+msgid ""
+"R|If selected then the input_dir should be a parent folder containing "
+"multiple videos and/or folders of images you wish to extract from. The faces "
+"will be output to separate sub-folders in the output_dir."
+msgstr ""
+"R|Если выбрано, то input_dir должен быть родительской папкой, содержащей "
+"несколько видео и/или папок с изображениями, из которых вы хотите выполнить "
+"извлечение. Лица будут выведены в отдельные вложенные папки в output_dir."
+
+#: lib/cli/args_extract_convert.py:133 lib/cli/args_extract_convert.py:152
+#: lib/cli/args_extract_convert.py:167 lib/cli/args_extract_convert.py:206
+#: lib/cli/args_extract_convert.py:224 lib/cli/args_extract_convert.py:237
+#: lib/cli/args_extract_convert.py:247 lib/cli/args_extract_convert.py:257
+#: lib/cli/args_extract_convert.py:503 lib/cli/args_extract_convert.py:529
+#: lib/cli/args_extract_convert.py:568
+msgid "Plugins"
+msgstr "Плагины"
+
+#: lib/cli/args_extract_convert.py:135
+msgid ""
+"R|Detector to use. Some of these have configurable settings in '/config/"
+"extract.ini' or 'Settings > Configure Extract Plugins':\n"
+"L|cv2-dnn: A CPU only extractor which is the least reliable and least "
+"resource intensive. Use this if not using a GPU and time is important.\n"
+"L|mtcnn: Good detector. Fast on CPU, faster on GPU. Uses fewer resources "
+"than other GPU detectors but can often return more false positives.\n"
+"L|s3fd: Best detector. Slow on CPU, faster on GPU. Can detect more faces and "
+"fewer false positives than other GPU detectors, but is a lot more resource "
+"intensive.\n"
+"L|external: Import a face detection bounding box from a json file. "
+"(configurable in Detect settings)"
+msgstr ""
+"R|Детектор для использования. Некоторые из них имеют настраиваемые параметры "
+"в '/config/extract.ini' или 'Settings > Configure Extract Plugins':\n"
+"L|cv2-dnn: Экстрактор только для процессора, который является наименее "
+"надежным и наименее ресурсоемким. 
Используйте его, если не используется GPU "
+"и важно время.\n"
+"L|mtcnn: Хороший детектор. Быстрый на CPU, еще быстрее на GPU. Использует "
+"меньше ресурсов, чем другие детекторы на GPU, но часто может давать больше "
+"ложных срабатываний.\n"
+"L|s3fd: Лучший детектор. Медленный на CPU, более быстрый на GPU. Может "
+"обнаружить больше лиц и меньше ложных срабатываний, чем другие детекторы на "
+"GPU, но требует гораздо больше ресурсов.\n"
+"L|external: импорт ограничивающей рамки обнаруженного лица из файла JSON. "
+"(настраивается в настройках обнаружения)"
+
+#: lib/cli/args_extract_convert.py:154
+msgid ""
+"R|Aligner to use.\n"
+"L|cv2-dnn: A CPU only landmark detector. Faster, less resource intensive, "
+"but less accurate. Only use this if not using a GPU and time is important.\n"
+"L|fan: Best aligner. Fast on GPU, slow on CPU.\n"
+"L|external: Import 68 point 2D landmarks or an aligned bounding box from a "
+"json file. (configurable in Align settings)"
+msgstr ""
+"R|Выравниватель для использования.\n"
+"L|cv2-dnn: Детектор ориентиров только для процессора. Быстрее, менее "
+"ресурсоемкий, но менее точный. Используйте его, только если не используется "
+"GPU и важно время.\n"
+"L|fan: Лучший выравниватель. Быстрый на GPU, медленный на CPU.\n"
+"L|external: импорт 68-точечных 2D-ориентиров или выровненной ограничивающей "
+"рамки из файла JSON. (настраивается в настройках выравнивания)"
+
+#: lib/cli/args_extract_convert.py:169
+msgid ""
+"R|Additional Masker(s) to use. The masks generated here will all take up GPU "
+"RAM. You can select none, one or multiple masks, but the extraction may take "
+"longer the more you select. NB: The Extended and Components (landmark based) "
+"masks are automatically generated on extraction.\n"
+"L|bisenet-fp: Relatively lightweight NN based mask that provides more "
+"refined control over the area to be masked including full head masking "
+"(configurable in mask settings).\n"
+"L|custom: A dummy mask that fills the mask area with all 1s or 0s "
+"(configurable in settings). This is only required if you intend to manually "
+"edit the custom masks yourself in the manual tool. This mask does not use "
+"the GPU so will not use any additional VRAM.\n"
+"L|vgg-clear: Mask designed to provide smart segmentation of mostly frontal "
+"faces clear of obstructions. Profile faces and obstructions may result in "
+"sub-par performance.\n"
+"L|vgg-obstructed: Mask designed to provide smart segmentation of mostly "
+"frontal faces. The mask model has been specifically trained to recognize "
+"some facial obstructions (hands and eyeglasses). Profile faces may result in "
+"sub-par performance.\n"
+"L|unet-dfl: Mask designed to provide smart segmentation of mostly frontal "
+"faces. The mask model has been trained by community members and will need "
+"testing for further description. Profile faces may result in sub-par "
+"performance.\n"
+"The auto generated masks are as follows:\n"
+"L|components: Mask designed to provide facial segmentation based on the "
+"positioning of landmark locations. A convex hull is constructed around the "
+"exterior of the landmarks to create a mask.\n"
+"L|extended: Mask designed to provide facial segmentation based on the "
+"positioning of landmark locations. 
A convex hull is constructed around the " +"exterior of the landmarks and the mask is extended upwards onto the " +"forehead.\n" +"(eg: `-M unet-dfl vgg-clear`, `--masker vgg-obstructed`)" +msgstr "" +"R|Дополнительный маскер(ы) для использования. Все маски, созданные здесь, " +"будут занимать видеопамять GPU. Вы можете выбрать ни одной, одну или " +"несколько масок, но извлечение может занять больше времени, чем больше масок " +"вы выберете. Примечание: Расширенные маски и маски компонентов (на основе " +"ориентиров) генерируются автоматически при извлечении.\n" +"L|bisenet-fp: Относительно легкая маска на основе NN, которая обеспечивает " +"более точный контроль над маскируемой областью, включая полное маскирование " +"головы (настраивается в настройках маски).\n" +"L|custom: Фиктивная маска, которая заполняет область маски всеми 1 или 0 " +"(настраивается в настройках). Она необходима только в том случае, если вы " +"собираетесь вручную редактировать пользовательские маски в ручном " +"инструменте. Эта маска не задействует GPU, поэтому не будет использовать " +"дополнительную память VRAM.\n" +"L|vgg-clear: Маска предназначена для интеллектуальной сегментации " +"преимущественно фронтальных лиц без препятствий. Профильные лица и " +"препятствия могут привести к снижению производительности.\n" +"L|vgg-obstructed: Маска, разработанная для интеллектуальной сегментации " +"преимущественно фронтальных лиц. Модель маски была специально обучена " +"распознавать некоторые препятствия на лице (руки и очки). Лица в профиль " +"могут иметь низкую производительность.\n" +"L|unet-dfl: Маска, разработанная для интеллектуальной сегментации " +"преимущественно фронтальных лиц. Модель маски была обучена членами " +"сообщества и для дальнейшего описания нуждается в тестировании. Профильные " +"лица могут привести к низкой производительности.\n" +"Автоматически сгенерированные маски выглядят следующим образом:\n" +"L|components: Маска, разработанная для сегментации лица на основе " +"расположения ориентиров. Для создания маски вокруг внешних ориентиров " +"строится выпуклая оболочка.\n" +"L|extended: Маска, предназначенная для сегментации лица на основе " +"расположения ориентиров. Выпуклый корпус строится вокруг внешних ориентиров, " +"и маска расширяется вверх на лоб.\n" +"(например: `-M unet-dfl vgg-clear`, `--masker vgg-obstructed`)" + +#: lib/cli/args_extract_convert.py:208 +msgid "" +"R|Performing normalization can help the aligner better align faces with " +"difficult lighting conditions at an extraction speed cost. Different methods " +"will yield different results on different sets. NB: This does not impact the " +"output face, just the input to the aligner.\n" +"L|none: Don't perform normalization on the face.\n" +"L|clahe: Perform Contrast Limited Adaptive Histogram Equalization on the " +"face.\n" +"L|hist: Equalize the histograms on the RGB channels.\n" +"L|mean: Normalize the face colors to the mean." +msgstr "" +"R|Проведение нормализации может помочь выравнивателю лучше выравнивать лица " +"со сложными условиями освещения при затратах на скорость извлечения. " +"Различные методы дают разные результаты на разных наборах. NB: Это не влияет " +"на выходное лицо, только на вход выравнивателя.\n" +"L|none: Не выполнять нормализацию лица.\n" +"L|clahe: Выполнить для лица адаптивную гистограммную эквализацию с " +"ограничением контраста.\n" +"L|hist: Уравнять гистограммы в каналах RGB.\n" +"L|mean: Нормализовать цвета лица к среднему значению." 
+ +#: lib/cli/args_extract_convert.py:226 +msgid "" +"The number of times to re-feed the detected face into the aligner. Each time " +"the face is re-fed into the aligner the bounding box is adjusted by a small " +"amount. The final landmarks are then averaged from each iteration. Helps to " +"remove 'micro-jitter' but at the cost of slower extraction speed. The more " +"times the face is re-fed into the aligner, the less micro-jitter should " +"occur but the longer extraction will take." +msgstr "" +"Количество повторных подач обнаруженной области лица в выравниватель. При " +"каждой повторной подаче лица в выравниватель ограничивающая рамка " +"корректируется на небольшую величину. Затем конечные ориентиры усредняются " +"по результатам каждой итерации. Это помогает устранить \"микро-дрожание\", " +"но ценой снижения скорости извлечения. Чем больше раз лицо повторно подается " +"в выравниватель, тем меньше микро-дрожание, но тем больше времени займет " +"извлечение." + +#: lib/cli/args_extract_convert.py:239 +msgid "" +"Re-feed the initially found aligned face through the aligner. Can help " +"produce better alignments for faces that are rotated beyond 45 degrees in " +"the frame or are at extreme angles. Slows down extraction." +msgstr "" +"Повторная подача первоначально найденной выровненной области лица через " +"выравниватель. Может помочь получить лучшее выравнивание для лиц, повернутых " +"в кадре более чем на 45 градусов или расположенных под экстремальными " +"углами. Замедляет извлечение." + +#: lib/cli/args_extract_convert.py:249 +msgid "" +"If a face isn't found, rotate the images to try to find a face. Can find " +"more faces at the cost of extraction speed. Pass in a single number to use " +"increments of that size up to 360, or pass in a list of numbers to enumerate " +"exactly what angles to check." +msgstr "" +"Если лицо не найдено, поворачивает изображения, чтобы попытаться найти лицо. " +"Может найти больше лиц ценой снижения скорости извлечения. Передайте одно " +"число, чтобы использовать приращения этого размера до 360, или передайте " +"список чисел, чтобы перечислить, какие именно углы нужно проверить." + +#: lib/cli/args_extract_convert.py:259 +msgid "" +"Obtain and store face identity encodings from VGGFace2. Slows down extract a " +"little, but will save time if using 'sort by face'" +msgstr "" +"Получение и хранение кодировок идентификации лица из VGGFace2. Немного " +"замедляет извлечение, но экономит время при использовании \"сортировки по " +"лицам\"." + +#: lib/cli/args_extract_convert.py:269 lib/cli/args_extract_convert.py:280 +#: lib/cli/args_extract_convert.py:293 lib/cli/args_extract_convert.py:307 +#: lib/cli/args_extract_convert.py:614 lib/cli/args_extract_convert.py:623 +#: lib/cli/args_extract_convert.py:638 lib/cli/args_extract_convert.py:651 +#: lib/cli/args_extract_convert.py:665 +msgid "Face Processing" +msgstr "Обработка лиц" + +#: lib/cli/args_extract_convert.py:271 +msgid "" +"Filters out faces detected below this size. Length, in pixels across the " +"diagonal of the bounding box. Set to 0 for off" +msgstr "" +"Отфильтровывает лица, обнаруженные ниже этого размера. Длина в пикселях по " +"диагонали ограничивающего поля. Установите значение 0, чтобы выключить" + +#: lib/cli/args_extract_convert.py:282 +msgid "" +"Optionally filter out people who you do not wish to extract by passing in " +"images of those people. Should be a small variety of images at different " +"angles and in different conditions. 
A folder containing the required images " +"or multiple image files, space separated, can be selected." +msgstr "" +"По желанию отфильтруйте людей, которых вы не хотите извлекать, передав " +"изображения этих людей. Должно быть небольшое разнообразие изображений под " +"разными углами и в разных условиях. Можно выбрать папку, содержащую " +"необходимые изображения, или несколько файлов изображений, разделенных " +"пробелами." + +#: lib/cli/args_extract_convert.py:295 +msgid "" +"Optionally select people you wish to extract by passing in images of that " +"person. Should be a small variety of images at different angles and in " +"different conditions A folder containing the required images or multiple " +"image files, space separated, can be selected." +msgstr "" +"По желанию выберите людей, которых вы хотите извлечь, передав изображения " +"этого человека. Должно быть небольшое разнообразие изображений под разными " +"углами и в разных условиях. Можно выбрать папку, содержащую необходимые " +"изображения, или несколько файлов изображений, разделенных пробелами." + +#: lib/cli/args_extract_convert.py:309 +msgid "" +"For use with the optional nfilter/filter files. Threshold for positive face " +"recognition. Higher values are stricter." +msgstr "" +"Для использования с дополнительными файлами nfilter/filter. Порог для " +"положительного распознавания лица. Более высокие значения являются более " +"строгими." + +#: lib/cli/args_extract_convert.py:318 lib/cli/args_extract_convert.py:331 +#: lib/cli/args_extract_convert.py:344 lib/cli/args_extract_convert.py:356 +msgid "output" +msgstr "вывод" + +#: lib/cli/args_extract_convert.py:320 +msgid "" +"The output size of extracted faces. Make sure that the model you intend to " +"train supports your required size. This will only need to be changed for hi-" +"res models." +msgstr "" +"Выходной размер извлеченных лиц. Убедитесь, что модель, которую вы " +"собираетесь тренировать, поддерживает требуемый размер. Это необходимо " +"изменить только для моделей высокого разрешения." + +#: lib/cli/args_extract_convert.py:333 +msgid "" +"Extract every 'nth' frame. This option will skip frames when extracting " +"faces. For example a value of 1 will extract faces from every frame, a value " +"of 10 will extract faces from every 10th frame." +msgstr "" +"Извлекать каждый 'n-й' кадр. Этот параметр пропускает кадры при извлечении " +"лиц. Например, значение 1 будет извлекать лица из каждого кадра, значение 10 " +"будет извлекать лица из каждого 10-го кадра." + +#: lib/cli/args_extract_convert.py:346 +msgid "" +"Automatically save the alignments file after a set amount of frames. By " +"default the alignments file is only saved at the end of the extraction " +"process. NB: If extracting in 2 passes then the alignments file will only " +"start to be saved out during the second pass. WARNING: Don't interrupt the " +"script when writing the file because it might get corrupted. Set to 0 to " +"turn off" +msgstr "" +"Автоматическое сохранение файла выравнивания после заданного количества " +"кадров. По умолчанию файл выравнивания сохраняется только в конце процесса " +"извлечения. Примечание: Если извлечение выполняется в 2 прохода, то файл " +"выравнивания начнет сохраняться только во время второго прохода. " +"ПРЕДУПРЕЖДЕНИЕ: Не прерывайте работу скрипта при записи файла, так как он " +"может быть поврежден. Установите значение 0, чтобы отключить" + +#: lib/cli/args_extract_convert.py:357 +msgid "Draw landmarks on the ouput faces for debugging purposes." 
+msgstr "Нарисуйте ориентиры на выходящих гранях для отладки." + +#: lib/cli/args_extract_convert.py:363 lib/cli/args_extract_convert.py:373 +#: lib/cli/args_extract_convert.py:381 lib/cli/args_extract_convert.py:388 +#: lib/cli/args_extract_convert.py:678 lib/cli/args_extract_convert.py:691 +#: lib/cli/args_extract_convert.py:712 lib/cli/args_extract_convert.py:718 +msgid "settings" +msgstr "настройки" + +#: lib/cli/args_extract_convert.py:365 +msgid "" +"Don't run extraction in parallel. Will run each part of the extraction " +"process separately (one after the other) rather than all at the same time. " +"Useful if VRAM is at a premium." +msgstr "" +"Не запускать извлечение параллельно. Каждая часть процесса извлечения будет " +"выполняться отдельно (одна за другой), а не одновременно. Полезно, если " +"память VRAM ограничена." + +#: lib/cli/args_extract_convert.py:375 +msgid "" +"Skips frames that have already been extracted and exist in the alignments " +"file" +msgstr "" +"Пропускает кадры, которые уже были извлечены и существуют в файле " +"выравнивания" + +#: lib/cli/args_extract_convert.py:382 +msgid "Skip frames that already have detected faces in the alignments file" +msgstr "" +"Пропустить кадры, в которых уже есть обнаруженные лица в файле выравнивания" + +#: lib/cli/args_extract_convert.py:389 +msgid "Skip saving the detected faces to disk. Just create an alignments file" +msgstr "" +"Не сохранять обнаруженные лица на диск. Просто создать файл выравнивания" + +#: lib/cli/args_extract_convert.py:463 +msgid "" +"Swap the original faces in a source video/images to your final faces.\n" +"Conversion plugins can be configured in the 'Settings' Menu" +msgstr "" +"Поменять исходные лица в исходном видео/изображении на ваши конечные лица.\n" +"Плагины конвертирования можно настроить в меню \"Настройки\"" + +#: lib/cli/args_extract_convert.py:485 +msgid "" +"Only required if converting from images to video. Provide The original video " +"that the source frames were extracted from (for extracting the fps and " +"audio)." +msgstr "" +"Требуется только при преобразовании из изображений в видео. Предоставьте " +"исходное видео, из которого были извлечены исходные кадры (для извлечения " +"кадров в секунду и звука)." + +#: lib/cli/args_extract_convert.py:494 +msgid "" +"Model directory. The directory containing the trained model you wish to use " +"for conversion." +msgstr "" +"Папка модели. Папка, содержащая обученную модель, которую вы хотите " +"использовать для преобразования." + +#: lib/cli/args_extract_convert.py:505 +msgid "" +"R|Performs color adjustment to the swapped face. Some of these options have " +"configurable settings in '/config/convert.ini' or 'Settings > Configure " +"Convert Plugins':\n" +"L|avg-color: Adjust the mean of each color channel in the swapped " +"reconstruction to equal the mean of the masked area in the original image.\n" +"L|color-transfer: Transfers the color distribution from the source to the " +"target image using the mean and standard deviations of the L*a*b* color " +"space.\n" +"L|manual-balance: Manually adjust the balance of the image in a variety of " +"color spaces. Best used with the Preview tool to set correct values.\n" +"L|match-hist: Adjust the histogram of each color channel in the swapped " +"reconstruction to equal the histogram of the masked area in the original " +"image.\n" +"L|seamless-clone: Use cv2's seamless clone function to remove extreme " +"gradients at the mask seam by smoothing colors. 
Generally does not give very " +"satisfactory results.\n" +"L|none: Don't perform color adjustment." +msgstr "" +"R|Производит корректировку цвета замененного лица. Некоторые из этих " +"параметров настраиваются в '/config/convert.ini' или 'Настройки > Настроить " +"плагины конвертации':\n" +"L|avg-color: корректирует среднее значение каждого цветового канала в " +"измененной реконструкции, чтобы оно было равно среднему значению " +"маскированной области в исходном изображении.\n" +"L|color-transfer: Переносит распределение цветов с исходного изображения на " +"целевое, используя среднее и стандартные отклонения цветового пространства " +"L*a*b*.\n" +"L|manual-balance: Ручная настройка баланса изображения в различных цветовых " +"пространствах. Лучше всего использовать с инструментом предварительного " +"просмотра для установки правильных значений.\n" +"L|match-hist: Настроить гистограмму каждого цветового канала в измененной " +"реконструкции так, чтобы она соответствовала гистограмме маскированной " +"области исходного изображения.\n" +"L|seamless-clone: Используйте функцию бесшовного клонирования cv2 для " +"удаления экстремальных градиентов на шве маски путем сглаживания цветов. " +"Обычно дает не очень удовлетворительные результаты.\n" +"L|none: Не выполнять коррекцию цвета." + +#: lib/cli/args_extract_convert.py:531 +msgid "" +"R|Masker to use. NB: The mask you require must exist within the alignments " +"file. You can add additional masks with the Mask Tool.\n" +"L|none: Don't use a mask.\n" +"L|bisenet-fp_face: Relatively lightweight NN based mask that provides more " +"refined control over the area to be masked (configurable in mask settings). " +"Use this version of bisenet-fp if your model is trained with 'face' or " +"'legacy' centering.\n" +"L|bisenet-fp_head: Relatively lightweight NN based mask that provides more " +"refined control over the area to be masked (configurable in mask settings). " +"Use this version of bisenet-fp if your model is trained with 'head' " +"centering.\n" +"L|custom_face: Custom user created, face centered mask.\n" +"L|custom_head: Custom user created, head centered mask.\n" +"L|components: Mask designed to provide facial segmentation based on the " +"positioning of landmark locations. A convex hull is constructed around the " +"exterior of the landmarks to create a mask.\n" +"L|extended: Mask designed to provide facial segmentation based on the " +"positioning of landmark locations. A convex hull is constructed around the " +"exterior of the landmarks and the mask is extended upwards onto the " +"forehead.\n" +"L|vgg-clear: Mask designed to provide smart segmentation of mostly frontal " +"faces clear of obstructions. Profile faces and obstructions may result in " +"sub-par performance.\n" +"L|vgg-obstructed: Mask designed to provide smart segmentation of mostly " +"frontal faces. The mask model has been specifically trained to recognize " +"some facial obstructions (hands and eyeglasses). Profile faces may result in " +"sub-par performance.\n" +"L|unet-dfl: Mask designed to provide smart segmentation of mostly frontal " +"faces. The mask model has been trained by community members and will need " +"testing for further description. Profile faces may result in sub-par " +"performance.\n" +"L|predicted: If the 'Learn Mask' option was enabled during training, this " +"will use the mask that was created by the trained model." +msgstr "" +"R|Маскер для использования. Примечание: Нужная маска должна существовать в " +"файле выравнивания. 
Вы можете добавить дополнительные маски с помощью " +"инструмента Mask Tool.\n" +"L|none: Не использовать маску.\n" +"L|bisenet-fp_face: Относительно легкая маска на основе NN, которая " +"обеспечивает более точный контроль над маскируемой областью (настраивается в " +"настройках маски). Используйте эту версию bisenet-fp, если ваша модель " +"обучена с центрированием 'face' или 'legacy'.\n" +"L|bisenet-fp_head: Относительно легкая маска на основе NN, которая " +"обеспечивает более точный контроль над маскируемой областью (настраивается в " +"настройках маски). Используйте эту версию bisenet-fp, если ваша модель " +"обучена с центрированием 'head'.\n" +"L|custom_face: Созданная пользователем маска, центрированная по лицу.\n" +"L|custom_head: Созданная пользователем маска, центрированная по голове.\n" +"L|components: Маска, разработанная для сегментации лица на основе " +"расположения ориентиров. Для создания маски вокруг внешних ориентиров " +"строится выпуклая оболочка.\n" +"L|extended: Маска, предназначенная для сегментации лица на основе " +"расположения ориентиров. Выпуклая оболочка строится вокруг внешних " +"ориентиров, и маска расширяется вверх на лоб.\n" +"L|vgg-clear: Маска предназначена для интеллектуальной сегментации " +"преимущественно фронтальных лиц без препятствий. Профильные лица и " +"препятствия могут привести к снижению производительности.\n" +"L|vgg-obstructed: Маска, разработанная для интеллектуальной сегментации " +"преимущественно фронтальных лиц. Модель маски была специально обучена " +"распознавать некоторые препятствия на лице (руки и очки). Лица в профиль " +"могут иметь низкую производительность.\n" +"L|unet-dfl: Маска, разработанная для интеллектуальной сегментации " +"преимущественно фронтальных лиц. Модель маски была обучена членами " +"сообщества и для дальнейшего описания нуждается в тестировании. Профильные " +"лица могут привести к низкой производительности.\n" +"L|predicted: Если во время обучения была включена опция 'Изучить Маску', то " +"будет использоваться маска, созданная обученной моделью." + +#: lib/cli/args_extract_convert.py:570 +msgid "" +"R|The plugin to use to output the converted images. The writers are " +"configurable in '/config/convert.ini' or 'Settings > Configure Convert " +"Plugins:'\n" +"L|ffmpeg: [video] Writes out the convert straight to video. When the input " +"is a series of images then the '-ref' (--reference-video) parameter must be " +"set.\n" +"L|gif: [animated image] Create an animated gif.\n" +"L|opencv: [images] The fastest image writer, but less options and formats " +"than other plugins.\n" +"L|patch: [images] Outputs the raw swapped face patch, along with the " +"transformation matrix required to re-insert the face back into the original " +"frame. Use this option if you wish to post-process and composite the final " +"face within external tools.\n" +"L|pillow: [images] Slower than opencv, but has more options and supports " +"more formats." +msgstr "" +"R|Плагин, который нужно использовать для вывода преобразованных изображений. " +"Модули записи настраиваются в '/config/convert.ini' или 'Настройки > " +"Настроить плагины конвертации:'\n" +"L|ffmpeg: [видео] Записывает конвертацию прямо в видео. 
Если на вход " +"подается серия изображений, необходимо установить параметр '-ref' (--" +"reference-video).\n" +"L|gif: [анимированное изображение] Создает анимированный gif.\n" +"L|opencv: [изображения] Самый быстрый модуль записи изображений, но имеет " +"меньше опций и форматов, чем другие плагины.\n" +"L|patch: [изображения] Выводит необработанный фрагмент измененного лица " +"вместе с матрицей преобразования, необходимой для повторной вставки лица " +"обратно в исходный кадр. Используйте эту опцию, если вы хотите выполнить " +"постобработку и компоновку итогового лица во внешних инструментах.\n" +"L|pillow: [изображения] Медленнее, чем opencv, но имеет больше опций и " +"поддерживает больше форматов." + +#: lib/cli/args_extract_convert.py:591 lib/cli/args_extract_convert.py:600 +#: lib/cli/args_extract_convert.py:703 +msgid "Frame Processing" +msgstr "Обработка кадров" + +#: lib/cli/args_extract_convert.py:593 +#, python-format +msgid "" +"Scale the final output frames by this amount. 100%% will output the frames " +"at source dimensions. 50%% at half size 200%% at double size" +msgstr "" +"Масштабирование конечных выходных кадров на эту величину. 100%% выводит " +"кадры в исходном размере, 50%% в половинном размере, 200%% в двойном " +"размере" + +#: lib/cli/args_extract_convert.py:602 +msgid "" +"Frame ranges to apply transfer to e.g. For frames 10 to 50 and 90 to 100 use " +"--frame-ranges 10-50 90-100. Frames falling outside of the selected range " +"will be discarded unless '-k' (--keep-unchanged) is selected. NB: If you are " +"converting from images, then the filenames must end with the frame-number!" +msgstr "" +"Диапазоны кадров для применения переноса, например, для кадров с 10 по 50 и " +"с 90 по 100 используйте --frame-ranges 10-50 90-100. Кадры, выходящие за " +"пределы выбранного диапазона, будут отброшены, если не выбрана опция '-k' (--" +"keep-unchanged). Примечание: Если вы конвертируете из изображений, то имена " +"файлов должны заканчиваться номером кадра!" + +#: lib/cli/args_extract_convert.py:616 +msgid "" +"Scale the swapped face by this percentage. Positive values will enlarge the " +"face, Negative values will shrink the face." +msgstr "" +"Масштабировать замененное лицо на этот процент. Положительные значения " +"увеличат лицо, в то время как отрицательные значения уменьшат его." + +#: lib/cli/args_extract_convert.py:625 +msgid "" +"If you have not cleansed your alignments file, then you can filter out faces " +"by defining a folder here that contains the faces extracted from your input " +"files/video. If this folder is defined, then only faces that exist within " +"your alignments file and also exist within the specified folder will be " +"converted. Leaving this blank will convert all faces that exist within the " +"alignments file." +msgstr "" +"Если вы не очистили свой файл выравнивания, то вы можете отфильтровать лица, " +"определив здесь папку, содержащую лица, извлеченные из ваших входных файлов/" +"видео. Если эта папка определена, то будут преобразованы только те лица, " +"которые существуют в вашем файле выравнивания, а также в указанной папке. " +"Если оставить этот параметр пустым, будут преобразованы все лица, " +"существующие в файле выравнивания." + +#: lib/cli/args_extract_convert.py:640 +msgid "" +"Optionally filter out people who you do not wish to process by passing in an " +"image of that person. Should be a front portrait with a single person in the " +"image. Multiple images can be added space separated. NB: Using face filter " +"will significantly decrease extraction speed and its accuracy cannot be " +"guaranteed." 
+msgstr "" +"По желанию отфильтровать людей, которых вы не хотите обрабатывать, передав " +"изображение этого человека. Это должен быть фронтальный портрет с " +"изображением одного человека. Можно добавить несколько изображений, " +"разделенных пробелами. Примечание: Использование фильтра лиц значительно " +"снизит скорость извлечения, а его точность не гарантируется." + +#: lib/cli/args_extract_convert.py:653 +msgid "" +"Optionally select people you wish to process by passing in an image of that " +"person. Should be a front portrait with a single person in the image. " +"Multiple images can be added space separated. NB: Using face filter will " +"significantly decrease extraction speed and its accuracy cannot be " +"guaranteed." +msgstr "" +"По желанию выберите людей, которых вы хотите обработать, передав изображение " +"этого человека. Это должен быть фронтальный портрет с изображением одного " +"человека. Можно добавить несколько изображений, разделенных пробелами. " +"Примечание: Использование фильтра лиц значительно снизит скорость " +"извлечения, а его точность не гарантируется." + +#: lib/cli/args_extract_convert.py:667 +msgid "" +"For use with the optional nfilter/filter files. Threshold for positive face " +"recognition. Lower values are stricter. NB: Using face filter will " +"significantly decrease extraction speed and its accuracy cannot be " +"guaranteed." +msgstr "" +"Для использования с дополнительными файлами nfilter/filter. Порог для " +"положительного распознавания лиц. Более низкие значения являются более " +"строгими. Примечание: Использование фильтра лиц значительно снизит скорость " +"извлечения, а его точность не гарантируется." + +#: lib/cli/args_extract_convert.py:680 +msgid "" +"The maximum number of parallel processes for performing conversion. " +"Converting images is system RAM heavy so it is possible to run out of memory " +"if you have a lot of processes and not enough RAM to accommodate them all. " +"Setting this to 0 will use the maximum available. No matter what you set " +"this to, it will never attempt to use more processes than are available on " +"your system. If singleprocess is enabled this setting will be ignored." +msgstr "" +"Максимальное количество параллельных процессов для выполнения конвертации. " +"Конвертирование изображений занимает много системной оперативной памяти, " +"поэтому может закончиться память, если у вас много процессов и недостаточно " +"оперативной памяти для их размещения. Если установить значение 0, будет " +"использован максимум доступной памяти. Независимо от того, какое значение вы " +"установите, программа никогда не будет пытаться использовать больше " +"процессов, чем доступно в вашей системе. Если включена однопоточная " +"обработка, этот параметр будет проигнорирован." + +#: lib/cli/args_extract_convert.py:693 +msgid "" +"Enable On-The-Fly Conversion. NOT recommended. You should generate a clean " +"alignments file for your destination video. However, if you wish you can " +"generate the alignments on-the-fly by enabling this option. This will use an " +"inferior extraction pipeline and will lead to substandard results. If an " +"alignments file is found, this option will be ignored." +msgstr "" +"Включить преобразование \"на лету\". НЕ рекомендуется. Вы должны " +"сгенерировать чистый файл выравнивания для конечного видео. Однако при " +"желании вы можете генерировать выравнивания \"на лету\", включив эту опцию. 
" +"При этом будет использоваться некачественный конвейер извлечения, что " +"приведет к некачественным результатам. Если файл выравнивания найден, этот " +"параметр будет проигнорирован." + +#: lib/cli/args_extract_convert.py:705 +msgid "" +"When used with --frame-ranges outputs the unchanged frames that are not " +"processed instead of discarding them." +msgstr "" +"При использовании с --frame-ranges выводит неизмененные кадры, которые не " +"были обработаны, вместо того, чтобы отбрасывать их." + +#: lib/cli/args_extract_convert.py:713 +msgid "Swap the model. Instead converting from of A -> B, converts B -> A" +msgstr "" +"Поменять модель местами. Вместо преобразования из A -> B, преобразуется B -> " +"A" + +#: lib/cli/args_extract_convert.py:719 +msgid "Disable multiprocessing. Slower but less resource intensive." +msgstr "Отключение многопоточной обработки. Медленнее, но менее ресурсоемко." + +#~ msgid "" +#~ "[LEGACY] This only needs to be selected if a legacy model is being loaded " +#~ "or if there are multiple models in the model folder" +#~ msgstr "" +#~ "[ОТБРОШЕН] Этот параметр необходимо выбрать только в том случае, если " +#~ "загружается устаревшая модель или если в папке моделей имеется несколько " +#~ "моделей" diff --git a/locales/ru/LC_MESSAGES/lib.cli.args_train.mo b/locales/ru/LC_MESSAGES/lib.cli.args_train.mo new file mode 100644 index 0000000000..7ae5608465 Binary files /dev/null and b/locales/ru/LC_MESSAGES/lib.cli.args_train.mo differ diff --git a/locales/ru/LC_MESSAGES/lib.cli.args_train.po b/locales/ru/LC_MESSAGES/lib.cli.args_train.po new file mode 100755 index 0000000000..e78537cc6b --- /dev/null +++ b/locales/ru/LC_MESSAGES/lib.cli.args_train.po @@ -0,0 +1,1060 @@ +# SOME DESCRIPTIVE TITLE. +# Copyright (C) YEAR THE PACKAGE'S COPYRIGHT HOLDER +# This file is distributed under the same license as the PACKAGE package. +# FIRST AUTHOR , YEAR. +# +msgid "" +msgstr "" +"Project-Id-Version: \n" +"Report-Msgid-Bugs-To: \n" +"POT-Creation-Date: 2025-12-15 20:02+0000\n" +"PO-Revision-Date: 2025-12-19 23:27+0000\n" +"Last-Translator: \n" +"Language-Team: \n" +"Language: ru\n" +"MIME-Version: 1.0\n" +"Content-Type: text/plain; charset=UTF-8\n" +"Content-Transfer-Encoding: 8bit\n" +"Plural-Forms: nplurals=3; plural=(n%10==1 && n%100!=11 ? 0 : n%10>=2 && " +"n%10<=4 && (n%100<12 || n%100>14) ? 1 : 2);\n" +"X-Generator: Poedit 3.8\n" + +#: lib/cli/args_train.py:30 +msgid "" +"Train a model on extracted original (A) and swap (B) faces.\n" +"Training models can take a long time. Anything from 24hrs to over a week\n" +"Model plugins can be configured in the 'Settings' Menu" +msgstr "" +"Обучение модели на извлеченных оригинальных (A) и подмененных (B) лицах.\n" +"Обучение моделей может занять много времени. От 24 часов до недели.\n" +"Плагины для моделей можно настроить в меню \"Настройки\"" + +#: lib/cli/args_train.py:49 lib/cli/args_train.py:58 +msgid "faces" +msgstr "лица" + +#: lib/cli/args_train.py:51 +msgid "" +"Input directory. A directory containing training images for face A. This is " +"the original face, i.e. the face that you want to remove and replace with " +"face B." +msgstr "" +"Входная папка. Папка, содержащая обучающие изображения для лица A. Это " +"исходное лицо, т.е. лицо, которое вы хотите удалить и заменить лицом B." + +#: lib/cli/args_train.py:60 +msgid "" +"Input directory. A directory containing training images for face B. This is " +"the swap face, i.e. the face that you want to place onto the head of person " +"A." +msgstr "" +"Входная папка. 
Папка, содержащая обучающие изображения для лица B. Это " +"подменное лицо, т.е. лицо, которое вы хотите поместить на голову человека A." + +#: lib/cli/args_train.py:67 lib/cli/args_train.py:80 lib/cli/args_train.py:97 +#: lib/cli/args_train.py:123 lib/cli/args_train.py:133 +msgid "model" +msgstr "модель" + +#: lib/cli/args_train.py:69 +msgid "" +"Model directory. This is where the training data will be stored. You should " +"always specify a new folder for new models. If starting a new model, select " +"either an empty folder, or a folder which does not exist (which will be " +"created). If continuing to train an existing model, specify the location of " +"the existing model." +msgstr "" +"Папка модели. Здесь будут храниться данные для обучения. Для новых моделей " +"всегда следует указывать новую папку. Если вы начинаете новую модель, " +"выберите либо пустую папку, либо несуществующую папку (которая будет " +"создана). Если вы продолжаете обучение существующей модели, укажите " +"местоположение существующей модели." + +#: lib/cli/args_train.py:82 +msgid "" +"R|Load the weights from a pre-existing model into a newly created model. For " +"most models this will load weights from the Encoder of the given model into " +"the encoder of the newly created model. Some plugins may have specific " +"configuration options allowing you to load weights from other layers. " +"Weights will only be loaded when creating a new model. This option will be " +"ignored if you are resuming an existing model. Generally you will also want " +"to 'freeze-weights' whilst the rest of your model catches up with your " +"Encoder.\n" +"NB: Weights can only be loaded from models of the same plugin as you intend " +"to train." +msgstr "" +"R|Загрузить веса из уже существующей модели во вновь созданную модель. Для " +"большинства моделей это означает загрузку весов из кодировщика данной модели " +"в кодировщик вновь создаваемой модели. Некоторые плагины могут иметь " +"специальные параметры конфигурации, позволяющие загружать веса из других " +"слоев. Веса будут загружаться только при создании новой модели. Эта опция " +"будет проигнорирована, если вы возобновляете существующую модель. Обычно " +"также требуется \"заморозить\" веса, пока остальная часть модели догоняет " +"кодировщик.\n" +"Примечание: Веса могут быть загружены только из моделей того же плагина, " +"который вы собираетесь обучать." + +#: lib/cli/args_train.py:99 +msgid "" +"R|Select which trainer to use. Trainers can be configured from the Settings " +"menu or the config folder.\n" +"L|original: The original model created by /u/deepfakes.\n" +"L|dfaker: 64px in/128px out model from dfaker. Enable 'warp-to-landmarks' " +"for full dfaker method.\n" +"L|dfl-h128: 128px in/out model from deepfacelab\n" +"L|dfl-sae: Adaptable model from deepfacelab\n" +"L|dlight: A lightweight, high resolution DFaker variant.\n" +"L|iae: A model that uses intermediate layers to try to get better details\n" +"L|lightweight: A lightweight model for low-end cards. Don't expect great " +"results. Can train as low as 1.6GB with batch size 8.\n" +"L|realface: A high detail, dual density model based on DFaker, with " +"customizable in/out resolution. The autoencoders are unbalanced so B>A swaps " +"won't work so well. By andenixa et al. Very configurable.\n" +"L|unbalanced: 128px in/out model from andenixa. The autoencoders are " +"unbalanced so B>A swaps won't work so well. Very configurable.\n" +"L|villain: 128px in/out model from villainguy. 
Very resource hungry (You " +"will require a GPU with a fair amount of VRAM). Good for details, but more " +"susceptible to color differences." +msgstr "" +"R|Выберите, какой тренажер использовать. Тренажеры можно настроить в меню " +"\"Настройки\" или в папке config.\n" +"L|original: Оригинальная модель, созданная /u/deepfakes.\n" +"L|dfaker: модель 64px вход/ 128px выход от dfaker. Включите 'warp-to-" +"landmarks' для полного метода dfaker.\n" +"L|dfl-h128: модель 128px вход/выход от deepfacelab\n" +"L|dfl-sae: Адаптируемая модель от deepfacelab\n" +"L|dlight: Легкий вариант DFaker с высоким разрешением.\n" +"L|iae: Модель, использующая промежуточные слои для получения лучших " +"деталей.\n" +"L|lightweight: Облегченная модель для карт низкого класса. Не ожидайте " +"высоких результатов. Может обучаться на 1,6 ГБ при размере пачки 8.\n" +"L|realface: Модель с высокой детализацией и двойной плотностью, основанная " +"на DFaker, с настраиваемым разрешением входа/выхода. Автоэнкодеры " +"несбалансированы, поэтому замены B>A не будут работать так хорошо. Автор " +"andenixa и др. Очень настраиваемая.\n" +"L|unbalanced: модель 128px вход/выход от andenixa. Автокодировщики " +"несбалансированы, поэтому замены B>A не будут работать так хорошо. Очень " +"настраиваемая.\n" +"L|villain: модель 128px вход/выход от villainguy. Очень требовательна к " +"ресурсам (вам потребуется GPU с достаточным количеством VRAM). Хороша для " +"детализации, но более восприимчива к цветовым различиям." + +#: lib/cli/args_train.py:125 +msgid "" +"Output a summary of the model and exit. If a model folder is provided then a " +"summary of the saved model is displayed. Otherwise a summary of the model " +"that would be created by the chosen plugin and configuration settings is " +"displayed." +msgstr "" +"Вывести сводку модели и выйти. Если указана папка модели, то выводится " +"сводка сохраненной модели. В противном случае отображается сводка модели, " +"которая будет создана выбранным плагином и настройками конфигурации." + +#: lib/cli/args_train.py:135 +msgid "" +"Freeze the weights of the model. Freezing weights means that some of the " +"parameters in the model will no longer continue to learn, but those that are " +"not frozen will continue to learn. For most models, this will freeze the " +"encoder, but some models may have configuration options for freezing other " +"layers." +msgstr "" +"Заморозить веса модели. Замораживание весов означает, что некоторые " +"параметры в модели больше не будут продолжать обучение, но те, которые не " +"заморожены, будут продолжать обучение. Для большинства моделей это означает " +"замораживание кодера, но некоторые модели могут иметь опции конфигурации для " +"замораживания других слоев." + +#: lib/cli/args_train.py:147 lib/cli/args_train.py:160 +#: lib/cli/args_train.py:174 lib/cli/args_train.py:183 +#: lib/cli/args_train.py:190 lib/cli/args_train.py:199 +msgid "training" +msgstr "тренировка" + +#: lib/cli/args_train.py:149 +msgid "" +"Batch size. This is the number of images processed through the model for " +"each side per iteration. NB: As the model is fed 2 sides at a time, the " +"actual number of images within the model at any one time is double the " +"number that you set here. Larger batches require more GPU RAM." +msgstr "" +"Размер пачки. Это количество изображений, обрабатываемых моделью для каждой " +"стороны за итерацию. 
Примечание: Поскольку модель обрабатывает 2 стороны " +"одновременно, фактическое количество изображений в модели в любой момент " +"времени будет вдвое больше, чем заданное здесь. Большие пачки требуют " +"больше видеопамяти GPU." + +#: lib/cli/args_train.py:162 +msgid "" +"Length of training in iterations. This is only really used for automation. " +"There is no 'correct' number of iterations a model should be trained for. " +"You should stop training when you are happy with the previews. However, if " +"you want the model to stop automatically at a set number of iterations, you " +"can set that value here." +msgstr "" +"Продолжительность обучения в итерациях. Этот параметр действительно " +"используется только для автоматизации. Не существует \"правильного\" " +"количества итераций, за которое следует обучить модель. Вы должны прекратить " +"обучение, когда будете удовлетворены предварительным просмотром. Однако если " +"вы хотите, чтобы модель автоматически останавливалась при определенном " +"количестве итераций, вы можете задать это значение здесь." + +#: lib/cli/args_train.py:176 +msgid "" +"Learning rate warmup. Linearly increase the learning rate from 0 to the " +"chosen target rate over the number of iterations given here. 0 to disable." +msgstr "" +"Разогрев скорости обучения. Линейно увеличивает скорость обучения от 0 до " +"выбранного целевого значения за указанное здесь количество итераций. 0 — " +"отключить." + +#: lib/cli/args_train.py:184 +msgid "Use distibuted training on multi-gpu setups." +msgstr "" +"Используйте распределенное обучение на системах с несколькими графическими " +"процессорами." + +#: lib/cli/args_train.py:192 +msgid "" +"Disables TensorBoard logging. NB: Disabling logs means that you will not be " +"able to use the graph or analysis for this session in the GUI." +msgstr "" +"Отключает ведение журналов TensorBoard. Примечание: Отключение ведения " +"журналов означает, что вы не сможете использовать график или анализ для этой " +"сессии в графическом интерфейсе." + +#: lib/cli/args_train.py:201 +msgid "" +"Use the Learning Rate Finder to discover the optimal learning rate for " +"training. For new models, this will calculate the optimal learning rate for " +"the model. For existing models this will use the optimal learning rate that " +"was discovered when initializing the model. Setting this option will ignore " +"the manually configured learning rate (configurable in train settings)." +msgstr "" +"Используйте инструмент поиска скорости обучения (Learning Rate Finder), " +"чтобы найти оптимальную скорость обучения вашей модели. Для новых моделей " +"это позволит рассчитать оптимальную скорость обучения. Для существующих " +"моделей будет использована оптимальная скорость обучения, найденная при " +"инициализации модели. Установка этой опции приведет к игнорированию вручную " +"настроенной скорости обучения (настраиваемой в параметрах обучения)." + +#: lib/cli/args_train.py:214 lib/cli/args_train.py:224 +msgid "Saving" +msgstr "Сохранение" + +#: lib/cli/args_train.py:215 +msgid "Sets the number of iterations between each model save." +msgstr "Устанавливает количество итераций между каждым сохранением модели." + +#: lib/cli/args_train.py:226 +msgid "" +"Sets the number of iterations before saving a backup snapshot of the model " +"in it's current state. Set to 0 for off." +msgstr "" +"
" +"Устанавливает количество итераций перед сохранением резервного снимка модели " +"в текущем состоянии. Установите значение 0 для выключения." + +#: lib/cli/args_train.py:233 lib/cli/args_train.py:245 +#: lib/cli/args_train.py:257 +msgid "timelapse" +msgstr "таймлапс" + +#: lib/cli/args_train.py:235 +msgid "" +"Optional for creating a timelapse. Timelapse will save an image of your " +"selected faces into the timelapse-output folder at every save iteration. " +"This should be the input folder of 'A' faces that you would like to use for " +"creating the timelapse. You must also supply a --timelapse-output and a --" +"timelapse-input-B parameter." +msgstr "" +"Опционально для создания таймлапса. Timelapse будет сохранять изображение " +"выбранных лиц в папку timelapse-output на каждой итерации сохранения. Это " +"должна быть входная папка с лицами 'A', которые вы хотите использовать для " +"создания timelapse. Вы также должны указать параметры --timelapse-output и --" +"timelapse-input-B." + +#: lib/cli/args_train.py:247 +msgid "" +"Optional for creating a timelapse. Timelapse will save an image of your " +"selected faces into the timelapse-output folder at every save iteration. " +"This should be the input folder of 'B' faces that you would like to use for " +"creating the timelapse. You must also supply a --timelapse-output and a --" +"timelapse-input-A parameter." +msgstr "" +"Опционально для создания таймлапса. Timelapse будет сохранять изображение " +"выбранных лиц в папку timelapse-output на каждой итерации сохранения. Это " +"должна быть входная папка с лицами 'B', которые вы хотите использовать для " +"создания timelapse. Вы также должны указать параметры --timelapse-output и --" +"timelapse-input-A." + +#: lib/cli/args_train.py:259 +msgid "" +"Optional for creating a timelapse. Timelapse will save an image of your " +"selected faces into the timelapse-output folder at every save iteration. If " +"the input folders are supplied but no output folder, it will default to your " +"model folder/timelapse/" +msgstr "" +"Опционально для создания таймлапса. Timelapse будет сохранять изображение " +"выбранных лиц в папку timelapse-output на каждой итерации сохранения. Если " +"указаны входные папки, но нет выходной папки, то по умолчанию будет выбрана " +"папка модели/timelapse/" + +#: lib/cli/args_train.py:268 lib/cli/args_train.py:275 +msgid "preview" +msgstr "предпросмотр" + +#: lib/cli/args_train.py:269 +msgid "Show training preview output. in a separate window." +msgstr "Показать вывод предварительного просмотра тренировки в отдельном окне." + +#: lib/cli/args_train.py:277 +msgid "" +"Writes the training result to a file. The image will be stored in the root " +"of your FaceSwap folder." +msgstr "" +"Записывает результат обучения в файл. Изображение будет сохранено в корне " +"папки Faceswap." + +#: lib/cli/args_train.py:284 lib/cli/args_train.py:294 +#: lib/cli/args_train.py:304 lib/cli/args_train.py:314 +msgid "augmentation" +msgstr "аугментация" + +#: lib/cli/args_train.py:286 +msgid "" +"Warps training faces to closely matched Landmarks from the opposite face-set " +"rather than randomly warping the face. This is the 'dfaker' way of doing " +"warping." +msgstr "" +"Искажает обучаемые лица до близко подходящих ориентиров из противоположного " +"набора лиц вместо случайного искажения лица. Это способ выполнения искажения " +"от \"dfaker\" ." + +#: lib/cli/args_train.py:296 +msgid "" +"To effectively learn, a random set of images are flipped horizontally. 
" +"Sometimes it is desirable for this not to occur. Generally this should be " +"left off except for during 'fit training'." +msgstr "" +"Для эффективного обучения случайный набор изображений переворачивается по " +"горизонтали. Иногда желательно, чтобы этого не происходило. Как правило, это " +"не нужно делать, за исключением случаев \"тренировки подгонки\"." + +#: lib/cli/args_train.py:306 +msgid "" +"Color augmentation helps make the model less susceptible to color " +"differences between the A and B sets, at an increased training time cost. " +"Enable this option to disable color augmentation." +msgstr "" +"Аугментация цвета помогает сделать модель менее восприимчивой к цветовым " +"различиям между наборами A и B, что влечет за собой увеличение затрат " +"времени на обучение. Включите этот параметр для отключения цветовой " +"аугментации." + +#: lib/cli/args_train.py:316 +msgid "" +"Warping is integral to training the Neural Network. This option should only " +"be enabled towards the very end of training to try to bring out more detail. " +"Think of it as 'fine-tuning'. Enabling this option from the beginning is " +"likely to kill a model and lead to terrible results." +msgstr "" +"Искажение является неотъемлемой частью обучения нейронной сети. Эту опцию " +"следует включать только в самом конце обучения, чтобы попытаться получить " +"больше деталей. Считайте это \"тонкой настройкой\". Включение этой опции в " +"самом начале, скорее всего, погубит модель и приведет к ужасным результатам." + +#~ msgid "" +#~ "R|Select the distribution stategy to use.\n" +#~ "L|default: Use Tensorflow's default distribution strategy.\n" +#~ "L|central-storage: Centralizes variables on the CPU whilst operations are " +#~ "performed on 1 or more local GPUs. This can help save some VRAM at the " +#~ "cost of some speed by not storing variables on the GPU. Note: Mixed-" +#~ "Precision is not supported on multi-GPU setups.\n" +#~ "L|mirrored: Supports synchronous distributed training across multiple " +#~ "local GPUs. A copy of the model and all variables are loaded onto each " +#~ "GPU with batches distributed to each GPU at each iteration." +#~ msgstr "" +#~ "R|Выберите стратегию распределения для использования.\n" +#~ "L|default: Использовать стратегию распространения Tensorflow по " +#~ "умолчанию.\n" +#~ "L|central-storage: Централизует переменные на CPU, в то время как " +#~ "операции выполняются на 1 или более локальных GPU. Это может помочь " +#~ "сэкономить немного VRAM за счет некоторой скорости, поскольку переменные " +#~ "не хранятся на GPU. Примечание: Mixed-Precision не поддерживается на " +#~ "многопроцессорных установках.\n" +#~ "L|mirrored: Поддерживает синхронное распределенное обучение на нескольких " +#~ "локальных GPU. Копия модели и все переменные загружаются на каждый GPU с " +#~ "распределением партий на каждый GPU на каждой итерации." + +#~ msgid "Global Options" +#~ msgstr "Глобальные Настройки" + +#~ msgid "" +#~ "R|Exclude GPUs from use by Faceswap. Select the number(s) which " +#~ "correspond to any GPU(s) that you do not wish to be made available to " +#~ "Faceswap. Selecting all GPUs here will force Faceswap into CPU mode.\n" +#~ "L|{}" +#~ msgstr "" +#~ "R|Исключить GPU из использования Faceswap. Выберите номер (номера), " +#~ "соответствующие любому GPU, который вы не хотите предоставлять Faceswap. " +#~ "Если выбрать здесь все GPU, Faceswap перейдет в режим CPU.\n" +#~ "L|{}" + +#~ msgid "" +#~ "Optionally overide the saved config with the path to a custom config file." 
+#~ msgstr "" +#~ "Опционально переопределите сохраненную конфигурацию, указав путь к " +#~ "пользовательскому файлу конфигурации." + +#~ msgid "" +#~ "Log level. Stick with INFO or VERBOSE unless you need to file an error " +#~ "report. Be careful with TRACE as it will generate a lot of data" +#~ msgstr "" +#~ "Уровень логирования. Придерживайтесь INFO или VERBOSE, если только вам не " +#~ "нужно отправить отчет об ошибке. Будьте осторожны с TRACE, поскольку он " +#~ "генерирует много данных" + +#~ msgid "" +#~ "Path to store the logfile. Leave blank to store in the faceswap folder" +#~ msgstr "" +#~ "Путь для хранения файла журнала. Оставьте пустым, чтобы хранить в папке " +#~ "faceswap" + +#~ msgid "Data" +#~ msgstr "Данные" + +#~ msgid "" +#~ "Input directory or video. Either a directory containing the image files " +#~ "you wish to process or path to a video file. NB: This should be the " +#~ "source video/frames NOT the source faces." +#~ msgstr "" +#~ "Входная папка или видео. Либо каталог, содержащий файлы изображений, " +#~ "которые вы хотите обработать, либо путь к видеофайлу. ПРИМЕЧАНИЕ: Это " +#~ "должно быть исходное видео/кадры, а не исходные лица." + +#~ msgid "Output directory. This is where the converted files will be saved." +#~ msgstr "Выходная папка. Здесь будут сохранены преобразованные файлы." + +#~ msgid "" +#~ "Optional path to an alignments file. Leave blank if the alignments file " +#~ "is at the default location." +#~ msgstr "" +#~ "Необязательный путь к файлу выравниваний. Оставьте пустым, если файл " +#~ "выравнивания находится в месте по умолчанию." + +#~ msgid "" +#~ "Extract faces from image or video sources.\n" +#~ "Extraction plugins can be configured in the 'Settings' Menu" +#~ msgstr "" +#~ "Извлечение лиц из источников изображений или видео.\n" +#~ "Плагины извлечения можно настроить в меню \"Настройки\"" + +#~ msgid "" +#~ "R|If selected then the input_dir should be a parent folder containing " +#~ "multiple videos and/or folders of images you wish to extract from. The " +#~ "faces will be output to separate sub-folders in the output_dir." +#~ msgstr "" +#~ "R|Если выбрано, то input_dir должен быть родительской папкой, содержащей " +#~ "несколько видео и/или папок с изображениями, из которых вы хотите извлечь " +#~ "изображение. Лица будут выведены в отдельные вложенные папки в output_dir." + +#~ msgid "Plugins" +#~ msgstr "Плагины" + +#~ msgid "" +#~ "R|Detector to use. Some of these have configurable settings in '/config/" +#~ "extract.ini' or 'Settings > Configure Extract 'Plugins':\n" +#~ "L|cv2-dnn: A CPU only extractor which is the least reliable and least " +#~ "resource intensive. Use this if not using a GPU and time is important.\n" +#~ "L|mtcnn: Good detector. Fast on CPU, faster on GPU. Uses fewer resources " +#~ "than other GPU detectors but can often return more false positives.\n" +#~ "L|s3fd: Best detector. Slow on CPU, faster on GPU. Can detect more faces " +#~ "and fewer false positives than other GPU detectors, but is a lot more " +#~ "resource intensive." +#~ msgstr "" +#~ "R|Детектор для использования. Некоторые из них имеют настраиваемые " +#~ "параметры в '/config/extract.ini' или 'Settings > Configure Extract " +#~ "'Plugins':\n" +#~ "L|cv2-dnn: Экстрактор только для процессора, который является наименее " +#~ "надежным и наименее ресурсоемким. Используйте его, если не используется " +#~ "GPU и важно время.\n" +#~ "L|mtcnn: Хороший детектор. Быстрый на CPU, еще быстрее на GPU. 
Использует " +#~ "меньше ресурсов, чем другие детекторы на GPU, но часто может давать " +#~ "больше ложных срабатываний.\n" +#~ "L|s3fd: Лучший детектор. Медленный на CPU, более быстрый на GPU. Может " +#~ "обнаружить больше лиц и меньше ложных срабатываний, чем другие детекторы " +#~ "на GPU, но требует гораздо больше ресурсов." + +#~ msgid "" +#~ "R|Aligner to use.\n" +#~ "L|cv2-dnn: A CPU only landmark detector. Faster, less resource intensive, " +#~ "but less accurate. Only use this if not using a GPU and time is " +#~ "important.\n" +#~ "L|fan: Best aligner. Fast on GPU, slow on CPU." +#~ msgstr "" +#~ "R|Выравниватель для использования.\n" +#~ "L|cv2-dnn: Детектор ориентиров только для процессора. Быстрее, менее " +#~ "ресурсоемкий, но менее точный. Используйте его, только если не " +#~ "используется GPU и важно время.\n" +#~ "L|fan: Лучший выравниватель. Быстрый на GPU, медленный на CPU." + +#~ msgid "" +#~ "R|Additional Masker(s) to use. The masks generated here will all take up " +#~ "GPU RAM. You can select none, one or multiple masks, but the extraction " +#~ "may take longer the more you select. NB: The Extended and Components " +#~ "(landmark based) masks are automatically generated on extraction.\n" +#~ "L|bisenet-fp: Relatively lightweight NN based mask that provides more " +#~ "refined control over the area to be masked including full head masking " +#~ "(configurable in mask settings).\n" +#~ "L|custom: A dummy mask that fills the mask area with all 1s or 0s " +#~ "(configurable in settings). This is only required if you intend to " +#~ "manually edit the custom masks yourself in the manual tool. This mask " +#~ "does not use the GPU so will not use any additional VRAM.\n" +#~ "L|vgg-clear: Mask designed to provide smart segmentation of mostly " +#~ "frontal faces clear of obstructions. Profile faces and obstructions may " +#~ "result in sub-par performance.\n" +#~ "L|vgg-obstructed: Mask designed to provide smart segmentation of mostly " +#~ "frontal faces. The mask model has been specifically trained to recognize " +#~ "some facial obstructions (hands and eyeglasses). Profile faces may result " +#~ "in sub-par performance.\n" +#~ "L|unet-dfl: Mask designed to provide smart segmentation of mostly frontal " +#~ "faces. The mask model has been trained by community members and will need " +#~ "testing for further description. Profile faces may result in sub-par " +#~ "performance.\n" +#~ "The auto generated masks are as follows:\n" +#~ "L|components: Mask designed to provide facial segmentation based on the " +#~ "positioning of landmark locations. A convex hull is constructed around " +#~ "the exterior of the landmarks to create a mask.\n" +#~ "L|extended: Mask designed to provide facial segmentation based on the " +#~ "positioning of landmark locations. A convex hull is constructed around " +#~ "the exterior of the landmarks and the mask is extended upwards onto the " +#~ "forehead.\n" +#~ "(eg: `-M unet-dfl vgg-clear`, `--masker vgg-obstructed`)" +#~ msgstr "" +#~ "R|Дополнительный маскер(ы) для использования. Все маски, созданные здесь, " +#~ "будут занимать видеопамять GPU. Вы можете выбрать ни одной, одну или " +#~ "несколько масок, но извлечение может занять больше времени, чем больше " +#~ "масок вы выберете. 
Примечание: Расширенные маски и маски компонентов (на " +#~ "основе ориентиров) генерируются автоматически при извлечении.\n" +#~ "L|bisenet-fp: Относительно легкая маска на основе NN, которая " +#~ "обеспечивает более точный контроль над маскируемой областью, включая " +#~ "полное маскирование головы (настраивается в настройках маски).\n" +#~ "L|custom: Фиктивная маска, которая заполняет область маски всеми 1 или 0 " +#~ "(настраивается в настройках). Она необходима только в том случае, если вы " +#~ "собираетесь вручную редактировать пользовательские маски в ручном " +#~ "инструменте. Эта маска не задействует GPU, поэтому не будет использовать " +#~ "дополнительную память VRAM.\n" +#~ "L|vgg-clear: Маска предназначена для интеллектуальной сегментации " +#~ "преимущественно фронтальных лиц без препятствий. Профильные лица и " +#~ "препятствия могут привести к снижению производительности.\n" +#~ "L|vgg-obstructed: Маска, разработанная для интеллектуальной сегментации " +#~ "преимущественно фронтальных лиц. Модель маски была специально обучена " +#~ "распознавать некоторые препятствия на лице (руки и очки). Лица в профиль " +#~ "могут иметь низкую производительность.\n" +#~ "L|unet-dfl: Маска, разработанная для интеллектуальной сегментации " +#~ "преимущественно фронтальных лиц. Модель маски была обучена членами " +#~ "сообщества и для дальнейшего описания нуждается в тестировании. " +#~ "Профильные лица могут привести к низкой производительности.\n" +#~ "Автоматически сгенерированные маски выглядят следующим образом:\n" +#~ "L|components: Маска, разработанная для сегментации лица на основе " +#~ "расположения ориентиров. Для создания маски вокруг внешних ориентиров " +#~ "строится выпуклая оболочка.\n" +#~ "L|extended: Маска, предназначенная для сегментации лица на основе " +#~ "расположения ориентиров. Выпуклый корпус строится вокруг внешних " +#~ "ориентиров, и маска расширяется вверх на лоб.\n" +#~ "(например: `-M unet-dfl vgg-clear`, `--masker vgg-obstructed`)" + +#~ msgid "" +#~ "R|Performing normalization can help the aligner better align faces with " +#~ "difficult lighting conditions at an extraction speed cost. Different " +#~ "methods will yield different results on different sets. NB: This does not " +#~ "impact the output face, just the input to the aligner.\n" +#~ "L|none: Don't perform normalization on the face.\n" +#~ "L|clahe: Perform Contrast Limited Adaptive Histogram Equalization on the " +#~ "face.\n" +#~ "L|hist: Equalize the histograms on the RGB channels.\n" +#~ "L|mean: Normalize the face colors to the mean." +#~ msgstr "" +#~ "R|Проведение нормализации может помочь выравнивателю лучше выравнивать " +#~ "лица со сложными условиями освещения при затратах на скорость извлечения. " +#~ "Различные методы дают разные результаты на разных наборах. NB: Это не " +#~ "влияет на выходное лицо, только на вход выравнивателя.\n" +#~ "L|none: Не выполнять нормализацию лица.\n" +#~ "L|clahe: Выполнить для лица адаптивную гистограммную эквализацию с " +#~ "ограничением контраста.\n" +#~ "L|hist: Уравнять гистограммы в каналах RGB.\n" +#~ "L|mean: Нормализовать цвета лица к среднему значению." + +#~ msgid "" +#~ "The number of times to re-feed the detected face into the aligner. Each " +#~ "time the face is re-fed into the aligner the bounding box is adjusted by " +#~ "a small amount. The final landmarks are then averaged from each " +#~ "iteration. Helps to remove 'micro-jitter' but at the cost of slower " +#~ "extraction speed. 
The more times the face is re-fed into the aligner, the " +#~ "less micro-jitter should occur but the longer extraction will take." +#~ msgstr "" +#~ "Количество повторных подач обнаруженной области лица в выравниватель. При " +#~ "каждой повторной подаче лица в выравниватель ограничивающая рамка " +#~ "корректируется на небольшую величину. Затем конечные ориентиры " +#~ "усредняются по результатам каждой итерации. Это помогает устранить " +#~ "\"микро-дрожание\", но ценой снижения скорости извлечения. Чем больше раз " +#~ "лицо повторно подается в выравниватель, тем меньше микро-дрожание, но тем " +#~ "больше времени займет извлечение." + +#~ msgid "" +#~ "Re-feed the initially found aligned face through the aligner. Can help " +#~ "produce better alignments for faces that are rotated beyond 45 degrees in " +#~ "the frame or are at extreme angles. Slows down extraction." +#~ msgstr "" +#~ "Повторная подача первоначально найденной выровненной области лица через " +#~ "выравниватель. Может помочь получить лучшее выравнивание для лиц, " +#~ "повернутых в кадре более чем на 45 градусов или расположенных под " +#~ "экстремальными углами. Замедляет извлечение." + +#~ msgid "" +#~ "If a face isn't found, rotate the images to try to find a face. Can find " +#~ "more faces at the cost of extraction speed. Pass in a single number to " +#~ "use increments of that size up to 360, or pass in a list of numbers to " +#~ "enumerate exactly what angles to check." +#~ msgstr "" +#~ "Если лицо не найдено, поворачивает изображения, чтобы попытаться найти " +#~ "лицо. Может найти больше лиц ценой снижения скорости извлечения. " +#~ "Передайте одно число, чтобы использовать приращения этого размера до 360, " +#~ "или передайте список чисел, чтобы перечислить, какие именно углы нужно " +#~ "проверить." + +#~ msgid "" +#~ "Obtain and store face identity encodings from VGGFace2. Slows down " +#~ "extract a little, but will save time if using 'sort by face'" +#~ msgstr "" +#~ "Получение и хранение кодировок идентификации лица из VGGFace2. Немного " +#~ "замедляет извлечение, но экономит время при использовании \"сортировки по " +#~ "лицам\"." + +#~ msgid "Face Processing" +#~ msgstr "Обработка лиц" + +#~ msgid "" +#~ "Filters out faces detected below this size. Length, in pixels across the " +#~ "diagonal of the bounding box. Set to 0 for off" +#~ msgstr "" +#~ "Отфильтровывает лица, обнаруженные ниже этого размера. Длина в пикселях " +#~ "по диагонали ограничивающего поля. Установите значение 0, чтобы выключить" + +#~ msgid "" +#~ "Optionally filter out people who you do not wish to extract by passing in " +#~ "images of those people. Should be a small variety of images at different " +#~ "angles and in different conditions. A folder containing the required " +#~ "images or multiple image files, space separated, can be selected." +#~ msgstr "" +#~ "По желанию отфильтруйте людей, которых вы не хотите извлекать, передав " +#~ "изображения этих людей. Должно быть небольшое разнообразие изображений " +#~ "под разными углами и в разных условиях. Можно выбрать папку, содержащую " +#~ "необходимые изображения, или несколько файлов изображений, разделенных " +#~ "пробелами." + +#~ msgid "" +#~ "Optionally select people you wish to extract by passing in images of that " +#~ "person. Should be a small variety of images at different angles and in " +#~ "different conditions A folder containing the required images or multiple " +#~ "image files, space separated, can be selected." 
+#~ msgstr "" +#~ "По желанию выберите людей, которых вы хотите извлечь, передав изображения " +#~ "этого человека. Должно быть небольшое разнообразие изображений под " +#~ "разными углами и в разных условиях. Можно выбрать папку, содержащую " +#~ "необходимые изображения, или несколько файлов изображений, разделенных " +#~ "пробелами." + +#~ msgid "" +#~ "For use with the optional nfilter/filter files. Threshold for positive " +#~ "face recognition. Higher values are stricter." +#~ msgstr "" +#~ "Для использования с дополнительными файлами nfilter/filter. Порог для " +#~ "положительного распознавания лица. Более высокие значения являются более " +#~ "строгими." + +#~ msgid "output" +#~ msgstr "вывод" + +#~ msgid "" +#~ "The output size of extracted faces. Make sure that the model you intend " +#~ "to train supports your required size. This will only need to be changed " +#~ "for hi-res models." +#~ msgstr "" +#~ "Выходной размер извлеченных лиц. Убедитесь, что модель, которую вы " +#~ "собираетесь тренировать, поддерживает требуемый размер. Это необходимо " +#~ "изменить только для моделей высокого разрешения." + +#~ msgid "" +#~ "Extract every 'nth' frame. This option will skip frames when extracting " +#~ "faces. For example a value of 1 will extract faces from every frame, a " +#~ "value of 10 will extract faces from every 10th frame." +#~ msgstr "" +#~ "Извлекать каждый 'n-й' кадр. Этот параметр пропускает кадры при " +#~ "извлечении лиц. Например, значение 1 будет извлекать лица из каждого " +#~ "кадра, значение 10 будет извлекать лица из каждого 10-го кадра." + +#~ msgid "" +#~ "Automatically save the alignments file after a set amount of frames. By " +#~ "default the alignments file is only saved at the end of the extraction " +#~ "process. NB: If extracting in 2 passes then the alignments file will only " +#~ "start to be saved out during the second pass. WARNING: Don't interrupt " +#~ "the script when writing the file because it might get corrupted. Set to 0 " +#~ "to turn off" +#~ msgstr "" +#~ "Автоматическое сохранение файла выравнивания после заданного количества " +#~ "кадров. По умолчанию файл выравнивания сохраняется только в конце " +#~ "процесса извлечения. Примечание: Если извлечение выполняется в 2 прохода, " +#~ "то файл выравнивания начнет сохраняться только во время второго прохода. " +#~ "ПРЕДУПРЕЖДЕНИЕ: Не прерывайте работу скрипта при записи файла, так как он " +#~ "может быть поврежден. Установите значение 0, чтобы отключить" + +#~ msgid "Draw landmarks on the ouput faces for debugging purposes." +#~ msgstr "Нарисуйте ориентиры на выходящих гранях для отладки." + +#~ msgid "settings" +#~ msgstr "настройки" + +#~ msgid "" +#~ "Don't run extraction in parallel. Will run each part of the extraction " +#~ "process separately (one after the other) rather than all at the same " +#~ "time. Useful if VRAM is at a premium." +#~ msgstr "" +#~ "Не запускать извлечение параллельно. Каждая часть процесса извлечения " +#~ "будет выполняться отдельно (одна за другой), а не одновременно. Полезно, " +#~ "если память VRAM ограничена." 
+ +#~ msgid "" +#~ "Skips frames that have already been extracted and exist in the alignments " +#~ "file" +#~ msgstr "" +#~ "Пропускает кадры, которые уже были извлечены и существуют в файле " +#~ "выравнивания" + +#~ msgid "Skip frames that already have detected faces in the alignments file" +#~ msgstr "" +#~ "Пропустить кадры, в которых уже есть обнаруженные лица в файле " +#~ "выравнивания" + +#~ msgid "" +#~ "Skip saving the detected faces to disk. Just create an alignments file" +#~ msgstr "" +#~ "Не сохранять обнаруженные лица на диск. Просто создать файл выравнивания" + +#~ msgid "" +#~ "Swap the original faces in a source video/images to your final faces.\n" +#~ "Conversion plugins can be configured in the 'Settings' Menu" +#~ msgstr "" +#~ "Поменять исходные лица в исходном видео/изображении на ваши конечные " +#~ "лица.\n" +#~ "Плагины конвертирования можно настроить в меню \"Настройки\"" + +#~ msgid "" +#~ "Only required if converting from images to video. Provide The original " +#~ "video that the source frames were extracted from (for extracting the fps " +#~ "and audio)." +#~ msgstr "" +#~ "Требуется только при преобразовании из изображений в видео. Предоставьте " +#~ "исходное видео, из которого были извлечены исходные кадры (для извлечения " +#~ "кадров в секунду и звука)." + +#~ msgid "" +#~ "Model directory. The directory containing the trained model you wish to " +#~ "use for conversion." +#~ msgstr "" +#~ "Папка модели. Папка, содержащая обученную модель, которую вы хотите " +#~ "использовать для преобразования." + +#~ msgid "" +#~ "R|Performs color adjustment to the swapped face. Some of these options " +#~ "have configurable settings in '/config/convert.ini' or 'Settings > " +#~ "Configure Convert Plugins':\n" +#~ "L|avg-color: Adjust the mean of each color channel in the swapped " +#~ "reconstruction to equal the mean of the masked area in the original " +#~ "image.\n" +#~ "L|color-transfer: Transfers the color distribution from the source to the " +#~ "target image using the mean and standard deviations of the L*a*b* color " +#~ "space.\n" +#~ "L|manual-balance: Manually adjust the balance of the image in a variety " +#~ "of color spaces. Best used with the Preview tool to set correct values.\n" +#~ "L|match-hist: Adjust the histogram of each color channel in the swapped " +#~ "reconstruction to equal the histogram of the masked area in the original " +#~ "image.\n" +#~ "L|seamless-clone: Use cv2's seamless clone function to remove extreme " +#~ "gradients at the mask seam by smoothing colors. Generally does not give " +#~ "very satisfactory results.\n" +#~ "L|none: Don't perform color adjustment." +#~ msgstr "" +#~ "R|Производит корректировку цвета поменявшегося лица. Некоторые из этих " +#~ "параметров настраиваются в '/config/convert.ini' или 'Настройки > " +#~ "Настроить плагины конвертации':\n" +#~ "L|avg-color: корректирует среднее значение каждого цветового канала в " +#~ "реконструкции, чтобы оно было равно среднему значению маскированной " +#~ "области в исходном изображении.\n" +#~ "L|color-transfer: Переносит распределение цветов с исходного изображения " +#~ "на целевое, используя среднее и стандартные отклонения цветового " +#~ "пространства L*a*b*.\n" +#~ "L|manual-balance: Ручная настройка баланса изображения в различных " +#~ "цветовых пространствах. 
Лучше всего использовать с инструментом " +#~ "предварительного просмотра для установки правильных значений.\n" +#~ "L|match-hist: Настроить гистограмму каждого цветового канала в измененном " +#~ "восстановлении так, чтобы она соответствовала гистограмме маскированной " +#~ "области исходного изображения.\n" +#~ "L|seamless-clone: Используйте функцию бесшовного клонирования cv2 для " +#~ "удаления экстремальных градиентов на шве маски путем сглаживания цветов. " +#~ "Обычно дает не очень удовлетворительные результаты.\n" +#~ "L|none: Не выполнять коррекцию цвета." + +#~ msgid "" +#~ "R|Masker to use. NB: The mask you require must exist within the " +#~ "alignments file. You can add additional masks with the Mask Tool.\n" +#~ "L|none: Don't use a mask.\n" +#~ "L|bisenet-fp_face: Relatively lightweight NN based mask that provides " +#~ "more refined control over the area to be masked (configurable in mask " +#~ "settings). Use this version of bisenet-fp if your model is trained with " +#~ "'face' or 'legacy' centering.\n" +#~ "L|bisenet-fp_head: Relatively lightweight NN based mask that provides " +#~ "more refined control over the area to be masked (configurable in mask " +#~ "settings). Use this version of bisenet-fp if your model is trained with " +#~ "'head' centering.\n" +#~ "L|custom_face: Custom user created, face centered mask.\n" +#~ "L|custom_head: Custom user created, head centered mask.\n" +#~ "L|components: Mask designed to provide facial segmentation based on the " +#~ "positioning of landmark locations. A convex hull is constructed around " +#~ "the exterior of the landmarks to create a mask.\n" +#~ "L|extended: Mask designed to provide facial segmentation based on the " +#~ "positioning of landmark locations. A convex hull is constructed around " +#~ "the exterior of the landmarks and the mask is extended upwards onto the " +#~ "forehead.\n" +#~ "L|vgg-clear: Mask designed to provide smart segmentation of mostly " +#~ "frontal faces clear of obstructions. Profile faces and obstructions may " +#~ "result in sub-par performance.\n" +#~ "L|vgg-obstructed: Mask designed to provide smart segmentation of mostly " +#~ "frontal faces. The mask model has been specifically trained to recognize " +#~ "some facial obstructions (hands and eyeglasses). Profile faces may result " +#~ "in sub-par performance.\n" +#~ "L|unet-dfl: Mask designed to provide smart segmentation of mostly frontal " +#~ "faces. The mask model has been trained by community members and will need " +#~ "testing for further description. Profile faces may result in sub-par " +#~ "performance.\n" +#~ "L|predicted: If the 'Learn Mask' option was enabled during training, this " +#~ "will use the mask that was created by the trained model." +#~ msgstr "" +#~ "R|Маскер для использования. Примечание: Нужная маска должна существовать " +#~ "в файле выравнивания. Вы можете добавить дополнительные маски с помощью " +#~ "инструмента Mask Tool.\n" +#~ "L|none: Не использовать маску.\n" +#~ "L|bisenet-fp_face: Относительно легкая маска на основе NN, которая " +#~ "обеспечивает более точный контроль над маскируемой областью " +#~ "(настраивается в настройках маски). Используйте эту версию bisenet-fp, " +#~ "если ваша модель обучена с центрированием 'face' или 'legacy'.\n" +#~ "L|bisenet-fp_head: Относительно легкая маска на основе NN, которая " +#~ "обеспечивает более точный контроль над маскируемой областью " +#~ "(настраивается в настройках маски). 
Используйте эту версию bisenet-fp, "
+#~ "если ваша модель обучена с центрированием 'head'.\n"
+#~ "L|custom_face: Пользовательская маска, созданная пользователем и "
+#~ "центрированная по лицу.\n"
+#~ "L|custom_head: Созданная пользователем маска, центрированная по голове.\n"
+#~ "L|components: Маска, разработанная для сегментации лица на основе "
+#~ "расположения ориентиров. Для создания маски вокруг внешних ориентиров "
+#~ "строится выпуклая оболочка.\n"
+#~ "L|extended: Маска, предназначенная для сегментации лица на основе "
+#~ "расположения ориентиров. Выпуклая оболочка строится вокруг внешних "
+#~ "ориентиров, и маска расширяется вверх на лоб.\n"
+#~ "L|vgg-clear: Маска предназначена для интеллектуальной сегментации "
+#~ "преимущественно фронтальных лиц без препятствий. Профильные лица и "
+#~ "препятствия могут привести к снижению производительности.\n"
+#~ "L|vgg-obstructed: Маска, разработанная для интеллектуальной сегментации "
+#~ "преимущественно фронтальных лиц. Модель маски была специально обучена "
+#~ "распознавать некоторые препятствия на лице (руки и очки). Лица в профиль "
+#~ "могут иметь низкую производительность.\n"
+#~ "L|unet-dfl: Маска, разработанная для интеллектуальной сегментации "
+#~ "преимущественно фронтальных лиц. Модель маски была обучена членами "
+#~ "сообщества и для дальнейшего описания нуждается в тестировании. "
+#~ "Профильные лица могут привести к низкой производительности.\n"
+#~ "L|predicted: Если во время обучения была включена опция 'Изучить Маску', "
+#~ "то будет использоваться маска, созданная обученной моделью."
+
+#~ msgid ""
+#~ "R|The plugin to use to output the converted images. The writers are "
+#~ "configurable in '/config/convert.ini' or 'Settings > Configure Convert "
+#~ "Plugins:'\n"
+#~ "L|ffmpeg: [video] Writes out the convert straight to video. When the "
+#~ "input is a series of images then the '-ref' (--reference-video) parameter "
+#~ "must be set.\n"
+#~ "L|gif: [animated image] Create an animated gif.\n"
+#~ "L|opencv: [images] The fastest image writer, but less options and formats "
+#~ "than other plugins.\n"
+#~ "L|patch: [images] Outputs the raw swapped face patch, along with the "
+#~ "transformation matrix required to re-insert the face back into the "
+#~ "original frame. Use this option if you wish to post-process and composite "
+#~ "the final face within external tools.\n"
+#~ "L|pillow: [images] Slower than opencv, but has more options and supports "
+#~ "more formats."
+#~ msgstr ""
+#~ "R|Плагин, который нужно использовать для вывода преобразованных "
+#~ "изображений. Модули записи настраиваются в '/config/convert.ini' или "
+#~ "'Настройки > Настроить плагины конвертации:'\n"
+#~ "L|ffmpeg: [видео] Записывает конвертацию прямо в видео. Если на вход "
+#~ "подается серия изображений, необходимо установить параметр '-ref' (--"
+#~ "reference-video).\n"
+#~ "L|gif: [анимированное изображение] Создает анимированный gif.\n"
+#~ "L|opencv: [изображения] Самый быстрый модуль записи изображений, но имеет "
+#~ "меньше опций и форматов, чем другие плагины.\n"
+#~ "L|patch: [изображения] Выводит необработанный фрагмент измененного лица "
+#~ "вместе с матрицей преобразования, необходимой для повторной вставки лица "
+#~ "обратно в исходный кадр. Используйте эту опцию, если хотите выполнить "
+#~ "постобработку и компоновку финального лица во внешних инструментах.\n"
+#~ "L|pillow: [изображения] Медленнее, чем opencv, но имеет больше опций и "
+#~ "поддерживает больше форматов."
+
+#~ msgid "Frame Processing"
+#~ msgstr "Обработка кадров"
+
+#, python-format
+#~ msgid ""
+#~ "Scale the final output frames by this amount. 
100%% will output the "
+#~ "frames at source dimensions. 50%% at half size 200%% at double size"
+#~ msgstr ""
+#~ "Масштабирование конечных выходных кадров на эту величину. 100%% выводит "
+#~ "кадры в исходном размере, 50%% - в половинном размере, 200%% - в двойном "
+#~ "размере"
+
+#~ msgid ""
+#~ "Frame ranges to apply transfer to e.g. For frames 10 to 50 and 90 to 100 "
+#~ "use --frame-ranges 10-50 90-100. Frames falling outside of the selected "
+#~ "range will be discarded unless '-k' (--keep-unchanged) is selected. NB: "
+#~ "If you are converting from images, then the filenames must end with the "
+#~ "frame-number!"
+#~ msgstr ""
+#~ "Диапазоны кадров для применения переноса, например, для кадров с 10 по 50 "
+#~ "и с 90 по 100 используйте --frame-ranges 10-50 90-100. Кадры, выходящие "
+#~ "за пределы выбранного диапазона, будут отброшены, если не выбрана опция '-"
+#~ "k' (--keep-unchanged). Примечание: Если вы конвертируете из изображений, "
+#~ "то имена файлов должны заканчиваться номером кадра!"
+
+#~ msgid ""
+#~ "Scale the swapped face by this percentage. Positive values will enlarge "
+#~ "the face, Negative values will shrink the face."
+#~ msgstr ""
+#~ "Масштабировать замененное лицо на этот процент. Положительные значения "
+#~ "увеличат лицо, в то время как отрицательные значения уменьшат его."
+
+#~ msgid ""
+#~ "If you have not cleansed your alignments file, then you can filter out "
+#~ "faces by defining a folder here that contains the faces extracted from "
+#~ "your input files/video. If this folder is defined, then only faces that "
+#~ "exist within your alignments file and also exist within the specified "
+#~ "folder will be converted. Leaving this blank will convert all faces that "
+#~ "exist within the alignments file."
+#~ msgstr ""
+#~ "Если вы не очистили свой файл выравнивания, то вы можете отфильтровать "
+#~ "лица, определив здесь папку, содержащую лица, извлеченные из ваших "
+#~ "входных файлов/видео. Если эта папка определена, то будут преобразованы "
+#~ "только те лица, которые существуют в вашем файле выравнивания, а также в "
+#~ "указанной папке. Если оставить этот параметр пустым, будут преобразованы "
+#~ "все лица, существующие в файле выравнивания."
+
+#~ msgid ""
+#~ "Optionally filter out people who you do not wish to process by passing in "
+#~ "an image of that person. Should be a front portrait with a single person "
+#~ "in the image. Multiple images can be added space separated. NB: Using "
+#~ "face filter will significantly decrease extraction speed and its accuracy "
+#~ "cannot be guaranteed."
+#~ msgstr ""
+#~ "По желанию отфильтровать людей, которых вы не хотите обрабатывать, "
+#~ "передав изображение этого человека. Это должен быть фронтальный портрет с "
+#~ "изображением одного человека. Можно добавить несколько изображений, "
+#~ "разделенных пробелами. Примечание: Использование фильтра лиц значительно "
+#~ "снизит скорость извлечения, а его точность не гарантируется."
+
+#~ msgid ""
+#~ "Optionally select people you wish to process by passing in an image of "
+#~ "that person. Should be a front portrait with a single person in the "
+#~ "image. Multiple images can be added space separated. NB: Using face "
+#~ "filter will significantly decrease extraction speed and its accuracy "
+#~ "cannot be guaranteed."
+#~ msgstr ""
+#~ "По желанию выберите людей, которых вы хотите обработать, передав "
+#~ "изображение этого человека. Это должен быть фронтальный портрет с "
+#~ "изображением одного человека. 
Можно добавить несколько изображений, "
+#~ "разделенных пробелами. Примечание: Использование фильтра лиц значительно "
+#~ "снизит скорость извлечения, а его точность не гарантируется."
+
+#~ msgid ""
+#~ "For use with the optional nfilter/filter files. Threshold for positive "
+#~ "face recognition. Lower values are stricter. NB: Using face filter will "
+#~ "significantly decrease extraction speed and its accuracy cannot be "
+#~ "guaranteed."
+#~ msgstr ""
+#~ "Для использования с дополнительными файлами nfilter/filter. Порог для "
+#~ "положительного распознавания лиц. Более низкие значения являются более "
+#~ "строгими. Примечание: Использование фильтра лиц значительно снизит "
+#~ "скорость извлечения, а его точность не гарантируется."
+
+#~ msgid ""
+#~ "The maximum number of parallel processes for performing conversion. "
+#~ "Converting images is system RAM heavy so it is possible to run out of "
+#~ "memory if you have a lot of processes and not enough RAM to accommodate "
+#~ "them all. Setting this to 0 will use the maximum available. No matter "
+#~ "what you set this to, it will never attempt to use more processes than "
+#~ "are available on your system. If singleprocess is enabled this setting "
+#~ "will be ignored."
+#~ msgstr ""
+#~ "Максимальное количество параллельных процессов для выполнения "
+#~ "конвертации. Конвертирование изображений занимает много системной "
+#~ "оперативной памяти, поэтому может закончиться память, если у вас много "
+#~ "процессов и недостаточно оперативной памяти для их размещения. Если "
+#~ "установить значение 0, будет использовано максимальное доступное "
+#~ "количество процессов. Независимо от того, какое значение вы установите, "
+#~ "программа никогда не будет пытаться использовать больше процессов, чем "
+#~ "доступно в вашей системе. Если включен режим 'singleprocess', эта "
+#~ "настройка будет проигнорирована."
+
+#~ msgid ""
+#~ "[LEGACY] This only needs to be selected if a legacy model is being loaded "
+#~ "or if there are multiple models in the model folder"
+#~ msgstr ""
+#~ "[УСТАРЕЛО] Этот параметр необходимо выбрать только в том случае, если "
+#~ "загружается устаревшая модель или если в папке моделей имеется несколько "
+#~ "моделей"
+
+#~ msgid ""
+#~ "Enable On-The-Fly Conversion. NOT recommended. You should generate a "
+#~ "clean alignments file for your destination video. However, if you wish "
+#~ "you can generate the alignments on-the-fly by enabling this option. This "
+#~ "will use an inferior extraction pipeline and will lead to substandard "
+#~ "results. If an alignments file is found, this option will be ignored."
+#~ msgstr ""
+#~ "Включить преобразование \"на лету\". НЕ рекомендуется. Вы должны "
+#~ "сгенерировать чистый файл выравнивания для конечного видео. Однако при "
+#~ "желании вы можете генерировать выравнивания \"на лету\", включив эту "
+#~ "опцию. При этом будет использоваться некачественный конвейер извлечения, "
+#~ "что приведет к некачественным результатам. Если файл выравнивания найден, "
+#~ "этот параметр будет проигнорирован."
+
+#~ msgid ""
+#~ "When used with --frame-ranges outputs the unchanged frames that are not "
+#~ "processed instead of discarding them."
+#~ msgstr ""
+#~ "При использовании с --frame-ranges выводит неизмененные кадры, которые не "
+#~ "были обработаны, вместо того, чтобы отбрасывать их."
+
+#~ msgid "Swap the model. Instead converting from of A -> B, converts B -> A"
+#~ msgstr ""
+#~ "Поменять модель местами. 
Вместо преобразования из A -> B, преобразуется B "
+#~ "-> A"
+
+#~ msgid "Disable multiprocessing. Slower but less resource intensive."
+#~ msgstr ""
+#~ "Отключение многопроцессной обработки. Медленнее, но менее ресурсоемко."
+
+#~ msgid "Output to Shell console instead of GUI console"
+#~ msgstr "Вывод в консоль Shell вместо консоли GUI"
+
+#~ msgid ""
+#~ "[Deprecated - Use '-D, --distribution-strategy' instead] Use the "
+#~ "Tensorflow Mirrored Distrubution Strategy to train on multiple GPUs."
+#~ msgstr ""
+#~ "[Устарело - Используйте '-D, --distribution-strategy' вместо этого] "
+#~ "Используйте стратегию Tensorflow Mirrored Distribution Strategy "
+#~ "(стратегия зеркального распределения Tensorflow) для обучения на "
+#~ "нескольких GPU."
diff --git a/locales/ru/LC_MESSAGES/lib.config.mo b/locales/ru/LC_MESSAGES/lib.config.mo
new file mode 100644
index 0000000000..49787224ff
Binary files /dev/null and b/locales/ru/LC_MESSAGES/lib.config.mo differ
diff --git a/locales/ru/LC_MESSAGES/lib.config.objects.mo b/locales/ru/LC_MESSAGES/lib.config.objects.mo
new file mode 100644
index 0000000000..d10f8be1df
Binary files /dev/null and b/locales/ru/LC_MESSAGES/lib.config.objects.mo differ
diff --git a/locales/ru/LC_MESSAGES/lib.config.objects.po b/locales/ru/LC_MESSAGES/lib.config.objects.po
new file mode 100644
index 0000000000..b85d04bffc
--- /dev/null
+++ b/locales/ru/LC_MESSAGES/lib.config.objects.po
@@ -0,0 +1,76 @@
+# SOME DESCRIPTIVE TITLE.
+# Copyright (C) YEAR THE PACKAGE'S COPYRIGHT HOLDER
+# This file is distributed under the same license as the PACKAGE package.
+# FIRST AUTHOR , YEAR.
+#
+msgid ""
+msgstr ""
+"Project-Id-Version: \n"
+"Report-Msgid-Bugs-To: \n"
+"POT-Creation-Date: 2025-12-11 19:02+0000\n"
+"PO-Revision-Date: 2025-12-12 13:08+0000\n"
+"Last-Translator: \n"
+"Language-Team: \n"
+"Language: ru_RU\n"
+"MIME-Version: 1.0\n"
+"Content-Type: text/plain; charset=UTF-8\n"
+"Content-Transfer-Encoding: 8bit\n"
+"X-Generator: Poedit 3.8\n"
+
+#: lib/config/objects.py:115
+msgid ""
+"\n"
+"This option can be updated for existing models.\n"
+msgstr ""
+"\n"
+"Эту настройку можно обновить для существующих моделей.\n"
+
+#: lib/config/objects.py:117
+msgid ""
+"\n"
+"If selecting multiple options then each option should be separated by a "
+"space or a comma (e.g. 
item1, item2, item3)\n"
+msgstr ""
+"\n"
+"Если выбираете несколько опций, тогда каждая опция должна быть разделена "
+"пробелом или запятой (например: опция1, опция2, опция3)\n"
+
+#: lib/config/objects.py:120
+msgid ""
+"\n"
+"Choose from: {}"
+msgstr ""
+"\n"
+"Выберите из: {}"
+
+#: lib/config/objects.py:122
+msgid ""
+"\n"
+"Choose from: True, False"
+msgstr ""
+"\n"
+"Выберите из: True, False"
+
+#: lib/config/objects.py:126
+msgid ""
+"\n"
+"Select an integer between {} and {}"
+msgstr ""
+"\n"
+"Выберите целое число между {} и {}"
+
+#: lib/config/objects.py:130
+msgid ""
+"\n"
+"Select a decimal number between {} and {}"
+msgstr ""
+"\n"
+"Выберите десятичное число между {} и {}"
+
+#: lib/config/objects.py:132
+msgid ""
+"\n"
+"[Default: {}]"
+msgstr ""
+"\n"
+"[По умолчанию: {}]"
diff --git a/locales/ru/LC_MESSAGES/plugins.extract.extract_config.mo b/locales/ru/LC_MESSAGES/plugins.extract.extract_config.mo
new file mode 100644
index 0000000000..411175f38b
Binary files /dev/null and b/locales/ru/LC_MESSAGES/plugins.extract.extract_config.mo differ
diff --git a/locales/ru/LC_MESSAGES/plugins.extract.extract_config.po b/locales/ru/LC_MESSAGES/plugins.extract.extract_config.po
new file mode 100644
index 0000000000..9d42c596bf
--- /dev/null
+++ b/locales/ru/LC_MESSAGES/plugins.extract.extract_config.po
@@ -0,0 +1,164 @@
+# SOME DESCRIPTIVE TITLE.
+# Copyright (C) YEAR THE PACKAGE'S COPYRIGHT HOLDER
+# This file is distributed under the same license as the PACKAGE package.
+# FIRST AUTHOR , YEAR.
+#
+msgid ""
+msgstr ""
+"Project-Id-Version: \n"
+"Report-Msgid-Bugs-To: \n"
+"POT-Creation-Date: 2025-12-12 13:11+0000\n"
+"PO-Revision-Date: 2025-12-12 13:14+0000\n"
+"Last-Translator: \n"
+"Language-Team: \n"
+"Language: ru_RU\n"
+"MIME-Version: 1.0\n"
+"Content-Type: text/plain; charset=UTF-8\n"
+"Content-Transfer-Encoding: 8bit\n"
+"X-Generator: Poedit 3.8\n"
+
+#: plugins/extract/extract_config.py:23
+msgid "Options that apply to all extraction plugins"
+msgstr "Параметры, применимые ко всем плагинам извлечения"
+
+#: plugins/extract/extract_config.py:30 plugins/extract/extract_config.py:45
+#: plugins/extract/extract_config.py:60 plugins/extract/extract_config.py:72
+#: plugins/extract/extract_config.py:85 plugins/extract/extract_config.py:95
+#: plugins/extract/extract_config.py:107
+msgid "filters"
+msgstr "фильтры"
+
+#: plugins/extract/extract_config.py:32
+msgid ""
+"Filters out faces below this size. This is a multiplier of the minimum "
+"dimension of the frame (i.e. 1280x720 = 720). If the original face extract "
+"box is smaller than the minimum dimension times this multiplier, it is "
+"considered a false positive and discarded. Faces which are found to be "
+"unusually smaller than the frame tend to be misaligned images, except in "
+"extreme long-shots. These can be usually be safely discarded."
+msgstr ""
+"Отфильтровывает лица меньше этого размера. Это множитель минимального "
+"размера кадра (т.е. 1280x720 = 720). Если исходное поле извлечения лица "
+"меньше минимального размера, умноженного на этот множитель, оно считается "
+"ложным срабатыванием и отбрасывается. Лица, которые оказываются необычно "
+"маленькими относительно кадра, как правило, являются неправильно "
+"выровненными изображениями, за исключением очень дальних планов. Обычно их "
+"можно смело отбрасывать."
+
+#: plugins/extract/extract_config.py:47
+msgid ""
+"Filters out faces above this size. This is a multiplier of the minimum "
+"dimension of the frame (i.e. 1280x720 = 720). 
If the original face extract "
+"box is larger than the minimum dimension times this multiplier, it is "
+"considered a false positive and discarded. Faces which are found to be "
+"unusually larger than the frame tend to be misaligned images except in "
+"extreme close-ups. These can be usually be safely discarded."
+msgstr ""
+"Отфильтровывает лица, превышающие этот размер. Это множитель минимального "
+"размера кадра (т.е. 1280x720 = 720). Если исходный блок извлечения лица "
+"больше, чем минимальный размер кадра, умноженный на этот множитель, он "
+"считается ложным срабатыванием и отбрасывается. Лица, размер которых "
+"необычно превышает размер кадра, как правило, являются неправильно "
+"выровненными изображениями, за исключением экстремальных крупных планов. "
+"Обычно их можно смело отбрасывать."
+
+#: plugins/extract/extract_config.py:62
+msgid ""
+"Filters out faces who's landmarks are above this distance from an 'average' "
+"face. Values above 15 tend to be fairly safe. Values above 10 will remove "
+"more false positives, but may also filter out some faces at extreme angles."
+msgstr ""
+"Отфильтровывает лица, ориентиры которых находятся дальше этого расстояния "
+"от 'среднего' лица. Значения выше 15, как правило, достаточно безопасны. "
+"Значения выше 10 устраняют больше ложных срабатываний, но также могут "
+"отфильтровать некоторые лица под экстремальными углами."
+
+#: plugins/extract/extract_config.py:74
+msgid ""
+"Filters out faces who's calculated roll is greater than zero +/- this value "
+"in degrees. Aligned faces should have a roll value close to zero. Values "
+"that are a significant distance from 0 degrees tend to be misaligned images. "
+"These can usually be safely disgarded."
+msgstr ""
+"Отфильтровывает лица, у которых расчетный угол наклона больше нуля +/- это "
+"значение в градусах. Выровненные лица должны иметь значение угла наклона, "
+"близкое к нулю. Значения, которые значительно удалены от 0 градусов, как "
+"правило, представляют собой неправильно выровненные изображения. Обычно их "
+"можно смело отбрасывать."
+
+#: plugins/extract/extract_config.py:87
+msgid ""
+"Filters out faces where the lowest point of the aligned face's eye or "
+"eyebrow is lower than the highest point of the aligned face's mouth. Any "
+"faces where this occurs are misaligned and can be safely disgarded."
+msgstr ""
+"Отфильтровывает лица, у которых нижняя точка глаза или брови выровненного "
+"лица находится ниже, чем верхняя точка рта выровненного лица. Все лица, на "
+"которых это происходит, являются неправильно выровненными и могут быть смело "
+"отброшены."
+
+#: plugins/extract/extract_config.py:97
+msgid ""
+"If enabled, and 're-feed' has been selected for extraction, then interim "
+"alignments will be filtered prior to averaging the final landmarks. This can "
+"help improve the final alignments by removing any obvious misaligns from the "
+"interim results, and may also help pick up difficult alignments. If "
+"disabled, then all re-feed results will be averaged."
+msgstr ""
+"Если эта функция включена, и для извлечения выбрана 'повторная подача'('re-"
+"feed'), то промежуточные выравнивания будут отфильтрованы перед усреднением "
+"окончательных ориентиров. Это может помочь улучшить окончательное "
+"выравнивание, удалив любые очевидные несоответствия из промежуточных "
+"результатов, а также может помочь выявить сложные выравнивания. Если эта "
+"функция отключена, то все результаты повторной подачи будут усреднены."
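The min/max size filters above reduce to a simple multiplier check against the frame's smaller dimension (the 1280x720 = 720 example). A rough sketch of that arithmetic, with hypothetical names and treating the extract box size as its diagonal:

```python
import math

def passes_size_filters(frame_w: int, frame_h: int,
                        box_w: float, box_h: float,
                        min_mult: float = 0.0, max_mult: float = 10.0) -> bool:
    """Reject faces whose extract box is implausibly small or large
    relative to the frame's minimum dimension (e.g. 1280x720 -> 720)."""
    min_dim = min(frame_w, frame_h)        # 720 for a 1280x720 frame
    box_size = math.hypot(box_w, box_h)    # box diagonal in pixels
    return min_dim * min_mult <= box_size <= min_dim * max_mult
```

A face failing either bound is treated as a false positive and, per the keep-filtered option above, is either deleted or saved to a sub-folder.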
+ +#: plugins/extract/extract_config.py:109 +msgid "" +"If enabled, saves any filtered out images into a sub-folder during the " +"extraction process. If disabled, filtered faces are deleted. Note: The faces " +"will always be filtered out of the alignments file, regardless of whether " +"you keep the faces or not." +msgstr "" +"Если включена, то в процессе извлечения отфильтрованные изображения " +"сохраняются в подпапке. Если отключено, отфильтрованные лица удаляются. " +"Примечание: Лица всегда будут отфильтрованы из файла выравнивания, " +"независимо от того, сохраняете вы эти лица или нет." + +#: plugins/extract/extract_config.py:118 plugins/extract/extract_config.py:128 +msgid "re-align" +msgstr "повторное выравнивание" + +#: plugins/extract/extract_config.py:120 +msgid "" +"If enabled, and 're-align' has been selected for extraction, then all re-" +"feed iterations are re-aligned. If disabled, then only the final averaged " +"output from re-feed will be re-aligned." +msgstr "" +"Если включено, и для извлечения выбрано 'повторное выравнивание'('re-" +"align'), то все итерации повторной подачи выравниваются повторно. Если " +"отключено, то выравнивается только конечный усредненный результат повторной " +"подачи." + +#: plugins/extract/extract_config.py:130 +msgid "" +"If enabled, and 're-align' has been selected for extraction, then any " +"alignments which would be filtered out will not be re-aligned." +msgstr "" +"Если эта функция включена, и для извлечения выбрано 'повторное " +"выравнивание'('re-align'), то все выравнивания, которые будут отфильтрованы, " +"не будут повторно выравниваться." + +#~ msgid "settings" +#~ msgstr "настройки" + +#~ msgid "" +#~ "Enable the Tensorflow GPU `allow_growth` configuration option. This " +#~ "option prevents Tensorflow from allocating all of the GPU VRAM at launch " +#~ "but can lead to higher VRAM fragmentation and slower performance. Should " +#~ "only be enabled if you are having problems running extraction." +#~ msgstr "" +#~ "Включите опцию конфигурации Tensorflow GPU `allow_growth`. Эта опция не " +#~ "позволяет Tensorflow выделять всю видеопамять видеокарты при запуске, но " +#~ "может привести к повышенной фрагментации видеопамяти и снижению " +#~ "производительности. Следует включать только в том случае, если у вас есть " +#~ "проблемы с запуском извлечения." diff --git a/locales/ru/LC_MESSAGES/plugins.train.train_config.mo b/locales/ru/LC_MESSAGES/plugins.train.train_config.mo new file mode 100644 index 0000000000..3c41502926 Binary files /dev/null and b/locales/ru/LC_MESSAGES/plugins.train.train_config.mo differ diff --git a/locales/ru/LC_MESSAGES/plugins.train.train_config.po b/locales/ru/LC_MESSAGES/plugins.train.train_config.po new file mode 100644 index 0000000000..a36c337933 --- /dev/null +++ b/locales/ru/LC_MESSAGES/plugins.train.train_config.po @@ -0,0 +1,1290 @@ +# SOME DESCRIPTIVE TITLE. +# Copyright (C) YEAR THE PACKAGE'S COPYRIGHT HOLDER +# This file is distributed under the same license as the PACKAGE package. +# FIRST AUTHOR , YEAR. 
+#
+msgid ""
+msgstr ""
+"Project-Id-Version: \n"
+"Report-Msgid-Bugs-To: \n"
+"POT-Creation-Date: 2025-12-13 13:39+0000\n"
+"PO-Revision-Date: 2025-12-15 22:01+0700\n"
+"Last-Translator: \n"
+"Language-Team: \n"
+"Language: ru_RU\n"
+"MIME-Version: 1.0\n"
+"Content-Type: text/plain; charset=UTF-8\n"
+"Content-Transfer-Encoding: 8bit\n"
+"X-Generator: Poedit 3.5\n"
+
+#: plugins/train/train_config.py:21
+msgid ""
+"\n"
+"NB: Unless specifically stated, values changed here will only take effect "
+"when creating a new model."
+msgstr ""
+"\n"
+"Примечание: Если не указано иное, измененные здесь значения вступят в силу "
+"только при создании новой модели."
+
+#: plugins/train/train_config.py:30
+msgid "Options that apply to all models"
+msgstr "Настройки, применимые ко всем моделям"
+
+#: plugins/train/train_config.py:43 plugins/train/train_config.py:66
+#: plugins/train/train_config.py:86
+msgid "face"
+msgstr "лицо"
+
+#: plugins/train/train_config.py:45
+msgid ""
+"How to center the training image. The extracted images are centered on the "
+"middle of the skull based on the face's estimated pose. A subsection of "
+"these images are used for training. The centering used dictates how this "
+"subsection will be cropped from the aligned images.\n"
+"\tface: Centers the training image on the center of the face, adjusting for "
+"pitch and yaw.\n"
+"\thead: Centers the training image on the center of the head, adjusting for "
+"pitch and yaw. NB: You should only select head centering if you intend to "
+"include the full head (including hair) in the final swap. This may give "
+"mixed results. Additionally, it is only worth choosing head centering if you "
+"are training with a mask that includes the hair (e.g. BiSeNet-FP-Head).\n"
+"\tlegacy: The 'original' extraction technique. Centers the training image "
+"near the tip of the nose with no adjustment. Can result in the edges of the "
+"face appearing outside of the training area."
+msgstr ""
+"Как центрировать тренировочное изображение. Извлеченные изображения "
+"центрируются на середине черепа на основе оценки позы лица. Для тренировки "
+"используется часть этих изображений. Выбранное центрирование определяет, "
+"как эта часть будет вырезана из выровненных изображений.\n"
+"\tface: Центрирует тренировочное изображение по центру лица, регулируя угол "
+"наклона и поворота.\n"
+"\thead: Центрирует тренировочное изображение по центру головы, регулируя "
+"угол наклона и поворота. Примечание: Центрирование по голове следует "
+"выбирать только в том случае, если вы планируете включить голову полностью "
+"(включая волосы) в финальную замену. Это может дать смешанные результаты. "
+"Кроме того, центрирование по голове имеет смысл только при тренировке с "
+"маской, включающей волосы (например: BiSeNet-FP-Head).\n"
+"\tlegacy: 'Оригинальная' техника извлечения. Центрирует тренировочное "
+"изображение около кончика носа без корректировок. Может привести к тому, "
+"что края лица окажутся вне тренировочной зоны."
+
+#: plugins/train/train_config.py:68
+msgid ""
+"How much of the extracted image to train on. A lower coverage will limit the "
+"model's scope to a zoomed-in central area while higher amounts can include "
+"the entire face. A trade-off exists between lower amounts given more detail "
+"versus higher amounts avoiding noticeable swap transitions. For 'Face' "
+"centering you will want to leave this above 75%. For Head centering you will "
+"most likely want to set this to 100%. 
Sensible values for 'Legacy' centering "
+"are:\n"
+"\t62.5% spans from eyebrow to eyebrow.\n"
+"\t75.0% spans from temple to temple.\n"
+"\t87.5% spans from ear to ear.\n"
+"\t100.0% is a mugshot."
+msgstr ""
+"На какой части извлеченного изображения проводить тренировку. Низкое "
+"покрытие ограничит область охвата модели приближенной центральной зоной, в "
+"то время как большие значения могут охватывать все лицо. Существует "
+"компромисс между меньшими объемами, дающими больше деталей, и большими "
+"объемами, позволяющими избежать заметных переходов замены. Для "
+"центрирования 'Face' следует оставить значение выше 75%. Для центрирования "
+"'Head' скорее всего стоит установить 100%. Разумные значения для "
+"центрирования 'Legacy':\n"
+"\t62.5% охватывает от бровей до бровей.\n"
+"\t75% охватывает от виска до виска.\n"
+"\t87.5% охватывает от уха до уха.\n"
+"\t100% - полный снимок."
+
+#: plugins/train/train_config.py:88
+msgid ""
+"How much to adjust the vertical position of the aligned face as a percentage "
+"of face image size. Negative values move the face up (expose more chin and "
+"less forehead). Positive values move the face down (expose less chin and "
+"more forehead)"
+msgstr ""
+"На сколько процентов от размера изображения лица сдвигать его по вертикали "
+"после выравнивания. Отрицательные значения сдвигают лицо вверх (в кадре "
+"становится больше подбородка и меньше лба). Положительные значения сдвигают "
+"лицо вниз (в кадре становится меньше подбородка и больше лба)."
+
+#: plugins/train/train_config.py:99 plugins/train/train_config.py:109
+msgid "initialization"
+msgstr "инициализация"
+
+#: plugins/train/train_config.py:101
+msgid ""
+"Use ICNR to tile the default initializer in a repeating pattern. This "
+"strategy is designed for pairing with sub-pixel / pixel shuffler to reduce "
+"the 'checkerboard effect' in image reconstruction. \n"
+"\t https://arxiv.org/ftp/arxiv/papers/1707/1707.02937.pdf"
+msgstr ""
+"Использовать ICNR для размножения инициализатора по умолчанию повторяющимся "
+"шаблоном. Эта стратегия предназначена для использования в паре с "
+"субпиксельным/пиксельным перетасовщиком для уменьшения \"эффекта шахматной "
+"доски\" при реконструкции изображения. \n"
+"\t [ТОЛЬКО на английском] https://arxiv.org/ftp/arxiv/papers/1707/1707.02937."
+"pdf"
+
+#: plugins/train/train_config.py:111
+msgid ""
+"Use Convolution Aware Initialization for convolutional layers. This can help "
+"eradicate the vanishing and exploding gradient problem as well as lead to "
+"higher accuracy, lower loss and faster convergence.\n"
+"NB:\n"
+"\t This can use more VRAM when creating a new model so you may want to lower "
+"the batch size for the first run. The batch size can be raised again when "
+"reloading the model.\n"
+"\t Multi-GPU is not supported for this option, so you should start the model "
+"on a single GPU. Once training has started, you can stop training, enable "
+"multi-GPU and resume.\n"
+"\t Building the model will likely take several minutes as the calculations "
+"for this initialization technique are expensive. This will only impact "
+"starting a new model."
+msgstr ""
+"Использовать свёрточно-осведомлённую инициализацию для сверточных слоев. 
" +"Может помочь устранить проблему исчезающего и взрывающегося градиента, а " +"также повысить точность, снизить потери и ускорить сходимость.\n" +"Примечание:\n" +"\t При создании новой модели может потребоваться больше видеопамяти, поэтому " +"для первого запуска лучше уменьшить размер пачки. Размер пачки может быть " +"увеличен при перезагрузке модели. \n" +"\t Использование нескольких видеокарт не поддерживается, поэтому модель " +"следует запускать на одной видеокарте. После начала обучения вы можете " +"остановить обучение, включить несколько видеокарт и возобновить его.\n" +"\t Построение модели, скорее всего, займет несколько минут, поскольку " +"вычисления для этой техники инициализации являются дорогостоящими. Это " +"повлияет только на запуск новой модели." + +#: plugins/train/train_config.py:126 plugins/train/train_config.py:138 +#: plugins/train/train_config.py:155 +msgid "Learning Rate Finder" +msgstr "Инструмент поиска оптимального коэффициента обучения" + +#: plugins/train/train_config.py:128 +msgid "" +"The number of iterations to process to find the optimal learning rate. " +"Higher values will take longer, but will be more accurate." +msgstr "" +"Количество итераций для поиска оптимального коэффициента обучения. Большие " +"значения займут больше времени, но будут более точными." + +#: plugins/train/train_config.py:140 +msgid "" +"The operation mode for the learning rate finder. Only applicable to new " +"models. For existing models this will always default to 'set'.\n" +"\tset - Train with the discovered optimal learning rate.\n" +"\tgraph_and_set - Output a graph in the training folder showing the " +"discovered learning rates and train with the optimal learning rate.\n" +"\tgraph_and_exit - Output a graph in the training folder with the discovered " +"learning rates and exit." +msgstr "" +"Режим работы для поиска коэффициента обучения. Применимо только для новых " +"моделей. Для уже существующих моделей режим будет автоматически выставлен в " +"'set'.\n" +"\tset - Обучение с найденным оптимальным коэффициентом обучения.\n" +"\tgraph_and_set - Вывод графика в папку обучения, показывающего найденные " +"коэффициенты обучения, и обучение с оптимальным коэффициентом.\n" +"\tgraph_and_exit - Вывод графика в папку обучения с найденными " +"коэффициентами обучения с последующим выходом из программы." + +#: plugins/train/train_config.py:157 +msgid "" +"How aggressively to set the Learning Rate. More aggressive can learn faster, " +"but is more likely to lead to exploding gradients.\n" +"\tdefault - The default optimal learning rate. A safe choice for nearly all " +"use cases.\n" +"\taggressive - Set's a higher learning rate than the default. May learn " +"faster but with a higher chance of exploding gradients.\n" +"\textreme - The highest optimal learning rate. A much higher risk of " +"exploding gradients." +msgstr "" +"Насколько агрессивно устанавливать коэффициент обучения. Более агрессивный " +"подход может обучать быстрее, но с большей вероятностью может привести к " +"взрыву градиентов.\n" +"\tdefault - Оптимальный коэффициент обучения по умолчанию. Безопасный выбор " +"для почти всех случаев использования.\n" +"\taggressive - Устанавливает коэффициент обучения выше, чем по умолчанию. " +"Может обучать быстрее, но с большей вероятностью взрыва градиента.\n" +"\textreme - Наивысший оптимальный коэффициент обучения. Гораздо выше риск " +"взрыва градиента." 
+
+#: plugins/train/train_config.py:172 plugins/train/train_config.py:183
+#: plugins/train/train_config.py:199
+msgid "network"
+msgstr "сеть"
+
+#: plugins/train/train_config.py:174
+msgid ""
+"Use reflection padding rather than zero padding with convolutions. Each "
+"convolution must pad the image boundaries to maintain the proper sizing. "
+"More complex padding schemes can reduce artifacts at the border of the "
+"image.\n"
+"\t http://www-cs.engr.ccny.cuny.edu/~wolberg/cs470/hw/hw2_pad.txt"
+msgstr ""
+"Использовать при свертках дополнение отражением (reflection padding) вместо "
+"дополнения нулями. Каждая свертка должна дополнять границы изображения для "
+"сохранения правильного размера. Более сложные схемы дополнения могут "
+"уменьшить артефакты на границе изображения.\n"
+"\t http://www-cs.engr.ccny.cuny.edu/~wolberg/cs470/hw/hw2_pad.txt"
+
+#: plugins/train/train_config.py:185
+msgid ""
+"NVIDIA GPUs can run operations in float16 faster than in float32. Mixed "
+"precision allows you to use a mix of float16 with float32, to get the "
+"performance benefits from float16 and the numeric stability benefits from "
+"float32.\n"
+"\n"
+"This is untested on non-Nvidia cards, but will run on most Nvidia models. it "
+"will only speed up training on more recent GPUs. Those with compute "
+"capability 7.0 or higher will see the greatest performance benefit from "
+"mixed precision because they have Tensor Cores. Older GPUs offer no math "
+"performance benefit for using mixed precision, however memory and bandwidth "
+"savings can enable some speedups. Generally RTX GPUs and later will offer "
+"the most benefit."
+msgstr ""
+"Видеокарты NVIDIA могут выполнять операции в float16 быстрее, чем в "
+"float32. Смешанная точность позволяет использовать сочетание float16 и "
+"float32, чтобы получить выигрыш в производительности от float16 и численную "
+"стабильность от float32.\n"
+"\n"
+"Данная функция не проверена на видеокартах, отличных от Nvidia, но будет "
+"работать на большинстве моделей Nvidia. Ускорение тренировки будет только "
+"на более новых видеокартах. Видеокарты с вычислительной способностью "
+"('Compute Capability') 7.0 и выше получат наибольший выигрыш от смешанной "
+"точности, потому что у них есть тензорные ядра. Старые видеокарты не "
+"получают выигрыша в скорости вычислений от смешанной точности, однако "
+"экономия памяти и пропускной способности может дать небольшое ускорение. "
+"Как правило, наибольший выигрыш дают видеокарты RTX и новее."
+
+#: plugins/train/train_config.py:201
+msgid ""
+"If a 'NaN' is generated in the model, this means that the model has "
+"corrupted and the model is likely to start deteriorating from this point on. "
+"Enabling NaN protection will stop training immediately in the event of a "
+"NaN. The last save will not contain the NaN, so you may still be able to "
+"rescue your model."
+msgstr ""
+"Если в модели сгенерировано 'Не число' (далее NaN), это значит, что модель "
+"повреждена и с этого момента, скорее всего, начнет деградировать. Включение "
+"защиты от NaN немедленно остановит тренировку при обнаружении NaN. "
+"Последнее сохранение не будет содержать в себе NaN, так что у вас будет "
+"возможность спасти вашу модель."
+
+#: plugins/train/train_config.py:211
+msgid "convert"
+msgstr "конвертирование"
+
+#: plugins/train/train_config.py:213
+msgid ""
+"[GPU Only]. 
The number of faces to feed through the model at once when "
+"running the Convert process.\n"
+"\n"
+"NB: Increasing this figure is unlikely to improve convert speed, however, if "
+"you are getting Out of Memory errors, then you may want to reduce the batch "
+"size."
+msgstr ""
+"[Только для видеокарт] Количество лиц, одновременно проходящих через модель "
+"во время конвертирования.\n"
+"\n"
+"Примечание: Увеличение этого значения вряд ли повлечет за собой ускорение "
+"конвертирования, однако, если у вас появляются ошибки 'Out of Memory', тогда "
+"стоит снизить размер пачки."
+
+#: plugins/train/train_config.py:224
+msgid ""
+"Focal Frequency Loss. Analyzes the frequency spectrum of the images rather "
+"than the images themselves. This loss function can be used on its own, but "
+"the original paper found increased benefits when using it as a complementary "
+"loss to another spacial loss function (e.g. MSE). Ref: Focal Frequency Loss "
+"for Image Reconstruction and Synthesis https://arxiv.org/pdf/2012.12821.pdf "
+"NB: This loss does not currently work on AMD cards."
+msgstr ""
+"Потеря фокальной частоты. Анализирует частотный спектр изображений, а не "
+"сами изображения. Эта функция потерь может использоваться сама по себе, но в "
+"оригинальной статье было обнаружено, что она дает больше преимуществ при "
+"использовании в качестве дополнительной потери к другой пространственной "
+"функции потерь (например, MSE). Ссылка: Focal Frequency Loss for Image "
+"Reconstruction and Synthesis [ТОЛЬКО на английском] https://arxiv.org/"
+"pdf/2012.12821.pdf NB: Эта потеря в настоящее время не работает на картах "
+"AMD."
+
+#: plugins/train/train_config.py:231
+msgid ""
+"Nvidia FLIP. A perceptual loss measure that approximates the difference "
+"perceived by humans as they alternate quickly (or flip) between two images. "
+"Used on its own and this loss function creates a distinct grid on the "
+"output. However it can be helpful when used as a complimentary loss "
+"function. Ref: FLIP: A Difference Evaluator for Alternating Images: https://"
+"research.nvidia.com/sites/default/files/node/3260/FLIP_Paper.pdf"
+msgstr ""
+"Nvidia FLIP. Мера потерь восприятия, которая аппроксимирует разницу, "
+"воспринимаемую человеком при быстром чередовании (или перелистывании) двух "
+"изображений. Используемая сама по себе, эта функция потерь создает на выходе "
+"отчетливую сетку. Однако она может быть полезна при использовании в качестве "
+"дополнительной функции потерь. Ссылка: FLIP: A Difference Evaluator for "
+"Alternating Images [ТОЛЬКО на английском]: https://research.nvidia.com/sites/"
+"default/files/node/3260/FLIP_Paper.pdf"
+
+#: plugins/train/train_config.py:238
+msgid ""
+"Gradient Magnitude Similarity Deviation seeks to match the global standard "
+"deviation of the pixel to pixel differences between two images. Similar in "
+"approach to SSIM. Ref: Gradient Magnitude Similarity Deviation: An Highly "
+"Efficient Perceptual Image Quality Index https://arxiv.org/ftp/arxiv/"
+"papers/1308/1308.3052.pdf"
+msgstr ""
+"Отклонение схожести величины градиента (Gradient Magnitude Similarity "
+"Deviation) стремится сопоставить глобальное стандартное отклонение "
+"попиксельных различий между двумя изображениями. Подход похож на SSIM. 
Ссылка: " +"Gradient Magnitude Similarity Deviation: An Highly Efficient Perceptual " +"Image Quality Index [ТОЛЬКО на английском] https://arxiv.org/ftp/arxiv/" +"papers/1308/1308.3052.pdf" + +#: plugins/train/train_config.py:243 +msgid "" +"The L_inf norm will reduce the largest individual pixel error in an image. " +"As each largest error is minimized sequentially, the overall error is " +"improved. This loss will be extremely focused on outliers." +msgstr "" +"Норма L_inf уменьшает наибольшую ошибку отдельного пикселя в изображении. По " +"мере последовательной минимизации каждой наибольшей ошибки улучшается общая " +"ошибка. Эта потеря будет чрезвычайно сосредоточена на выбросах." + +#: plugins/train/train_config.py:247 +msgid "" +"Laplacian Pyramid Loss. Attempts to improve results by focussing on edges " +"using Laplacian Pyramids. As this loss function gives priority to edges over " +"other low-frequency information, like color, it should not be used on its " +"own. The original implementation uses this loss as a complimentary function " +"to MSE. Ref: Optimizing the Latent Space of Generative Networks https://" +"arxiv.org/abs/1707.05776" +msgstr "" +"Потеря пирамиды Лапласиана. Пытается улучшить результаты, концентрируясь на " +"краях с помощью пирамид Лапласиана. Поскольку эта функция потерь отдает " +"приоритет краям, а не другой низкочастотной информации, например, цвету, ее " +"не следует использовать самостоятельно. В оригинальной реализации эта потеря " +"используется как дополнительная функция к MSE. Ссылка: Optimizing the Latent " +"Space of Generative Networks [ТОЛЬКО на английском] https://arxiv.org/" +"abs/1707.05776" + +#: plugins/train/train_config.py:254 +msgid "" +"LPIPS is a perceptual loss that uses the feature outputs of other pretrained " +"models as a loss metric. Be aware that this loss function will use more " +"VRAM. Used on its own and this loss will create a distinct moire pattern on " +"the output, however it can be helpful as a complimentary loss function. The " +"output of this function is strong, so depending on your chosen primary loss " +"function, you are unlikely going to want to set the weight above about 25%. " +"Ref: The Unreasonable Effectiveness of Deep Features as a Perceptual Metric " +"http://arxiv.org/abs/1801.03924\n" +"This variant uses the AlexNet backbone. A fairly light and old model which " +"performed best in the paper's original implementation.\n" +"NB: For AMD Users the final linear layer is not implemented." +msgstr "" +"LPIPS - это перцептивная потеря, которая использует в качестве метрики " +"потерь выходные характеристики других предварительно обученных моделей. " +"Имейте в виду, что эта функция потерь использует больше VRAM. При " +"самостоятельном использовании эта потеря создает на выходе отчетливый " +"муаровый рисунок, однако она может быть полезна как дополнительная функция " +"потерь. Вывод этой функции является сильным, поэтому, в зависимости от " +"выбранной вами основной функции потерь, вы вряд ли захотите устанавливать " +"вес выше 25%. Ссылка: The Unreasonable Effectiveness of Deep Features as a " +"Perceptual Metric [ТОЛЬКО на английском] http://arxiv.org/abs/1801.03924.\n" +"Этот вариант использует основу AlexNet. Это довольно легкая и старая модель, " +"которая лучше всего показала себя в оригинальной реализации.\n" +"NB: Для пользователей AMD последний линейный слой не реализован." + +#: plugins/train/train_config.py:264 +msgid "" +"Same as lpips_alex, but using the SqueezeNet backbone. 
A more lightweight " +"version of AlexNet.\n" +"NB: For AMD Users the final linear layer is not implemented." +msgstr "" +"То же, что и lpips_alex, но использует основу SqueezeNet. Более облегченная " +"версия AlexNet.\n" +"NB: Для пользователей AMD последний линейный слой не реализован." + +#: plugins/train/train_config.py:267 +msgid "" +"Same as lpips_alex, but using the VGG16 backbone. A more heavyweight model.\n" +"NB: For AMD Users the final linear layer is not implemented." +msgstr "" +"То же, что и lpips_alex, но использует основу VGG16. Более тяжелая модель.\n" +"NB: Для пользователей AMD последний линейный слой не реализован." + +#: plugins/train/train_config.py:270 +msgid "" +"log(cosh(x)) acts similar to MSE for small errors and to MAE for large " +"errors. Like MSE, it is very stable and prevents overshoots when errors are " +"near zero. Like MAE, it is robust to outliers." +msgstr "" +"log(cosh(x)) действует аналогично MSE для малых ошибок и MAE для больших " +"ошибок. Как и MSE, он очень стабилен и предотвращает переборы, когда ошибки " +"близки к нулю. Как и MAE, он устойчив к выбросам." + +#: plugins/train/train_config.py:274 +msgid "" +"Mean absolute error will guide reconstructions of each pixel towards its " +"median value in the training dataset. Robust to outliers but as a median, it " +"can potentially ignore some infrequent image types in the dataset." +msgstr "" +"Средняя абсолютная погрешность направляет реконструкцию каждого пикселя к " +"его медианному значению в обучающем наборе данных. Устойчив к выбросам, но в " +"качестве медианы может игнорировать некоторые редкие типы изображений в " +"наборе данных." + +#: plugins/train/train_config.py:278 +msgid "" +"Mean squared error will guide reconstructions of each pixel towards its " +"average value in the training dataset. As an avg, it will be susceptible to " +"outliers and typically produces slightly blurrier results. Ref: Multi-Scale " +"Structural Similarity for Image Quality Assessment https://www.cns.nyu.edu/" +"pub/eero/wang03b.pdf" +msgstr "" +"Средняя квадратичная погрешность направляет реконструкцию каждого пикселя к " +"его среднему значению в наборе данных для обучения. Как среднее значение, " +"оно будет чувствительно к выбросам и обычно дает немного более размытые " +"результаты. Ссылка: Multi-Scale Structural Similarity for Image Quality " +"Assessment [ТОЛЬКО на английском]https://www.cns.nyu.edu/pub/eero/wang03b.pdf" + +#: plugins/train/train_config.py:283 +msgid "" +"Multiscale Structural Similarity Index Metric is similar to SSIM except that " +"it performs the calculations along multiple scales of the input image." +msgstr "" +"Метрика Индекса Многомасштабного Структурного Сходства (Multiscale " +"Structural Similarity Index Metric) похожа на SSIM, за исключением того, что " +"она выполняет вычисления по нескольким масштабам входного изображения." + +#: plugins/train/train_config.py:286 +msgid "" +"Smooth_L1 is a modification of the MAE loss to correct two of its " +"disadvantages. This loss has improved stability and guidance for small " +"errors. Ref: A General and Adaptive Robust Loss Function https://arxiv.org/" +"pdf/1701.03077.pdf" +msgstr "" +"Smooth_L1 - это модификация потери MAE для исправления двух ее недостатков. " +"Эта потеря улучшает стабильность и ориентирование при небольших " +"погрешностях. 
Ссылка: A General and Adaptive Robust Loss Function [ТОЛЬКО на "
+"английском] https://arxiv.org/pdf/1701.03077.pdf"
+
+#: plugins/train/train_config.py:290
+msgid ""
+"Structural Similarity Index Metric is a perception-based loss that considers "
+"changes in texture, luminance, contrast, and local spatial statistics of an "
+"image. Potentially delivers more realistic looking images. Ref: Image "
+"Quality Assessment: From Error Visibility to Structural Similarity http://"
+"www.cns.nyu.edu/pub/eero/wang03-reprint.pdf"
+msgstr ""
+"Метрика индекса структурного сходства ('Structural Similarity Index Metric') "
+"- это основанная на восприятии потеря, которая учитывает изменения в "
+"текстуре, яркости, контрасте и локальной пространственной статистике "
+"изображения. Потенциально обеспечивает более реалистичный вид изображений. "
+"Ссылка: Image Quality Assessment: From Error Visibility to Structural "
+"Similarity [ТОЛЬКО на английском] http://www.cns.nyu.edu/pub/eero/wang03-"
+"reprint.pdf"
+
+#: plugins/train/train_config.py:295
+msgid ""
+"Instead of minimizing the difference between the absolute value of each "
+"pixel in two reference images, compute the pixel to pixel spatial difference "
+"in each image and then minimize that difference between two images. Allows "
+"for large color shifts, but maintains the structure of the image."
+msgstr ""
+"Вместо того чтобы минимизировать разницу между абсолютным значением каждого "
+"пикселя в двух эталонных изображениях, вычислить пространственную разницу "
+"между пикселями в каждом изображении и затем минимизировать эту разницу "
+"между двумя изображениями. Это допускает большие цветовые сдвиги, но "
+"сохраняет структуру изображения."
+
+#: plugins/train/train_config.py:299
+msgid "Do not use an additional loss function."
+msgstr "Не использовать дополнительную функцию потерь."
+
+#: plugins/train/train_config.py:315
+msgid ""
+"Loss configuration options\n"
+"Loss is the mechanism by which a Neural Network judges how well it thinks "
+"that it is recreating a face."
+msgstr ""
+"Настройки потерь\n"
+"Потеря - это механизм, с помощью которого нейронная сеть оценивает, "
+"насколько хорошо она воссоздает лицо."
+
+#: plugins/train/train_config.py:321 plugins/train/train_config.py:331
+#: plugins/train/train_config.py:343 plugins/train/train_config.py:362
+#: plugins/train/train_config.py:372 plugins/train/train_config.py:391
+#: plugins/train/train_config.py:402 plugins/train/train_config.py:421
+#: plugins/train/train_config.py:436 plugins/train/train_config.py:450
+#: plugins/train/train_config.py:464
+msgid "loss"
+msgstr "потери"
+
+#: plugins/train/train_config.py:322
+msgid "The loss function to use."
+msgstr "Используемая функция потерь."
+
+#: plugins/train/train_config.py:333
+msgid ""
+"The second loss function to use. If using a structural based loss (such as "
+"SSIM, MS-SSIM or GMSD) it is common to add an L1 regularization(MAE) or L2 "
+"regularization (MSE) function. You can adjust the weighting of this loss "
+"function with the loss_weight_2 option.\n"
+"\n"
+"\t\n"
+"\n"
+"\t"
+msgstr ""
+"Вторая используемая функция потерь. При использовании потерь, основанных на "
+"структуре (таких как SSIM, MS-SSIM или GMSD), обычно добавляется функция "
+"регуляризации L1 (MAE) или регуляризации L2 (MSE). Вы можете настроить вес "
+"этой функции потерь с помощью параметра loss_weight_2. 
\n" +"\n" +"\t\n" +"\n" +"\t" + +#: plugins/train/train_config.py:345 +msgid "" +"The amount of weight to apply to the second loss function.\n" +"\n" +"\n" +"\n" +"The value given here is as a percentage denoting how much the selected " +"function should contribute to the overall loss cost of the model. For " +"example:\n" +"\t 100 - The loss calculated for the second loss function will be applied at " +"its full amount towards the overall loss score. \n" +"\t 25 - The loss calculated for the second loss function will be reduced by " +"a quarter prior to adding to the overall loss score. \n" +"\t 400 - The loss calculated for the second loss function will be mulitplied " +"4 times prior to adding to the overall loss score. \n" +"\t 0 - Disables the second loss function altogether." +msgstr "" +"Величина веса, применяемая ко второй функции потерь.\n" +"\n" +"\n" +"\n" +"Значение задается в процентах и показывает, какой вклад выбранная функция " +"должна внести в общую стоимость потерь модели. Например:\n" +"\t 100 - Потери, рассчитанные для второй функции потерь, будут применены в " +"полном объеме к общей стоимости потерь. \n" +"\t25 - Потери, рассчитанные для второй функции потерь, будут уменьшены на " +"четверть перед добавлением к общей стоимости потерь. \n" +"\t400 - Потери, рассчитанные для второй функции потерь, будут умножены в 4 " +"раза перед добавлением к общей оценке потерь. \n" +"\t 0 - Полностью отключает вторую функцию потерь." + +#: plugins/train/train_config.py:363 +msgid "" +"The third loss function to use. You can adjust the weighting of this loss " +"function with the loss_weight_3 option.\n" +"\n" +"\t\n" +"\n" +"\t" +msgstr "" +"Третья используемая функция потерь. Вы можете настроить вес этой функции " +"потерь с помощью параметра loss_weight_3.\n" +"\n" +"\t\n" +"\n" +"\t" + +#: plugins/train/train_config.py:374 +msgid "" +"The amount of weight to apply to the third loss function.\n" +"\n" +"\n" +"\n" +"The value given here is as a percentage denoting how much the selected " +"function should contribute to the overall loss cost of the model. For " +"example:\n" +"\t 100 - The loss calculated for the third loss function will be applied at " +"its full amount towards the overall loss score. \n" +"\t 25 - The loss calculated for the third loss function will be reduced by a " +"quarter prior to adding to the overall loss score. \n" +"\t 400 - The loss calculated for the third loss function will be mulitplied " +"4 times prior to adding to the overall loss score. \n" +"\t 0 - Disables the third loss function altogether." +msgstr "" +"Величина веса, применяемая к третьей функции потерь.\n" +"\n" +"\n" +"\n" +"Значение задается в процентах и показывает, какой вклад выбранная функция " +"должна внести в общую стоимость потерь модели. Например:\n" +"\t 100 - Потери, рассчитанные для четвертой функции потерь, будут применены " +"в полном объеме к общей стоимости потерь. \n" +"\t25 - Потери, рассчитанные для четвертой функции потерь, будут уменьшены на " +"четверть перед добавлением к общей стоимости потерь. \n" +"\t400 - Потери, рассчитанные для четвертой функции потерь, будут умножены в " +"4 раза перед добавлением к общей оценке потерь. \n" +"\t 0 - Полностью отключает четвертую функцию потерь." + +#: plugins/train/train_config.py:393 +msgid "" +"The fourth loss function to use. You can adjust the weighting of this loss " +"function with the loss_weight_3 option.\n" +"\n" +"\t\n" +"\n" +"\t" +msgstr "" +"Четвертая используемая функция потерь. 
Вы можете настроить вес этой функции "
+"потерь с помощью параметра 'loss_weight_4'.\n"
+"\n"
+"\t\n"
+"\n"
+"\t"
+
+#: plugins/train/train_config.py:404
+msgid ""
+"The amount of weight to apply to the fourth loss function.\n"
+"\n"
+"\n"
+"\n"
+"The value given here is as a percentage denoting how much the selected "
+"function should contribute to the overall loss cost of the model. For "
+"example:\n"
+"\t 100 - The loss calculated for the fourth loss function will be applied at "
+"its full amount towards the overall loss score. \n"
+"\t 25 - The loss calculated for the fourth loss function will be reduced by "
+"a quarter prior to adding to the overall loss score. \n"
+"\t 400 - The loss calculated for the fourth loss function will be mulitplied "
+"4 times prior to adding to the overall loss score. \n"
+"\t 0 - Disables the fourth loss function altogether."
+msgstr ""
+"Величина веса, применяемая к четвертой функции потерь.\n"
+"\n"
+"\n"
+"\n"
+"Значение задается в процентах и показывает, какой вклад выбранная функция "
+"должна внести в общую стоимость потерь модели. Например:\n"
+"\t 100 - Потери, рассчитанные для четвертой функции потерь, будут применены "
+"в полном объеме к общей стоимости потерь. \n"
+"\t 25 - Потери, рассчитанные для четвертой функции потерь, будут уменьшены "
+"на четверть перед добавлением к общей стоимости потерь. \n"
+"\t 400 - Потери, рассчитанные для четвертой функции потерь, будут умножены "
+"в 4 раза перед добавлением к общей оценке потерь. \n"
+"\t 0 - Полностью отключает четвертую функцию потерь."
+
+#: plugins/train/train_config.py:423
+msgid ""
+"The loss function to use when learning a mask.\n"
+"\t MAE - Mean absolute error will guide reconstructions of each pixel "
+"towards its median value in the training dataset. Robust to outliers but as "
+"a median, it can potentially ignore some infrequent image types in the "
+"dataset.\n"
+"\t MSE - Mean squared error will guide reconstructions of each pixel towards "
+"its average value in the training dataset. As an average, it will be "
+"susceptible to outliers and typically produces slightly blurrier results."
+msgstr ""
+"Функция потерь, используемая при обучении маски.\n"
+"\tMAE - средняя абсолютная погрешность('Mean absolute error') направляет "
+"реконструкцию каждого пикселя к его медианному значению в обучающем наборе "
+"данных. Устойчива к выбросам, но как медиана может игнорировать некоторые "
+"редкие типы изображений в наборе данных.\n"
+"\tMSE - средняя квадратичная погрешность('Mean squared error') направляет "
+"реконструкцию каждого пикселя к его среднему значению в обучающем наборе "
+"данных. Как среднее значение, оно чувствительно к выбросам и обычно дает "
+"немного более размытые результаты."
+
+#: plugins/train/train_config.py:438
+msgid ""
+"The amount of priority to give to the eyes.\n"
+"\n"
+"The value given here is as a multiplier of the main loss score. For "
+"example:\n"
+"\t 1 - The eyes will receive the same priority as the rest of the face. \n"
+"\t 10 - The eyes will be given a score 10 times higher than the rest of the "
+"face.\n"
+"\n"
+"NB: Penalized Mask Loss must be enable to use this option."
+msgstr ""
+"Величина приоритета, которую следует придать глазам.\n"
+"\n"
+"Значение дается как множитель основного показателя потерь. Например:\n"
+"\t 1 - Глаза получат тот же приоритет, что и остальное лицо. 
\n" +"\t 10 - глаза получат оценку в 10 раз выше, чем остальные части лица.\n" +"\n" +"NB: Penalized Mask Loss должен быть включен, чтобы использовать эту опцию." + +#: plugins/train/train_config.py:452 +msgid "" +"The amount of priority to give to the mouth.\n" +"\n" +"The value given here is as a multiplier of the main loss score. For " +"Example:\n" +"\t 1 - The mouth will receive the same priority as the rest of the face. \n" +"\t 10 - The mouth will be given a score 10 times higher than the rest of the " +"face.\n" +"\n" +"NB: Penalized Mask Loss must be enable to use this option." +msgstr "" +"Величина приоритета, которую следует придать рту.\n" +"\n" +"Значение дается как множитель основного показателя потерь. Например:\n" +"\t 1 - Рот получит тот же приоритет, что и остальное лицо. \n" +"\t 10 - Рот получит оценку в 10 раз выше, чем остальные части лица.\n" +"\n" +"NB: Penalized Mask Loss должен быть включен, чтобы использовать эту опцию." + +#: plugins/train/train_config.py:466 +msgid "" +"Image loss function is weighted by mask presence. For areas of the image " +"without the facial mask, reconstruction errors will be ignored while the " +"masked face area is prioritized. May increase overall quality by focusing " +"attention on the core face area." +msgstr "" +"Функция потерь изображения взвешивается по наличию маски. Для областей " +"изображения без маски лица погрешности реконструкции игнорируются, в то " +"время как область лица с маской является приоритетной. Может повысить общее " +"качество за счет концентрации внимания на основной области лица." + +#: plugins/train/train_config.py:473 plugins/train/train_config.py:514 +#: plugins/train/train_config.py:525 plugins/train/train_config.py:539 +#: plugins/train/train_config.py:549 +msgid "mask" +msgstr "маска" + +#: plugins/train/train_config.py:475 +msgid "" +"The mask to be used for training. If you have selected 'Learn Mask' or " +"'Penalized Mask Loss' you must select a value other than 'none'. The " +"required mask should have been selected as part of the Extract process. If " +"it does not exist in the alignments file then it will be generated prior to " +"training commencing.\n" +"\tnone: Don't use a mask.\n" +"\tbisenet-fp_face: Relatively lightweight NN based mask that provides more " +"refined control over the area to be masked (configurable in mask settings). " +"Use this version of bisenet-fp if your model is trained with 'face' or " +"'legacy' centering.\n" +"\tbisenet-fp_head: Relatively lightweight NN based mask that provides more " +"refined control over the area to be masked (configurable in mask settings). " +"Use this version of bisenet-fp if your model is trained with 'head' " +"centering.\n" +"\tcomponents: Mask designed to provide facial segmentation based on the " +"positioning of landmark locations. A convex hull is constructed around the " +"exterior of the landmarks to create a mask.\n" +"\tcustom_face: Custom user created, face centered mask.\n" +"\tcustom_head: Custom user created, head centered mask.\n" +"\textended: Mask designed to provide facial segmentation based on the " +"positioning of landmark locations. A convex hull is constructed around the " +"exterior of the landmarks and the mask is extended upwards onto the " +"forehead.\n" +"\tvgg-clear: Mask designed to provide smart segmentation of mostly frontal " +"faces clear of obstructions. 
Profile faces and obstructions may result in "
+"sub-par performance.\n"
+"\tvgg-obstructed: Mask designed to provide smart segmentation of mostly "
+"frontal faces. The mask model has been specifically trained to recognize "
+"some facial obstructions (hands and eyeglasses). Profile faces may result in "
+"sub-par performance.\n"
+"\tunet-dfl: Mask designed to provide smart segmentation of mostly frontal "
+"faces. The mask model has been trained by community members and will need "
+"testing for further description. Profile faces may result in sub-par "
+"performance."
+msgstr ""
+"Маска, которая будет использоваться для обучения. Если вы выбрали 'Learn "
+"Mask' или 'Penalized Mask Loss', вы должны выбрать значение, отличное от "
+"'none'. Необходимая маска должна быть выбрана в процессе извлечения. Если "
+"она не существует в файле выравниваний, то она будет создана до начала "
+"обучения.\n"
+"\tnone: Не использовать маску.\n"
+"\tbisenet-fp_face: Относительно легкая маска на основе NN, которая "
+"обеспечивает более точный контроль над маскируемой областью (настраивается в "
+"настройках маски). Используйте эту версию bisenet-fp, если ваша модель "
+"обучена с центрированием 'face' или 'legacy'.\n"
+"\tbisenet-fp_head: Относительно легкая маска на основе NN, которая "
+"обеспечивает более точный контроль над маскируемой областью (настраивается в "
+"параметрах маски). Используйте эту версию bisenet-fp, если ваша модель "
+"обучена с центрированием 'head'.\n"
+"\tcomponents: Маска, разработанная для сегментации лица на основе "
+"расположения ориентиров. Для создания маски вокруг внешних ориентиров "
+"строится выпуклая оболочка.\n"
+"\tcustom_face: Пользовательская маска, центрированная по лицу.\n"
+"\tcustom_head: Пользовательская маска, центрированная по голове.\n"
+"\textended: Маска, разработанная для сегментации лица на основе расположения "
+"ориентиров. Выпуклая оболочка строится вокруг внешних ориентиров, и маска "
+"расширяется вверх на лоб.\n"
+"\tvgg-clear: Маска предназначена для интеллектуальной сегментации "
+"преимущественно фронтальных лиц без препятствий. Профильные лица и "
+"препятствия могут привести к снижению производительности.\n"
+"\tvgg-obstructed: Маска, разработанная для интеллектуальной сегментации "
+"преимущественно фронтальных лиц. Модель маски была специально обучена "
+"распознавать некоторые препятствия на лице (руки и очки). Профильные лица "
+"могут привести к снижению производительности.\n"
+"\tunet-dfl: Маска, разработанная для интеллектуальной сегментации "
+"преимущественно фронтальных лиц. Модель маски была обучена членами "
+"сообщества и нуждается в тестировании для более подробного описания. "
+"Профильные лица могут привести к снижению производительности."
+
+#: plugins/train/train_config.py:516
+msgid ""
+"Dilate or erode the mask. Negative values erode the mask (make it smaller). "
+"Positive values dilate the mask (make it larger). The value given is a "
+"percentage of the total mask size."
+msgstr ""
+"Расширяет или сужает маску. Отрицательные значения сужают маску (делают её "
+"меньше). Положительные значения расширяют маску (делают её больше). "
+"Заданное значение - процент от общего размера маски."
+
+#: plugins/train/train_config.py:527
+msgid ""
+"Apply gaussian blur to the mask input. This has the effect of smoothing the "
+"edges of the mask, which can help with poorly calculated masks and give less "
+"of a hard edge to the predicted mask. The size is in pixels (calculated from "
+"a 128px mask). Set to 0 to not apply gaussian blur. 
This value should be "
+"odd, if an even number is passed in then it will be rounded to the next odd "
+"number."
+msgstr ""
+"Применить размытие по Гауссу к входной маске. Дает эффект сглаживания краев "
+"маски, что может помочь с плохо вычисленными масками и делает край "
+"предсказанной маски менее резким. Размер задается в пикселях (рассчитан для "
+"маски 128 пикселей). Установите 0, чтобы не применять размытие по Гауссу. "
+"Это значение должно быть нечетным, если передано четное число, то оно будет "
+"округлено до следующего нечетного числа."
+
+#: plugins/train/train_config.py:541
+msgid ""
+"Sets pixels that are near white to white and near black to black. Set to 0 "
+"for off."
+msgstr ""
+"Приводит почти белые пиксели к белому, а почти черные - к черному. "
+"Установите 0, чтобы отключить."
+
+#: plugins/train/train_config.py:551
+msgid ""
+"Dedicate a portion of the model to learning how to duplicate the input mask. "
+"Increases VRAM usage in exchange for learning a quick ability to try to "
+"replicate more complex mask models."
+msgstr ""
+"Выделить часть модели на обучение дублированию входной маски. Увеличивает "
+"использование видеопамяти в обмен на способность быстро воспроизводить более "
+"сложные модели масок."
+
+#: plugins/train/train_config.py:559
+msgid ""
+"Optimizer configuration options\n"
+"The optimizer applies the output of the loss function to the model.\n"
+msgstr ""
+"Настройки оптимизатора\n"
+"Оптимизатор использует значения функции потерь для обновления параметров "
+"модели.\n"
+
+#: plugins/train/train_config.py:565 plugins/train/train_config.py:600
+#: plugins/train/train_config.py:613 plugins/train/train_config.py:634
+msgid "optimizer"
+msgstr "оптимизатор"
+
+#: plugins/train/train_config.py:567
+msgid ""
+"The optimizer to use.\n"
+"\t adabelief - Adapting Stepsizes by the Belief in Observed Gradients. An "
+"optimizer with the aim to converge faster, generalize better and remain more "
+"stable. (https://arxiv.org/abs/2010.07468). NB: Epsilon for AdaBelief needs "
+"to be set to a smaller value than other Optimizers. Generally setting the "
+"'Epsilon Exponent' to around '-16' should work.\n"
+"\t adam - Adaptive Moment Optimization. A stochastic gradient descent method "
+"that is based on adaptive estimation of first-order and second-order "
+"moments.\n"
+"\t adamax - a variant of Adam based on the infinity norm. Due to its "
+"capability of adjusting the learning rate based on data characteristics, it "
+"is suited to learn time-variant process, parameters follow those provided in "
+"the paper\n"
+"\t adamw - Like 'adam' but with an added method to decay weights per the "
+"techniques discussed in the paper (https://arxiv.org/abs/1711.05101). NB: "
+"Weight decay should be set at 0.004 for default implementation.\n"
+"\t lion - A method that uses the sign operator to control the magnitude of "
+"the update, rather than relying on second-order moments (Adam). saves VRAM "
+"by only tracking the momentum. Performance gains should be better with "
+"larger batch sizes. A suitable learning rate for Lion is typically 3-10x "
+"smaller than that for AdamW. The weight decay for Lion should be 3-10x "
+"larger than that for AdamW to maintain a similar strength.\n"
+"\t nadam - Adaptive Moment Optimization with Nesterov Momentum. Much like "
+"Adam but uses a different formula for calculating momentum.\n"
+"\t rms-prop - Root Mean Square Propagation. 
Maintains a moving (discounted) "
+"average of the square of the gradients. Divides the gradient by the root of "
+"this average."
+msgstr ""
+"Используемый оптимизатор.\n"
+"\t adabelief - Адаптация размеров шагов по убеждению в наблюдаемых "
+"градиентах('Adapting Stepsizes by the Belief in Observed Gradients'). "
+"Оптимизатор, цель которого - быстрее сходиться, лучше обобщаться и "
+"оставаться более стабильным. ([ТОЛЬКО на английском] https://arxiv.org/"
+"abs/2010.07468). Примечание: значение Epsilon для AdaBelief должно быть "
+"меньше, чем для других оптимизаторов. Как правило, значение 'Epsilon "
+"Exponent' должно быть около '-16'.\n"
+"\t adam - Адаптивная оптимизация моментов('Adaptive Moment Optimization'). "
+"Стохастический метод градиентного спуска, основанный на адаптивной оценке "
+"моментов первого и второго порядка.\n"
+"\t adamax - вариант Adam, основанный на норме бесконечности (infinity norm). "
+"Благодаря способности адаптировать скорость обучения в зависимости от "
+"характеристик данных, он подходит для обучения процессам с изменяющимися во "
+"времени характеристиками (time-variant processes). Параметры следуют "
+"значениям, указанным в статье.\n"
+"\t adamw - похож на 'Adam', но с добавленным методом затухания весов (weight "
+"decay) в соответствии с техниками, описанными в статье ([ТОЛЬКО на "
+"английском] https://arxiv.org/abs/1711.05101). Примечание: Для стандартной "
+"реализации коэффициент weight decay рекомендуется установить на 0.004.\n"
+"\t lion - метод, который использует оператор знака для контроля величины "
+"обновления, вместо зависимости от моментов второго порядка (как в Adam). "
+"Экономит VRAM, отслеживая только моментум. Прирост производительности лучше "
+"проявляется при больших размерах пачки. Подходящая скорость обучения для "
+"Lion обычно в 3–10 раз меньше, чем для AdamW. Weight decay для Lion следует "
+"делать в 3–10 раз больше, чем для AdamW, чтобы сохранить аналогичную силу "
+"регуляризации.\n"
+"\t nadam - Адаптивная оптимизация моментов с моментумом Нестерова ('Adaptive "
+"Moment Optimization with Nesterov Momentum'). Похож на Adam, но использует "
+"другую формулу для вычисления момента.\n"
+"\t rms-prop - Распространение корневого среднего квадрата ('Root Mean Square "
+"Propagation'). Поддерживает скользящее (дисконтированное) среднее квадрата "
+"градиентов. Делит градиент на корень из этого среднего."
+
+#: plugins/train/train_config.py:602
+msgid ""
+"Learning rate - how fast your network will learn (how large are the "
+"modifications to the model weights after one batch of training). Values that "
+"are too large might result in model crashes and the inability of the model "
+"to find the best solution. Values that are too small might be unable to "
+"escape from dead-ends and find the best global minimum."
+msgstr ""
+"Скорость обучения - насколько быстро ваша модель будет обучаться (насколько "
+"велики изменения весов модели после одной пачки тренировки). Слишком "
+"большие значения могут привести к крахам модели и невозможности модели найти "
+"лучшее решение. Слишком маленькие значения могут привести к невозможности "
+"выбраться из тупиков и найти лучший глобальный минимум."
+
+#: plugins/train/train_config.py:615
+msgid ""
+"The epsilon adds a small constant to weight updates to attempt to avoid "
+"'divide by zero' errors. 
Unless you are using the AdaBelief Optimizer, then "
+"Generally this option should be left at default value, For AdaBelief, "
+"setting this to around '-16' should work.\n"
+"In all instances if you are getting 'NaN' loss values, and have been unable "
+"to resolve the issue any other way (for example, increasing batch size, or "
+"lowering learning rate), then raising the epsilon can lead to a more stable "
+"model. It may, however, come at the cost of slower training and a less "
+"accurate final result.\n"
+"Note: The value given here is the 'exponent' to the epsilon. For example, "
+"choosing '-7' will set the epsilon to 1e-7. Choosing '-3' will set the "
+"epsilon to 0.001 (1e-3).\n"
+"Note: Not used by the Lion optimizer"
+msgstr ""
+"Эпсилон добавляет небольшую константу к обновлениям веса, чтобы попытаться "
+"избежать ошибок \"деления на ноль\". Если вы не используете оптимизатор "
+"AdaBelief, то, как правило, этот параметр следует оставить по умолчанию. Для "
+"AdaBelief подойдет значение около '-16'.\n"
+"Во всех случаях, если вы получаете значения потерь 'NaN' и не смогли решить "
+"проблему другим способом (например, увеличив размер пачки или уменьшив "
+"скорость обучения), то увеличение эпсилона может привести к более стабильной "
+"модели. Однако это может стоить более медленного обучения и менее точного "
+"конечного результата.\n"
+"Примечание: Значение, указанное здесь, - это показатель степени эпсилона. "
+"Например, при выборе значения '-7' эпсилон будет равен 1e-7. При выборе "
+"значения '-3' эпсилон будет равен 0,001 (1e-3).\n"
+"Примечание: Не используется оптимизатором Lion"
+
+#: plugins/train/train_config.py:636
+msgid ""
+"When to save the Optimizer Weights. Saving the optimizer weights is not "
+"necessary and will increase the model file size 3x (and by extension the "
+"amount of time it takes to save the model). However, it can be useful to "
+"save these weights if you want to guarantee that a resumed model carries off "
+"exactly from where it left off, rather than spending a few hundred "
+"iterations catching up.\n"
+"\t never - Don't save optimizer weights.\n"
+"\t always - Save the optimizer weights at every save iteration. Model saving "
+"will take longer, due to the increased file size, but you will always have "
+"the last saved optimizer state in your model file.\n"
+"\t exit - Only save the optimizer weights when explicitly terminating a "
+"model. This can be when the model is actively stopped or when the target "
+"iterations are met. Note: If the training session ends because of another "
+"reason (e.g. power outage, Out of Memory Error, NaN detected) then the "
+"optimizer weights will NOT be saved."
+msgstr ""
+"Когда сохранять веса оптимизатора. Сохранение весов оптимизатора не является "
+"необходимым и увеличит размер файла модели в 3 раза (и соответственно время, "
+"необходимое для сохранения модели). Однако может быть полезно сохранить эти "
+"веса, если вы хотите гарантировать, что возобновленная модель продолжит "
+"работу именно с того места, где она остановилась, а не потратит несколько "
+"сотен итераций на то, чтобы наверстать упущенное.\n"
+"\t never - не сохранять веса оптимизатора.\n"
+"\t always - сохранять веса оптимизатора при каждой итерации сохранения. "
+"Сохранение модели займет больше времени из-за увеличенного размера файла, но "
+"в файле модели всегда будет последнее сохраненное состояние оптимизатора.\n"
+"\t exit - сохранять веса оптимизатора только при явном завершении модели. 
" +"Это может быть, когда модель активно останавливается или когда выполняются " +"целевые итерации. Примечание. Если сеанс обучения завершается по другой " +"причине (например, отключение питания, ошибка нехватки памяти, обнаружение " +"NaN), веса оптимизатора НЕ будут сохранены." + +#: plugins/train/train_config.py:657 plugins/train/train_config.py:676 +#: plugins/train/train_config.py:695 +msgid "clipping" +msgstr "клиппинг" + +#: plugins/train/train_config.py:659 +msgid "" +"Apply clipping to the gradients. Can help prevent NaNs and improve model " +"optimization at the expense of VRAM.\n" +"\tautoclip: Analyzes the gradient weights and adjusts the normalization " +"value dynamically to fit the data\n" +"\tglobal_norm: Clips the gradient of each weight so that the global norm is " +"no higher than the given value.\n" +"\tnorm: Clips the gradient of each weight so that its norm is no higher than " +"the given value.\n" +"\tvalue: Clips the gradient of each weight so that it is no higher than the " +"given value.\n" +"\tnone: Don't perform any clipping to the gradients." +msgstr "" +"Применять клиппинг (обрезку) градиентов. Помогает предотвратить NaN'ы и " +"улучшить оптимизацию модели, но за счёт увеличения расхода VRAM.\n" +"\tautoclip: Анализирует значения градиентов и динамически подстраивает порог " +"нормализации под текущие данные.\n" +"\tglobal_norm: Обрезает градиенты так, чтобы глобальная норма (норма всего " +"вектора градиентов модели) не превышала заданного значения.\n" +"\tnorm: Обрезает градиенты так, чтобы норма не превышала заданного " +"значения.\n" +"\tvalue: Обрезает градиенты по значению — каждый элемент градиента " +"ограничивается диапазоном [-value, value].\n" +"\tnone: Не выполнять обрезку градиентов." + +#: plugins/train/train_config.py:678 +msgid "" +"The amount of clipping to perform.\n" +"\tautoclip: The percentile to clip at. A value of 1.0 will clip at the 10th " +"percentile a value of 2.5 will clip at the 25th percentile etc. Default: " +"1.0\n" +"\tglobal_norm: The gradient of each weight is clipped so that the global " +"norm is no higher than this value.\n" +"\tnorm: The gradient of each weight is clipped so that its norm is no higher " +"than this value.\n" +"\tvalue: The gradient of each weight is clipped to be no higher than this " +"value.\n" +"\tnone: This option is ignored." +msgstr "" +"Величина обрезки градиентов.\n" +"\tautoclip: Процентиль, по которому выполняется обрезка. Значение 1.0 — " +"обрезка по 10-му процентилю, 2.5 — по 25-му процентилю и т.д. По умолчанию: " +"1.0\n" +"\tglobal_norm: Градиенты обрезаются так, чтобы глобальная норма не превышала " +"это значение.\n" +"\tnorm: Градиенты обрезаются так, чтобы норма не превышала это значение.\n" +"\tvalue: Каждый элемент градиента обрезается по абсолютному значению " +"(диапазон [-value, value]).\n" +"\tnone: Эта опция игнорируется." + +#: plugins/train/train_config.py:697 +msgid "" +"The maximum number of prior iterations for autoclipper to analyze when " +"calculating the normalization amount. 0 to always include all prior " +"iterations." +msgstr "" +"Максимальное количество предыдущих итераций, которые автоклиппер анализирует " +"при расчёте величины нормализации. Значение 0 означает, что всегда " +"учитываются все предыдущие итерации." + +#: plugins/train/train_config.py:706 plugins/train/train_config.py:715 +msgid "updates" +msgstr "обновления" + +#: plugins/train/train_config.py:707 +msgid "" +"If set, weight decay is applied. 0.0 for no weight decay. 
Default is 0.0 for "
+"all optimizers except AdamW (0.004)"
+msgstr ""
+"Если задано значение больше 0, применяется затухание весов (weight decay). "
+"Значение 0.0 отключает затухание. По умолчанию 0.0 для всех оптимизаторов, "
+"кроме AdamW (0.004)."
+
+#: plugins/train/train_config.py:717
+msgid ""
+"Values above 1 will enable Gradient Accumulation. Updates will not be at "
+"every iteration; instead they will occur every number of iterations given "
+"here. The update will be the average value of the gradients since the last "
+"update. Can be useful when your batch size is very small, in order to reduce "
+"gradient noise at each update iteration."
+msgstr ""
+"Значения больше 1 включают накопление градиентов (Gradient Accumulation). "
+"Обновление параметров будет происходить не на каждой итерации, а каждые "
+"указанное здесь количество итераций. При обновлении будет использоваться "
+"среднее значение градиентов, накопленных с момента последнего обновления. "
+"Полезно, когда размер пачки очень мал — позволяет уменьшить шум градиентов "
+"на каждом шаге обновления."
+
+#: plugins/train/train_config.py:728 plugins/train/train_config.py:738
+#: plugins/train/train_config.py:749
+msgid "exponential moving average"
+msgstr "экспоненциальная скользящая средняя"
+
+#: plugins/train/train_config.py:730
+msgid ""
+"Enable exponential moving average (EMA). EMA consists of computing an "
+"exponential moving average of the weights of the model (as the weight values "
+"change after each training batch), and periodically overwriting the weights "
+"with their moving average"
+msgstr ""
+"Включить экспоненциальную скользящую среднюю (EMA) весов. EMA подразумевает "
+"расчёт экспоненциальной скользящей средней весов модели по мере их "
+"обновления после каждой пачки, с периодической заменой текущих весов на эту "
+"среднюю"
+
+#: plugins/train/train_config.py:740
+msgid ""
+"Only used if use_ema is enabled. This is the momentum to use when computing "
+"the EMA of the model's weights: new_average = ema_momentum * old_average + "
+"(1 - ema_momentum) * current_variable_value."
+msgstr ""
+"Параметр активен только при включённой EMA. Определяет коэффициент momentum "
+"для экспоненциальной скользящей средней весов модели по формуле: new_average "
+"= ema_momentum * old_average + (1 - ema_momentum) * current_variable_value."
+
+#: plugins/train/train_config.py:751
+msgid ""
+"Only used if use_ema is enabled. Set the number of iterations, to overwrite "
+"the model variable by its moving average. "
+msgstr ""
+"Активен только при включённой EMA. Указывает интервал в итерациях, после "
+"которого веса основной модели заменяются на значения их экспоненциальной "
+"скользящей средней. "
+
+#: plugins/train/train_config.py:759 plugins/train/train_config.py:770
+#: plugins/train/train_config.py:781
+msgid "optimizer specific"
+msgstr "параметры, специфичные для оптимизатора"
+
+#: plugins/train/train_config.py:761
+msgid ""
+"The exponential decay rate for the 1st moment estimates. Used for the "
+"following Optimizers: AdaBelief, Adam, Adamax, AdamW, Lion, nAdam. Ignored "
+"for all others."
+msgstr ""
+"Коэффициент экспоненциального затухания для оценок первого момента. "
+"Применяется только к оптимизаторам: AdaBelief, Adam, Adamax, AdamW, "
+"Lion, nAdam. Для остальных оптимизаторов игнорируется."
+
+#: plugins/train/train_config.py:772
+msgid ""
+"The exponential decay rate for the 2nd moment estimates. Used for the "
+"following Optimizers: AdaBelief, Adam, Adamax, AdamW, Lion, nAdam. 
Ignored "
+"for all others."
+msgstr ""
+"Коэффициент экспоненциального затухания для оценок второго момента. "
+"Применяется только к оптимизаторам: AdaBelief, Adam, Adamax, "
+"AdamW, Lion, nAdam. Для остальных оптимизаторов игнорируется."
+
+#: plugins/train/train_config.py:783
+msgid ""
+"Whether to apply AMSGrad variant of the algorithm from the paper 'On the "
+"Convergence of Adam and beyond. Used for the following Optimizers: "
+"AdaBelief, Adam, AdamW. Ignored for all others.'"
+msgstr ""
+"Применять ли вариант AMSGrad алгоритма из статьи «On the Convergence of Adam "
+"and Beyond». Используется только для следующих оптимизаторов: AdaBelief, "
+"Adam, AdamW. Для всех остальных игнорируется."
+
+#~ msgid ""
+#~ "The amount of weight to apply to the second loss function.\n"
+#~ "\n"
+#~ "\n"
+#~ "\n"
+#~ "The value given here is as a percentage denoting how much the selected "
+#~ "function should contribute to the overall loss cost of the model. For "
+#~ "example:\n"
+#~ "\t 100 - The loss calculated for the fourth loss function will be applied "
+#~ "at its full amount towards the overall loss score. \n"
+#~ "\t 25 - The loss calculated for the fourth loss function will be reduced "
+#~ "by a quarter prior to adding to the overall loss score. \n"
+#~ "\t 400 - The loss calculated for the fourth loss function will be "
+#~ "mulitplied 4 times prior to adding to the overall loss score. \n"
+#~ "\t 0 - Disables the fourth loss function altogether."
+#~ msgstr ""
+#~ "Величина веса, применяемая к второй функции потерь.\n"
+#~ "\n"
+#~ "\n"
+#~ "\n"
+#~ "Значение задается в процентах и показывает, какой вклад выбранная функция "
+#~ "должна внести в общую стоимость потерь модели. Например:\n"
+#~ "\t 100 - Потери, рассчитанные для второй функции потерь, будут применены "
+#~ "в полном объеме к общей стоимости потерь. \n"
+#~ "\t25 - Потери, рассчитанные для второй функции потерь, будут уменьшены на "
+#~ "четверть перед добавлением к общей стоимости потерь. \n"
+#~ "\t400 - Потери, рассчитанные для второй функции потерь, будут умножены в "
+#~ "4 раза перед добавлением к общей оценке потерь. \n"
+#~ "\t 0 - Полностью отключает вторую функцию потерь."
+
+#, fuzzy
+#~| msgid ""
+#~| "The amount of weight to apply to the fourth loss function.\n"
+#~| "\n"
+#~| "\n"
+#~| "\n"
+#~| "The value given here is as a percentage denoting how much the selected "
+#~| "function should contribute to the overall loss cost of the model. For "
+#~| "example:\n"
+#~| "\t 100 - The loss calculated for the fourth loss function will be "
+#~| "applied at its full amount towards the overall loss score. \n"
+#~| "\t 25 - The loss calculated for the fourth loss function will be reduced "
+#~| "by a quarter prior to adding to the overall loss score. \n"
+#~| "\t 400 - The loss calculated for the fourth loss function will be "
+#~| "mulitplied 4 times prior to adding to the overall loss score. \n"
+#~| "\t 0 - Disables the fourth loss function altogether."
+#~ msgid ""
+#~ "The amount of weight to apply to the third loss function.\n"
+#~ "\n"
+#~ "\n"
+#~ "\n"
+#~ "The value given here is as a percentage denoting how much the selected "
+#~ "function should contribute to the overall loss cost of the model. For "
+#~ "example:\n"
+#~ "\t 100 - The loss calculated for the fourth loss function will be applied "
+#~ "at its full amount towards the overall loss score. 
\n" +#~ "\t 25 - The loss calculated for the fourth loss function will be reduced " +#~ "by a quarter prior to adding to the overall loss score. \n" +#~ "\t 400 - The loss calculated for the fourth loss function will be " +#~ "mulitplied 4 times prior to adding to the overall loss score. \n" +#~ "\t 0 - Disables the fourth loss function altogether." +#~ msgstr "" +#~ "Величина веса, применяемая к четвертой функции потерь.\n" +#~ "\n" +#~ "\n" +#~ "\n" +#~ "Значение задается в процентах и показывает, какой вклад выбранная функция " +#~ "должна внести в общую стоимость потерь модели. Например:\n" +#~ "\t 100 - Потери, рассчитанные для четвертой функции потерь, будут " +#~ "применены в полном объеме к общей стоимости потерь. \n" +#~ "\t25 - Потери, рассчитанные для четвертой функции потерь, будут уменьшены " +#~ "на четверть перед добавлением к общей стоимости потерь. \n" +#~ "\t400 - Потери, рассчитанные для четвертой функции потерь, будут умножены " +#~ "в 4 раза перед добавлением к общей оценке потерь. \n" +#~ "\t 0 - Полностью отключает четвертую функцию потерь." + +#~ msgid "" +#~ "Apply AutoClipping to the gradients. AutoClip analyzes the gradient " +#~ "weights and adjusts the normalization value dynamically to fit the data. " +#~ "Can help prevent NaNs and improve model optimization at the expense of " +#~ "VRAM. Ref: AutoClip: Adaptive Gradient Clipping for Source Separation " +#~ "Networks https://arxiv.org/abs/2007.14469" +#~ msgstr "" +#~ "Применить AutoClipping к градиентам. AutoClip анализирует веса градиентов " +#~ "и динамически корректирует значение нормализации, чтобы оно подходило к " +#~ "данным. Может помочь избежать NaN('не число') и улучшить оптимизацию " +#~ "модели ценой видеопамяти. Ссылка: AutoClip: Adaptive Gradient Clipping " +#~ "for Source Separation Networks [ТОЛЬКО на английском] https://arxiv.org/" +#~ "abs/2007.14469" + +#~ msgid "" +#~ "Enable the Tensorflow GPU 'allow_growth' configuration option. This " +#~ "option prevents Tensorflow from allocating all of the GPU VRAM at launch " +#~ "but can lead to higher VRAM fragmentation and slower performance. Should " +#~ "only be enabled if you are receiving errors regarding 'cuDNN fails to " +#~ "initialize' when commencing training." +#~ msgstr "" +#~ "[Только для Nvidia]. Включите опцию конфигурации Tensorflow GPU " +#~ "`allow_growth`. Эта опция не позволяет Tensorflow выделять всю " +#~ "видеопамять видеокарты при запуске, но может привести к повышенной " +#~ "фрагментации видеопамяти и снижению производительности. Следует включать " +#~ "только в том случае, если у вас появляются ошибки, рода 'cuDNN fails to " +#~ "initialize'(cuDNN не может инициализироваться) при начале тренировки." diff --git a/locales/ru/LC_MESSAGES/tools.alignments.cli.mo b/locales/ru/LC_MESSAGES/tools.alignments.cli.mo new file mode 100644 index 0000000000..5277c793d2 Binary files /dev/null and b/locales/ru/LC_MESSAGES/tools.alignments.cli.mo differ diff --git a/locales/ru/LC_MESSAGES/tools.alignments.cli.po b/locales/ru/LC_MESSAGES/tools.alignments.cli.po new file mode 100644 index 0000000000..3f68a44acf --- /dev/null +++ b/locales/ru/LC_MESSAGES/tools.alignments.cli.po @@ -0,0 +1,276 @@ +# SOME DESCRIPTIVE TITLE. +# Copyright (C) YEAR THE PACKAGE'S COPYRIGHT HOLDER +# This file is distributed under the same license as the PACKAGE package. +# FIRST AUTHOR , YEAR. 
+
+#
+msgid ""
+msgstr ""
+"Project-Id-Version: \n"
+"Report-Msgid-Bugs-To: \n"
+"POT-Creation-Date: 2024-04-19 11:28+0100\n"
+"PO-Revision-Date: 2024-04-19 11:31+0100\n"
+"Last-Translator: \n"
+"Language-Team: \n"
+"Language: ru\n"
+"MIME-Version: 1.0\n"
+"Content-Type: text/plain; charset=UTF-8\n"
+"Content-Transfer-Encoding: 8bit\n"
+"Plural-Forms: nplurals=3; plural=(n%10==1 && n%100!=11 ? 0 : n%10>=2 && "
+"n%10<=4 && (n%100<12 || n%100>14) ? 1 : 2);\n"
+"X-Generator: Poedit 3.4.2\n"
+
+#: tools/alignments/cli.py:16
+msgid ""
+"This command lets you perform various tasks pertaining to an alignments file."
+msgstr ""
+"Эта команда позволяет выполнять различные задачи, относящиеся к файлу "
+"выравнивания."
+
+#: tools/alignments/cli.py:31
+msgid ""
+"Alignments tool\n"
+"This tool allows you to perform numerous actions on or using an alignments "
+"file against its corresponding faceset/frame source."
+msgstr ""
+"Инструмент выравнивания\n"
+"Этот инструмент позволяет выполнять многочисленные действия с файлом "
+"выравнивания или с его помощью применительно к соответствующему набору лиц/"
+"кадров."
+
+#: tools/alignments/cli.py:43
+msgid " Must Pass in a frames folder/source video file (-r)."
+msgstr " Необходимо передать папку с кадрами/исходный видеофайл (-r)."
+
+#: tools/alignments/cli.py:44
+msgid " Must Pass in a faces folder (-c)."
+msgstr " Необходимо передать папку с лицами (-c)."
+
+#: tools/alignments/cli.py:45
+msgid ""
+" Must Pass in either a frames folder/source video file OR a faces folder (-r "
+"or -c)."
+msgstr ""
+" Необходимо передать либо папку с кадрами/исходный видеофайл, либо папку с "
+"лицами (-r или -c)."
+
+#: tools/alignments/cli.py:47
+msgid ""
+" Must Pass in a frames folder/source video file AND a faces folder (-r and -"
+"c)."
+msgstr ""
+" Необходимо передать и папку с кадрами/исходный видеофайл, и папку с лицами "
+"(-r и -c)."
+
+#: tools/alignments/cli.py:49
+msgid " Use the output option (-o) to process results."
+msgstr " Используйте опцию вывода (-o) для обработки результатов."
+
+#: tools/alignments/cli.py:58 tools/alignments/cli.py:104
+msgid "processing"
+msgstr "обработка"
+
+#: tools/alignments/cli.py:61
+#, python-brace-format
+msgid ""
+"R|Choose which action you want to perform. NB: All actions require an "
+"alignments file (-a) to be passed in.\n"
+"L|'draw': Draw landmarks on frames in the selected folder/video. A subfolder "
+"will be created within the frames folder to hold the output.{0}\n"
+"L|'export': Export the contents of an alignments file to a json file. Can be "
+"used for editing alignment information in external tools and then re-"
+"importing by using Faceswap's Extract 'Import' plugins. Note: masks and "
+"identity vectors will not be included in the exported file, so will be re-"
+"generated when the json file is imported back into Faceswap. All data is "
+"exported with the origin (0, 0) at the top left of the canvas.\n"
+"L|'extract': Re-extract faces from the source frames/video based on "
+"alignment data. This is a lot quicker than re-detecting faces. Can pass in "
+"the '-een' (--extract-every-n) parameter to only extract every nth frame."
+"{1}\n"
+"L|'from-faces': Generate alignment file(s) from a folder of extracted faces. "
+"if the folder of faces comes from multiple sources, then multiple alignments "
+"files will be created. 
NB: for faces which have been extracted from folders "
+"of source images, rather than a video, a single alignments file will be "
+"created as there is no way for the process to know how many folders of "
+"images were originally used. You do not need to provide an alignments file "
+"path to run this job. {3}\n"
+"L|'missing-alignments': Identify frames that do not exist in the alignments "
+"file.{2}{0}\n"
+"L|'missing-frames': Identify frames in the alignments file that do not "
+"appear within the frames folder/video.{2}{0}\n"
+"L|'multi-faces': Identify where multiple faces exist within the alignments "
+"file.{2}{4}\n"
+"L|'no-faces': Identify frames that exist within the alignment file but no "
+"faces were detected.{2}{0}\n"
+"L|'remove-faces': Remove deleted faces from an alignments file. The original "
+"alignments file will be backed up.{3}\n"
+"L|'rename' - Rename faces to correspond with their parent frame and position "
+"index in the alignments file (i.e. how they are named after running extract)."
+"{3}\n"
+"L|'sort': Re-index the alignments from left to right. For alignments with "
+"multiple faces this will ensure that the left-most face is at index 0.\n"
+"L|'spatial': Perform spatial and temporal filtering to smooth alignments "
+"(EXPERIMENTAL!)"
+msgstr ""
+"R|Выберите действие, которое вы хотите выполнить. Примечание: Все действия "
+"требуют передачи файла выравнивания (-a).\n"
+"L|'draw': Нарисовать ориентиры на кадрах в выбранной папке/видео. В папке с "
+"кадрами будет создана подпапка для хранения результатов.{0}\n"
+"L|'export': Экспортировать содержимое файла выравнивания в файл JSON. Может "
+"использоваться для редактирования информации о выравнивании во внешних "
+"инструментах, а затем повторно импортируется с помощью плагинов Faceswap "
+"Extract 'Import'. ПРИМЕЧАНИЕ. Маски и векторы идентификации не будут "
+"включены в экспортированный файл, поэтому будут повторно сгенерированы, "
+"когда файл JSON будет импортирован обратно в Faceswap. Все данные "
+"экспортируются с началом координат (0, 0) в верхнем левом углу холста.\n"
+"L|'extract': Повторное извлечение лиц из исходных кадров/видео на основе "
+"данных о выравнивании. Это намного быстрее, чем повторное обнаружение лиц. "
+"Можно передать параметр '-een' (--extract-every-n), чтобы извлекать только "
+"каждый n-й кадр.{1}\n"
+"L|'from-faces': Создать файл(ы) выравнивания из папки с извлеченными лицами. "
+"Если папка с лицами получена из нескольких источников, то будет создано "
+"несколько файлов выравнивания. Примечание: для лиц, которые были извлечены "
+"из папок с исходными изображениями, а не из видео, будет создан один файл "
+"выравнивания, поскольку процесс не может знать, сколько папок с "
+"изображениями было использовано изначально. Для выполнения этого задания не "
+"нужно указывать путь к файлу выравнивания. {3}\n"
+"L|'missing-alignments': Определить кадры, которых нет в файле выравнивания."
+"{2}{0}\n"
+"L|'missing-frames': Определить кадры в файле выравнивания, которые не "
+"появляются в папке frames/video.{2}{0}\n"
+"L|'multi-faces': Определить, где в файле выравнивания существует несколько "
+"лиц.{2}{4}\n"
+"L|'no-faces': Идентифицировать кадры, которые существуют в файле "
+"выравнивания, но лица не были обнаружены.{2}{0}\n"
+"L|'remove-faces': Удалить удаленные лица из файла выравнивания. Будет "
+"создана резервная копия исходного файла выравнивания.{3}\n"
+"L|'rename' - Переименовать лица в соответствии с их родительским кадром и "
+"индексом позиции в файле выравниваний (т.е. 
как они будут названы после "
+"запуска extract).{3}\n"
+"L|'sort': Переиндексировать выравнивания слева направо. Для выравниваний с "
+"несколькими лицами это гарантирует, что самое левое лицо будет иметь индекс "
+"0.\n"
+"L|'spatial': Выполнить пространственную и временную фильтрацию для "
+"сглаживания выравниваний (ЭКСПЕРИМЕНТАЛЬНО!)."
+
+#: tools/alignments/cli.py:107
+msgid ""
+"R|How to output discovered items ('faces' and 'frames' only):\n"
+"L|'console': Print the list of frames to the screen. (DEFAULT)\n"
+"L|'file': Output the list of frames to a text file (stored within the source "
+"directory).\n"
+"L|'move': Move the discovered items to a sub-folder within the source "
+"directory."
+msgstr ""
+"R|Как вывести обнаруженные элементы (только \"лица\" и \"кадры\"):\n"
+"L|'console': Вывести список кадров на экран. (DEFAULT)\n"
+"L|'file': Вывести список кадров в текстовый файл (хранящийся в исходном "
+"каталоге).\n"
+"L|'move': Переместить обнаруженные элементы в подпапку в исходном каталоге."
+
+#: tools/alignments/cli.py:118 tools/alignments/cli.py:141
+#: tools/alignments/cli.py:148
+msgid "data"
+msgstr "данные"
+
+#: tools/alignments/cli.py:125
+msgid ""
+"Full path to the alignments file to be processed. If you have input a "
+"'frames_dir' and don't provide this option, the process will try to find the "
+"alignments file at the default location. All jobs require an alignments file "
+"with the exception of 'from-faces' when the alignments file will be "
+"generated in the specified faces folder."
+msgstr ""
+"Полный путь к обрабатываемому файлу выравниваний. Если вы ввели 'frames_dir' "
+"и не указали этот параметр, процесс попытается найти файл выравнивания в "
+"месте по умолчанию. Все задания требуют файл выравнивания, за исключением "
+"задания 'from-faces', когда файл выравнивания будет создан в указанной папке "
+"с лицами."
+
+#: tools/alignments/cli.py:142
+msgid "Directory containing source frames that faces were extracted from."
+msgstr "Папка, содержащая исходные кадры, из которых были извлечены лица."
+
+#: tools/alignments/cli.py:150
+msgid ""
+"R|Run the aligmnents tool on multiple sources. The following jobs support "
+"batch mode:\n"
+"L|draw, extract, from-faces, missing-alignments, missing-frames, no-faces, "
+"sort, spatial.\n"
+"If batch mode is selected then the other options should be set as follows:\n"
+"L|alignments_file: For 'sort' and 'spatial' this should point to the parent "
+"folder containing the alignments files to be processed. For all other jobs "
+"this option is ignored, and the alignments files must exist at their default "
+"location relative to the original frames folder/video.\n"
+"L|faces_dir: For 'from-faces' this should be a parent folder, containing sub-"
+"folders of extracted faces from which to generate alignments files. For "
+"'extract' this should be a parent folder where sub-folders will be created "
+"for each extraction to be run. For all other jobs this option is ignored.\n"
+"L|frames_dir: For 'draw', 'extract', 'missing-alignments', 'missing-frames' "
+"and 'no-faces' this should be a parent folder containing video files or sub-"
+"folders of images to perform the alignments job on. The alignments file "
+"should exist at the default location. For all other jobs this option is "
+"ignored."
+msgstr ""
+"R|Запуск инструмента выравнивания на нескольких источниках. 
Следующие "
+"задания поддерживают пакетный режим:\n"
+"L|draw, extract, from-faces, missing-alignments, missing-frames, no-faces, "
+"sort, spatial.\n"
+"Если выбран пакетный режим, то остальные опции должны быть установлены "
+"следующим образом:\n"
+"L|alignments_file: Для заданий 'sort' и 'spatial' этот параметр должен "
+"указывать на родительскую папку, содержащую файлы выравниваний, которые "
+"будут обрабатываться. Для всех остальных заданий этот параметр игнорируется, "
+"и файлы выравнивания должны существовать в их расположении по умолчанию "
+"относительно исходной папки кадров/видео.\n"
+"L|faces_dir: Для 'from-faces' это должна быть родительская папка, содержащая "
+"вложенные папки с извлеченными лицами, из которых будут сгенерированы файлы "
+"выравнивания. Для 'extract' это должна быть родительская папка, в которой "
+"будут создаваться вложенные папки для каждого выполняемого извлечения. Для "
+"всех остальных заданий этот параметр игнорируется.\n"
+"L|frames_dir: Для 'draw', 'extract', 'missing-alignments', 'missing-frames' "
+"и 'no-faces' это должна быть родительская папка, содержащая видеофайлы или "
+"вложенные папки изображений для выполнения задания выравнивания. Файл "
+"выравнивания должен существовать в месте по умолчанию. Для всех остальных "
+"заданий этот параметр игнорируется."
+
+#: tools/alignments/cli.py:176 tools/alignments/cli.py:188
+#: tools/alignments/cli.py:198
+msgid "extract"
+msgstr "извлечение"
+
+#: tools/alignments/cli.py:178
+msgid ""
+"[Extract only] Extract every 'nth' frame. This option will skip frames when "
+"extracting faces. For example a value of 1 will extract faces from every "
+"frame, a value of 10 will extract faces from every 10th frame."
+msgstr ""
+"[Только извлечение] Извлекать каждый \"n-й\" кадр. Этот параметр пропускает "
+"кадры при извлечении лиц. Например, значение 1 будет извлекать лица из "
+"каждого кадра, значение 10 будет извлекать лица из каждого 10-го кадра."
+
+#: tools/alignments/cli.py:189
+msgid "[Extract only] The output size of extracted faces."
+msgstr "[Только извлечение] Выходной размер извлеченных лиц."
+
+#: tools/alignments/cli.py:200
+msgid ""
+"[Extract only] Only extract faces that have been resized by this percent or "
+"more to meet the specified extract size (`-sz`, `--size`). Useful for "
+"excluding low-res images from a training set. Set to 0 to extract all faces. "
+"Eg: For an extract size of 512px, A setting of 50 will only include faces "
+"that have been resized from 256px or above. Setting to 100 will only extract "
+"faces that have been resized from 512px or above. A setting of 200 will only "
+"extract faces that have been downscaled from 1024px or above."
+msgstr ""
+"[Только извлечение] Извлекать только те лица, размер которых был изменен на "
+"данный процент или более, чтобы соответствовать заданному размеру извлечения "
+"(`-sz`, `--size`). Полезно для исключения изображений с низким разрешением "
+"из обучающего набора. Установите значение 0, чтобы извлечь все лица. "
+"Например: Для размера извлечения 512px, при установке значения 50 будут "
+"извлечены только лица, размер которых был изменен с 256px или выше. При "
+"значении 100 будут извлечены только лица, размер которых был изменен с 512px "
+"или выше. При значении 200 будут извлечены только лица, уменьшенные с 1024px "
+"или выше."
+
+#~ msgid "Directory containing extracted faces."
+#~ msgstr "Папка, содержащая извлеченные лица." 
diff --git a/locales/ru/LC_MESSAGES/tools.effmpeg.cli.mo b/locales/ru/LC_MESSAGES/tools.effmpeg.cli.mo
new file mode 100644
index 0000000000..b47b17d08d
Binary files /dev/null and b/locales/ru/LC_MESSAGES/tools.effmpeg.cli.mo differ
diff --git a/locales/ru/LC_MESSAGES/tools.effmpeg.cli.po b/locales/ru/LC_MESSAGES/tools.effmpeg.cli.po
new file mode 100644
index 0000000000..8322e90571
--- /dev/null
+++ b/locales/ru/LC_MESSAGES/tools.effmpeg.cli.po
@@ -0,0 +1,191 @@
+# SOME DESCRIPTIVE TITLE.
+# Copyright (C) YEAR ORGANIZATION
+# FIRST AUTHOR <EMAIL@ADDRESS>, YEAR.
+#
+msgid ""
+msgstr ""
+"Project-Id-Version: \n"
+"Report-Msgid-Bugs-To: \n"
+"POT-Creation-Date: 2024-03-28 23:50+0000\n"
+"PO-Revision-Date: 2024-03-29 00:08+0000\n"
+"Last-Translator: \n"
+"Language-Team: \n"
+"Language: ru\n"
+"MIME-Version: 1.0\n"
+"Content-Type: text/plain; charset=UTF-8\n"
+"Content-Transfer-Encoding: 8bit\n"
+"Plural-Forms: nplurals=3; plural=(n%10==1 && n%100!=11 ? 0 : n%10>=2 && "
+"n%10<=4 && (n%100<12 || n%100>14) ? 1 : 2);\n"
+"Generated-By: pygettext.py 1.5\n"
+"X-Generator: Poedit 3.4.2\n"
+
+#: tools/effmpeg/cli.py:15
+msgid "This command allows you to easily execute common ffmpeg tasks."
+msgstr "Эта команда позволяет легко выполнять общие задачи ffmpeg."
+
+#: tools/effmpeg/cli.py:52
+msgid "A wrapper for ffmpeg for performing image <> video converting."
+msgstr "Обертка для ffmpeg для выполнения конвертации изображений <> видео."
+
+#: tools/effmpeg/cli.py:64
+msgid ""
+"R|Choose which action you want ffmpeg ffmpeg to do.\n"
+"L|'extract': turns videos into images \n"
+"L|'gen-vid': turns images into videos \n"
+"L|'get-fps' returns the chosen video's fps.\n"
+"L|'get-info' returns information about a video.\n"
+"L|'mux-audio' add audio from one video to another.\n"
+"L|'rescale' resize video.\n"
+"L|'rotate' rotate video.\n"
+"L|'slice' cuts a portion of the video into a separate video file."
+msgstr ""
+"R|Выберите, какое действие вы хотите, чтобы выполнял ffmpeg.\n"
+"L|'extract': превращает видео в изображения \n"
+"L|'gen-vid': превращает изображения в видео \n"
+"L|'get-fps' возвращает частоту кадров в секунду выбранного видео.\n"
+"L|'get-info' возвращает информацию о видео.\n"
+"L|'mux-audio' добавляет звук из одного видео в другое.\n"
+"L|'rescale' изменяет размер видео.\n"
+"L|'rotate' поворачивает видео.\n"
+"L|'slice' вырезает часть видео в отдельный видеофайл."
+
+#: tools/effmpeg/cli.py:78
+msgid "Input file."
+msgstr "Входной файл."
+
+#: tools/effmpeg/cli.py:79 tools/effmpeg/cli.py:86 tools/effmpeg/cli.py:100
+msgid "data"
+msgstr "данные"
+
+#: tools/effmpeg/cli.py:89
+msgid ""
+"Output file. If no output is specified then: if the output is meant to be a "
+"video then a video called 'out.mkv' will be created in the input directory; "
+"if the output is meant to be a directory then a directory called 'out' will "
+"be created inside the input directory. Note: the chosen output file "
+"extension will determine the file encoding."
+msgstr ""
+"Выходной файл. Если выходной файл не указан, то: если выходным файлом "
+"является видео, то в каталоге ввода будет создан видеофайл с именем 'out."
+"mkv'; если выходным файлом является каталог, то внутри каталога ввода будет "
+"создан каталог с именем 'out'. Примечание: выбранное расширение выходного "
+"файла определяет кодировку файла."
+
+#: tools/effmpeg/cli.py:102
+msgid "Path to reference video if 'input' was not a video."
+msgstr "Путь к опорному видео, если 'input' не является видео." 
+
+#: tools/effmpeg/cli.py:108 tools/effmpeg/cli.py:118 tools/effmpeg/cli.py:156
+#: tools/effmpeg/cli.py:185
+msgid "output"
+msgstr "выход"
+
+#: tools/effmpeg/cli.py:110
+msgid ""
+"Provide video fps. Can be an integer, float or fraction. Negative values "
+"will will make the program try to get the fps from the input or reference "
+"videos."
+msgstr ""
+"Задает частоту кадров видео (fps). Может быть целым числом, числом с "
+"плавающей запятой или дробью. Отрицательные значения заставят программу "
+"попытаться получить fps из входного или опорного видео."
+
+#: tools/effmpeg/cli.py:120
+msgid ""
+"Image format that extracted images should be saved as. '.bmp' will offer the "
+"fastest extraction speed, but will take the most storage space. '.png' will "
+"be slower but will take less storage."
+msgstr ""
+"Формат изображения, в котором должны быть сохранены извлеченные изображения. "
+"'.bmp' обеспечивает самую высокую скорость извлечения, но занимает больше "
+"всего места на диске. '.png' будет медленнее, но займет меньше места."
+
+#: tools/effmpeg/cli.py:127 tools/effmpeg/cli.py:136 tools/effmpeg/cli.py:145
+msgid "clip"
+msgstr "клип"
+
+#: tools/effmpeg/cli.py:129
+msgid ""
+"Enter the start time from which an action is to be applied. Default: "
+"00:00:00, in HH:MM:SS format. You can also enter the time with or without "
+"the colons, e.g. 00:0000 or 026010."
+msgstr ""
+"Введите время начала, с которого будет применяться действие. По умолчанию: "
+"00:00:00, в формате ЧЧ:ММ:СС. Вы также можете ввести время с двоеточиями или "
+"без них, например, 00:0000 или 026010."
+
+#: tools/effmpeg/cli.py:138
+msgid ""
+"Enter the end time to which an action is to be applied. If both an end time "
+"and duration are set, then the end time will be used and the duration will "
+"be ignored. Default: 00:00:00, in HH:MM:SS."
+msgstr ""
+"Введите время окончания, до которого будет применяться действие. Если заданы "
+"и время окончания, и продолжительность, то будет использоваться время "
+"окончания, а продолжительность будет игнорироваться. По умолчанию: 00:00:00, "
+"в формате ЧЧ:ММ:СС."
+
+#: tools/effmpeg/cli.py:147
+msgid ""
+"Enter the duration of the chosen action, for example if you enter 00:00:10 "
+"for slice, then the first 10 seconds after and including the start time will "
+"be cut out into a new video. Default: 00:00:00, in HH:MM:SS format. You can "
+"also enter the time with or without the colons, e.g. 00:0000 or 026010."
+msgstr ""
+"Введите продолжительность выбранного действия, например, если вы введете "
+"00:00:10 для нарезки, то первые 10 секунд, начиная со времени начала "
+"(включительно), будут вырезаны в новое видео. По умолчанию: 00:00:00, в "
+"формате ЧЧ:ММ:СС. Вы также можете ввести время с двоеточиями или без них, "
+"например, 00:0000 или 026010."
+
+#: tools/effmpeg/cli.py:158
+msgid ""
+"Mux the audio from the reference video into the input video. This option is "
+"only used for the 'gen-vid' action. 'mux-audio' action has this turned on "
+"implicitly."
+msgstr ""
+"Микшировать аудио из опорного видео во входное видео. Эта опция используется "
+"только для действия 'gen-vid'. Действие 'mux-audio' включает эту опцию "
+"неявно."
+
+#: tools/effmpeg/cli.py:169 tools/effmpeg/cli.py:179
+msgid "rotate"
+msgstr "поворот"
+
+#: tools/effmpeg/cli.py:171
+msgid ""
+"Transpose the video. If transpose is set, then degrees will be ignored. For "
+"cli you can enter either the number or the long command name, e.g. 
to use " +"(1, 90Clockwise) -tr 1 or -tr 90Clockwise" +msgstr "" +"Транспонировать видео. Если задано транспонирование, то градусы будут " +"игнорироваться. Для командой строки вы можете ввести либо число, либо " +"длинное имя команды, например, для использования (1, 90 по часовой стрелке) -" +"tr 1 или -tr 90 по часовой стрелке" + +#: tools/effmpeg/cli.py:180 +msgid "Rotate the video clockwise by the given number of degrees." +msgstr "Поверните видео по часовой стрелке на заданное количество градусов." + +#: tools/effmpeg/cli.py:187 +msgid "Set the new resolution scale if the chosen action is 'rescale'." +msgstr "Установите новый масштаб разрешения, если выбрано действие 'rescale'." + +#: tools/effmpeg/cli.py:192 tools/effmpeg/cli.py:200 +msgid "settings" +msgstr "настройки" + +#: tools/effmpeg/cli.py:194 +msgid "" +"Reduces output verbosity so that only serious errors are printed. If both " +"quiet and verbose are set, verbose will override quiet." +msgstr "" +"Уменьшает многословность вывода, чтобы выводились только серьезные ошибки. " +"Если заданы и quiet, и verbose, то verbose будет преобладать над quiet." + +#: tools/effmpeg/cli.py:202 +msgid "" +"Increases output verbosity. If both quiet and verbose are set, verbose will " +"override quiet." +msgstr "" +"Повышает точность вывода. Если заданы и quiet, и verbose, то verbose будет " +"преобладать над quiet." diff --git a/locales/ru/LC_MESSAGES/tools.manual.mo b/locales/ru/LC_MESSAGES/tools.manual.mo new file mode 100644 index 0000000000..6e724e6f9d Binary files /dev/null and b/locales/ru/LC_MESSAGES/tools.manual.mo differ diff --git a/locales/ru/LC_MESSAGES/tools.manual.po b/locales/ru/LC_MESSAGES/tools.manual.po new file mode 100644 index 0000000000..2c74501b46 --- /dev/null +++ b/locales/ru/LC_MESSAGES/tools.manual.po @@ -0,0 +1,294 @@ +# SOME DESCRIPTIVE TITLE. +# Copyright (C) YEAR ORGANIZATION +# FIRST AUTHOR , YEAR. +# +msgid "" +msgstr "" +"Project-Id-Version: \n" +"Report-Msgid-Bugs-To: \n" +"POT-Creation-Date: 2024-03-28 23:55+0000\n" +"PO-Revision-Date: 2024-03-29 00:07+0000\n" +"Last-Translator: \n" +"Language-Team: \n" +"Language: ru\n" +"MIME-Version: 1.0\n" +"Content-Type: text/plain; charset=UTF-8\n" +"Content-Transfer-Encoding: 8bit\n" +"Plural-Forms: nplurals=3; plural=(n%10==1 && n%100!=11 ? 0 : n%10>=2 && " +"n%10<=4 && (n%100<12 || n%100>14) ? 1 : 2);\n" +"Generated-By: pygettext.py 1.5\n" +"X-Generator: Poedit 3.4.2\n" + +#: tools/manual/cli.py:13 +msgid "" +"This command lets you perform various actions on frames, faces and " +"alignments files using visual tools." +msgstr "" +"Эта команда позволяет выполнять различные действия с кадрами, гранями и " +"файлами выравнивания с помощью визуальных инструментов." + +#: tools/manual/cli.py:23 +msgid "" +"A tool to perform various actions on frames, faces and alignments files " +"using visual tools" +msgstr "" +"Инструмент для выполнения различных действий с кадрами, лицами и файлами " +"выравнивания с помощью визуальных инструментов" + +#: tools/manual/cli.py:35 tools/manual/cli.py:44 +msgid "data" +msgstr "данные" + +#: tools/manual/cli.py:38 +msgid "" +"Path to the alignments file for the input, if not at the default location" +msgstr "" +"Путь к файлу выравниваний для входных данных, если он не находится в месте " +"по умолчанию" + +#: tools/manual/cli.py:46 +msgid "" +"Video file or directory containing source frames that faces were extracted " +"from." +msgstr "" +"Видеофайл или папка, содержащая исходные кадры, из которых были извлечены " +"лица." 
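The start/end/duration help above states that times may be entered with or without colons (e.g. 00:0000 or 026010). A rough sketch of how such input could be normalised to HH:MM:SS; this is an illustration of that stated contract only, not effmpeg's actual parser:

    def to_hh_mm_ss(value: str) -> str:
        # '026010' or '00:0000' -> 'HH:MM:SS'; range validation is left to ffmpeg
        digits = value.replace(":", "")
        if len(digits) != 6 or not digits.isdigit():
            raise ValueError(f"Expected 6 digits (HHMMSS), got {value!r}")
        return f"{digits[:2]}:{digits[2:4]}:{digits[4:]}"

    print(to_hh_mm_ss("026010"))   # -> '02:60:10'
    print(to_hh_mm_ss("00:0000"))  # -> '00:00:00'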
+ +#: tools/manual/cli.py:53 tools/manual/cli.py:62 +msgid "options" +msgstr "опции" + +#: tools/manual/cli.py:55 +msgid "" +"Force regeneration of the low resolution jpg thumbnails in the alignments " +"file." +msgstr "" +"Принудительное восстановление миниатюр jpg низкого разрешения в файле " +"выравнивания." + +#: tools/manual/cli.py:64 +msgid "" +"The process attempts to speed up generation of thumbnails by extracting from " +"the video in parallel threads. For some videos, this causes the caching " +"process to hang. If this happens, then set this option to generate the " +"thumbnails in a slower, but more stable single thread." +msgstr "" +"Процесс пытается ускорить генерацию эскизов путем извлечения из видео в " +"параллельных потоках. Для некоторых видео это приводит к зависанию процесса " +"кэширования. Если это происходит, установите этот параметр, чтобы " +"генерировать эскизы в более медленном, но более стабильном одном потоке." + +#: tools/manual\faceviewer\frame.py:163 +msgid "Display the landmarks mesh" +msgstr "Отображение сетки ориентиров" + +#: tools/manual\faceviewer\frame.py:164 +msgid "Display the mask" +msgstr "Отображение маски" + +#: tools/manual\frameviewer\editor\_base.py:628 +#: tools/manual\frameviewer\editor\landmarks.py:44 +#: tools/manual\frameviewer\editor\mask.py:75 +msgid "Magnify/Demagnify the View" +msgstr "Увеличение/уменьшение изображения" + +#: tools/manual\frameviewer\editor\bounding_box.py:33 +#: tools/manual\frameviewer\editor\extract_box.py:32 +msgid "Delete Face" +msgstr "Удалить лицо" + +#: tools/manual\frameviewer\editor\bounding_box.py:36 +msgid "" +"Bounding Box Editor\n" +"Edit the bounding box being fed into the aligner to recalculate the " +"landmarks.\n" +"\n" +" - Grab the corner anchors to resize the bounding box.\n" +" - Click and drag the bounding box to relocate.\n" +" - Click in empty space to create a new bounding box.\n" +" - Right click a bounding box to delete a face." +msgstr "" +"Редактор ограничительных рамок\n" +"Отредактируйте ограничивающую рамку, подаваемую в выравниватель, чтобы " +"пересчитать ориентиры.\n" +"\n" +"- Захватите угловые опоры, чтобы изменить размер ограничивающей рамки.\n" +" - Щелкните и перетащите ограничивающую рамку для перемещения.\n" +" - Щелкните в пустом пространстве, чтобы создать новую ограничивающую " +"рамку.\n" +"- Щелкните правой кнопкой мыши ограничительную рамку, чтобы удалить лицо." + +#: tools/manual\frameviewer\editor\bounding_box.py:70 +msgid "" +"Aligner to use. FAN will obtain better alignments, but cv2-dnn can be useful " +"if FAN cannot get decent alignments and you want to set a base to edit from." +msgstr "" +"Выравниватель для использования. FAN получит лучшие выравнивания, но cv2-dnn " +"может быть полезен, если FAN не может получить достойные выравнивания, и вы " +"хотите установить базу для редактирования." + +#: tools/manual\frameviewer\editor\bounding_box.py:83 +msgid "" +"Normalization method to use for feeding faces to the aligner. This can help " +"the aligner better align faces with difficult lighting conditions. Different " +"methods will yield different results on different sets. NB: This does not " +"impact the output face, just the input to the aligner.\n" +"\tnone: Don't perform normalization on the face.\n" +"\tclahe: Perform Contrast Limited Adaptive Histogram Equalization on the " +"face.\n" +"\thist: Equalize the histograms on the RGB channels.\n" +"\tmean: Normalize the face colors to the mean." 
+msgstr "" +"Метод нормализации, используемый для подачи лиц в выравниватель. Это может " +"помочь выравнивателю лучше выравнивать лица при сложных условиях освещения. " +"Различные методы дают разные результаты на разных наборах. Примечание: Это " +"не влияет на выходное лицо, только на входное в выравниватель.\n" +"\tnone: Не выполнять нормализацию лица.\n" +"\tclahe: Выполнить для лица адаптивную гистограммную эквализацию с " +"ограничением контраста.\n" +"\thist: Выравнивание гистограмм по каналам RGB.\n" +"\tmean: Нормализовать цвета лица к среднему значению." + +#: tools/manual\frameviewer\editor\extract_box.py:35 +msgid "" +"Extract Box Editor\n" +"Move the extract box that has been generated by the aligner. Click and " +"drag:\n" +"\n" +" - Inside the bounding box to relocate the landmarks.\n" +" - The corner anchors to resize the landmarks.\n" +" - Outside of the corners to rotate the landmarks." +msgstr "" +"Редактор поля извлечения\n" +"Переместите поле извлечения, созданное выравнивателем. Нажмите и " +"перетащите:\n" +"\n" +" - Внутри ограничивающей рамки для перемещения опорных точек.\n" +"- По угловым опорам для изменения размера опорных точек.\n" +"- За пределами углов, чтобы повернуть опорные точки." + +#: tools/manual\frameviewer\editor\landmarks.py:27 +msgid "" +"Landmark Point Editor\n" +"Edit the individual landmark points.\n" +"\n" +" - Click and drag individual points to relocate.\n" +" - Draw a box to select multiple points to relocate." +msgstr "" +"Редактор точек ориентира\n" +"Редактирование отдельных опорных точек.\n" +"\n" +" - Щелкните и перетащите отдельные точки для перемещения.\n" +" - Нарисуйте рамку, чтобы выбрать несколько точек для перемещения." + +#: tools/manual\frameviewer\editor\mask.py:33 +msgid "" +"Mask Editor\n" +"Edit the mask.\n" +" - NB: For Landmark based masks (e.g. components/extended) it is better to " +"make sure the landmarks are correct rather than editing the mask directly. " +"Any change to the landmarks after editing the mask will override your manual " +"edits." +msgstr "" +"Редактор маски\n" +"Отредактировать маску.\n" +" - Примечание: Для масок, основанных на ориентирах (например, компоненты/" +"расширенные), лучше убедиться в правильности ориентиров, а не редактировать " +"маску напрямую. Любое изменение ориентиров после редактирования маски " +"отменит ваши ручные правки." + +#: tools/manual\frameviewer\editor\mask.py:77 +msgid "Draw Tool" +msgstr "Инструмент рисования" + +#: tools/manual\frameviewer\editor\mask.py:78 +msgid "Erase Tool" +msgstr "Инструмент \"Ластик\"" + +#: tools/manual\frameviewer\editor\mask.py:97 +msgid "Select which mask to edit" +msgstr "Выбрать, какую маску редактировать" + +#: tools/manual\frameviewer\editor\mask.py:104 +msgid "Set the brush size. ([ - decrease, ] - increase)" +msgstr "Установить размер кисти. ([ - уменьшение, ] - увеличение)" + +#: tools/manual\frameviewer\editor\mask.py:111 +msgid "Select the brush cursor color." +msgstr "Установить цвет курсора кисти." 
+ +#: tools/manual\frameviewer\frame.py:78 +msgid "Play/Pause (SPACE)" +msgstr "Воспроизвести/Приостановить (ПРОБЕЛ)" + +#: tools/manual\frameviewer\frame.py:79 +msgid "Go to First Frame (HOME)" +msgstr "Перейти к первому кадру (HOME)" + +#: tools/manual\frameviewer\frame.py:80 +msgid "Go to Previous Frame (Z)" +msgstr "Перейти к предыдущему кадру (Z/Я)" + +#: tools/manual\frameviewer\frame.py:81 +msgid "Go to Next Frame (X)" +msgstr "Перейти к следующему кадру (X/Ч)" + +#: tools/manual\frameviewer\frame.py:82 +msgid "Go to Last Frame (END)" +msgstr "Перейти к последнему кадру (END)" + +#: tools/manual\frameviewer\frame.py:83 +msgid "Extract the faces to a folder... (Ctrl+E)" +msgstr "Извлечь лица в папку... (Ctrl+E)" + +#: tools/manual\frameviewer\frame.py:84 +msgid "Save the Alignments file (Ctrl+S)" +msgstr "Сохранить файл выравнивания (Ctrl+S)" + +#: tools/manual\frameviewer\frame.py:85 +msgid "Filter Frames to only those Containing the Selected Item (F)" +msgstr "Отфильтровать кадры, содержащие только выбранный элемент (F/А)" + +#: tools/manual\frameviewer\frame.py:86 +msgid "" +"Set the distance from an 'average face' to be considered misaligned. Higher " +"distances are more restrictive" +msgstr "" +"Установить расстояние от \"среднего лица\", на котором оно будет считаться " +"смещенным. Большие расстояния являются более ограничительными" + +#: tools/manual\frameviewer\frame.py:391 +msgid "View alignments" +msgstr "Просмотреть выравнивания" + +#: tools/manual\frameviewer\frame.py:392 +msgid "Bounding box editor" +msgstr "Редактор ограничительных рамок" + +#: tools/manual\frameviewer\frame.py:393 +msgid "Location editor" +msgstr "Редактор расположения" + +#: tools/manual\frameviewer\frame.py:394 +msgid "Mask editor" +msgstr "Редактор маски" + +#: tools/manual\frameviewer\frame.py:395 +msgid "Landmark point editor" +msgstr "Редактор точек ориентира" + +#: tools/manual\frameviewer\frame.py:470 +msgid "Next" +msgstr "Следующий" + +#: tools/manual\frameviewer\frame.py:470 +msgid "Previous" +msgstr "Предыдущий" + +#: tools/manual\frameviewer\frame.py:481 +msgid "Revert to saved Alignments ({})" +msgstr "Откатить до сохраненных выравниваний ({})" + +#: tools/manual\frameviewer\frame.py:487 +msgid "Copy {} Alignments ({})" +msgstr "Копировать {} выравнивания ({})" diff --git a/locales/ru/LC_MESSAGES/tools.mask.cli.mo b/locales/ru/LC_MESSAGES/tools.mask.cli.mo new file mode 100644 index 0000000000..89631d0cd9 Binary files /dev/null and b/locales/ru/LC_MESSAGES/tools.mask.cli.mo differ diff --git a/locales/ru/LC_MESSAGES/tools.mask.cli.po b/locales/ru/LC_MESSAGES/tools.mask.cli.po new file mode 100644 index 0000000000..6cabf81c53 --- /dev/null +++ b/locales/ru/LC_MESSAGES/tools.mask.cli.po @@ -0,0 +1,331 @@ +# SOME DESCRIPTIVE TITLE. +# Copyright (C) YEAR THE PACKAGE'S COPYRIGHT HOLDER +# This file is distributed under the same license as the PACKAGE package. +# FIRST AUTHOR , YEAR. +# +msgid "" +msgstr "" +"Project-Id-Version: \n" +"Report-Msgid-Bugs-To: \n" +"POT-Creation-Date: 2024-06-28 13:45+0100\n" +"PO-Revision-Date: 2024-06-28 13:48+0100\n" +"Last-Translator: \n" +"Language-Team: \n" +"Language: ru\n" +"MIME-Version: 1.0\n" +"Content-Type: text/plain; charset=UTF-8\n" +"Content-Transfer-Encoding: 8bit\n" +"Plural-Forms: nplurals=3; plural=(n%10==1 && n%100!=11 ? 0 : n%10>=2 && " +"n%10<=4 && (n%100<12 || n%100>14) ? 
1 : 2);\n"
+"X-Generator: Poedit 3.4.4\n"
+
+#: tools/mask/cli.py:15
+msgid ""
+"This tool allows you to generate, import, export or preview masks for "
+"existing alignments."
+msgstr ""
+"Этот инструмент позволяет создавать, импортировать, экспортировать или "
+"просматривать маски для существующих выравниваний."
+
+#: tools/mask/cli.py:25
+msgid ""
+"Mask tool\n"
+"Generate, import, export or preview masks for existing alignments files."
+msgstr ""
+"Инструмент \"Маска\"\n"
+"Создавайте, импортируйте, экспортируйте или просматривайте маски для "
+"существующих файлов выравниваний."
+
+#: tools/mask/cli.py:35 tools/mask/cli.py:47 tools/mask/cli.py:58
+#: tools/mask/cli.py:69
+msgid "data"
+msgstr "данные"
+
+#: tools/mask/cli.py:39
+msgid ""
+"Full path to the alignments file that contains the masks if not at the "
+"default location. NB: If the input-type is faces and you wish to update the "
+"corresponding alignments file, then you must provide a value here as the "
+"location cannot be automatically detected."
+msgstr ""
+"Полный путь к файлу выравниваний, содержащему маски, если он не находится "
+"в месте по умолчанию. Примечание: Если input-type - лица, и вы хотите "
+"обновить соответствующий файл выравнивания, то вы должны указать значение "
+"здесь, так как местоположение не может быть определено автоматически."
+
+#: tools/mask/cli.py:51
+msgid "Directory containing extracted faces, source frames, or a video file."
+msgstr "Папка, содержащая извлеченные лица, исходные кадры или видеофайл."
+
+#: tools/mask/cli.py:61
+msgid ""
+"R|Whether the `input` is a folder of faces or a folder frames/video\n"
+"L|faces: The input is a folder containing extracted faces.\n"
+"L|frames: The input is a folder containing frames or is a video"
+msgstr ""
+"R|Выбирается ли \"вход\" как папка лиц или как папка кадров/видео\n"
+"L|faces: Входом является папка, содержащая извлеченные лица.\n"
+"L|frames: Входом является папка с кадрами или видео"
+
+#: tools/mask/cli.py:71
+msgid ""
+"R|Run the mask tool on multiple sources. If selected then the other options "
+"should be set as follows:\n"
+"L|input: A parent folder containing either all of the video files to be "
+"processed, or containing sub-folders of frames/faces.\n"
+"L|output-folder: If provided, then sub-folders will be created within the "
+"given location to hold the previews for each input.\n"
+"L|alignments: Alignments field will be ignored for batch processing. The "
+"alignments files must exist at the default location (for frames). For batch "
+"processing of masks with 'faces' as the input type, then only the PNG header "
+"within the extracted faces will be updated."
+msgstr ""
+"R|Запустить инструмент маски на нескольких источниках. Если выбрано, то "
+"остальные параметры должны быть установлены следующим образом:\n"
+"L|input: Родительская папка, содержащая либо все видеофайлы для обработки, "
+"либо содержащая вложенные папки кадров/лиц.\n"
+"L|output-folder: Если указано, то в заданном месте будут созданы вложенные "
+"папки для хранения превью для каждого входа.\n"
+"L|alignments: Поле выравнивания будет игнорироваться при пакетной обработке. "
+"Файлы выравнивания должны существовать в месте по умолчанию (для кадров). "
+"При пакетной обработке масок с типом входа \"лица\" будут обновлены только "
+"заголовки PNG в извлеченных лицах."
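These .po catalogs are compiled into the .mo files committed alongside them and looked up by domain at runtime. A small sketch of loading one with Python's stdlib gettext — the domain and directory follow the locales/ru/LC_MESSAGES/tools.mask.cli.mo layout in this diff; the call itself is standard library, not a faceswap API:

    import gettext

    trans = gettext.translation("tools.mask.cli", localedir="locales",
                                languages=["ru"], fallback=True)
    _ = trans.gettext

    print(_("data"))  # -> 'данные' when the Russian catalog is found,
                      #    or the untranslated msgid when it is not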
+
+#: tools/mask/cli.py:87 tools/mask/cli.py:119
+msgid "process"
+msgstr "обработка"
+
+#: tools/mask/cli.py:89
+msgid ""
+"R|Masker to use.\n"
+"L|bisenet-fp: Relatively lightweight NN based mask that provides more "
+"refined control over the area to be masked including full head masking "
+"(configurable in mask settings).\n"
+"L|components: Mask designed to provide facial segmentation based on the "
+"positioning of landmark locations. A convex hull is constructed around the "
+"exterior of the landmarks to create a mask.\n"
+"L|custom: A dummy mask that fills the mask area with all 1s or 0s "
+"(configurable in settings). This is only required if you intend to manually "
+"edit the custom masks yourself in the manual tool. This mask does not use "
+"the GPU.\n"
+"L|extended: Mask designed to provide facial segmentation based on the "
+"positioning of landmark locations. A convex hull is constructed around the "
+"exterior of the landmarks and the mask is extended upwards onto the "
+"forehead.\n"
+"L|vgg-clear: Mask designed to provide smart segmentation of mostly frontal "
+"faces clear of obstructions. Profile faces and obstructions may result in "
+"sub-par performance.\n"
+"L|vgg-obstructed: Mask designed to provide smart segmentation of mostly "
+"frontal faces. The mask model has been specifically trained to recognize "
+"some facial obstructions (hands and eyeglasses). Profile faces may result in "
+"sub-par performance.\n"
+"L|unet-dfl: Mask designed to provide smart segmentation of mostly frontal "
+"faces. The mask model has been trained by community members. Profile faces "
+"may result in sub-par performance."
+msgstr ""
+"R|Маскер для использования.\n"
+"L|bisenet-fp: Относительно легкая маска на основе NN, которая обеспечивает "
+"более точный контроль над маскируемой областью, включая полное маскирование "
+"головы (настраивается в настройках маски).\n"
+"L|components: Маска, разработанная для сегментации лица на основе "
+"расположения ориентиров. Для создания маски вокруг внешних ориентиров "
+"строится выпуклая оболочка.\n"
+"L|custom (пользовательская): Фиктивная маска, которая заполняет область "
+"маски всеми 1 или 0 (настраивается в настройках). Она необходима только в "
+"том случае, если вы собираетесь вручную редактировать пользовательские маски "
+"в ручном инструменте. Эта маска не использует GPU.\n"
+"L|extended: Маска предназначена для сегментации лица на основе расположения "
+"ориентиров. Выпуклая оболочка строится вокруг внешних ориентиров, и маска "
+"расширяется вверх на лоб.\n"
+"L|vgg-clear: Маска предназначена для интеллектуальной сегментации "
+"преимущественно фронтальных лиц без препятствий. Профильные лица и "
+"препятствия могут привести к снижению производительности.\n"
+"L|vgg-obstructed: Маска, разработанная для интеллектуальной сегментации "
+"преимущественно фронтальных лиц. Модель маски была специально обучена "
+"распознавать некоторые препятствия на лице (руки и очки). Лица в профиль "
+"могут иметь низкую производительность.\n"
+"L|unet-dfl: Маска, разработанная для интеллектуальной сегментации "
+"преимущественно фронтальных лиц. Модель маски была обучена членами "
+"сообщества. Профильные лица могут иметь низкую производительность."
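The components and extended maskers above are both described as a convex hull drawn around the landmark exterior. A minimal numpy/OpenCV illustration of that idea (not faceswap's actual masker code):

    import cv2
    import numpy as np

    def hull_mask(landmarks: np.ndarray, size: int = 128) -> np.ndarray:
        # Fill the convex hull of (N, 2) landmark points as an 8-bit mask
        mask = np.zeros((size, size), dtype=np.uint8)
        hull = cv2.convexHull(landmarks.astype(np.int32))
        cv2.fillConvexPoly(mask, hull, 255)
        return mask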
+
+#: tools/mask/cli.py:121
+msgid ""
+"R|The Mask tool process to perform.\n"
+"L|all: Update the mask for all faces in the alignments file for the selected "
+"'masker'.\n"
+"L|missing: Create a mask for all faces in the alignments file where a mask "
+"does not previously exist for the selected 'masker'.\n"
+"L|output: Don't update the masks, just output the selected 'masker' for "
+"review/editing in external tools to the given output folder.\n"
+"L|import: Import masks that have been edited outside of faceswap into the "
+"alignments file. Note: 'custom' must be the selected 'masker' and the masks "
+"must be in the same format as the 'input-type' (frames or faces)"
+msgstr ""
+"R|Выполняемый процесс инструмента \"Маска\".\n"
+"L|all: Обновить маску для всех лиц в файле выравниваний для выбранного "
+"'masker'.\n"
+"L|missing: Создать маску для всех лиц в файле выравниваний, для которых "
+"маска выбранного 'masker' ранее не существовала.\n"
+"L|output: Не обновлять маски, а просто вывести выбранный 'masker' для "
+"просмотра/редактирования во внешних инструментах в указанную выходную "
+"папку.\n"
+"L|import: Импортировать маски, отредактированные вне faceswap, в файл "
+"выравниваний. Примечание: должен быть выбран 'masker' 'custom', и маски "
+"должны быть в том же формате, что и 'input-type' (frames или faces)"
+
+#: tools/mask/cli.py:135 tools/mask/cli.py:154 tools/mask/cli.py:176
+msgid "import"
+msgstr "импортировать"
+
+#: tools/mask/cli.py:137
+msgid ""
+"R|Import only. The path to the folder that contains masks to be imported.\n"
+"L|How the masks are provided is not important, but they will be stored, "
+"internally, as 8-bit grayscale images.\n"
+"L|If the input are images, then the masks must be named exactly the same as "
+"input frames/faces (excluding the file extension).\n"
+"L|If the input is a video file, then the filename of the masks is not "
+"important but should contain the frame number at the end of the filename "
+"(but before the file extension). The frame number can be separated from the "
+"rest of the filename by any non-numeric character and can be padded by any "
+"number of zeros. The frame number must correspond correctly to the frame "
+"number in the original video (starting from frame 1)."
+msgstr ""
+"R|Только импорт. Путь к папке, содержащей маски для импорта.\n"
+"L|Как предоставляются маски, не важно, но они будут храниться внутри как 8-"
+"битные изображения в оттенках серого.\n"
+"L|Если входными данными являются изображения, то имена масок должны быть "
+"точно такими же, как у входных кадров/лиц (за исключением расширения "
+"файла).\n"
+"L|Если входной файл представляет собой видеофайл, то имя файла масок не "
+"важно, но должно содержать номер кадра в конце имени файла (но перед "
+"расширением файла). Номер кадра может быть отделен от остальной части имени "
+"файла любым нечисловым символом и дополнен любым количеством нулей. Номер "
+"кадра должен правильно соответствовать номеру кадра в исходном видео "
+"(начиная с кадра 1)."
+
+#: tools/mask/cli.py:156
+msgid ""
+"R|Import/Output only. When importing masks, this is the centering to use. "
+"For output this is only used for outputting custom imported masks, and "
+"should correspond to the centering used when importing the mask. 
Note: For " +"any job other than 'import' and 'output' this option is ignored as mask " +"centering is handled internally.\n" +"L|face: Centers the mask on the center of the face, adjusting for pitch and " +"yaw. Outside of requirements for full head masking/training, this is likely " +"to be the best choice.\n" +"L|head: Centers the mask on the center of the head, adjusting for pitch and " +"yaw. Note: You should only select head centering if you intend to include " +"the full head (including hair) within the mask and are looking to train a " +"full head model.\n" +"L|legacy: The 'original' extraction technique. Centers the mask near the of " +"the nose with and crops closely to the face. Can result in the edges of the " +"mask appearing outside of the training area." +msgstr "" +"R|Только импорт/вывод. При импорте масок это центрирование для " +"использования. Для вывода это используется только для вывода " +"пользовательских импортированных масок и должно соответствовать " +"центрированию, используемому при импорте маски. Примечание: для любого " +"задания, кроме «импорта» и «вывода», эта опция игнорируется, поскольку " +"центрирование маски обрабатывается внутренне.\n" +"L|face: центрирует маску по центру лица с регулировкой угла наклона и " +"отклонения от курса. Помимо требований к полной маскировке/тренировке " +"головы, это, вероятно, будет лучшим выбором.\n" +"L|head: центрирует маску по центру головы с регулировкой угла наклона и " +"отклонения от курса. Примечание. Выбирать центрирование головы следует " +"только в том случае, если вы собираетесь включить в маску всю голову " +"(включая волосы) и хотите обучить модель полной головы.\n" +"L|legacy: «Оригинальная» техника извлечения. Центрирует маску возле носа и " +"приближает ее к лицу. Это может привести к тому, что края маски окажутся за " +"пределами тренировочной зоны." + +#: tools/mask/cli.py:181 +msgid "" +"Import only. The size, in pixels to internally store the mask at.\n" +"The default is 128 which is fine for nearly all usecases. Larger sizes will " +"result in larger alignments files and longer processing." +msgstr "" +"Только импорт. Размер в пикселях для внутреннего хранения маски.\n" +"Значение по умолчанию — 128, что подходит практически для всех случаев " +"использования. Большие размеры приведут к увеличению размера файлов " +"выравниваний и более длительной обработке." + +#: tools/mask/cli.py:189 tools/mask/cli.py:197 tools/mask/cli.py:211 +#: tools/mask/cli.py:225 tools/mask/cli.py:235 +msgid "output" +msgstr "вывод" + +#: tools/mask/cli.py:191 +msgid "" +"Optional output location. If provided, a preview of the masks created will " +"be output in the given folder." +msgstr "" +"Необязательное местоположение вывода. Если указано, предварительный просмотр " +"созданных масок будет выведен в указанную папку." + +#: tools/mask/cli.py:202 +msgid "" +"Apply gaussian blur to the mask output. Has the effect of smoothing the " +"edges of the mask giving less of a hard edge. the size is in pixels. This " +"value should be odd, if an even number is passed in then it will be rounded " +"to the next odd number. NB: Only effects the output preview. Set to 0 for off" +msgstr "" +"Применяет гауссово размытие к выходу маски. Сглаживает края маски, делая их " +"менее жесткими. размер в пикселях. Это значение должно быть нечетным, если " +"передано четное число, то оно будет округлено до следующего нечетного числа. " +"Примечание: влияет только на предварительный просмотр. 
Установите значение 0 " +"для выключения" + +#: tools/mask/cli.py:216 +msgid "" +"Helps reduce 'blotchiness' on some masks by making light shades white and " +"dark shades black. Higher values will impact more of the mask. NB: Only " +"effects the output preview. Set to 0 for off" +msgstr "" +"Помогает уменьшить \"пятнистость\" на некоторых масках, делая светлые " +"оттенки белыми, а темные - черными. Более высокие значения влияют на большую " +"часть маски. Примечание: влияет только на предварительный просмотр. " +"Установите значение 0 для выключения" + +#: tools/mask/cli.py:227 +msgid "" +"R|How to format the output when processing is set to 'output'.\n" +"L|combined: The image contains the face/frame, face mask and masked face.\n" +"L|masked: Output the face/frame as rgba image with the face masked.\n" +"L|mask: Only output the mask as a single channel image." +msgstr "" +"R|Как форматировать вывод, когда обработка установлена на 'output'.\n" +"L|combined: Изображение содержит лицо/кадр, маску лица и маскированное " +"лицо.\n" +"L|masked: Вывести лицо/кадр как изображение rgba с маскированным лицом.\n" +"L|mask: Выводить только маску как одноканальное изображение." + +#: tools/mask/cli.py:237 +msgid "" +"R|Whether to output the whole frame or only the face box when using output " +"processing. Only has an effect when using frames as input." +msgstr "" +"R|Выводить ли весь кадр или только поле лица при использовании выходной " +"обработки. Имеет значение только при использовании кадров в качестве входных " +"данных." + +#~ msgid "" +#~ "R|Whether to update all masks in the alignments files, only those faces " +#~ "that do not already have a mask of the given `mask type` or just to " +#~ "output the masks to the `output` location.\n" +#~ "L|all: Update the mask for all faces in the alignments file.\n" +#~ "L|missing: Create a mask for all faces in the alignments file where a " +#~ "mask does not previously exist.\n" +#~ "L|output: Don't update the masks, just output them for review in the " +#~ "given output folder." +#~ msgstr "" +#~ "R|Обновлять ли все маски в файлах выравнивания, только те лица, которые " +#~ "еще не имеют маски заданного `mask type` или просто выводить маски в " +#~ "место `output`.\n" +#~ "L|all: Обновить маску для всех лиц в файле выравнивания.\n" +#~ "L|missing: Создать маску для всех лиц в файле выравнивания, для которых " +#~ "маска ранее не существовала.\n" +#~ "L|output: Не обновлять маски, а просто вывести их для просмотра в " +#~ "указанную выходную папку." diff --git a/locales/ru/LC_MESSAGES/tools.model.cli.mo b/locales/ru/LC_MESSAGES/tools.model.cli.mo new file mode 100644 index 0000000000..37b7545821 Binary files /dev/null and b/locales/ru/LC_MESSAGES/tools.model.cli.mo differ diff --git a/locales/ru/LC_MESSAGES/tools.model.cli.po b/locales/ru/LC_MESSAGES/tools.model.cli.po new file mode 100644 index 0000000000..bef71ab233 --- /dev/null +++ b/locales/ru/LC_MESSAGES/tools.model.cli.po @@ -0,0 +1,92 @@ +# SOME DESCRIPTIVE TITLE. +# Copyright (C) YEAR THE PACKAGE'S COPYRIGHT HOLDER +# This file is distributed under the same license as the PACKAGE package. +# FIRST AUTHOR , YEAR. 
+#
+msgid ""
+msgstr ""
+"Project-Id-Version: \n"
+"Report-Msgid-Bugs-To: \n"
+"POT-Creation-Date: 2024-03-28 23:51+0000\n"
+"PO-Revision-Date: 2024-03-29 00:07+0000\n"
+"Last-Translator: \n"
+"Language-Team: \n"
+"Language: ru\n"
+"MIME-Version: 1.0\n"
+"Content-Type: text/plain; charset=UTF-8\n"
+"Content-Transfer-Encoding: 8bit\n"
+"Plural-Forms: nplurals=3; plural=(n%10==1 && n%100!=11 ? 0 : n%10>=2 && "
+"n%10<=4 && (n%100<12 || n%100>14) ? 1 : 2);\n"
+"X-Generator: Poedit 3.4.2\n"
+
+#: tools/model/cli.py:13
+msgid "This tool lets you perform actions on saved Faceswap models."
+msgstr ""
+"Этот инструмент позволяет выполнять действия над сохраненными моделями "
+"Faceswap."
+
+#: tools/model/cli.py:22
+msgid "A tool for performing actions on Faceswap trained model files"
+msgstr ""
+"Инструмент для выполнения действий над файлами обученных моделей Faceswap"
+
+#: tools/model/cli.py:34
+msgid ""
+"Model directory. A directory containing the model you wish to perform an "
+"action on."
+msgstr ""
+"Папка модели. Папка, содержащая модель, над которой вы хотите выполнить "
+"действие."
+
+#: tools/model/cli.py:43
+msgid ""
+"R|Choose which action you want to perform.\n"
+"L|'inference' - Create an inference only copy of the model. Strips any "
+"layers from the model which are only required for training. NB: This is for "
+"exporting the model for use in external applications. Inference generated "
+"models cannot be used within Faceswap. See the 'format' option for "
+"specifying the model output format.\n"
+"L|'nan-scan' - Scan the model file for NaNs or Infs (invalid data).\n"
+"L|'restore' - Restore a model from backup."
+msgstr ""
+"R|Выберите действие, которое вы хотите выполнить.\n"
+"L|'inference' - Создать копию модели только для проведения расчетов. Удаляет "
+"из модели все слои, которые нужны только для обучения. Примечание: Эта "
+"функция предназначена для экспорта модели для использования во внешних "
+"приложениях. Модели, созданные в режиме вывода, не могут быть использованы в "
+"Faceswap. См. опцию 'format' для указания формата вывода модели.\n"
+"L|'nan-scan' - Проверить файл модели на наличие NaNs или Infs (недопустимых "
+"данных).\n"
+"L|'restore' - Восстановить модель из резервной копии."
+
+#: tools/model/cli.py:57 tools/model/cli.py:69
+msgid "inference"
+msgstr "вывод"
+
+#: tools/model/cli.py:59
+msgid ""
+"R|The format to save the model as. Note: Only used for 'inference' job.\n"
+"L|'h5' - Standard Keras H5 format. Does not store any custom layer "
+"information. Layers will need to be loaded from Faceswap to use.\n"
+"L|'saved-model' - Tensorflow's Saved Model format. Contains all information "
+"required to load the model outside of Faceswap."
+msgstr ""
+"R|Формат для сохранения модели. Примечание: Используется только для задания "
+"'inference'.\n"
+"L|'h5' - Стандартный формат Keras H5. Не хранит никакой информации о "
+"пользовательских слоях. Для использования слои должны быть загружены из "
+"Faceswap.\n"
+"L|'saved-model' - формат сохраненной модели Tensorflow. Содержит всю "
+"информацию, необходимую для загрузки модели вне Faceswap."
+
+#: tools/model/cli.py:71
+#, fuzzy
+#| msgid ""
+#| "Only used for 'inference' job. Generate the inference model for B -> A "
+#| "instead of A -> B."
+msgid ""
+"Only used for 'inference' job. Generate the inference model for B -> A "
+"instead of A -> B."
+msgstr ""
+"Используется только для задания 'inference'. Создайте модель вывода для B -> "
+"A вместо A -> B."
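The format help above contrasts Keras H5 with Tensorflow's SavedModel. A hedged sketch of the difference when exporting any Keras model under Tensorflow 2.x (generic Keras calls on a toy model, not the model tool's internals):

    from tensorflow import keras

    model = keras.Sequential([keras.layers.Dense(4, input_shape=(8,))])

    # H5: a single file; custom layers must be re-registered in code on load
    model.save("swap_model.h5")

    # SavedModel: a directory that also serialises the compute graph, so it
    # can be reloaded outside of the code base that defined the model
    model.save("swap_model")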
diff --git a/locales/ru/LC_MESSAGES/tools.preview.mo b/locales/ru/LC_MESSAGES/tools.preview.mo new file mode 100644 index 0000000000..780e7173eb Binary files /dev/null and b/locales/ru/LC_MESSAGES/tools.preview.mo differ diff --git a/locales/ru/LC_MESSAGES/tools.preview.po b/locales/ru/LC_MESSAGES/tools.preview.po new file mode 100644 index 0000000000..ebcaea18d8 --- /dev/null +++ b/locales/ru/LC_MESSAGES/tools.preview.po @@ -0,0 +1,93 @@ +# SOME DESCRIPTIVE TITLE. +# Copyright (C) YEAR THE PACKAGE'S COPYRIGHT HOLDER +# This file is distributed under the same license as the PACKAGE package. +# FIRST AUTHOR , YEAR. +# +msgid "" +msgstr "" +"Project-Id-Version: \n" +"Report-Msgid-Bugs-To: \n" +"POT-Creation-Date: 2024-03-28 23:53+0000\n" +"PO-Revision-Date: 2024-03-29 00:06+0000\n" +"Last-Translator: \n" +"Language-Team: \n" +"Language: ru\n" +"MIME-Version: 1.0\n" +"Content-Type: text/plain; charset=UTF-8\n" +"Content-Transfer-Encoding: 8bit\n" +"Plural-Forms: nplurals=3; plural=(n%10==1 && n%100!=11 ? 0 : n%10>=2 && " +"n%10<=4 && (n%100<12 || n%100>14) ? 1 : 2);\n" +"X-Generator: Poedit 3.4.2\n" + +#: tools/preview/cli.py:15 +msgid "This command allows you to preview swaps to tweak convert settings." +msgstr "" +"Эта команда позволяет просматривать замены для настройки параметров " +"конвертирования." + +#: tools/preview/cli.py:30 +msgid "" +"Preview tool\n" +"Allows you to configure your convert settings with a live preview" +msgstr "" +"Инструмент предпросмотра\n" +"Позволяет настраивать параметры конвертации с помощью предварительного " +"просмотра в реальном времени" + +#: tools/preview/cli.py:47 tools/preview/cli.py:57 tools/preview/cli.py:65 +msgid "data" +msgstr "данные" + +#: tools/preview/cli.py:50 +msgid "" +"Input directory or video. Either a directory containing the image files you " +"wish to process or path to a video file." +msgstr "" +"Входная папка или видео. Либо папка, содержащая файлы изображений, которые " +"необходимо обработать, либо путь к видеофайлу." + +#: tools/preview/cli.py:60 +msgid "" +"Path to the alignments file for the input, if not at the default location" +msgstr "" +"Путь к файлу выравниваний для входных данных, если он не находится в месте " +"по умолчанию" + +#: tools/preview/cli.py:68 +msgid "" +"Model directory. A directory containing the trained model you wish to " +"process." +msgstr "" +"Папка модели. Папка, содержащая обученную модель, которую вы хотите " +"обработать." + +#: tools/preview/cli.py:74 +msgid "Swap the model. Instead of A -> B, swap B -> A" +msgstr "Поменять местами модели. 
Вместо A -> B заменить B -> A"
+
+#: tools/preview/control_panels.py:510
+msgid "Save full config"
+msgstr "Сохранить полную конфигурацию"
+
+#: tools/preview/control_panels.py:513
+msgid "Reset full config to default values"
+msgstr "Сбросить полную конфигурацию до заводских значений"
+
+#: tools/preview/control_panels.py:516
+msgid "Reset full config to saved values"
+msgstr "Сбросить полную конфигурацию до сохраненных значений"
+
+#: tools/preview/control_panels.py:667
+#, python-brace-format
+msgid "Save {title} config"
+msgstr "Сохранить конфигурацию {title}"
+
+#: tools/preview/control_panels.py:670
+#, python-brace-format
+msgid "Reset {title} config to default values"
+msgstr "Сбросить конфигурацию {title} до заводских значений"
+
+#: tools/preview/control_panels.py:673
+#, python-brace-format
+msgid "Reset {title} config to saved values"
+msgstr "Сбросить конфигурацию {title} до сохраненных значений"
diff --git a/locales/ru/LC_MESSAGES/tools.sort.cli.mo b/locales/ru/LC_MESSAGES/tools.sort.cli.mo
new file mode 100644
index 0000000000..6b832be91e
Binary files /dev/null and b/locales/ru/LC_MESSAGES/tools.sort.cli.mo differ
diff --git a/locales/ru/LC_MESSAGES/tools.sort.cli.po b/locales/ru/LC_MESSAGES/tools.sort.cli.po
new file mode 100644
index 0000000000..9b76d494ae
--- /dev/null
+++ b/locales/ru/LC_MESSAGES/tools.sort.cli.po
@@ -0,0 +1,414 @@
+# SOME DESCRIPTIVE TITLE.
+# Copyright (C) YEAR THE PACKAGE'S COPYRIGHT HOLDER
+# This file is distributed under the same license as the PACKAGE package.
+# FIRST AUTHOR , YEAR.
+#
+msgid ""
+msgstr ""
+"Project-Id-Version: \n"
+"Report-Msgid-Bugs-To: \n"
+"POT-Creation-Date: 2024-03-28 23:53+0000\n"
+"PO-Revision-Date: 2024-03-29 00:06+0000\n"
+"Last-Translator: \n"
+"Language-Team: \n"
+"Language: ru\n"
+"MIME-Version: 1.0\n"
+"Content-Type: text/plain; charset=UTF-8\n"
+"Content-Transfer-Encoding: 8bit\n"
+"Plural-Forms: nplurals=3; plural=(n%10==1 && n%100!=11 ? 0 : n%10>=2 && "
+"n%10<=4 && (n%100<12 || n%100>14) ? 1 : 2);\n"
+"X-Generator: Poedit 3.4.2\n"
+
+#: tools/sort/cli.py:15
+msgid "This command lets you sort images using various methods."
+msgstr "Эта команда позволяет сортировать изображения различными методами."
+
+#: tools/sort/cli.py:21
+msgid ""
+" Adjust the '-t' ('--threshold') parameter to control the strength of "
+"grouping."
+msgstr ""
+" Настройте параметр '-t' ('--threshold') для контроля силы группировки."
+
+#: tools/sort/cli.py:22
+msgid ""
+" Adjust the '-b' ('--bins') parameter to control the number of bins for "
+"grouping. Each image is allocated to a bin by the percentage of color pixels "
+"that appear in the image."
+msgstr ""
+" Настройте параметр '-b' ('--bins') для управления количеством корзин для "
+"группировки. Каждое изображение распределяется по корзинам в зависимости от "
+"процента цветных пикселей, присутствующих в изображении."
+
+#: tools/sort/cli.py:25
+msgid ""
+" Adjust the '-b' ('--bins') parameter to control the number of bins for "
+"grouping. Each image is allocated to a bin by the number of degrees the face "
+"is orientated from center."
+msgstr ""
+" Настройте параметр '-b' ('--bins') для управления количеством корзин для "
+"группировки. Каждое изображение распределяется по корзинам по количеству "
+"градусов, на которые лицо ориентировано от центра."
+
+#: tools/sort/cli.py:28
+msgid ""
+" Adjust the '-b' ('--bins') parameter to control the number of bins for "
+"grouping. The minimum and maximum values are taken for the chosen sort "
+"metric. 
The bins are then populated with the results from the group sorting."
+msgstr ""
+" Настройте параметр '-b' ('--bins') для управления количеством корзин для "
+"группировки. Для выбранной метрики сортировки берутся минимальное и "
+"максимальное значения. Затем корзины заполняются результатами групповой "
+"сортировки."
+
+#: tools/sort/cli.py:32
+msgid "faces by blurriness."
+msgstr "лица по размытости."
+
+#: tools/sort/cli.py:33
+msgid "faces by fft filtered blurriness."
+msgstr "лица по размытости с фильтрацией fft."
+
+#: tools/sort/cli.py:34
+msgid ""
+"faces by the estimated distance of the alignments from an 'average' face. "
+"This can be useful for eliminating misaligned faces. Sorts from most like an "
+"average face to least like an average face."
+msgstr ""
+"лица по оценочному расстоянию выравнивания от \"среднего\" лица. Это может "
+"быть полезно для устранения неправильно расположенных лиц. Сортирует от "
+"наиболее похожего на среднее лицо к наименее похожему на среднее лицо."
+
+#: tools/sort/cli.py:37
+msgid ""
+"faces using VGG Face2 by face similarity. This uses a pairwise clustering "
+"algorithm to check the distances between 512 features on every face in your "
+"set and order them appropriately."
+msgstr ""
+"лиц с помощью VGG Face2 по сходству лиц. При этом используется алгоритм "
+"парной кластеризации для проверки расстояний между 512 признаками на каждом "
+"лице в вашем наборе и их упорядочивания соответствующим образом."
+
+#: tools/sort/cli.py:40
+msgid "faces by their landmarks."
+msgstr "лица по их ориентирам."
+
+#: tools/sort/cli.py:41
+msgid "Like 'face-cnn' but sorts by dissimilarity."
+msgstr "Как 'face-cnn', но сортирует по непохожести."
+
+#: tools/sort/cli.py:42
+msgid "faces by Yaw (rotation left to right)."
+msgstr "лица по Yaw (вращение слева направо)."
+
+#: tools/sort/cli.py:43
+msgid "faces by Pitch (rotation up and down)."
+msgstr "лица по Pitch (вращение вверх и вниз)."
+
+#: tools/sort/cli.py:44
+msgid ""
+"faces by Roll (rotation). Aligned faces should have a roll value close to "
+"zero. The further the Roll value from zero the higher liklihood the face is "
+"misaligned."
+msgstr ""
+"лица по Roll (повороту). Выровненные лица должны иметь значение Roll, "
+"близкое к нулю. Чем дальше значение Roll от нуля, тем выше вероятность того, "
+"что лицо неправильно выровнено."
+
+#: tools/sort/cli.py:46
+msgid "faces by their color histogram."
+msgstr "лица по их цветовой гистограмме."
+
+#: tools/sort/cli.py:47
+msgid "Like 'hist' but sorts by dissimilarity."
+msgstr "Как 'hist', но сортирует по непохожести."
+
+#: tools/sort/cli.py:48
+msgid ""
+"images by the average intensity of the converted grayscale color channel."
+msgstr ""
+"изображения по средней интенсивности преобразованного полутонового цветового "
+"канала."
+
+#: tools/sort/cli.py:49
+msgid ""
+"images by their number of black pixels. Useful when faces are near borders "
+"and a large part of the image is black."
+msgstr ""
+"изображения по количеству черных пикселей. Полезно, когда лица находятся "
+"вблизи границ и большая часть изображения черная."
+
+#: tools/sort/cli.py:51
+msgid ""
+"images by the average intensity of the converted Y color channel. Bright "
+"lighting and oversaturated images will be ranked first."
+msgstr ""
+"изображений по средней интенсивности преобразованного цветового канала Y. "
+"Яркое освещение и перенасыщенные изображения будут ранжироваться в первую "
+"очередь."
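Sorting 'faces by blurriness' above implies a numeric sharpness score per image. A common stand-in for such a score is the variance of the Laplacian, sketched below purely as an illustration; faceswap's blur and blur-fft metrics are computed on the masked face and may differ:

    import cv2

    def blur_score(path: str) -> float:
        # Higher variance of the Laplacian = sharper image
        gray = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        if gray is None:
            raise FileNotFoundError(path)
        return cv2.Laplacian(gray, cv2.CV_64F).var()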
+ +#: tools/sort/cli.py:53 +msgid "" +"images by the average intensity of the converted Cg color channel. Green " +"images will be ranked first and red images will be last." +msgstr "" +"изображений по средней интенсивности преобразованного цветового канала Cg. " +"Зеленые изображения занимают первое место, а красные - последнее." + +#: tools/sort/cli.py:55 +msgid "" +"images by the average intensity of the converted Co color channel. Orange " +"images will be ranked first and blue images will be last." +msgstr "" +"изображений по средней интенсивности преобразованного цветового канала Co. " +"Оранжевые изображения занимают первое место, а синие - последнее." + +#: tools/sort/cli.py:57 +msgid "" +"images by their size in the original frame. Faces further from the camera " +"and from lower resolution sources will be sorted first, whilst faces closer " +"to the camera and from higher resolution sources will be sorted last." +msgstr "" +"изображения по их размеру в исходном кадре. Лица, расположенные дальше от " +"камеры и полученные из источников с низким разрешением, будут отсортированы " +"первыми, а лица, расположенные ближе к камере и полученные из источников с " +"высоким разрешением, будут отсортированы последними." + +#: tools/sort/cli.py:81 +msgid "Sort faces using a number of different techniques" +msgstr "Сортировка лиц с использованием различных методов" + +#: tools/sort/cli.py:91 tools/sort/cli.py:98 tools/sort/cli.py:110 +#: tools/sort/cli.py:150 +msgid "data" +msgstr "данные" + +#: tools/sort/cli.py:92 +msgid "Input directory of aligned faces." +msgstr "Входная папка соотнесенных лиц." + +#: tools/sort/cli.py:100 +msgid "" +"Output directory for sorted aligned faces. If not provided and 'keep' is " +"selected then a new folder called 'sorted' will be created within the input " +"folder to house the output. If not provided and 'keep' is not selected then " +"the images will be sorted in-place, overwriting the original contents of the " +"'input_dir'" +msgstr "" +"Выходная папка для отсортированных выровненных лиц. Если не указано и " +"выбрано 'keep', то в папке input будет создана новая папка под названием " +"'sorted' для размещения выходных данных. Если не указано и не выбрано " +"'keep', то изображения будут отсортированы на месте, перезаписывая исходное " +"содержимое 'input_dir'." + +#: tools/sort/cli.py:112 +msgid "" +"R|If selected then the input_dir should be a parent folder containing " +"multiple folders of faces you wish to sort. The faces will be output to " +"separate sub-folders in the output_dir" +msgstr "" +"R|Если выбрано, то input_dir должен быть родительской папкой, содержащей " +"несколько папок с лицами, которые вы хотите отсортировать. Лица будут " +"выведены в отдельные вложенные папки в output_dir" + +#: tools/sort/cli.py:121 +msgid "sort settings" +msgstr "настройки сортировки" + +#: tools/sort/cli.py:124 +msgid "" +"R|Choose how images are sorted. Selecting a sort method gives the images a " +"new filename based on the order the image appears within the given method.\n" +"L|'none': Don't sort the images. When a 'group-by' method is selected, " +"selecting 'none' means that the files will be moved/copied into their " +"respective bins, but the files will keep their original filenames. Selecting " +"'none' for both 'sort-by' and 'group-by' will do nothing" +msgstr "" +"R|Выбор способа сортировки изображений. 
При выборе метода сортировки " +"изображениям присваивается новое имя файла, основанное на порядке появления " +"изображения в данном методе.\n" +"L|'none': Не сортировать изображения. Если выбран метод 'group-by', выбор " +"'none' означает, что файлы будут перемещены/скопированы в соответствующие " +"корзины, но файлы сохранят свои оригинальные имена. Выбор значения 'none' " +"как для 'sort-by', так и для 'group-by' ничего не даст" + +#: tools/sort/cli.py:136 tools/sort/cli.py:164 tools/sort/cli.py:184 +msgid "group settings" +msgstr "настройки группировки" + +#: tools/sort/cli.py:139 +#, fuzzy +#| msgid "" +#| "R|Selecting a group by method will move/copy files into numbered bins " +#| "based on the selected method.\n" +#| "L|'none': Don't bin the images. Folders will be sorted by the selected " +#| "'sort-by' but will not be binned, instead they will be sorted into a " +#| "single folder. Selecting 'none' for both 'sort-by' and 'group-by' will " +#| "do nothing" +msgid "" +"R|Selecting a group by method will move/copy files into numbered bins based " +"on the selected method.\n" +"L|'none': Don't bin the images. Folders will be sorted by the selected 'sort-" +"by' but will not be binned, instead they will be sorted into a single " +"folder. Selecting 'none' for both 'sort-by' and 'group-by' will do nothing" +msgstr "" +"R|Выбор группы по методу приведет к перемещению/копированию файлов в " +"пронумерованные корзины в соответствии с выбранным методом.\n" +"L|'none': Не сортировать изображения. Папки будут отсортированы по " +"выбранному \"sort-by\", но не будут разбиты на папки, вместо этого они будут " +"отсортированы в одну папку. Выбор значения 'none' как для 'sort-by', так и " +"для 'group-by' ничего не даст" + +#: tools/sort/cli.py:152 +msgid "" +"Whether to keep the original files in their original location. Choosing a " +"'sort-by' method means that the files have to be renamed. Selecting 'keep' " +"means that the original files will be kept, and the renamed files will be " +"created in the specified output folder. Unselecting keep means that the " +"original files will be moved and renamed based on the selected sort/group " +"criteria." +msgstr "" +"Сохранять ли исходные файлы в их первоначальном расположении. Выбор метода " +"\"сортировать по\" означает, что файлы должны быть переименованы. Выбор " +"'keep' означает, что исходные файлы будут сохранены, а переименованные файлы " +"будут созданы в указанной выходной папке. Отмена выбора \"keep\" означает, " +"что исходные файлы будут перемещены и переименованы в соответствии с " +"выбранными критериями сортировки/группировки." + +#: tools/sort/cli.py:167 +msgid "" +"R|Float value. Minimum threshold to use for grouping comparison with 'face-" +"cnn' 'hist' and 'face' methods.\n" +"The lower the value the more discriminating the grouping is. Leaving -1.0 " +"will allow Faceswap to choose the default value.\n" +"L|For 'face-cnn' 7.2 should be enough, with 4 being very discriminating. \n" +"L|For 'hist' 0.3 should be enough, with 0.2 being very discriminating. \n" +"L|For 'face' between 0.1 (more bins) to 0.5 (fewer bins) should be about " +"right.\n" +"Be careful setting a value that's too extrene in a directory with many " +"images, as this could result in a lot of folders being created. Defaults: " +"face-cnn 7.2, hist 0.3, face 0.25" +msgstr "" +"R|Плавающее значение. 
Минимальный порог, используемый для сравнения "
+"группировок с методами 'face-cnn' 'hist' и 'face'.\n"
+"Чем меньше значение, тем более дискриминационной является группировка. Если "
+"оставить значение -1.0, Faceswap сможет выбрать значение по умолчанию.\n"
+"L|Для 'face-cnn' 7,2 должно быть достаточно, при этом 4 будет очень "
+"дискриминационным. \n"
+"L|Для 'hist' 0.3 должно быть достаточно, при этом 0.2 очень хорошо "
+"различает. \n"
+"L|Для 'face' от 0,1 (больше бинов) до 0,5 (меньше бинов) должно быть "
+"достаточно.\n"
+"Будьте осторожны, устанавливая слишком большое значение в каталоге с большим "
+"количеством изображений, так как это может привести к созданию большого "
+"количества папок. По умолчанию: face-cnn 7.2, hist 0.3, face 0.25"
+
+#: tools/sort/cli.py:187
+#, fuzzy, python-format
+#| msgid ""
+#| "R|Integer value. Used to control the number of bins created for grouping "
+#| "by: any 'blur' methods, 'color' methods or 'face metric' methods "
+#| "('distance', 'size') and 'orientation; methods ('yaw', 'pitch'). For any "
+#| "other grouping methods see the '-t' ('--threshold') option.\n"
+#| "L|For 'face metric' methods the bins are filled, according the the "
+#| "distribution of faces between the minimum and maximum chosen metric.\n"
+#| "L|For 'color' methods the number of bins represents the divider of the "
+#| "percentage of colored pixels. Eg. For a bin number of '5': The first "
+#| "folder will have the faces with 0%% to 20%% colored pixels, second 21%% "
+#| "to 40%%, etc. Any empty bins will be deleted, so you may end up with "
+#| "fewer bins than selected.\n"
+#| "L|For 'blur' methods folder 0 will be the least blurry, while the last "
+#| "folder will be the blurriest.\n"
+#| "L|For 'orientation' methods the number of bins is dictated by how much "
+#| "180 degrees is divided. Eg. If 18 is selected, then each folder will be a "
+#| "10 degree increment. Folder 0 will contain faces looking the most to the "
+#| "left/down whereas the last folder will contain the faces looking the most "
+#| "to the right/up. NB: Some bins may be empty if faces do not fit the "
+#| "criteria.\n"
+#| "Default value: 5"
+msgid ""
+"R|Integer value. Used to control the number of bins created for grouping by: "
+"any 'blur' methods, 'color' methods or 'face metric' methods ('distance', "
+"'size') and 'orientation; methods ('yaw', 'pitch'). For any other grouping "
+"methods see the '-t' ('--threshold') option.\n"
+"L|For 'face metric' methods the bins are filled, according the the "
+"distribution of faces between the minimum and maximum chosen metric.\n"
+"L|For 'color' methods the number of bins represents the divider of the "
+"percentage of colored pixels. Eg. For a bin number of '5': The first folder "
+"will have the faces with 0%% to 20%% colored pixels, second 21%% to 40%%, "
+"etc. Any empty bins will be deleted, so you may end up with fewer bins than "
+"selected.\n"
+"L|For 'blur' methods folder 0 will be the least blurry, while the last "
+"folder will be the blurriest.\n"
+"L|For 'orientation' methods the number of bins is dictated by how much 180 "
+"degrees is divided. Eg. If 18 is selected, then each folder will be a 10 "
+"degree increment. Folder 0 will contain faces looking the most to the left/"
+"down whereas the last folder will contain the faces looking the most to the "
+"right/up. NB: Some bins may be empty if faces do not fit the criteria. \n"
+"Default value: 5"
+msgstr ""
+"R|Целочисленное значение. 
Используется для управления количеством бинов, " +"создаваемых для группировки: любыми методами 'размытия', 'цвета' или " +"методами 'метрики лица' ('расстояние', 'размер') и 'ориентации; методы " +"('yaw', 'pitch'). Для любых других методов группировки смотрите опцию '-" +"t' ('--threshold').\n" +"L|Для методов 'face metric' бины заполняются в соответствии с распределением " +"лиц между минимальной и максимальной выбранной метрикой.\n" +"L|Для методов 'color' количество бинов представляет собой делитель процента " +"цветных пикселей. Например, для числа бинов \"5\": В первой папке будут лица " +"с 0%% - 20%% цветных пикселей, во второй 21%% - 40%% и т.д. Все пустые папки " +"будут удалены, поэтому в итоге у вас может оказаться меньше папок, чем было " +"выбрано.\n" +"L|Для методов 'blur' папка 0 будет наименее размытой, а последняя папка " +"будет самой размытой.\n" +"L|Для методов \"orientation\" количество бинов диктуется тем, на сколько " +"делится 180 градусов. Например, если выбрано 18, то каждая папка будет иметь " +"шаг в 10 градусов. Папка 0 будет содержать лица, направленные больше всего " +"влево/вниз, а последняя папка будет содержать лица, направленные больше " +"всего вправо/вверх. Примечание: Некоторые папки могут быть пустыми, если " +"лица не соответствуют критериям.\n" +"Значение по умолчанию: 5" + +#: tools/sort/cli.py:207 tools/sort/cli.py:217 +msgid "settings" +msgstr "настройки" + +#: tools/sort/cli.py:210 +msgid "" +"Logs file renaming changes if grouping by renaming, or it logs the file " +"copying/movement if grouping by folders. If no log file is specified with " +"'--log-file', then a 'sort_log.json' file will be created in the input " +"directory." +msgstr "" +"Ведет журнал изменений переименования файлов при группировке по " +"переименованию, или журнал копирования/перемещения файлов при группировке по " +"папкам. Если файл журнала не указан с помощью '--log-file', то в каталоге " +"ввода будет создан файл 'sort_log.json'." + +#: tools/sort/cli.py:221 +msgid "" +"Specify a log file to use for saving the renaming or grouping information. " +"If specified extension isn't 'json' or 'yaml', then json will be used as the " +"serializer, with the supplied filename. Default: sort_log.json" +msgstr "" +"Укажите файл журнала, который будет использоваться для сохранения информации " +"о переименовании или группировке. Если указанное расширение не 'json' или " +"'yaml', то в качестве сериализатора будет использоваться json, с указанным " +"именем файла. По умолчанию: sort_log.json" + +#~ msgid " option is deprecated. Use 'yaw'" +#~ msgstr " является устаревшей. Используйте 'yaw'" + +#~ msgid " option is deprecated. Use 'color-black'" +#~ msgstr " является устаревшей. Используйте 'color-black'" + +#~ msgid "output" +#~ msgstr "вывод" + +#~ msgid "" +#~ "Deprecated and no longer used. The final processing will be dictated by " +#~ "the sort/group by methods and whether 'keep_original' is selected." +#~ msgstr "" +#~ "Устарело и больше не используется. Окончательная обработка будет " +#~ "диктоваться методами sort/group by и тем, выбрана ли опция " +#~ "'keep_original'." diff --git a/locales/tools.alignments.cli.pot b/locales/tools.alignments.cli.pot new file mode 100644 index 0000000000..4f1e02ae15 --- /dev/null +++ b/locales/tools.alignments.cli.pot @@ -0,0 +1,178 @@ +# SOME DESCRIPTIVE TITLE. +# Copyright (C) YEAR THE PACKAGE'S COPYRIGHT HOLDER +# This file is distributed under the same license as the PACKAGE package. +# FIRST AUTHOR , YEAR. 
+#
+#, fuzzy
+msgid ""
+msgstr ""
+"Project-Id-Version: PACKAGE VERSION\n"
+"Report-Msgid-Bugs-To: \n"
+"POT-Creation-Date: 2024-04-19 11:28+0100\n"
+"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
+"Last-Translator: FULL NAME \n"
+"Language-Team: LANGUAGE \n"
+"Language: \n"
+"MIME-Version: 1.0\n"
+"Content-Type: text/plain; charset=CHARSET\n"
+"Content-Transfer-Encoding: 8bit\n"
+
+#: tools/alignments/cli.py:16
+msgid ""
+"This command lets you perform various tasks pertaining to an alignments file."
+msgstr ""
+
+#: tools/alignments/cli.py:31
+msgid ""
+"Alignments tool\n"
+"This tool allows you to perform numerous actions on or using an alignments "
+"file against its corresponding faceset/frame source."
+msgstr ""
+
+#: tools/alignments/cli.py:43
+msgid " Must pass in a frames folder/source video file (-r)."
+msgstr ""
+
+#: tools/alignments/cli.py:44
+msgid " Must pass in a faces folder (-c)."
+msgstr ""
+
+#: tools/alignments/cli.py:45
+msgid ""
+" Must pass in either a frames folder/source video file OR a faces folder (-r "
+"or -c)."
+msgstr ""
+
+#: tools/alignments/cli.py:47
+msgid ""
+" Must pass in a frames folder/source video file AND a faces folder (-r and -"
+"c)."
+msgstr ""
+
+#: tools/alignments/cli.py:49
+msgid " Use the output option (-o) to process results."
+msgstr ""
+
+#: tools/alignments/cli.py:58 tools/alignments/cli.py:104
+msgid "processing"
+msgstr ""
+
+#: tools/alignments/cli.py:61
+#, python-brace-format
+msgid ""
+"R|Choose which action you want to perform. NB: All actions require an "
+"alignments file (-a) to be passed in.\n"
+"L|'draw': Draw landmarks on frames in the selected folder/video. A subfolder "
+"will be created within the frames folder to hold the output.{0}\n"
+"L|'export': Export the contents of an alignments file to a json file. Can be "
+"used for editing alignment information in external tools and then re-"
+"importing by using Faceswap's Extract 'Import' plugins. Note: masks and "
+"identity vectors will not be included in the exported file, so will be re-"
+"generated when the json file is imported back into Faceswap. All data is "
+"exported with the origin (0, 0) at the top left of the canvas.\n"
+"L|'extract': Re-extract faces from the source frames/video based on "
+"alignment data. This is a lot quicker than re-detecting faces. Can pass in "
+"the '-een' (--extract-every-n) parameter to only extract every nth frame."
+"{1}\n"
+"L|'from-faces': Generate alignment file(s) from a folder of extracted faces. "
+"If the folder of faces comes from multiple sources, then multiple alignments "
+"files will be created. NB: for faces which have been extracted from folders "
+"of source images, rather than a video, a single alignments file will be "
+"created as there is no way for the process to know how many folders of "
+"images were originally used. You do not need to provide an alignments file "
+"path to run this job. {3}\n"
+"L|'missing-alignments': Identify frames that do not exist in the alignments "
+"file.{2}{0}\n"
+"L|'missing-frames': Identify frames in the alignments file that do not "
+"appear within the frames folder/video.{2}{0}\n"
+"L|'multi-faces': Identify where multiple faces exist within the alignments "
+"file.{2}{4}\n"
+"L|'no-faces': Identify frames that exist within the alignment file but no "
+"faces were detected.{2}{0}\n"
+"L|'remove-faces': Remove deleted faces from an alignments file. 
The original "
+"alignments file will be backed up.{3}\n"
+"L|'rename': Rename faces to correspond with their parent frame and position "
+"index in the alignments file (i.e. how they are named after running extract)."
+"{3}\n"
+"L|'sort': Re-index the alignments from left to right. For alignments with "
+"multiple faces this will ensure that the left-most face is at index 0.\n"
+"L|'spatial': Perform spatial and temporal filtering to smooth alignments "
+"(EXPERIMENTAL!)"
+msgstr ""
+
+#: tools/alignments/cli.py:107
+msgid ""
+"R|How to output discovered items ('faces' and 'frames' only):\n"
+"L|'console': Print the list of frames to the screen. (DEFAULT)\n"
+"L|'file': Output the list of frames to a text file (stored within the source "
+"directory).\n"
+"L|'move': Move the discovered items to a sub-folder within the source "
+"directory."
+msgstr ""
+
+#: tools/alignments/cli.py:118 tools/alignments/cli.py:141
+#: tools/alignments/cli.py:148
+msgid "data"
+msgstr ""
+
+#: tools/alignments/cli.py:125
+msgid ""
+"Full path to the alignments file to be processed. If you have input a "
+"'frames_dir' and don't provide this option, the process will try to find the "
+"alignments file at the default location. All jobs require an alignments file "
+"with the exception of 'from-faces' when the alignments file will be "
+"generated in the specified faces folder."
+msgstr ""
+
+#: tools/alignments/cli.py:142
+msgid "Directory containing source frames that faces were extracted from."
+msgstr ""
+
+#: tools/alignments/cli.py:150
+msgid ""
+"R|Run the alignments tool on multiple sources. The following jobs support "
+"batch mode:\n"
+"L|draw, extract, from-faces, missing-alignments, missing-frames, no-faces, "
+"sort, spatial.\n"
+"If batch mode is selected then the other options should be set as follows:\n"
+"L|alignments_file: For 'sort' and 'spatial' this should point to the parent "
+"folder containing the alignments files to be processed. For all other jobs "
+"this option is ignored, and the alignments files must exist at their default "
+"location relative to the original frames folder/video.\n"
+"L|faces_dir: For 'from-faces' this should be a parent folder, containing sub-"
+"folders of extracted faces from which to generate alignments files. For "
+"'extract' this should be a parent folder where sub-folders will be created "
+"for each extraction to be run. For all other jobs this option is ignored.\n"
+"L|frames_dir: For 'draw', 'extract', 'missing-alignments', 'missing-frames' "
+"and 'no-faces' this should be a parent folder containing video files or sub-"
+"folders of images to perform the alignments job on. The alignments file "
+"should exist at the default location. For all other jobs this option is "
+"ignored."
+msgstr ""
+
+#: tools/alignments/cli.py:176 tools/alignments/cli.py:188
+#: tools/alignments/cli.py:198
+msgid "extract"
+msgstr ""
+
+#: tools/alignments/cli.py:178
+msgid ""
+"[Extract only] Extract every 'nth' frame. This option will skip frames when "
+"extracting faces. For example a value of 1 will extract faces from every "
+"frame, a value of 10 will extract faces from every 10th frame."
+msgstr ""
+
+#: tools/alignments/cli.py:189
+msgid "[Extract only] The output size of extracted faces."
+msgstr ""
+
+#: tools/alignments/cli.py:200
+msgid ""
+"[Extract only] Only extract faces that have been resized by this percent or "
+"more to meet the specified extract size (`-sz`, `--size`). Useful for "
+"excluding low-res images from a training set. 
Set to 0 to extract all faces. "
+"Eg: For an extract size of 512px, a setting of 50 will only include faces "
+"that have been resized from 256px or above. Setting to 100 will only extract "
+"faces that have been resized from 512px or above. A setting of 200 will only "
+"extract faces that have been downscaled from 1024px or above."
+msgstr ""
diff --git a/locales/tools.effmpeg.cli.pot b/locales/tools.effmpeg.cli.pot
new file mode 100644
index 0000000000..72ab831efa
--- /dev/null
+++ b/locales/tools.effmpeg.cli.pot
@@ -0,0 +1,147 @@
+# SOME DESCRIPTIVE TITLE.
+# Copyright (C) YEAR THE PACKAGE'S COPYRIGHT HOLDER
+# This file is distributed under the same license as the PACKAGE package.
+# FIRST AUTHOR , YEAR.
+#
+#, fuzzy
+msgid ""
+msgstr ""
+"Project-Id-Version: PACKAGE VERSION\n"
+"Report-Msgid-Bugs-To: \n"
+"POT-Creation-Date: 2024-03-28 23:50+0000\n"
+"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
+"Last-Translator: FULL NAME \n"
+"Language-Team: LANGUAGE \n"
+"Language: \n"
+"MIME-Version: 1.0\n"
+"Content-Type: text/plain; charset=CHARSET\n"
+"Content-Transfer-Encoding: 8bit\n"
+
+#: tools/effmpeg/cli.py:15
+msgid "This command allows you to easily execute common ffmpeg tasks."
+msgstr ""
+
+#: tools/effmpeg/cli.py:52
+msgid "A wrapper for ffmpeg for performing image <> video converting."
+msgstr ""
+
+#: tools/effmpeg/cli.py:64
+msgid ""
+"R|Choose which action you want ffmpeg to do.\n"
+"L|'extract': turns videos into images \n"
+"L|'gen-vid': turns images into videos \n"
+"L|'get-fps' returns the chosen video's fps.\n"
+"L|'get-info' returns information about a video.\n"
+"L|'mux-audio' adds audio from one video to another.\n"
+"L|'rescale' resizes video.\n"
+"L|'rotate' rotates video.\n"
+"L|'slice' cuts a portion of the video into a separate video file."
+msgstr ""
+
+#: tools/effmpeg/cli.py:78
+msgid "Input file."
+msgstr ""
+
+#: tools/effmpeg/cli.py:79 tools/effmpeg/cli.py:86 tools/effmpeg/cli.py:100
+msgid "data"
+msgstr ""
+
+#: tools/effmpeg/cli.py:89
+msgid ""
+"Output file. If no output is specified then: if the output is meant to be a "
+"video then a video called 'out.mkv' will be created in the input directory; "
+"if the output is meant to be a directory then a directory called 'out' will "
+"be created inside the input directory. Note: the chosen output file "
+"extension will determine the file encoding."
+msgstr ""
+
+#: tools/effmpeg/cli.py:102
+msgid "Path to reference video if 'input' was not a video."
+msgstr ""
+
+#: tools/effmpeg/cli.py:108 tools/effmpeg/cli.py:118 tools/effmpeg/cli.py:156
+#: tools/effmpeg/cli.py:185
+msgid "output"
+msgstr ""
+
+#: tools/effmpeg/cli.py:110
+msgid ""
+"Provide video fps. Can be an integer, float or fraction. Negative values "
+"will make the program try to get the fps from the input or reference "
+"videos."
+msgstr ""
+
+#: tools/effmpeg/cli.py:120
+msgid ""
+"Image format that extracted images should be saved as. '.bmp' will offer the "
+"fastest extraction speed, but will take the most storage space. '.png' will "
+"be slower but will take less storage."
+msgstr ""
+
+#: tools/effmpeg/cli.py:127 tools/effmpeg/cli.py:136 tools/effmpeg/cli.py:145
+msgid "clip"
+msgstr ""
+
+#: tools/effmpeg/cli.py:129
+msgid ""
+"Enter the start time from which an action is to be applied. Default: "
+"00:00:00, in HH:MM:SS format. You can also enter the time with or without "
+"the colons, e.g. 00:0000 or 026010."
+msgstr ""
+
+#: tools/effmpeg/cli.py:138
+msgid ""
+"Enter the end time to which an action is to be applied. 
If both an end time " +"and duration are set, then the end time will be used and the duration will " +"be ignored. Default: 00:00:00, in HH:MM:SS." +msgstr "" + +#: tools/effmpeg/cli.py:147 +msgid "" +"Enter the duration of the chosen action, for example if you enter 00:00:10 " +"for slice, then the first 10 seconds after and including the start time will " +"be cut out into a new video. Default: 00:00:00, in HH:MM:SS format. You can " +"also enter the time with or without the colons, e.g. 00:0000 or 026010." +msgstr "" + +#: tools/effmpeg/cli.py:158 +msgid "" +"Mux the audio from the reference video into the input video. This option is " +"only used for the 'gen-vid' action. 'mux-audio' action has this turned on " +"implicitly." +msgstr "" + +#: tools/effmpeg/cli.py:169 tools/effmpeg/cli.py:179 +msgid "rotate" +msgstr "" + +#: tools/effmpeg/cli.py:171 +msgid "" +"Transpose the video. If transpose is set, then degrees will be ignored. For " +"cli you can enter either the number or the long command name, e.g. to use " +"(1, 90Clockwise) -tr 1 or -tr 90Clockwise" +msgstr "" + +#: tools/effmpeg/cli.py:180 +msgid "Rotate the video clockwise by the given number of degrees." +msgstr "" + +#: tools/effmpeg/cli.py:187 +msgid "Set the new resolution scale if the chosen action is 'rescale'." +msgstr "" + +#: tools/effmpeg/cli.py:192 tools/effmpeg/cli.py:200 +msgid "settings" +msgstr "" + +#: tools/effmpeg/cli.py:194 +msgid "" +"Reduces output verbosity so that only serious errors are printed. If both " +"quiet and verbose are set, verbose will override quiet." +msgstr "" + +#: tools/effmpeg/cli.py:202 +msgid "" +"Increases output verbosity. If both quiet and verbose are set, verbose will " +"override quiet." +msgstr "" diff --git a/locales/tools.manual.pot b/locales/tools.manual.pot new file mode 100644 index 0000000000..4e3fe2e9ab --- /dev/null +++ b/locales/tools.manual.pot @@ -0,0 +1,224 @@ +# SOME DESCRIPTIVE TITLE. +# Copyright (C) YEAR THE PACKAGE'S COPYRIGHT HOLDER +# This file is distributed under the same license as the PACKAGE package. +# FIRST AUTHOR , YEAR. +# +#, fuzzy +msgid "" +msgstr "" +"Project-Id-Version: PACKAGE VERSION\n" +"Report-Msgid-Bugs-To: \n" +"POT-Creation-Date: 2024-03-28 23:55+0000\n" +"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n" +"Last-Translator: FULL NAME \n" +"Language-Team: LANGUAGE \n" +"Language: \n" +"MIME-Version: 1.0\n" +"Content-Type: text/plain; charset=CHARSET\n" +"Content-Transfer-Encoding: 8bit\n" + +#: tools/manual/cli.py:13 +msgid "" +"This command lets you perform various actions on frames, faces and " +"alignments files using visual tools." +msgstr "" + +#: tools/manual/cli.py:23 +msgid "" +"A tool to perform various actions on frames, faces and alignments files " +"using visual tools" +msgstr "" + +#: tools/manual/cli.py:35 tools/manual/cli.py:44 +msgid "data" +msgstr "" + +#: tools/manual/cli.py:38 +msgid "" +"Path to the alignments file for the input, if not at the default location" +msgstr "" + +#: tools/manual/cli.py:46 +msgid "" +"Video file or directory containing source frames that faces were extracted " +"from." +msgstr "" + +#: tools/manual/cli.py:53 tools/manual/cli.py:62 +msgid "options" +msgstr "" + +#: tools/manual/cli.py:55 +msgid "" +"Force regeneration of the low resolution jpg thumbnails in the alignments " +"file." +msgstr "" + +#: tools/manual/cli.py:64 +msgid "" +"The process attempts to speed up generation of thumbnails by extracting from " +"the video in parallel threads. 
For some videos, this causes the caching " +"process to hang. If this happens, then set this option to generate the " +"thumbnails in a slower, but more stable single thread." +msgstr "" + +#: tools/manual\faceviewer\frame.py:163 +msgid "Display the landmarks mesh" +msgstr "" + +#: tools/manual\faceviewer\frame.py:164 +msgid "Display the mask" +msgstr "" + +#: tools/manual\frameviewer\editor\_base.py:628 +#: tools/manual\frameviewer\editor\landmarks.py:44 +#: tools/manual\frameviewer\editor\mask.py:75 +msgid "Magnify/Demagnify the View" +msgstr "" + +#: tools/manual\frameviewer\editor\bounding_box.py:33 +#: tools/manual\frameviewer\editor\extract_box.py:32 +msgid "Delete Face" +msgstr "" + +#: tools/manual\frameviewer\editor\bounding_box.py:36 +msgid "" +"Bounding Box Editor\n" +"Edit the bounding box being fed into the aligner to recalculate the landmarks.\n" +"\n" +" - Grab the corner anchors to resize the bounding box.\n" +" - Click and drag the bounding box to relocate.\n" +" - Click in empty space to create a new bounding box.\n" +" - Right click a bounding box to delete a face." +msgstr "" + +#: tools/manual\frameviewer\editor\bounding_box.py:70 +msgid "Aligner to use. FAN will obtain better alignments, but cv2-dnn can be useful if FAN cannot get decent alignments and you want to set a base to edit from." +msgstr "" + +#: tools/manual\frameviewer\editor\bounding_box.py:83 +msgid "" +"Normalization method to use for feeding faces to the aligner. This can help the aligner better align faces with difficult lighting conditions. Different methods will yield different results on different sets. NB: This does not impact the output face, just the input to the aligner.\n" +"\tnone: Don't perform normalization on the face.\n" +"\tclahe: Perform Contrast Limited Adaptive Histogram Equalization on the face.\n" +"\thist: Equalize the histograms on the RGB channels.\n" +"\tmean: Normalize the face colors to the mean." +msgstr "" + +#: tools/manual\frameviewer\editor\extract_box.py:35 +msgid "" +"Extract Box Editor\n" +"Move the extract box that has been generated by the aligner. Click and drag:\n" +"\n" +" - Inside the bounding box to relocate the landmarks.\n" +" - The corner anchors to resize the landmarks.\n" +" - Outside of the corners to rotate the landmarks." +msgstr "" + +#: tools/manual\frameviewer\editor\landmarks.py:27 +msgid "" +"Landmark Point Editor\n" +"Edit the individual landmark points.\n" +"\n" +" - Click and drag individual points to relocate.\n" +" - Draw a box to select multiple points to relocate." +msgstr "" + +#: tools/manual\frameviewer\editor\mask.py:33 +msgid "" +"Mask Editor\n" +"Edit the mask.\n" +" - NB: For Landmark based masks (e.g. components/extended) it is better to make sure the landmarks are correct rather than editing the mask directly. Any change to the landmarks after editing the mask will override your manual edits." +msgstr "" + +#: tools/manual\frameviewer\editor\mask.py:77 +msgid "Draw Tool" +msgstr "" + +#: tools/manual\frameviewer\editor\mask.py:78 +msgid "Erase Tool" +msgstr "" + +#: tools/manual\frameviewer\editor\mask.py:97 +msgid "Select which mask to edit" +msgstr "" + +#: tools/manual\frameviewer\editor\mask.py:104 +msgid "Set the brush size. ([ - decrease, ] - increase)" +msgstr "" + +#: tools/manual\frameviewer\editor\mask.py:111 +msgid "Select the brush cursor color." 
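The four normalization choices described in the aligner string above correspond to standard OpenCV operations. A minimal illustrative sketch for a BGR uint8 face crop (assumptions: parameter values and the LAB-channel variant of CLAHE are mine; Faceswap's actual implementation may differ):

```python
import cv2
import numpy as np


def normalize_face(face: np.ndarray, method: str = "none") -> np.ndarray:
    """Normalize a BGR uint8 image with 'none', 'clahe', 'hist' or 'mean'."""
    if method == "clahe":
        # Contrast Limited Adaptive Histogram Equalization on the lightness channel
        lab = cv2.cvtColor(face, cv2.COLOR_BGR2LAB)
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(4, 4))
        lab[..., 0] = clahe.apply(lab[..., 0])
        return cv2.cvtColor(lab, cv2.COLOR_LAB2BGR)
    if method == "hist":
        # Equalize the histogram of each BGR channel independently
        return cv2.merge([cv2.equalizeHist(chan) for chan in cv2.split(face)])
    if method == "mean":
        # Shift each channel so its mean lands at mid-gray
        floats = face.astype("float32")
        shifted = floats + (127.5 - floats.mean(axis=(0, 1)))
        return np.clip(shifted, 0, 255).astype("uint8")
    return face  # 'none': feed the face to the aligner untouched
```

As the string notes, this only changes the aligner's input; the output face is untouched.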
+msgstr "" + +#: tools/manual\frameviewer\frame.py:78 +msgid "Play/Pause (SPACE)" +msgstr "" + +#: tools/manual\frameviewer\frame.py:79 +msgid "Go to First Frame (HOME)" +msgstr "" + +#: tools/manual\frameviewer\frame.py:80 +msgid "Go to Previous Frame (Z)" +msgstr "" + +#: tools/manual\frameviewer\frame.py:81 +msgid "Go to Next Frame (X)" +msgstr "" + +#: tools/manual\frameviewer\frame.py:82 +msgid "Go to Last Frame (END)" +msgstr "" + +#: tools/manual\frameviewer\frame.py:83 +msgid "Extract the faces to a folder... (Ctrl+E)" +msgstr "" + +#: tools/manual\frameviewer\frame.py:84 +msgid "Save the Alignments file (Ctrl+S)" +msgstr "" + +#: tools/manual\frameviewer\frame.py:85 +msgid "Filter Frames to only those Containing the Selected Item (F)" +msgstr "" + +#: tools/manual\frameviewer\frame.py:86 +msgid "Set the distance from an 'average face' to be considered misaligned. Higher distances are more restrictive" +msgstr "" + +#: tools/manual\frameviewer\frame.py:391 +msgid "View alignments" +msgstr "" + +#: tools/manual\frameviewer\frame.py:392 +msgid "Bounding box editor" +msgstr "" + +#: tools/manual\frameviewer\frame.py:393 +msgid "Location editor" +msgstr "" + +#: tools/manual\frameviewer\frame.py:394 +msgid "Mask editor" +msgstr "" + +#: tools/manual\frameviewer\frame.py:395 +msgid "Landmark point editor" +msgstr "" + +#: tools/manual\frameviewer\frame.py:470 +msgid "Next" +msgstr "" + +#: tools/manual\frameviewer\frame.py:470 +msgid "Previous" +msgstr "" + +#: tools/manual\frameviewer\frame.py:481 +msgid "Revert to saved Alignments ({})" +msgstr "" + +#: tools/manual\frameviewer\frame.py:487 +msgid "Copy {} Alignments ({})" +msgstr "" diff --git a/locales/tools.mask.cli.pot b/locales/tools.mask.cli.pot new file mode 100644 index 0000000000..f8024c88ee --- /dev/null +++ b/locales/tools.mask.cli.pot @@ -0,0 +1,200 @@ +# SOME DESCRIPTIVE TITLE. +# Copyright (C) YEAR THE PACKAGE'S COPYRIGHT HOLDER +# This file is distributed under the same license as the PACKAGE package. +# FIRST AUTHOR , YEAR. +# +#, fuzzy +msgid "" +msgstr "" +"Project-Id-Version: PACKAGE VERSION\n" +"Report-Msgid-Bugs-To: \n" +"POT-Creation-Date: 2024-06-28 13:45+0100\n" +"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n" +"Last-Translator: FULL NAME \n" +"Language-Team: LANGUAGE \n" +"Language: \n" +"MIME-Version: 1.0\n" +"Content-Type: text/plain; charset=CHARSET\n" +"Content-Transfer-Encoding: 8bit\n" + +#: tools/mask/cli.py:15 +msgid "" +"This tool allows you to generate, import, export or preview masks for " +"existing alignments." +msgstr "" + +#: tools/mask/cli.py:25 +msgid "" +"Mask tool\n" +"Generate, import, export or preview masks for existing alignments files." +msgstr "" + +#: tools/mask/cli.py:35 tools/mask/cli.py:47 tools/mask/cli.py:58 +#: tools/mask/cli.py:69 +msgid "data" +msgstr "" + +#: tools/mask/cli.py:39 +msgid "" +"Full path to the alignments file that contains the masks if not at the " +"default location. NB: If the input-type is faces and you wish to update the " +"corresponding alignments file, then you must provide a value here as the " +"location cannot be automatically detected." +msgstr "" + +#: tools/mask/cli.py:51 +msgid "Directory containing extracted faces, source frames, or a video file." 
+msgstr ""
+
+#: tools/mask/cli.py:61
+msgid ""
+"R|Whether the `input` is a folder of faces or a folder of frames/video\n"
+"L|faces: The input is a folder containing extracted faces.\n"
+"L|frames: The input is a folder containing frames or is a video"
+msgstr ""
+
+#: tools/mask/cli.py:71
+msgid ""
+"R|Run the mask tool on multiple sources. If selected then the other options "
+"should be set as follows:\n"
+"L|input: A parent folder containing either all of the video files to be "
+"processed, or containing sub-folders of frames/faces.\n"
+"L|output-folder: If provided, then sub-folders will be created within the "
+"given location to hold the previews for each input.\n"
+"L|alignments: Alignments field will be ignored for batch processing. The "
+"alignments files must exist at the default location (for frames). For batch "
+"processing of masks with 'faces' as the input type, only the PNG header "
+"within the extracted faces will be updated."
+msgstr ""
+
+#: tools/mask/cli.py:87 tools/mask/cli.py:119
+msgid "process"
+msgstr ""
+
+#: tools/mask/cli.py:89
+msgid ""
+"R|Masker to use.\n"
+"L|bisenet-fp: Relatively lightweight NN based mask that provides more "
+"refined control over the area to be masked including full head masking "
+"(configurable in mask settings).\n"
+"L|components: Mask designed to provide facial segmentation based on the "
+"positioning of landmark locations. A convex hull is constructed around the "
+"exterior of the landmarks to create a mask.\n"
+"L|custom: A dummy mask that fills the mask area with all 1s or 0s "
+"(configurable in settings). This is only required if you intend to manually "
+"edit the custom masks yourself in the manual tool. This mask does not use "
+"the GPU.\n"
+"L|extended: Mask designed to provide facial segmentation based on the "
+"positioning of landmark locations. A convex hull is constructed around the "
+"exterior of the landmarks and the mask is extended upwards onto the "
+"forehead.\n"
+"L|vgg-clear: Mask designed to provide smart segmentation of mostly frontal "
+"faces clear of obstructions. Profile faces and obstructions may result in "
+"sub-par performance.\n"
+"L|vgg-obstructed: Mask designed to provide smart segmentation of mostly "
+"frontal faces. The mask model has been specifically trained to recognize "
+"some facial obstructions (hands and eyeglasses). Profile faces may result in "
+"sub-par performance.\n"
+"L|unet-dfl: Mask designed to provide smart segmentation of mostly frontal "
+"faces. The mask model has been trained by community members. Profile faces "
+"may result in sub-par performance."
+msgstr ""
+
+#: tools/mask/cli.py:121
+msgid ""
+"R|The Mask tool process to perform.\n"
+"L|all: Update the mask for all faces in the alignments file for the selected "
+"'masker'.\n"
+"L|missing: Create a mask for all faces in the alignments file where a mask "
+"does not previously exist for the selected 'masker'.\n"
+"L|output: Don't update the masks, just output the selected 'masker' for "
+"review/editing in external tools to the given output folder.\n"
+"L|import: Import masks that have been edited outside of faceswap into the "
+"alignments file. Note: 'custom' must be the selected 'masker' and the masks "
+"must be in the same format as the 'input-type' (frames or faces)"
+msgstr ""
+
+#: tools/mask/cli.py:135 tools/mask/cli.py:154 tools/mask/cli.py:176
+msgid "import"
+msgstr ""
+
+#: tools/mask/cli.py:137
+msgid ""
+"R|Import only. 
The path to the folder that contains masks to be imported.\n"
+"L|How the masks are provided is not important, but they will be stored, "
+"internally, as 8-bit grayscale images.\n"
+"L|If the input is images, then the masks must be named exactly the same as "
+"the input frames/faces (excluding the file extension).\n"
+"L|If the input is a video file, then the filename of the masks is not "
+"important but should contain the frame number at the end of the filename "
+"(but before the file extension). The frame number can be separated from the "
+"rest of the filename by any non-numeric character and can be padded by any "
+"number of zeros. The frame number must correspond correctly to the frame "
+"number in the original video (starting from frame 1)."
+msgstr ""
+
+#: tools/mask/cli.py:156
+msgid ""
+"R|Import/Output only. When importing masks, this is the centering to use. "
+"For output this is only used for outputting custom imported masks, and "
+"should correspond to the centering used when importing the mask. Note: For "
+"any job other than 'import' and 'output' this option is ignored as mask "
+"centering is handled internally.\n"
+"L|face: Centers the mask on the center of the face, adjusting for pitch and "
+"yaw. Outside of requirements for full head masking/training, this is likely "
+"to be the best choice.\n"
+"L|head: Centers the mask on the center of the head, adjusting for pitch and "
+"yaw. Note: You should only select head centering if you intend to include "
+"the full head (including hair) within the mask and are looking to train a "
+"full head model.\n"
+"L|legacy: The 'original' extraction technique. Centers the mask near the "
+"nose and crops closely to the face. Can result in the edges of the "
+"mask appearing outside of the training area."
+msgstr ""
+
+#: tools/mask/cli.py:181
+msgid ""
+"Import only. The size, in pixels, to internally store the mask at.\n"
+"The default is 128 which is fine for nearly all use cases. Larger sizes will "
+"result in larger alignments files and longer processing."
+msgstr ""
+
+#: tools/mask/cli.py:189 tools/mask/cli.py:197 tools/mask/cli.py:211
+#: tools/mask/cli.py:225 tools/mask/cli.py:235
+msgid "output"
+msgstr ""
+
+#: tools/mask/cli.py:191
+msgid ""
+"Optional output location. If provided, a preview of the masks created will "
+"be output in the given folder."
+msgstr ""
+
+#: tools/mask/cli.py:202
+msgid ""
+"Apply gaussian blur to the mask output. Has the effect of smoothing the "
+"edges of the mask, giving less of a hard edge. The size is in pixels. This "
+"value should be odd; if an even number is passed in then it will be rounded "
+"to the next odd number. NB: Only affects the output preview. Set to 0 for off"
+msgstr ""
+
+#: tools/mask/cli.py:216
+msgid ""
+"Helps reduce 'blotchiness' on some masks by making light shades white and "
+"dark shades black. Higher values will impact more of the mask. NB: Only "
+"affects the output preview. Set to 0 for off"
+msgstr ""
+
+#: tools/mask/cli.py:227
+msgid ""
+"R|How to format the output when processing is set to 'output'.\n"
+"L|combined: The image contains the face/frame, face mask and masked face.\n"
+"L|masked: Output the face/frame as an rgba image with the face masked.\n"
+"L|mask: Only output the mask as a single channel image."
+msgstr ""
+
+#: tools/mask/cli.py:237
+msgid ""
+"R|Whether to output the whole frame or only the face box when using output "
+"processing. Only has an effect when using frames as input."
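The frame-number convention for imported masks described above amounts to "take the trailing digits of the filename stem". A minimal illustrative sketch (a hypothetical helper, not the mask tool's actual code):

```python
import re


def frame_number(filename: str) -> int:
    """Return the frame number encoded at the end of a mask filename.

    The number may be zero padded and separated from the rest of the name
    by any non-numeric character, e.g. 'mask-000012.png' -> 12.
    """
    stem = filename.rsplit(".", 1)[0]   # drop the file extension
    match = re.search(r"(\d+)$", stem)  # digits at the very end of the stem
    if not match:
        raise ValueError(f"No trailing frame number in '{filename}'")
    return int(match.group(1))          # int() discards any zero padding


assert frame_number("mask-000012.png") == 12
assert frame_number("video_mask.0497.png") == 497
```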
+msgstr "" diff --git a/locales/tools.model.cli.pot b/locales/tools.model.cli.pot new file mode 100644 index 0000000000..f5f2e9c690 --- /dev/null +++ b/locales/tools.model.cli.pot @@ -0,0 +1,63 @@ +# SOME DESCRIPTIVE TITLE. +# Copyright (C) YEAR THE PACKAGE'S COPYRIGHT HOLDER +# This file is distributed under the same license as the PACKAGE package. +# FIRST AUTHOR , YEAR. +# +#, fuzzy +msgid "" +msgstr "" +"Project-Id-Version: PACKAGE VERSION\n" +"Report-Msgid-Bugs-To: \n" +"POT-Creation-Date: 2024-03-28 23:51+0000\n" +"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n" +"Last-Translator: FULL NAME \n" +"Language-Team: LANGUAGE \n" +"Language: \n" +"MIME-Version: 1.0\n" +"Content-Type: text/plain; charset=CHARSET\n" +"Content-Transfer-Encoding: 8bit\n" + +#: tools/model/cli.py:13 +msgid "This tool lets you perform actions on saved Faceswap models." +msgstr "" + +#: tools/model/cli.py:22 +msgid "A tool for performing actions on Faceswap trained model files" +msgstr "" + +#: tools/model/cli.py:34 +msgid "" +"Model directory. A directory containing the model you wish to perform an " +"action on." +msgstr "" + +#: tools/model/cli.py:43 +msgid "" +"R|Choose which action you want to perform.\n" +"L|'inference' - Create an inference only copy of the model. Strips any " +"layers from the model which are only required for training. NB: This is for " +"exporting the model for use in external applications. Inference generated " +"models cannot be used within Faceswap. See the 'format' option for " +"specifying the model output format.\n" +"L|'nan-scan' - Scan the model file for NaNs or Infs (invalid data).\n" +"L|'restore' - Restore a model from backup." +msgstr "" + +#: tools/model/cli.py:57 tools/model/cli.py:69 +msgid "inference" +msgstr "" + +#: tools/model/cli.py:59 +msgid "" +"R|The format to save the model as. Note: Only used for 'inference' job.\n" +"L|'h5' - Standard Keras H5 format. Does not store any custom layer " +"information. Layers will need to be loaded from Faceswap to use.\n" +"L|'saved-model' - Tensorflow's Saved Model format. Contains all information " +"required to load the model outside of Faceswap." +msgstr "" + +#: tools/model/cli.py:71 +msgid "" +"Only used for 'inference' job. Generate the inference model for B -> A " +"instead of A -> B." +msgstr "" diff --git a/locales/tools.preview.pot b/locales/tools.preview.pot new file mode 100644 index 0000000000..1dac39da19 --- /dev/null +++ b/locales/tools.preview.pot @@ -0,0 +1,80 @@ +# SOME DESCRIPTIVE TITLE. +# Copyright (C) YEAR THE PACKAGE'S COPYRIGHT HOLDER +# This file is distributed under the same license as the PACKAGE package. +# FIRST AUTHOR , YEAR. +# +#, fuzzy +msgid "" +msgstr "" +"Project-Id-Version: PACKAGE VERSION\n" +"Report-Msgid-Bugs-To: \n" +"POT-Creation-Date: 2024-03-28 23:53+0000\n" +"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n" +"Last-Translator: FULL NAME \n" +"Language-Team: LANGUAGE \n" +"Language: \n" +"MIME-Version: 1.0\n" +"Content-Type: text/plain; charset=CHARSET\n" +"Content-Transfer-Encoding: 8bit\n" + +#: tools/preview/cli.py:15 +msgid "This command allows you to preview swaps to tweak convert settings." +msgstr "" + +#: tools/preview/cli.py:30 +msgid "" +"Preview tool\n" +"Allows you to configure your convert settings with a live preview" +msgstr "" + +#: tools/preview/cli.py:47 tools/preview/cli.py:57 tools/preview/cli.py:65 +msgid "data" +msgstr "" + +#: tools/preview/cli.py:50 +msgid "" +"Input directory or video. 
Either a directory containing the image files you " +"wish to process or path to a video file." +msgstr "" + +#: tools/preview/cli.py:60 +msgid "" +"Path to the alignments file for the input, if not at the default location" +msgstr "" + +#: tools/preview/cli.py:68 +msgid "" +"Model directory. A directory containing the trained model you wish to " +"process." +msgstr "" + +#: tools/preview/cli.py:74 +msgid "Swap the model. Instead of A -> B, swap B -> A" +msgstr "" + +#: tools/preview/control_panels.py:510 +msgid "Save full config" +msgstr "" + +#: tools/preview/control_panels.py:513 +msgid "Reset full config to default values" +msgstr "" + +#: tools/preview/control_panels.py:516 +msgid "Reset full config to saved values" +msgstr "" + +#: tools/preview/control_panels.py:667 +#, python-brace-format +msgid "Save {title} config" +msgstr "" + +#: tools/preview/control_panels.py:670 +#, python-brace-format +msgid "Reset {title} config to default values" +msgstr "" + +#: tools/preview/control_panels.py:673 +#, python-brace-format +msgid "Reset {title} config to saved values" +msgstr "" diff --git a/locales/tools.sort.cli.pot b/locales/tools.sort.cli.pot new file mode 100644 index 0000000000..8a963636d0 --- /dev/null +++ b/locales/tools.sort.cli.pot @@ -0,0 +1,262 @@ +# SOME DESCRIPTIVE TITLE. +# Copyright (C) YEAR THE PACKAGE'S COPYRIGHT HOLDER +# This file is distributed under the same license as the PACKAGE package. +# FIRST AUTHOR , YEAR. +# +#, fuzzy +msgid "" +msgstr "" +"Project-Id-Version: PACKAGE VERSION\n" +"Report-Msgid-Bugs-To: \n" +"POT-Creation-Date: 2024-03-28 23:53+0000\n" +"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n" +"Last-Translator: FULL NAME \n" +"Language-Team: LANGUAGE \n" +"Language: \n" +"MIME-Version: 1.0\n" +"Content-Type: text/plain; charset=CHARSET\n" +"Content-Transfer-Encoding: 8bit\n" + +#: tools/sort/cli.py:15 +msgid "This command lets you sort images using various methods." +msgstr "" + +#: tools/sort/cli.py:21 +msgid "" +" Adjust the '-t' ('--threshold') parameter to control the strength of " +"grouping." +msgstr "" + +#: tools/sort/cli.py:22 +msgid "" +" Adjust the '-b' ('--bins') parameter to control the number of bins for " +"grouping. Each image is allocated to a bin by the percentage of color pixels " +"that appear in the image." +msgstr "" + +#: tools/sort/cli.py:25 +msgid "" +" Adjust the '-b' ('--bins') parameter to control the number of bins for " +"grouping. Each image is allocated to a bin by the number of degrees the face " +"is orientated from center." +msgstr "" + +#: tools/sort/cli.py:28 +msgid "" +" Adjust the '-b' ('--bins') parameter to control the number of bins for " +"grouping. The minimum and maximum values are taken for the chosen sort " +"metric. The bins are then populated with the results from the group sorting." +msgstr "" + +#: tools/sort/cli.py:32 +msgid "faces by blurriness." +msgstr "" + +#: tools/sort/cli.py:33 +msgid "faces by fft filtered blurriness." +msgstr "" + +#: tools/sort/cli.py:34 +msgid "" +"faces by the estimated distance of the alignments from an 'average' face. " +"This can be useful for eliminating misaligned faces. Sorts from most like an " +"average face to least like an average face." +msgstr "" + +#: tools/sort/cli.py:37 +msgid "" +"faces using VGG Face2 by face similarity. This uses a pairwise clustering " +"algorithm to check the distances between 512 features on every face in your " +"set and order them appropriately." +msgstr "" + +#: tools/sort/cli.py:40 +msgid "faces by their landmarks." 
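The 'face' sort described above orders a set by identity-embedding similarity. One simple way to realize such an ordering is a greedy nearest-neighbour walk over the 512-dimension feature vectors; this is only a sketch of the idea (hypothetical helper, not Faceswap's pairwise clustering algorithm):

```python
import numpy as np


def order_by_similarity(embeddings: np.ndarray) -> list[int]:
    """Greedy nearest-neighbour ordering of (n_faces, 512) embeddings.

    Starting from the first face, repeatedly append the closest face
    (Euclidean distance) that has not yet been visited.
    """
    remaining = list(range(len(embeddings)))
    order = [remaining.pop(0)]
    while remaining:
        current = embeddings[order[-1]]
        dists = np.linalg.norm(embeddings[remaining] - current, axis=1)
        order.append(remaining.pop(int(dists.argmin())))
    return order


rng = np.random.default_rng(0)
print(order_by_similarity(rng.normal(size=(6, 512))))  # e.g. [0, 3, 1, 5, 2, 4]
```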
+msgstr ""
+
+#: tools/sort/cli.py:41
+msgid "Like 'face-cnn' but sorts by dissimilarity."
+msgstr ""
+
+#: tools/sort/cli.py:42
+msgid "faces by Yaw (rotation left to right)."
+msgstr ""
+
+#: tools/sort/cli.py:43
+msgid "faces by Pitch (rotation up and down)."
+msgstr ""
+
+#: tools/sort/cli.py:44
+msgid ""
+"faces by Roll (rotation). Aligned faces should have a roll value close to "
+"zero. The further the Roll value is from zero, the higher the likelihood "
+"that the face is misaligned."
+msgstr ""
+
+#: tools/sort/cli.py:46
+msgid "faces by their color histogram."
+msgstr ""
+
+#: tools/sort/cli.py:47
+msgid "Like 'hist' but sorts by dissimilarity."
+msgstr ""
+
+#: tools/sort/cli.py:48
+msgid ""
+"images by the average intensity of the converted grayscale color channel."
+msgstr ""
+
+#: tools/sort/cli.py:49
+msgid ""
+"images by their number of black pixels. Useful when faces are near borders "
+"and a large part of the image is black."
+msgstr ""
+
+#: tools/sort/cli.py:51
+msgid ""
+"images by the average intensity of the converted Y color channel. Bright "
+"lighting and oversaturated images will be ranked first."
+msgstr ""
+
+#: tools/sort/cli.py:53
+msgid ""
+"images by the average intensity of the converted Cg color channel. Green "
+"images will be ranked first and red images will be last."
+msgstr ""
+
+#: tools/sort/cli.py:55
+msgid ""
+"images by the average intensity of the converted Co color channel. Orange "
+"images will be ranked first and blue images will be last."
+msgstr ""
+
+#: tools/sort/cli.py:57
+msgid ""
+"images by their size in the original frame. Faces further from the camera "
+"and from lower resolution sources will be sorted first, whilst faces closer "
+"to the camera and from higher resolution sources will be sorted last."
+msgstr ""
+
+#: tools/sort/cli.py:81
+msgid "Sort faces using a number of different techniques"
+msgstr ""
+
+#: tools/sort/cli.py:91 tools/sort/cli.py:98 tools/sort/cli.py:110
+#: tools/sort/cli.py:150
+msgid "data"
+msgstr ""
+
+#: tools/sort/cli.py:92
+msgid "Input directory of aligned faces."
+msgstr ""
+
+#: tools/sort/cli.py:100
+msgid ""
+"Output directory for sorted aligned faces. If not provided and 'keep' is "
+"selected then a new folder called 'sorted' will be created within the input "
+"folder to house the output. If not provided and 'keep' is not selected then "
+"the images will be sorted in-place, overwriting the original contents of the "
+"'input_dir'"
+msgstr ""
+
+#: tools/sort/cli.py:112
+msgid ""
+"R|If selected then the input_dir should be a parent folder containing "
+"multiple folders of faces you wish to sort. The faces will be output to "
+"separate sub-folders in the output_dir"
+msgstr ""
+
+#: tools/sort/cli.py:121
+msgid "sort settings"
+msgstr ""
+
+#: tools/sort/cli.py:124
+msgid ""
+"R|Choose how images are sorted. Selecting a sort method gives the images a "
+"new filename based on the order the image appears within the given method.\n"
+"L|'none': Don't sort the images. When a 'group-by' method is selected, "
+"selecting 'none' means that the files will be moved/copied into their "
+"respective bins, but the files will keep their original filenames. 
Selecting "
+"'none' for both 'sort-by' and 'group-by' will do nothing"
+msgstr ""
+
+#: tools/sort/cli.py:136 tools/sort/cli.py:164 tools/sort/cli.py:184
+msgid "group settings"
+msgstr ""
+
+#: tools/sort/cli.py:139
+msgid ""
+"R|Selecting a group by method will move/copy files into numbered bins based "
+"on the selected method.\n"
+"L|'none': Don't bin the images. Folders will be sorted by the selected 'sort-"
+"by' but will not be binned, instead they will be sorted into a single "
+"folder. Selecting 'none' for both 'sort-by' and 'group-by' will do nothing"
+msgstr ""
+
+#: tools/sort/cli.py:152
+msgid ""
+"Whether to keep the original files in their original location. Choosing a "
+"'sort-by' method means that the files have to be renamed. Selecting 'keep' "
+"means that the original files will be kept, and the renamed files will be "
+"created in the specified output folder. Unselecting keep means that the "
+"original files will be moved and renamed based on the selected sort/group "
+"criteria."
+msgstr ""
+
+#: tools/sort/cli.py:167
+msgid ""
+"R|Float value. Minimum threshold to use for grouping comparison with 'face-"
+"cnn' 'hist' and 'face' methods.\n"
+"The lower the value the more discriminating the grouping is. Leaving -1.0 "
+"will allow Faceswap to choose the default value.\n"
+"L|For 'face-cnn' 7.2 should be enough, with 4 being very discriminating. \n"
+"L|For 'hist' 0.3 should be enough, with 0.2 being very discriminating. \n"
+"L|For 'face' between 0.1 (more bins) and 0.5 (fewer bins) should be about "
+"right.\n"
+"Be careful setting a value that's too extreme in a directory with many "
+"images, as this could result in a lot of folders being created. Defaults: "
+"face-cnn 7.2, hist 0.3, face 0.25"
+msgstr ""
+
+#: tools/sort/cli.py:187
+#, python-format
+msgid ""
+"R|Integer value. Used to control the number of bins created for grouping by: "
+"any 'blur' methods, 'color' methods or 'face metric' methods ('distance', "
+"'size') and 'orientation' methods ('yaw', 'pitch'). For any other grouping "
+"methods see the '-t' ('--threshold') option.\n"
+"L|For 'face metric' methods the bins are filled according to the "
+"distribution of faces between the minimum and maximum chosen metric.\n"
+"L|For 'color' methods the number of bins represents the divider of the "
+"percentage of colored pixels. Eg. For a bin number of '5': The first folder "
+"will have the faces with 0%% to 20%% colored pixels, second 21%% to 40%%, "
+"etc. Any empty bins will be deleted, so you may end up with fewer bins than "
+"selected.\n"
+"L|For 'blur' methods folder 0 will be the least blurry, while the last "
+"folder will be the blurriest.\n"
+"L|For 'orientation' methods the number of bins is dictated by how much 180 "
+"degrees is divided. Eg. If 18 is selected, then each folder will be a 10 "
+"degree increment. Folder 0 will contain faces looking the most to the left/"
+"down whereas the last folder will contain the faces looking the most to the "
+"right/up. NB: Some bins may be empty if faces do not fit the criteria. \n"
+"Default value: 5"
+msgstr ""
+
+#: tools/sort/cli.py:207 tools/sort/cli.py:217
+msgid "settings"
+msgstr ""
+
+#: tools/sort/cli.py:210
+msgid ""
+"Logs file renaming changes if grouping by renaming, or it logs the file "
+"copying/movement if grouping by folders. If no log file is specified with "
+"'--log-file', then a 'sort_log.json' file will be created in the input "
+"directory."
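The '--bins' rules described above are simple arithmetic. A sketch of how the 'color' and 'orientation' groupings could map a value to a folder index (hypothetical helpers illustrating the documented behaviour, not the sort tool's code):

```python
import math


def color_bin(pct_colored: float, num_bins: int = 5) -> int:
    """Folder index for a face with `pct_colored` percent colored pixels.

    With num_bins=5: 0-20% -> folder 0, 21-40% -> folder 1, and so on.
    """
    step = 100 / num_bins
    return max(0, min(num_bins - 1, math.ceil(pct_colored / step) - 1))


def orientation_bin(yaw_degrees: float, num_bins: int = 18) -> int:
    """Folder index for a yaw angle in [-90, 90]; 18 bins -> 10 degrees each."""
    width = 180 / num_bins
    return max(0, min(num_bins - 1, int((yaw_degrees + 90) // width)))


assert color_bin(18.0) == 0 and color_bin(55.0) == 2 and color_bin(100.0) == 4
assert orientation_bin(-90) == 0 and orientation_bin(0) == 9 and orientation_bin(90) == 17
```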
+msgstr "" + +#: tools/sort/cli.py:221 +msgid "" +"Specify a log file to use for saving the renaming or grouping information. " +"If specified extension isn't 'json' or 'yaml', then json will be used as the " +"serializer, with the supplied filename. Default: sort_log.json" +msgstr "" diff --git a/plugins/convert/color/__init__.py b/plugins/convert/color/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/plugins/convert/color/_base.py b/plugins/convert/color/_base.py new file mode 100644 index 0000000000..6bbe623e8f --- /dev/null +++ b/plugins/convert/color/_base.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 +""" Parent class for color Adjustments for faceswap.py converter """ + +import logging +import numpy as np + +from plugins.convert import convert_config + +logger = logging.getLogger(__name__) + + +class Adjustment(): + """ Parent class for adjustments """ + def __init__(self, configfile=None, config=None): + logger.debug("Initializing %s: (configfile: %s, config: %s)", + self.__class__.__name__, configfile, config) + convert_config.load_config(config_file=configfile) + logger.debug("Initialized %s", self.__class__.__name__) + + def process(self, old_face, new_face, raw_mask): + """ Override for specific color adjustment process """ + raise NotImplementedError + + def run(self, old_face, new_face, raw_mask): + """ Perform selected adjustment on face """ + # pylint:disable=duplicate-code + logger.trace("Performing color adjustment") # type:ignore[attr-defined] + # Remove Mask for processing + reinsert_mask = False + final_mask = None + if new_face.shape[2] == 4: + reinsert_mask = True + final_mask = new_face[:, :, -1] + new_face = new_face[:, :, :3] + new_face = self.process(old_face, new_face, raw_mask) + new_face = np.clip(new_face, 0.0, 1.0) + if reinsert_mask and new_face.shape[2] != 4: + # Reinsert Mask + assert final_mask is not None + new_face = np.concatenate((new_face, np.expand_dims(final_mask, axis=-1)), -1) + logger.trace("Performed color adjustment") # type:ignore[attr-defined] + return new_face diff --git a/plugins/convert/color/avg_color.py b/plugins/convert/color/avg_color.py new file mode 100644 index 0000000000..97a590599d --- /dev/null +++ b/plugins/convert/color/avg_color.py @@ -0,0 +1,43 @@ +#!/usr/bin/env python3 +""" Average colour adjustment color matching adjustment plugin for faceswap.py converter """ + +import numpy as np +from lib.utils import get_module_objects +from ._base import Adjustment + + +class Color(Adjustment): + """ Adjust the mean of the color channels to be the same for the swap and old frame """ + + def process(self, + old_face: np.ndarray, + new_face: np.ndarray, + raw_mask: np.ndarray) -> np.ndarray: + """ Adjust the mean of the original face and the new face to be the same + + Parameters + ---------- + old_face: :class:`numpy.ndarray` + The original face + new_face: :class:`numpy.ndarray` + The Faceswap generated face + raw_mask: :class:`numpy.ndarray` + A raw mask for including the face area only + + Returns + ------- + :class:`numpy.ndarray` + The adjusted face patch + """ + for _ in [0, 1]: + diff = old_face - new_face + if np.any(raw_mask): + avg_diff = np.sum(diff * raw_mask, axis=(0, 1)) + adjustment = avg_diff / np.sum(raw_mask, axis=(0, 1)) + else: + adjustment = diff + new_face += adjustment + return new_face + + +__all__ = get_module_objects(__name__) diff --git a/plugins/convert/color/color_transfer.py b/plugins/convert/color/color_transfer.py new file mode 100644 index 0000000000..6cb67f01a9 --- /dev/null 
+++ b/plugins/convert/color/color_transfer.py
@@ -0,0 +1,208 @@
+#!/usr/bin/env python3
+""" Color Transfer color matching adjustment plugin for faceswap.py converter
+    source: https://github.com/jrosebr1/color_transfer
+    The MIT License (MIT)
+
+    Copyright (c) 2014 Adrian Rosebrock, http://www.pyimagesearch.com
+
+    Permission is hereby granted, free of charge, to any person obtaining a copy
+    of this software and associated documentation files (the "Software"), to deal
+    in the Software without restriction, including without limitation the rights
+    to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+    copies of the Software, and to permit persons to whom the Software is
+    furnished to do so, subject to the following conditions:
+
+    The above copyright notice and this permission notice shall be included in
+    all copies or substantial portions of the Software.
+
+    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+    IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+    FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+    AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+    LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+    OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+    THE SOFTWARE. """
+
+import cv2
+import numpy as np
+from lib.utils import get_module_objects
+from ._base import Adjustment
+from . import color_transfer_defaults as cfg
+
+
+class Color(Adjustment):
+    """
+    Transfers the color distribution from the source to the target
+    image using the mean and standard deviations of the L*a*b*
+    color space.
+
+    This implementation is (loosely) based on the "Color Transfer
+    between Images" paper by Reinhard et al., 2001.
+    """
+
+    def process(self, old_face, new_face, raw_mask):  # pylint:disable=too-many-locals
+        """
+        Parameters
+        ----------
+        old_face: NumPy array
+            OpenCV image in BGR color space (the source of the color statistics)
+        new_face: NumPy array
+            OpenCV image in BGR color space (the target to be adjusted)
+        raw_mask: NumPy array
+            Mask selecting the face area that the transfer is calculated over
+
+        Two further options are read from the plugin config:
+        clip: Should components of L*a*b* image be scaled by np.clip before
+            converting back to BGR color space?
+            If False then components will be min-max scaled appropriately.
+            Clipping will keep target image brightness truer to the input.
+            Scaling will adjust image brightness to avoid washed out portions
+            in the resulting color transfer that can be caused by clipping.
+        preserve_paper: Should color transfer strictly follow the methodology
+            laid out in the original paper? The method does not always produce
+            aesthetically pleasing results.
+            If False then L*a*b* components will be scaled using the reciprocal of
+            the scaling factor proposed in the paper. 
This method seems to produce + more consistently aesthetically pleasing results + + Returns + ------- + transfer: NumPy array + OpenCV image (w, h, 3) NumPy array (uint8) + """ + clip = cfg.clip() + preserve_paper = cfg.preserve_paper() + + # convert the images from the RGB to L*ab* color space, being + # sure to utilizing the floating point data type (note: OpenCV + # expects floats to be 32-bit, so use that instead of 64-bit) + source = cv2.cvtColor( # pylint:disable=no-member + np.rint(old_face * raw_mask * 255.0).astype("uint8"), + cv2.COLOR_BGR2LAB).astype("float32") # pylint:disable=no-member + target = cv2.cvtColor( # pylint:disable=no-member + np.rint(new_face * raw_mask * 255.0).astype("uint8"), + cv2.COLOR_BGR2LAB).astype("float32") # pylint:disable=no-member + # compute color statistics for the source and target images + (l_mean_src, l_std_src, + a_mean_src, a_std_src, + b_mean_src, b_std_src) = self.image_stats(source) + (l_mean_tar, l_std_tar, + a_mean_tar, a_std_tar, + b_mean_tar, b_std_tar) = self.image_stats(target) + + # subtract the means from the target image + (light, col_a, col_b) = cv2.split(target) # pylint:disable=no-member + light -= l_mean_tar + col_a -= a_mean_tar + col_b -= b_mean_tar + + if preserve_paper: + # scale by the standard deviations using paper proposed factor + light = (l_std_tar / l_std_src) * light + col_a = (a_std_tar / a_std_src) * col_a + col_b = (b_std_tar / b_std_src) * col_b + else: + # scale by the standard deviations using reciprocal of paper proposed factor + light = (l_std_src / l_std_tar) * light + col_a = (a_std_src / a_std_tar) * col_a + col_b = (b_std_src / b_std_tar) * col_b + + # add in the source mean + light += l_mean_src + col_a += a_mean_src + col_b += b_mean_src + + # clip/scale the pixel intensities to [0, 255] if they fall + # outside this range + light = self._scale_array(light, clip=clip) + col_a = self._scale_array(col_a, clip=clip) + col_b = self._scale_array(col_b, clip=clip) + + # merge the channels together and convert back to the RGB color + # space, being sure to utilize the 8-bit unsigned integer data + # type + transfer = cv2.merge([light, col_a, col_b]) # pylint:disable=no-member + transfer = cv2.cvtColor( # pylint:disable=no-member + transfer.astype("uint8"), + cv2.COLOR_LAB2BGR).astype("float32") / 255.0 # pylint:disable=no-member + background = new_face * (1 - raw_mask) + merged = transfer + background + # return the color transferred image + return merged + + @staticmethod + def image_stats(image): + """ + Parameters + ---------- + + image: NumPy array + OpenCV image in L*a*b* color space + + Returns + ------- + Tuple of mean and standard deviations for the L*, a*, and b* + channels, respectively + """ + # compute the mean and standard deviation of each channel + (light, col_a, col_b) = cv2.split(image) # pylint:disable=no-member + (l_mean, l_std) = (light.mean(), light.std()) + (a_mean, a_std) = (col_a.mean(), col_a.std()) + (b_mean, b_std) = (col_b.mean(), col_b.std()) + + # return the color statistics + return (l_mean, l_std, a_mean, a_std, b_mean, b_std) + + @staticmethod + def _min_max_scale(arr, new_range=(0, 255)): + """ + Perform min-max scaling to a NumPy array + + Parameters + ---------- + arr: NumPy array to be scaled to [new_min, new_max] range + new_range: tuple of form (min, max) specifying range of + transformed array + + Returns + ------- + NumPy array that has been scaled to be in + [new_range[0], new_range[1]] range + """ + # get array's current min and max + arr_min = arr.min() + arr_max = 
arr.max()
+
+        # check if scaling needs to be done to be in new_range
+        if arr_min < new_range[0] or arr_max > new_range[1]:
+            # perform min-max scaling
+            scaled = (new_range[1] - new_range[0]) * (arr - arr_min) / (arr_max -
+                                                                        arr_min) + new_range[0]
+        else:
+            # return array if already in range
+            scaled = arr
+
+        return scaled
+
+    def _scale_array(self, arr, clip=True):
+        """
+        Trim NumPy array values to be in [0, 255] range with option of
+        clipping or scaling.
+
+        Parameters
+        ----------
+        arr: array to be trimmed to [0, 255] range
+        clip: should array be scaled by np.clip? if False then input
+            array will be min-max scaled to range
+            [max([arr.min(), 0]), min([arr.max(), 255])]
+
+        Returns
+        -------
+        NumPy array that has been scaled to be in [0, 255] range
+        """
+        if clip:
+            scaled = np.clip(arr, 0, 255)
+        else:
+            scale_range = (max([arr.min(), 0]), min([arr.max(), 255]))
+            scaled = self._min_max_scale(arr, new_range=scale_range)
+
+        return scaled
+
+
+__all__ = get_module_objects(__name__)
diff --git a/plugins/convert/color/color_transfer_defaults.py b/plugins/convert/color/color_transfer_defaults.py
new file mode 100755
index 0000000000..b12931c6f5
--- /dev/null
+++ b/plugins/convert/color/color_transfer_defaults.py
@@ -0,0 +1,55 @@
+#!/usr/bin/env python3
+""" The default options for the faceswap Color_Transfer Color plugin.
+
+Defaults files should be named `<plugin_name>_defaults.py`
+
+Any qualifying items placed into this file will automatically get added to the relevant config
+.ini files within the faceswap/config folder and added to the relevant GUI settings page.
+
+The following variable should be defined:
+
+    Parameters
+    ----------
+    HELPTEXT: str
+        A string describing what this plugin does
+
+Further plugin configuration options are assigned using:
+>>> <option_name> = ConfigItem(...)
+
+where <option_name> is the name of the configuration option to be added (lower-case, alpha-numeric
++ underscore only) and ConfigItem(...) is the [`~lib.config.objects.ConfigItem`] data for the
+option.
+
+See the docstring/ReadtheDocs documentation for the required parameters of the ConfigItem object.
+Items will be grouped together as per their `group` parameter, but otherwise will be processed in
+the order that they are added to this module.
+"""
+from lib.config import ConfigItem
+
+
+HELPTEXT = (
+    "Options for transferring the color distribution from the source to the target image using "
+    "the mean and standard deviations of the L*a*b* color space.\nThis implementation is "
+    "(loosely) based on the 'Color Transfer between Images' paper by Reinhard et al., 2001.")
+
+
+clip = ConfigItem(
+    datatype=bool,
+    default=True,
+    group="method",
+    info="Should components of L*a*b* image be scaled by numpy.clip before converting back to "
+         "BGR color space?\nIf False then components will be min-max scaled appropriately.\n"
+         "Clipping will keep target image brightness truer to the input.\nScaling will adjust "
+         "image brightness to avoid washed out portions in the resulting color transfer that "
+         "can be caused by clipping.")
+
+preserve_paper = ConfigItem(
+    datatype=bool,
+    group="method",
+    default=True,
+    info="Should color transfer strictly follow the methodology laid out in the original paper?"
+         "\nThe method does not always produce aesthetically pleasing results.\nIf False then "
+         "L*a*b* components will be scaled using the reciprocal of the scaling factor "
+         "proposed in the paper. This method seems to produce more consistently aesthetically 
This method seems to produce more consistently aesthetically " + "pleasing results.") diff --git a/plugins/convert/color/manual_balance.py b/plugins/convert/color/manual_balance.py new file mode 100644 index 0000000000..3719bec67a --- /dev/null +++ b/plugins/convert/color/manual_balance.py @@ -0,0 +1,54 @@ +#!/usr/bin/env python3 +""" Manual Balance colour adjustment plugin for faceswap.py converter """ + +import cv2 +import numpy as np +from lib.utils import get_module_objects +from ._base import Adjustment +from . import manual_balance_defaults as cfg + + +class Color(Adjustment): + """ Adjust the mean of the color channels to be the same for the swap and old frame """ + + def process(self, old_face, new_face, raw_mask): + image = self.convert_colorspace(new_face * 255.0) + adjustment = np.array([cfg.balance_1() / 100.0, + cfg.balance_2() / 100.0, + cfg.balance_3() / 100.0]).astype("float32") + for idx in range(3): + if adjustment[idx] >= 0: + image[:, :, idx] = ((1 - image[:, :, idx]) * adjustment[idx]) + image[:, :, idx] + else: + image[:, :, idx] = image[:, :, idx] * (1 + adjustment[idx]) + + image = self.convert_colorspace(image * 255.0, to_bgr=True) + image = self.adjust_contrast(image) + return image + + def adjust_contrast(self, image): + """ + Adjust image contrast and brightness. + """ + contrast = max(-126, int(round(cfg.contrast() * 1.27))) + brightness = max(-126, int(round(cfg.brightness() * 1.27))) + + if not contrast and not brightness: + return image + + image = np.rint(image * 255.0).astype("uint8") + image = np.clip(image * (contrast/127+1) - contrast + brightness, 0, 255) + image = np.clip(np.divide(image, 255, dtype=np.float32), .0, 1.0) + return image + + def convert_colorspace(self, new_face, to_bgr=False): + """ Convert colorspace based on mode or back to bgr """ + mode = cfg.colorspace().lower() + colorspace = "YCrCb" if mode == "ycrcb" else mode.upper() + conversion = f"{colorspace}2BGR" if to_bgr else f"BGR2{colorspace}" + image = cv2.cvtColor(new_face.astype("uint8"), # pylint:disable=no-member + getattr(cv2, f"COLOR_{conversion}")).astype("float32") / 255.0 + return image + + +__all__ = get_module_objects(__name__) diff --git a/plugins/convert/color/manual_balance_defaults.py b/plugins/convert/color/manual_balance_defaults.py new file mode 100755 index 0000000000..b3a7dfde6b --- /dev/null +++ b/plugins/convert/color/manual_balance_defaults.py @@ -0,0 +1,109 @@ +#!/usr/bin/env python3 +""" The default options for the faceswap Manual_Balance Color plugin. + +Defaults files should be named `_defaults.py` + +Any qualifying items placed into this file will automatically get added to the relevant config +.ini files within the faceswap/config folder and added to the relevant GUI settings page. + +The following variable should be defined: + + Parameters + ---------- + HELPTEXT: str + A string describing what this plugin does + +Further plugin configuration options are assigned using: +>>> = ConfigItem(...) + +where is the name of the configuration option to be added (lower-case, alpha-numeric ++ underscore only) and ConfigItem(...) is the [`~lib.config.objects.ConfigItem`] data for the +option. + +See the docstring/ReadtheDocs documentation required parameters for the ConfigItem object. +Items will be grouped together as per their `group` parameter, but otherwise will be processed in +the order that they are added to this module. 
+"""
+from lib.config import ConfigItem
+
+
+HELPTEXT = "Options for manually altering the balance of colors of the swapped face"
+
+
+colorspace = ConfigItem(
+    datatype=str,
+    default="HSV",
+    group="color balance",
+    info="The colorspace to use for adjustment: The three adjustment sliders will "
+         "affect the image differently depending on which colorspace is selected:"
+         "\n\t RGB: Red, Green, Blue. An additive colorspace where colors are obtained "
+         "by a linear combination of Red, Green, and Blue values. The three channels "
+         "are correlated by the amount of light hitting the surface. In RGB color "
+         "space the color information is separated into three channels but the same "
+         "three channels also encode brightness information."
+         "\n\t HSV: Hue, Saturation, Value. Hue - Dominant wavelength. Saturation - "
+         "Purity / shades of color. Value - Intensity. Its main advantage is that it "
+         "uses only one channel to describe color (H), making it very intuitive to "
+         "specify color."
+         "\n\t LAB: Lightness, A, B. Lightness - Intensity. A - Color range from green "
+         "to magenta. B - Color range from blue to yellow. The L channel is "
+         "independent of color information and encodes brightness only. The other two "
+         "channels encode color."
+         "\n\t YCrCb: Y - Luminance or Luma component obtained from RGB after gamma "
+         "correction. Cr - how far is the red component from Luma. Cb - how far is the "
+         "blue component from Luma. Separates the luminance and chrominance components "
+         "into different channels.",
+    choices=["RGB", "HSV", "LAB", "YCrCb"],
+    gui_radio=True)
+
+balance_1 = ConfigItem(
+    datatype=float,
+    default=0.0,
+    group="color balance",
+    info="Balance of channel 1:"
+         "\n\tRGB: Red"
+         "\n\tHSV: Hue"
+         "\n\tLAB: Lightness"
+         "\n\tYCrCb: Luma",
+    rounding=1,
+    min_max=(-100.0, 100.0))
+
+balance_2 = ConfigItem(
+    datatype=float,
+    default=0.0,
+    group="color balance",
+    info="Balance of channel 2:"
+         "\n\tRGB: Green"
+         "\n\tHSV: Saturation"
+         "\n\tLAB: Green > Magenta"
+         "\n\tYCrCb: Distance of red from Luma",
+    rounding=1,
+    min_max=(-100.0, 100.0))
+
+balance_3 = ConfigItem(
+    datatype=float,
+    default=0.0,
+    group="color balance",
+    info="Balance of channel 3:"
+         "\n\tRGB: Blue"
+         "\n\tHSV: Value (intensity)"
+         "\n\tLAB: Blue > Yellow"
+         "\n\tYCrCb: Distance of blue from Luma",
+    rounding=1,
+    min_max=(-100.0, 100.0))
+
+contrast = ConfigItem(
+    datatype=float,
+    default=0.0,
+    group="brightness contrast",
+    info="Amount of contrast applied.",
+    rounding=1,
+    min_max=(-100.0, 100.0))
+
+brightness = ConfigItem(
+    datatype=float,
+    default=0.0,
+    group="brightness contrast",
+    info="Amount of brightness applied.",
+    rounding=1,
+    min_max=(-100.0, 100.0))
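To make the slider semantics concrete, here is a standalone sketch of the per-channel arithmetic that `process` applies (illustrative values; plain NumPy, not part of the plugin):

import numpy as np

channel = np.array([0.2, 0.5, 0.8], dtype="float32")  # one colorspace channel in [0, 1]

adj = 0.5    # slider at +50: push the channel halfway towards 1.0
positive = (1 - channel) * adj + channel              # [0.6, 0.75, 0.9]

adj = -0.5   # slider at -50: scale the channel halfway towards 0.0
negative = channel * (1 + adj)                        # [0.1, 0.25, 0.4]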
diff --git a/plugins/convert/color/match_hist.py b/plugins/convert/color/match_hist.py
new file mode 100644
index 0000000000..c743118385
--- /dev/null
+++ b/plugins/convert/color/match_hist.py
@@ -0,0 +1,46 @@
+#!/usr/bin/env python3
+""" Match histogram colour adjustment plugin for faceswap.py converter """
+
+import numpy as np
+from lib.utils import get_module_objects
+from ._base import Adjustment
+from . import match_hist_defaults as cfg
+
+
+class Color(Adjustment):
+    """ Match the histogram of the color intensity of each channel """
+
+    def process(self, old_face, new_face, raw_mask):
+        mask_indices = np.nonzero(raw_mask.squeeze())
+        new_face = [self.hist_match(old_face[:, :, c],
+                                    new_face[:, :, c],
+                                    mask_indices,
+                                    cfg.threshold() / 100)
+                    for c in range(3)]
+        new_face = np.stack(new_face, axis=-1)
+        return new_face
+
+    @staticmethod
+    def hist_match(old_channel, new_channel, mask_indices, threshold):
+        """ Construct the histogram of the color intensity of a channel
+            for the swap and the original. Match the histogram of the original
+            by interpolation
+        """
+        if mask_indices[0].size == 0:
+            return new_channel
+
+        old_masked = old_channel[mask_indices]
+        new_masked = new_channel[mask_indices]
+        _, bin_idx, s_counts = np.unique(new_masked, return_inverse=True, return_counts=True)
+        t_values, t_counts = np.unique(old_masked, return_counts=True)
+        s_quants = np.cumsum(s_counts, dtype='float32')
+        t_quants = np.cumsum(t_counts, dtype='float32')
+        s_quants = threshold * s_quants / s_quants[-1]  # cdf
+        t_quants /= t_quants[-1]  # cdf
+        interp_s_values = np.interp(s_quants, t_quants, t_values)
+        new_channel[mask_indices] = interp_s_values[bin_idx]
+        return new_channel
+
+
+__all__ = get_module_objects(__name__)
diff --git a/plugins/convert/color/match_hist_defaults.py b/plugins/convert/color/match_hist_defaults.py
new file mode 100755
index 0000000000..19dd891c4a
--- /dev/null
+++ b/plugins/convert/color/match_hist_defaults.py
@@ -0,0 +1,41 @@
+#!/usr/bin/env python3
+""" The default options for the faceswap Match_Hist Color plugin.
+
+Defaults files should be named `<plugin_name>_defaults.py`
+
+Any qualifying items placed into this file will automatically get added to the relevant config
+.ini files within the faceswap/config folder and added to the relevant GUI settings page.
+
+The following variable should be defined:
+
+    Parameters
+    ----------
+    HELPTEXT: str
+        A string describing what this plugin does
+
+Further plugin configuration options are assigned using:
+>>> <option_name> = ConfigItem(...)
+
+where <option_name> is the name of the configuration option to be added (lower-case, alpha-numeric
++ underscore only) and ConfigItem(...) is the [`~lib.config.objects.ConfigItem`] data for the
+option.
+
+See the docstring/ReadtheDocs documentation for the required parameters of the ConfigItem object.
+Items will be grouped together as per their `group` parameter, but otherwise will be processed in
+the order that they are added to this module.
+"""
+from lib.config import ConfigItem
+
+
+HELPTEXT = "Options for matching the histograms between the source and destination faces"
+
+
+threshold = ConfigItem(
+    datatype=float,
+    default=99.0,
+    group="settings",
+    info="Adjust the threshold for histogram matching. Can reduce extreme colors leaking in "
+         "by filtering out colors at the extreme ends of the histogram spectrum.",
+    rounding=1,
+    min_max=(90.0, 100.0))
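The CDF-interpolation trick at the heart of `hist_match` can be demonstrated in isolation. The sketch below uses synthetic data and an illustrative `hist_match_1d` helper, with the threshold factor left at 1.0:

import numpy as np

def hist_match_1d(source, template):
    """Map `source` values so their CDF matches `template`'s CDF,
    the same interpolation hist_match performs per channel."""
    _, bin_idx, s_counts = np.unique(source, return_inverse=True, return_counts=True)
    t_values, t_counts = np.unique(template, return_counts=True)
    s_quants = np.cumsum(s_counts, dtype="float32")
    t_quants = np.cumsum(t_counts, dtype="float32")
    s_quants /= s_quants[-1]  # source CDF in [0, 1]
    t_quants /= t_quants[-1]  # template CDF in [0, 1]
    return np.interp(s_quants, t_quants, t_values)[bin_idx]

rng = np.random.default_rng(0)
swapped = rng.normal(100, 10, 1000)    # dim, low-contrast channel
original = rng.normal(150, 30, 1000)   # brighter, higher-contrast channel
matched = hist_match_1d(swapped, original)
print(matched.mean(), matched.std())   # ~150, ~30 after matching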
diff --git a/plugins/convert/color/seamless_clone.py b/plugins/convert/color/seamless_clone.py
new file mode 100644
index 0000000000..7d6680d5d7
--- /dev/null
+++ b/plugins/convert/color/seamless_clone.py
@@ -0,0 +1,45 @@
+#!/usr/bin/env python3
+""" Seamless clone adjustment plugin for faceswap.py converter
+NB: This probably isn't the best place for this, but it is independent of color adjustments and
+does not have a natural home, so here for now and called as an extra plugin from lib/convert.py
+"""
+import cv2
+import numpy as np
+from lib.utils import get_module_objects
+from ._base import Adjustment
+
+
+class Color(Adjustment):
+    """ Seamless clone the swapped face into the old face with cv2
+    NB: This probably isn't the best place for this, but it doesn't work well and does not have a
+    natural home, so here for now.
+    """
+    def process(self, old_face, new_face, raw_mask):  # pylint:disable=too-many-locals
+        height, width, _ = old_face.shape
+        height = height // 2
+        width = width // 2
+
+        y_indices, x_indices, _ = np.nonzero(raw_mask)
+        y_crop = slice(np.min(y_indices), np.max(y_indices))
+        x_crop = slice(np.min(x_indices), np.max(x_indices))
+        y_center = int(np.rint((np.max(y_indices) + np.min(y_indices)) / 2 + height))
+        x_center = int(np.rint((np.max(x_indices) + np.min(x_indices)) / 2 + width))
+
+        insertion = np.rint(new_face[y_crop, x_crop] * 255.0).astype("uint8")
+        insertion_mask = np.rint(raw_mask[y_crop, x_crop] * 255.0).astype("uint8")
+        insertion_mask[insertion_mask != 0] = 255
+        prior = np.rint(np.pad(old_face * 255.0,
+                               ((height, height), (width, width), (0, 0)),
+                               'constant')).astype("uint8")
+
+        blended = cv2.seamlessClone(insertion,  # pylint:disable=no-member
+                                    prior,
+                                    insertion_mask,
+                                    (x_center, y_center),
+                                    cv2.NORMAL_CLONE)  # pylint:disable=no-member
+        blended = blended[height:-height, width:-width]
+
+        return blended.astype("float32") / 255.0
+
+
+__all__ = get_module_objects(__name__)
diff --git a/plugins/convert/convert_config.py b/plugins/convert/convert_config.py
new file mode 100644
index 0000000000..9f174c77c1
--- /dev/null
+++ b/plugins/convert/convert_config.py
@@ -0,0 +1,41 @@
+#!/usr/bin/env python3
+""" Default configurations for convert """
+
+import logging
+import os
+
+from lib.config import FaceswapConfig
+
+logger = logging.getLogger(__name__)
+
+
+class _Config(FaceswapConfig):
+    """ Config File for Convert """
+
+    def set_defaults(self, helptext=""):
+        """ Set the default values for config """
+        super().set_defaults(helptext=helptext)
+        self._defaults_from_plugin(os.path.dirname(__file__))
+
+
+_CONFIG: _Config | None = None
+
+
+def load_config(config_file: str | None = None) -> _Config:
+    """ Load the Convert configuration .ini file
+
+    Parameters
+    ----------
+    config_file : str | None, optional
+        Path to a custom .ini configuration file to load. Default: ``None`` (use the default
+        configuration file)
+
+    Returns
+    -------
+    :class:`_Config`
+        The loaded convert config object
+    """
+    global _CONFIG  # pylint:disable=global-statement
+    if _CONFIG is None:
+        _CONFIG = _Config(configfile=config_file)
+    return _CONFIG
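A minimal usage sketch of the module-level singleton above: the first call builds the shared `_Config`; later calls return the cached instance and ignore `config_file` (the custom path here is illustrative):

from plugins.convert import convert_config

config = convert_config.load_config()              # loads the default .ini
same = convert_config.load_config("custom.ini")    # returns the cached object
assert config is same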
diff --git a/plugins/convert/mask/__init__.py b/plugins/convert/mask/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/plugins/convert/mask/mask_blend.py b/plugins/convert/mask/mask_blend.py
new file mode 100644
index 0000000000..a46014dbd2
--- /dev/null
+++ b/plugins/convert/mask/mask_blend.py
@@ -0,0 +1,295 @@
+#!/usr/bin/env python3
+""" Plugin to blend the edges of the face between the swap and the original face. """
+import logging
+import typing as T
+
+import cv2
+import numpy as np
+
+from lib.align import BlurMask, DetectedFace
+from lib.logger import parse_class_init
+from lib.utils import get_module_objects
+from plugins.convert import convert_config
+from . import mask_blend_defaults as cfg
+
+logger = logging.getLogger(__name__)
+
+
+class Mask():
+    """ Manipulations to perform to the mask that is to be applied to the output of the Faceswap
+    model.
+
+    Parameters
+    ----------
+    mask_type: str
+        The mask type to use for this plugin
+    output_size: int
+        The size of the output from the Faceswap model.
+    coverage_ratio: float
+        The coverage ratio that the Faceswap model was trained at.
+    configfile: str, optional
+        Optional location of custom configuration ``ini`` file. If ``None`` then use the default
+        config location. Default: ``None``
+    """
+    def __init__(self,
+                 mask_type: str,
+                 output_size: int,
+                 coverage_ratio: float,
+                 configfile: str | None = None) -> None:
+        logger.debug(parse_class_init(locals()))
+        self._mask_type = mask_type
+        convert_config.load_config(config_file=configfile)
+
+        self._coverage_ratio = coverage_ratio
+        self._box = self._get_box(output_size)
+
+        self._erodes = [erode / 100
+                        for erode in [cfg.erosion(), cfg.erosion_left(), cfg.erosion_top(),
+                                      cfg.erosion_right(), cfg.erosion_bottom()]]
+        self._do_erode = any(amount != 0 for amount in self._erodes)
+
+    def _get_box(self, output_size: int) -> np.ndarray:
+        """ Apply a gradient overlay to the edge of the swap box to smooth out any hard edges
+        where the face intersects with the edge of the swap area.
+
+        The gradient is created from 1/16th of the distance from the edge of the face box and
+        uses the parameters provided for the mask blend settings
+
+        Parameters
+        ----------
+        output_size: int
+            The size of the box that contains the swapped face
+
+        Returns
+        -------
+        :class:`numpy.ndarray`
+            The box mask
+        """
+        box = np.zeros((output_size, output_size, 1), dtype="float32")
+        edge = (output_size // 32) + 1
+        box[edge:-edge, edge:-edge] = 1.0
+
+        if cfg.type() != "none":
+            box = BlurMask("gaussian",
+                           box,
+                           6,
+                           is_ratio=True).blurred
+        return box
+
+    def run(self,
+            detected_face: DetectedFace,
+            source_offset: np.ndarray,
+            target_offset: np.ndarray,
+            centering: T.Literal["legacy", "face", "head"],
+            predicted_mask: np.ndarray | None = None) -> tuple[np.ndarray, np.ndarray]:
+        """ Obtain the requested mask type and perform any defined mask manipulations.
+
+        Parameters
+        ----------
+        detected_face: :class:`lib.align.DetectedFace`
+            The DetectedFace object as returned from :class:`scripts.convert.Predictor`.
+
+        source_offset: :class:`numpy.ndarray`
+            The (x, y) offset for the mask at its stored centering
+        target_offset: :class:`numpy.ndarray`
+            The (x, y) offset for the mask at the requested target centering
+        centering: [`"legacy"`, `"face"`, `"head"`]
+            The centering to obtain the mask for
+        predicted_mask: :class:`numpy.ndarray`, optional
+            The predicted mask as output from the Faceswap Model, if the model was trained
+            with a mask, otherwise ``None``. Default: ``None``.
+
+        Returns
+        -------
+        mask: :class:`numpy.ndarray`
+            The mask with all requested manipulations applied
+        raw_mask: :class:`numpy.ndarray`
+            The mask with no erosion/dilation applied
+        """
+        logger.trace("Performing mask adjustment: (detected_face: %s, "  # type: ignore
+                     "source_offset: %s, target_offset: %s, centering: '%s', "
+                     "predicted_mask: %s)",
+                     detected_face, source_offset, target_offset, centering,
+                     predicted_mask is not None)
+        mask = self._get_mask(detected_face,
+                              predicted_mask,
+                              centering,
+                              source_offset,
+                              target_offset)
+        raw_mask = mask.copy()
+
+        if self._mask_type != "none":
+            out = self._erode(mask) if self._do_erode else mask
+            out = np.minimum(out, self._box)
+        else:
+            out = mask
+
+        logger.trace(  # type: ignore
+            "mask shape: %s, raw_mask shape: %s", mask.shape, raw_mask.shape)
+        return out, raw_mask
+
+    def _get_mask(self,
+                  detected_face: DetectedFace,
+                  predicted_mask: np.ndarray | None,
+                  centering: T.Literal["legacy", "face", "head"],
+                  source_offset: np.ndarray,
+                  target_offset: np.ndarray) -> np.ndarray:
+        """ Return the requested mask with any requested blurring applied.
+
+        Parameters
+        ----------
+        detected_face: :class:`lib.align.DetectedFace`
+            The DetectedFace object as returned from :class:`scripts.convert.Predictor`.
+        predicted_mask: :class:`numpy.ndarray`
+            The predicted mask as output from the Faceswap Model if the model was trained
+            with a mask, otherwise ``None``
+        centering: [`"legacy"`, `"face"`, `"head"`]
+            The centering to obtain the mask for
+        source_offset: :class:`numpy.ndarray`
+            The (x, y) offset for the mask at its stored centering
+        target_offset: :class:`numpy.ndarray`
+            The (x, y) offset for the mask at the requested target centering
+
+        Returns
+        -------
+        :class:`numpy.ndarray`
+            The requested mask.
+        """
+        if self._mask_type == "none":
+            mask = np.ones_like(self._box)  # Return a dummy mask if not using a mask
+        elif self._mask_type == "predicted" and predicted_mask is not None:
+            mask = self._process_predicted_mask(predicted_mask)
+        else:
+            mask = self._get_stored_mask(detected_face, centering, source_offset, target_offset)
+
+        logger.trace(mask.shape)  # type: ignore
+        return mask
+
+    def _process_predicted_mask(self, mask: np.ndarray) -> np.ndarray:
+        """ Process blurring of the predicted mask
+
+        Parameters
+        ----------
+        mask: :class:`numpy.ndarray`
+            The predicted mask as output from the Faceswap Model
+
+        Returns
+        -------
+        :class:`numpy.ndarray`
+            The processed predicted mask
+        """
+        blur_type = T.cast(T.Literal["gaussian", "normalized", "none"], cfg.type().lower())
+        if blur_type != "none":
+            mask = BlurMask(blur_type,
+                            mask,
+                            cfg.kernel_size(),
+                            passes=cfg.passes()).blurred
+        return mask
+
+    def _get_stored_mask(self,
+                         detected_face: DetectedFace,
+                         centering: T.Literal["legacy", "face", "head"],
+                         source_offset: np.ndarray,
+                         target_offset: np.ndarray) -> np.ndarray:
+        """ Get the requested stored mask from the detected face object.
+
+        Parameters
+        ----------
+        detected_face: :class:`lib.align.DetectedFace`
+            The DetectedFace object as returned from :class:`scripts.convert.Predictor`.
+        centering: [`"legacy"`, `"face"`, `"head"`]
+            The centering to obtain the mask for
+        source_offset: :class:`numpy.ndarray`
+            The (x, y) offset for the mask at its stored centering
+        target_offset: :class:`numpy.ndarray`
+            The (x, y) offset for the mask at the requested target centering
+
+        Returns
+        -------
+        :class:`numpy.ndarray`
+            The mask sized to the Faceswap model output with any requested blurring applied.
+        """
+        mask = detected_face.mask[self._mask_type]
+        blur_type = T.cast(T.Literal["gaussian", "normalized"] | None, cfg.type().lower())
+        blur_type = None if blur_type == "none" else blur_type
+        mask.set_blur_and_threshold(blur_kernel=cfg.kernel_size(),
+                                    blur_type=blur_type,
+                                    blur_passes=cfg.passes(),
+                                    threshold=cfg.threshold())
+        mask.set_sub_crop(source_offset, target_offset, centering, self._coverage_ratio)
+        face_mask = mask.mask
+        mask_size = face_mask.shape[0]
+        face_size = self._box.shape[0]
+        if mask_size != face_size:
+            interp = cv2.INTER_CUBIC if mask_size < face_size else cv2.INTER_AREA
+            face_mask = cv2.resize(face_mask,
+                                   self._box.shape[:2],
+                                   interpolation=interp)[..., None].astype("float32") / 255.
+        else:
+            face_mask = face_mask.astype("float32") / 255.
+        return face_mask
+
+    # MASK MANIPULATIONS
+    def _erode(self, mask: np.ndarray) -> np.ndarray:
+        """ Erode or dilate the mask based on configuration options.
+
+        Parameters
+        ----------
+        mask: :class:`numpy.ndarray`
+            The mask to be eroded or dilated
+
+        Returns
+        -------
+        :class:`numpy.ndarray`
+            The mask with erosion/dilation applied
+        """
+        kernels = self._get_erosion_kernels(mask)
+        if not any(k.any() for k in kernels):
+            return mask  # No kernels could be created from selected input res
+        eroded = mask
+        for idx, (kernel, ratio) in enumerate(zip(kernels, self._erodes)):
+            if not kernel.any():
+                continue
+            anchor = [-1, -1]
+            if idx > 0:
+                pos = 1 if idx % 2 == 0 else 0
+                if ratio > 0:
+                    val = max(kernel.shape) - 1 if idx < 3 else 0
+                else:
+                    val = 0 if idx < 3 else max(kernel.shape) - 1
+                anchor[pos] = val
+
+            func = cv2.erode if ratio > 0 else cv2.dilate
+            eroded = func(eroded, kernel, iterations=1, anchor=anchor)
+
+        return eroded[..., None]
+
+    def _get_erosion_kernels(self, mask: np.ndarray) -> list[np.ndarray]:
+        """ Get the erosion kernels for each of the center, left, top, right and bottom
+        erosions.
+
+        An approximation is made based on the number of positive pixels within the mask to
+        create an ellipse to act as the kernel.
+
+        Parameters
+        ----------
+        mask: :class:`numpy.ndarray`
+            The mask to be eroded or dilated
+
+        Returns
+        -------
+        list
+            The erosion kernels to be used for erosion/dilation
+        """
+        mask_radius = np.sqrt(np.sum(mask)) / 2
+        kernel_sizes = [max(0, int(abs(ratio * mask_radius))) for ratio in self._erodes]
+        kernels = []
+        for idx, size in enumerate(kernel_sizes):
+            kernel = [size, size]
+            shape = cv2.MORPH_ELLIPSE if idx == 0 else cv2.MORPH_RECT
+            if idx > 1:
+                pos = 0 if idx % 2 == 0 else 1
+                kernel[pos] = 1  # Set x/y to 1px based on whether eroding top/bottom, left/right
+            kernels.append(cv2.getStructuringElement(shape, kernel) if size else np.array(0))
+        logger.trace("Erosion kernels: %s", [k.shape for k in kernels])  # type: ignore
+        return kernels
+
+
+__all__ = get_module_objects(__name__)
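A rough standalone illustration of how `_get_erosion_kernels` sizes its kernels (the mask size and erosion percentage are illustrative):

import cv2
import numpy as np

mask = np.ones((128, 128, 1), dtype="float32")
mask_radius = np.sqrt(mask.sum()) / 2          # ~64 for a fully-on 128px mask
size = max(0, int(abs(0.05 * mask_radius)))    # a 5% center erosion -> 3px kernel

# Center erosion uses an ellipse; the per-edge erosions use a 1px-wide rectangle
center_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (size, size))
top_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1, size))  # 1px wide, kernel-height tall

eroded = cv2.erode(mask, center_kernel, iterations=1)  # shrinks the swapped area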
diff --git a/plugins/convert/mask/mask_blend_defaults.py b/plugins/convert/mask/mask_blend_defaults.py
new file mode 100755
index 0000000000..4d926f1443
--- /dev/null
+++ b/plugins/convert/mask/mask_blend_defaults.py
@@ -0,0 +1,123 @@
+#!/usr/bin/env python3
+""" The default options for the faceswap Mask_Blend Mask plugin.
+
+Defaults files should be named `<plugin_name>_defaults.py`
+
+Any qualifying items placed into this file will automatically get added to the relevant config
+.ini files within the faceswap/config folder and added to the relevant GUI settings page.
+
+The following variable should be defined:
+
+    Parameters
+    ----------
+    HELPTEXT: str
+        A string describing what this plugin does
+
+Further plugin configuration options are assigned using:
+>>> <option_name> = ConfigItem(...)
+
+where <option_name> is the name of the configuration option to be added (lower-case, alpha-numeric
++ underscore only) and ConfigItem(...) is the [`~lib.config.objects.ConfigItem`] data for the
+option.
+
+See the docstring/ReadtheDocs documentation for the required parameters of the ConfigItem object.
+Items will be grouped together as per their `group` parameter, but otherwise will be processed in
+the order that they are added to this module.
+"""
+from lib.config import ConfigItem
+
+
+HELPTEXT = "Options for blending the edges between the mask and the background image"
+
+
+type = ConfigItem(  # pylint:disable=redefined-builtin
+    datatype=str,
+    default="normalized",
+    group="Blending type",
+    info="The type of blending to use:"
+         "\n\t gaussian: Blend with Gaussian filter. Slower, but often better than Normalized"
+         "\n\t normalized: Blend with Normalized box filter. Faster than Gaussian"
+         "\n\t none: Don't perform blending",
+    choices=["gaussian", "normalized", "none"])
+
+kernel_size = ConfigItem(
+    datatype=int,
+    default=3,
+    group="settings",
+    info="The kernel size dictates how much blending should occur.\n"
+         "The size is the diameter of the kernel in pixels (calculated from a 128px mask). "
+         "This value should be odd; if an even number is passed in then it will be rounded to "
+         "the next odd number. Higher sizes mean more blending.",
+    rounding=1,
+    min_max=(1, 9))
+
+passes = ConfigItem(
+    default=4,
+    datatype=int,
+    group="settings",
+    info="The number of passes to perform. Additional passes of the blending algorithm can "
+         "improve smoothing at a time cost. This is more useful for 'normalized' (box filter) "
+         "blending.\nAdditional passes have exponentially less effect so it's not worth "
+         "setting this too high.",
+    rounding=1,
+    min_max=(1, 8))
+
+threshold = ConfigItem(
+    default=4,
+    datatype=int,
+    group="settings",
+    info="Sets pixels that are near white to white and near black to black. 
Set to 0 for off.", + rounding=1, + min_max=(0, 50)) + +erosion = ConfigItem( + datatype=float, + default=0.0, + group="settings", + info="Apply erosion to the whole of the face mask.\n" + "Erosion kernel size as a percentage of the mask radius area.\n" + "Positive values apply erosion which reduces the size of the swapped area.\n" + "Negative values apply dilation which increases the swapped area.", + rounding=1, + min_max=(-100.0, 100.0)) + +erosion_top = ConfigItem( + datatype=float, + default=0.0, + group="settings", + info="Apply erosion to the top part of the mask only.\n" + "Positive values apply erosion which pulls the mask into the center.\n" + "Negative values apply dilation which pushes the mask away from the center.", + rounding=1, + min_max=(-100.0, 100.0)) + +erosion_bottom = ConfigItem( + datatype=float, + default=0.0, + group="settings", + info="Apply erosion to the bottom part of the mask only.\n" + "Positive values apply erosion which pulls the mask into the center.\n" + "Negative values apply dilation which pushes the mask away from the center.", + rounding=1, + min_max=(-100.0, 100.0)) + +erosion_left = ConfigItem( + default=0.0, + datatype=float, + group="settings", + info="Apply erosion to the left part of the mask only.\n" + "Positive values apply erosion which pulls the mask into the center.\n" + "Negative values apply dilation which pushes the mask away from the center.", + rounding=1, + min_max=(-100.0, 100.0)) + +erosion_right = ConfigItem( + datatype=float, + default=0.0, + group="settings", + info="Apply erosion to the right part of the mask only.\n" + "Positive values apply erosion which pulls the mask into the center.\n" + "Negative values apply dilation which pushes the mask away from the center.", + rounding=1, + min_max=(-100.0, 100.0)) diff --git a/plugins/convert/masked.py b/plugins/convert/masked.py deleted file mode 100644 index b7c3be6920..0000000000 --- a/plugins/convert/masked.py +++ /dev/null @@ -1,363 +0,0 @@ -#!/usr/bin/env python3 -""" Masked converter for faceswap.py - Based on: https://gist.github.com/anonymous/d3815aba83a8f79779451262599b0955 - found on https://www.reddit.com/r/deepfakes/ """ - -import logging -import cv2 -import numpy as np -from lib.model.masks import dfl_full - -logger = logging.getLogger(__name__) # pylint: disable=invalid-name - - -class Convert(): - """ Swap a source face with a target """ - def __init__(self, encoder, model, arguments): - logger.debug("Initializing %s: (encoder: '%s', model: %s, arguments: %s", - self.__class__.__name__, encoder, model, arguments) - self.encoder = encoder - self.args = arguments - self.input_size = model.input_shape[0] - self.training_size = model.state.training_size - self.training_coverage_ratio = model.training_opts["coverage_ratio"] - self.input_mask_shape = model.state.mask_shapes[0] if model.state.mask_shapes else None - - self.crop = None - self.mask = None - logger.debug("Initialized %s", self.__class__.__name__) - - def patch_image(self, image, detected_face): - """ Patch the image """ - logger.trace("Patching image") - image = image.astype('float32') - image_size = (image.shape[1], image.shape[0]) - coverage = int(self.training_coverage_ratio * self.training_size) - padding = (self.training_size - coverage) // 2 - logger.trace("coverage: %s, padding: %s", coverage, padding) - - self.crop = slice(padding, self.training_size - padding) - if not self.mask: # Init the mask on first image - self.mask = Mask(self.args.mask_type, self.training_size, padding, self.crop) - - 
detected_face.load_aligned(image, size=self.training_size, align_eyes=False) - new_image = self.get_new_image(image, detected_face, coverage, image_size) - image_mask = self.get_image_mask(detected_face, image_size) - patched_face = self.apply_fixes(image, new_image, image_mask, - image_size, detected_face) - - logger.trace("Patched image") - return patched_face - - def get_new_image(self, image, detected_face, coverage, image_size): - """ Get the new face from the predictor """ - logger.trace("coverage: %s", coverage) - src_face = detected_face.aligned_face - coverage_face = src_face[self.crop, self.crop] - old_face = coverage_face.copy() - coverage_face = cv2.resize(coverage_face, # pylint: disable=no-member - (self.input_size, self.input_size), - interpolation=cv2.INTER_AREA) # pylint: disable=no-member - coverage_face = np.expand_dims(coverage_face, 0) - np.clip(coverage_face / 255.0, 0.0, 1.0, out=coverage_face) - - if self.input_mask_shape: - mask = np.zeros(self.input_mask_shape, np.float32) - mask = np.expand_dims(mask, 0) - feed = [coverage_face, mask] - else: - feed = [coverage_face] - logger.trace("Input shapes: %s", [item.shape for item in feed]) - new_face = self.encoder(feed)[0] - new_face = new_face.squeeze() - logger.trace("Output shape: %s", new_face.shape) - - new_face = cv2.resize(new_face, # pylint: disable=no-member - (coverage, coverage), - interpolation=cv2.INTER_CUBIC) # pylint: disable=no-member - np.clip(new_face * 255.0, 0.0, 255.0, out=new_face) - - if self.args.smooth_box: - self.smooth_box(old_face, new_face) - - src_face[self.crop, self.crop] = new_face - background = image.copy() - interpolator = detected_face.adjusted_interpolators[1] - new_image = cv2.warpAffine( # pylint: disable=no-member - src_face, - detected_face.adjusted_matrix, - image_size, - background, - flags=cv2.WARP_INVERSE_MAP | interpolator, # pylint: disable=no-member - borderMode=cv2.BORDER_TRANSPARENT) # pylint: disable=no-member - return new_image - - @staticmethod - def smooth_box(old_face, new_face): - """ Perform gaussian blur on the edges of the output rect """ - height = new_face.shape[0] - crop = slice(0, height) - erode = slice(height // 15, -height // 15) - sigma = height / 16 # 10 for the default 160 size - window = int(np.ceil(sigma * 3.0)) - window = window + 1 if window % 2 == 0 else window - mask = np.zeros_like(new_face) - mask[erode, erode] = 1.0 - mask = cv2.GaussianBlur(mask, # pylint: disable=no-member - (window, window), - sigma) - new_face[crop, crop] = (mask * new_face + (1.0 - mask ) * old_face) - - def get_image_mask(self, detected_face, image_size): - """ Get the image mask """ - mask = self.mask.get_mask(detected_face, image_size) - if self.args.erosion_size != 0: - kwargs = {'src': mask, - 'kernel': self.set_erosion_kernel(mask), - 'iterations': 1} - if self.args.erosion_size > 0: - mask = cv2.erode(**kwargs) # pylint: disable=no-member - else: - mask = cv2.dilate(**kwargs) # pylint: disable=no-member - - if self.args.blur_size != 0: - blur_size = self.set_blur_size(mask) - mask = cv2.blur(mask, (blur_size, blur_size)) # pylint: disable=no-member - - return np.clip(mask, 0.0, 1.0, out=mask) - - def set_erosion_kernel(self, mask): - """ Set the erosion kernel """ - erosion_ratio = self.args.erosion_size / 100 - mask_radius = np.sqrt(np.sum(mask)) / 2 - percent_erode = max(1, int(abs(erosion_ratio * mask_radius))) - erosion_kernel = cv2.getStructuringElement( # pylint: disable=no-member - cv2.MORPH_ELLIPSE, # pylint: disable=no-member - (percent_erode, 
percent_erode)) - logger.trace("erosion_kernel shape: %s", erosion_kernel.shape) - return erosion_kernel - - def set_blur_size(self, mask): - """ Set the blur size to absolute or percentage """ - blur_ratio = self.args.blur_size / 100 - mask_radius = np.sqrt(np.sum(mask)) / 2 - blur_size = int(max(1, blur_ratio * mask_radius)) - logger.trace("blur_size: %s", blur_size) - return blur_size - - def apply_fixes(self, frame, new_image, image_mask, image_size, detected_face): - """ Apply fixes """ - - if self.args.sharpen_image is not None and self.args.sharpen_image.lower() != "none": - np.clip(new_image, 0.0, 255.0, out=new_image) - if self.args.sharpen_image == "box_filter": - kernel = np.ones((3, 3)) * (-1) - kernel[1, 1] = 9 - new_image = cv2.filter2D(new_image, -1, kernel) # pylint: disable=no-member - elif self.args.sharpen_image == "gaussian_filter": - blur = cv2.GaussianBlur(new_image, (0, 0), 3.0) # pylint: disable=no-member - new_image = cv2.addWeighted(new_image, # pylint: disable=no-member - 1.5, - blur, - -0.5, - 0, - new_image) - - if self.args.avg_color_adjust: - for _ in [0, 1]: - np.clip(new_image, 0.0, 255.0, out=new_image) - diff = frame - new_image - avg_diff = np.sum(diff * image_mask, axis=(0, 1)) - adjustment = avg_diff / np.sum(image_mask, axis=(0, 1)) - new_image = new_image + adjustment - - if self.args.match_histogram: - np.clip(new_image, 0.0, 255.0, out=new_image) - new_image = self.color_hist_match(new_image, frame, image_mask) - - if self.args.seamless_clone: - h, w, _ = frame.shape - h = h // 2 - w = w // 2 - - y_indices, x_indices, _ = np.nonzero(image_mask) - y_crop = slice(np.min(y_indices), np.max(y_indices)) - x_crop = slice(np.min(x_indices), np.max(x_indices)) - y_center = int(np.rint((np.max(y_indices) + np.min(y_indices)) / 2) + h) - x_center = int(np.rint((np.max(x_indices) + np.min(x_indices)) / 2) + w) - - ''' - # test with average of centroid rather than the h /2 , w/2 center - y_center = int(np.rint(np.average(y_indices) + h) - x_center = int(np.rint(np.average(x_indices) + w) - ''' - - insertion = np.rint(new_image[y_crop, x_crop, :]).astype('uint8') - insertion_mask = image_mask[y_crop, x_crop, :] - insertion_mask[insertion_mask != 0] = 255 - insertion_mask = insertion_mask.astype('uint8') - - prior = np.pad(frame, ((h, h), (w, w), (0, 0)), 'constant').astype('uint8') - - blended = cv2.seamlessClone(insertion, # pylint: disable=no-member - prior, - insertion_mask, - (x_center, y_center), - cv2.NORMAL_CLONE) # pylint: disable=no-member - blended = blended[h:-h, w:-w] - - else: - foreground = new_image * image_mask - background = frame * (1.0 - image_mask) - blended = foreground + background - - np.clip(blended, 0.0, 255.0, out=blended) - - if self.args.draw_transparent: - # Adding a 4th channel should happen after all other channel operations - - # Add mask as 4th channel for saving as alpha on supported output formats - new_image = dfl_full(detected_face.landmarks_as_xy, blended, channels=4 ) - image_mask = cv2.cvtColor(image_mask, cv2.COLOR_RGB2RGBA) - image = cv2.cvtColor(image, cv2.COLOR_RGB2RGBA) - - return np.rint(blended).astype('uint8') - - def color_hist_match(self, new, frame, image_mask): - for channel in [0, 1, 2]: - new[:, :, channel] = self.hist_match(new[:, :, channel], - frame[:, :, channel], - image_mask[:, :, channel]) - # source = np.stack([self.hist_match(source[:,:,c], target[:,:,c],image_mask[:,:,c]) - # for c in [0,1,2]], - # axis=2) - return new - - def hist_match(self, new, frame, image_mask): - - mask_indices = 
np.nonzero(image_mask) - if len(mask_indices[0]) == 0: - return new - - m_new = new[mask_indices].ravel() - m_frame = frame[mask_indices].ravel() - s_values, bin_idx, s_counts = np.unique(m_new, return_inverse=True, return_counts=True) - t_values, t_counts = np.unique(m_frame, return_counts=True) - s_quants = np.cumsum(s_counts, dtype='float32') - t_quants = np.cumsum(t_counts, dtype='float32') - s_quants /= s_quants[-1] # cdf - t_quants /= t_quants[-1] # cdf - interp_s_values = np.interp(s_quants, t_quants, t_values) - new.put(mask_indices, interp_s_values[bin_idx]) - - ''' - bins = np.arange(256) - template_CDF, _ = np.histogram(m_frame, bins=bins, density=True) - flat_new_image = np.interp(m_source.ravel(), bins[:-1], template_CDF) * 255.0 - return flat_new_image.reshape(m_source.shape) * 255.0 - ''' - - return new - - -class Mask(): - """ Return the requested mask """ - - def __init__(self, mask_type, training_size, padding, crop): - """ Set requested mask """ - logger.debug("Initializing %s: (mask_type: '%s', training_size: %s, padding: %s)", - self.__class__.__name__, mask_type, training_size, padding) - - self.training_size = training_size - self.padding = padding - self.mask_type = mask_type - self.crop = crop - - logger.debug("Initialized %s", self.__class__.__name__) - - def get_mask(self, detected_face, image_size): - """ Return a face mask """ - kwargs = {"matrix": detected_face.adjusted_matrix, - "interpolators": detected_face.adjusted_interpolators, - "landmarks": detected_face.landmarks_as_xy, - "image_size": image_size} - logger.trace("kwargs: %s", kwargs) - mask = getattr(self, self.mask_type)(**kwargs) - mask = self.finalize_mask(mask) - logger.trace("mask shape: %s", mask.shape) - return mask - - def cnn(self, **kwargs): - """ CNN Mask """ - # Insert FCN-VGG16 segmentation mask model here - logger.info("cnn not yet implemented, using facehull instead") - return self.facehull(**kwargs) - - def rect(self, **kwargs): - """ Namespace for rect mask. 
This is the same as 'none' in the cli """ - return self.none(**kwargs) - - def none(self, **kwargs): - """ Rect Mask """ - logger.trace("Getting mask") - interpolator = kwargs["interpolators"][1] - ones = np.zeros((self.training_size, self.training_size, 3), dtype='float32') - mask = np.zeros((kwargs["image_size"][1], kwargs["image_size"][0], 3), dtype='float32') - # central_core = slice(self.padding, -self.padding) - ones[self.crop, self.crop] = 1.0 - cv2.warpAffine(ones, # pylint: disable=no-member - kwargs["matrix"], - kwargs["image_size"], - mask, - flags=cv2.WARP_INVERSE_MAP | interpolator, # pylint: disable=no-member - borderMode=cv2.BORDER_CONSTANT, # pylint: disable=no-member - borderValue=0.0) - return mask - - def dfl(self, **kwargs): - """ DFaker Mask """ - logger.trace("Getting mask") - dummy = np.zeros((kwargs["image_size"][1], kwargs["image_size"][0], 3), dtype='float32') - mask = dfl_full(kwargs["landmarks"], dummy, channels=3) - mask = self.intersect_rect(mask, **kwargs) - return mask - - def facehull(self, **kwargs): - """ Facehull Mask """ - logger.trace("Getting mask") - mask = np.zeros((kwargs["image_size"][1], kwargs["image_size"][0], 3), dtype='float32') - hull = cv2.convexHull( # pylint: disable=no-member - np.array(kwargs["landmarks"]).reshape((-1, 2))) - cv2.fillConvexPoly(mask, # pylint: disable=no-member - hull, - (1.0, 1.0, 1.0), - lineType=cv2.LINE_AA) # pylint: disable=no-member - mask = self.intersect_rect(mask, **kwargs) - return mask - - def ellipse(self, **kwargs): - """ Ellipse Mask """ - logger.trace("Getting mask") - mask = np.zeros((kwargs["image_size"][1], kwargs["image_size"][0], 3), dtype='float32') - ell = cv2.fitEllipse( # pylint: disable=no-member - np.array(kwargs["landmarks"]).reshape((-1, 2))) - cv2.ellipse(mask, # pylint: disable=no-member - box=ell, - color=(1.0, 1.0, 1.0), - thickness=-1) - return mask - - def intersect_rect(self, hull_mask, **kwargs): - """ Intersect the given hull mask with the roi """ - logger.trace("Intersecting rect") - mask = self.rect(**kwargs) - mask *= hull_mask - return mask - - @staticmethod - def finalize_mask(mask): - """ Finalize the mask """ - logger.trace("Finalizing mask") - np.nan_to_num(mask, copy=False) - np.clip(mask, 0.0, 1.0, out=mask) - return mask diff --git a/plugins/convert/scaling/__init__.py b/plugins/convert/scaling/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/plugins/convert/scaling/_base.py b/plugins/convert/scaling/_base.py new file mode 100644 index 0000000000..db6f74407d --- /dev/null +++ b/plugins/convert/scaling/_base.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 +""" Parent class for scaling Adjustments for faceswap.py converter """ + +import logging +import numpy as np + +from lib.logger import parse_class_init +from plugins.convert import convert_config + +logger = logging.getLogger(__name__) + + +class Adjustment(): + """ Parent class for scaling adjustments """ + def __init__(self, configfile=None): + logger.debug(parse_class_init(locals())) + convert_config.load_config(config_file=configfile) + logger.debug("Initialized %s", self.__class__.__name__) + + def process(self, new_face): + """ Override for specific scaling adjustment process """ + raise NotImplementedError + + def run(self, new_face): + """ Perform selected adjustment on face """ + # pylint:disable=duplicate-code + logger.trace("Performing scaling adjustment") # type:ignore[attr-defined] + # Remove Mask for processing + reinsert_mask = False + final_mask = None + if new_face.shape[2] == 4: + 
reinsert_mask = True
+            final_mask = new_face[:, :, -1]
+            new_face = new_face[:, :, :3]
+        new_face = self.process(new_face)
+        new_face = np.clip(new_face, 0.0, 1.0)
+        if reinsert_mask and new_face.shape[2] != 4:
+            # Reinsert Mask
+            assert final_mask is not None
+            new_face = np.concatenate((new_face, np.expand_dims(final_mask, axis=-1)), -1)
+        logger.trace("Performed scaling adjustment")  # type:ignore[attr-defined]
+        return new_face
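The mask strip/reinsert contract of `Adjustment.run` in miniature (standalone sketch; the `* 1.1` stands in for a concrete `process` implementation):

import numpy as np

# run() strips a 4th (mask) channel before process() and re-attaches it after,
# so scaling plugins only ever see a 3-channel BGR face.
rgba_face = np.random.rand(64, 64, 4).astype("float32")
bgr, mask = rgba_face[:, :, :3], rgba_face[:, :, -1]
processed = np.clip(bgr * 1.1, 0.0, 1.0)                 # stand-in for process()
out = np.concatenate((processed, mask[..., None]), axis=-1)
assert out.shape == (64, 64, 4)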
diff --git a/plugins/convert/scaling/sharpen.py b/plugins/convert/scaling/sharpen.py
new file mode 100644
index 0000000000..158165b873
--- /dev/null
+++ b/plugins/convert/scaling/sharpen.py
@@ -0,0 +1,158 @@
+#!/usr/bin/env python3
+""" Sharpening for the enlarged face for faceswap.py converter """
+import cv2
+import numpy as np
+
+from lib.utils import get_module_objects
+
+from ._base import Adjustment, logger
+from . import sharpen_defaults as cfg
+
+
+class Scaling(Adjustment):
+    """ Sharpening Adjustments for the face applied after warp to final frame """
+
+    def process(self, new_face: np.ndarray) -> np.ndarray:
+        """ Sharpen using the requested technique
+
+        Parameters
+        ----------
+        new_face : :class:`numpy.ndarray`
+            A batch of swapped image patches that are to have sharpening applied
+
+        Returns
+        -------
+        :class:`numpy.ndarray`
+            The batch of swapped faces with sharpening applied
+        """
+        if cfg.method() == "none":
+            return new_face
+        amount = cfg.amount() / 100.0
+        kernel, radius = self.get_kernel_size(new_face, cfg.radius())
+        new_face = getattr(self, cfg.method())(new_face, kernel, radius, amount)
+        return new_face
+
+    @classmethod
+    def get_kernel_size(cls,
+                        new_face: np.ndarray,
+                        radius_percent: float) -> tuple[tuple[int, int], int]:
+        """ Return the kernel size and central point for the given radius
+        relative to frame width.
+
+        Parameters
+        ----------
+        new_face : :class:`numpy.ndarray`
+            The swapped image patch that is to have sharpening applied
+        radius_percent : float
+            The percentage of the image width to use as the sharpening kernel radius
+
+        Returns
+        -------
+        kernel_size : tuple[int, int]
+            The sharpening kernel size
+        radius : int
+            The pixel radius of the kernel
+        """
+        radius = max(1, round(new_face.shape[1] * radius_percent / 100))
+        kernel_size = int((radius * 2) + 1)
+        full_kernel_size = (kernel_size, kernel_size)
+        logger.trace(kernel_size)  # type:ignore[attr-defined]
+        return full_kernel_size, radius
+
+    @classmethod
+    def box(cls,
+            new_face: np.ndarray,
+            kernel_size: tuple[int, int],
+            radius: int,
+            amount: float) -> np.ndarray:
+        """ Sharpen using box filter
+
+        Parameters
+        ----------
+        new_face : :class:`numpy.ndarray`
+            The batch of swapped image patches that is to have sharpening applied
+        kernel_size : tuple[int, int]
+            The sharpening kernel size
+        radius : int
+            The pixel radius of the kernel
+        amount : float
+            The amount of sharpening to apply
+
+        Returns
+        -------
+        :class:`numpy.ndarray`
+            The batch of swapped faces with box sharpening applied
+        """
+        kernel: np.ndarray = np.zeros(kernel_size, dtype="float32")
+        kernel[radius, radius] = 1.0
+        box_filter = np.ones(kernel_size, dtype="float32") / kernel_size[0]**2
+        kernel = kernel + (kernel - box_filter) * amount
+        new_face = cv2.filter2D(new_face, -1, kernel)
+        return new_face
+
+    @classmethod
+    def gaussian(cls,
+                 new_face: np.ndarray,
+                 kernel_size: tuple[int, int],
+                 radius: float,  # pylint:disable=unused-argument
+                 amount: float) -> np.ndarray:
+        """ Sharpen using gaussian filter
+
+        Parameters
+        ----------
+        new_face : :class:`numpy.ndarray`
+            The batch of swapped image patches that is to have sharpening applied
+        kernel_size : tuple[int, int]
+            The sharpening kernel size
+        radius : float
+            The pixel radius of the kernel. Unused
+        amount : float
+            The amount of sharpening to apply
+
+        Returns
+        -------
+        :class:`numpy.ndarray`
+            The batch of swapped faces with gaussian sharpening applied
+        """
+        blur = cv2.GaussianBlur(new_face, kernel_size, 0)
+        new_face = cv2.addWeighted(new_face,
+                                   1.0 + (0.5 * amount),
+                                   blur,
+                                   -(0.5 * amount),
+                                   0)
+        return new_face
+    @classmethod
+    def unsharp_mask(cls,
+                     new_face: np.ndarray,
+                     kernel_size: tuple[int, int],
+                     center: float,  # pylint:disable=unused-argument
+                     amount: float) -> np.ndarray:
+        """ Sharpen using unsharp mask
+
+        Parameters
+        ----------
+        new_face : :class:`numpy.ndarray`
+            The batch of swapped image patches that is to have sharpening applied
+        kernel_size : tuple[int, int]
+            The sharpening kernel size
+        center : float
+            The pixel radius of the kernel. Unused
+        amount : float
+            The amount of sharpening to apply
+
+        Returns
+        -------
+        :class:`numpy.ndarray`
+            The batch of swapped faces with unsharp-mask sharpening applied
+        """
+        threshold = cfg.threshold() / 255.0
+        blur = cv2.GaussianBlur(new_face, kernel_size, 0)
+        low_contrast_mask = (abs(new_face - blur) < threshold).astype("float32")
+        sharpened = (new_face * (1.0 + amount)) + (blur * -amount)
+        # Keep the original in low-contrast (smooth) areas; sharpen only real edges
+        new_face = (new_face * low_contrast_mask) + (sharpened * (1.0 - low_contrast_mask))
+        return new_face
+
+
+__all__ = get_module_objects(__name__)
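A standalone sketch of the unsharp-mask logic above (illustrative kernel and settings; smooth areas are left untouched, as the threshold description in the defaults file below explains):

import cv2
import numpy as np

face = np.random.rand(128, 128, 3).astype("float32")
amount = 1.5               # cfg.amount() == 150
threshold = 5.0 / 255.0    # cfg.threshold() == 5.0

blur = cv2.GaussianBlur(face, (9, 9), 0)
low_contrast = (np.abs(face - blur) < threshold).astype("float32")
sharpened = face * (1.0 + amount) + blur * -amount
result = face * low_contrast + sharpened * (1.0 - low_contrast)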
diff --git a/plugins/convert/scaling/sharpen_defaults.py b/plugins/convert/scaling/sharpen_defaults.py
new file mode 100755
index 0000000000..bd0adca59a
--- /dev/null
+++ b/plugins/convert/scaling/sharpen_defaults.py
@@ -0,0 +1,81 @@
+#!/usr/bin/env python3
+""" The default options for the faceswap Sharpen Scaling plugin.
+
+Defaults files should be named `<plugin_name>_defaults.py`
+
+Any qualifying items placed into this file will automatically get added to the relevant config
+.ini files within the faceswap/config folder and added to the relevant GUI settings page.
+
+The following variable should be defined:
+
+    Parameters
+    ----------
+    HELPTEXT: str
+        A string describing what this plugin does
+
+Further plugin configuration options are assigned using:
+>>> <option_name> = ConfigItem(...)
+
+where <option_name> is the name of the configuration option to be added (lower-case, alpha-numeric
++ underscore only) and ConfigItem(...) is the [`~lib.config.objects.ConfigItem`] data for the
+option.
+
+See the docstring/ReadtheDocs documentation for the required parameters of the ConfigItem object.
+Items will be grouped together as per their `group` parameter, but otherwise will be processed in
+the order that they are added to this module.
+"""
+from lib.config import ConfigItem
+
+
+HELPTEXT = "Options for sharpening the face after placement"
+
+
+method = ConfigItem(
+    datatype=str,
+    default="none",
+    group="sharpen type",
+    info="The type of sharpening to use:"
+         "\n\t none: Don't perform any sharpening."
+         "\n\t box: Fastest, but weakest method. Uses a box filter to assess edges."
+         "\n\t gaussian: Slower, but better than box. Uses a gaussian filter to assess edges."
+         "\n\t unsharp_mask: Slowest, but most tweakable. Uses the unsharp-mask method to "
+         "assess edges.",
+    choices=["none", "box", "gaussian", "unsharp_mask"],
+    gui_radio=True)
+
+amount = ConfigItem(
+    datatype=int,
+    default=150,
+    group="settings",
+    info="Percentage that controls the magnitude of each overshoot (how much darker and how "
+         "much lighter the edge borders become).\nThis can also be thought of as how much "
+         "contrast is added at the edges. It does not affect the width of the edge rims.",
+    rounding=1,
+    min_max=(100, 500))
+
+radius = ConfigItem(
+    datatype=float,
+    default=0.3,
+    group="settings",
+    info="Affects the size of the edges to be enhanced or how wide the edge rims become, so a "
+         "smaller radius enhances smaller-scale detail.\nRadius is set as a percentage of the "
+         "final frame width and rounded to the nearest pixel. E.g. for a 1280px-wide frame, a "
+         "percentage of 0.6 will give a radius of 8px.\nHigher radius values can cause halos "
+         "at the edges, a detectable faint light rim around objects. Fine detail needs a "
+         "smaller radius.\nRadius and amount interact; reducing one allows more of the other.",
+    rounding=1,
+    min_max=(0.1, 5.0))
+
+threshold = ConfigItem(
+    datatype=float,
+    default=5.0,
+    group="settings",
+    info="[unsharp_mask only] Controls the minimal brightness change that will be sharpened "
+         "or how far apart adjacent tonal values have to be before the filter does anything.\n"
+         "This lack of action is important to prevent smooth areas from becoming speckled. "
+         "The threshold setting can be used to sharpen more pronounced edges, while leaving "
+         "subtler edges untouched.\nLow values should sharpen more because fewer areas are "
+         "excluded.\nHigher threshold values exclude areas of lower contrast.",
+    rounding=1,
+    min_max=(1.0, 10.0))
diff --git a/plugins/convert/writer/__init__.py b/plugins/convert/writer/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/plugins/convert/writer/_base.py b/plugins/convert/writer/_base.py
new file mode 100644
index 0000000000..0cce2cf748
--- /dev/null
+++ b/plugins/convert/writer/_base.py
@@ -0,0 +1,180 @@
+#!/usr/bin/env python3
+""" Parent class for output writers for faceswap.py converter """
+
+import logging
+import os
+import re
+import typing as T
+
+import numpy as np
+
+from lib.logger import parse_class_init
+from plugins.convert import convert_config
+
+logger = logging.getLogger(__name__)
+
+
+class Output():
+    """ Parent class for writer plugins.
+
+    Parameters
+    ----------
+    output_folder: str
+        The full path to the output folder where the converted media should be saved
+    configfile: str, optional
+        The full path to a custom configuration ini file. If ``None`` is passed
+        then the file is loaded from the default location. Default: ``None``.
+    """
+    def __init__(self, output_folder: str, configfile: str | None = None) -> None:
+        logger.debug(parse_class_init(locals()))
+        convert_config.load_config(config_file=configfile)
+        self.output_folder: str = output_folder
+
+        # For creating subfolders when separate mask is selected
+        self._subfolders_created: bool = False
+
+        # Methods for making sure frames are written out in frame order
+        self.re_search = re.compile(r"(\d+)(?=\.\w+$)")  # Identify frame numbers
+        self.cache: dict = {}  # Cache for when frames must be written in correct order
+        logger.debug("Initialized %s", self.__class__.__name__)
+
+    @property
+    def is_stream(self) -> bool:
+        """ bool: Whether the writer outputs a stream or a series of images.
+
+        Writers that write to a stream have a frame_order parameter to dictate
+        the order in which frames should be written out (e.g. gif/ffmpeg) """
+        retval = hasattr(self, "_frame_order")
+        return retval
+
+    @property
+    def output_alpha(self) -> bool:
+        """ bool : Override if the plugin can output an alpha channel and the user configuration
+        option is set to use it. Default ``False`` """
+        return False
+
+    @classmethod
+    def _set_frame_order(cls,
+                         total_count: int,
+                         frame_ranges: list[tuple[int, int]] | None) -> list[int]:
+        """ Obtain the full list of frames to be converted in order.
+ + Used for FFMPEG and Gif writers to ensure correct frame order + + Parameters + ---------- + total_count: int + The total number of frames to be converted + frame_ranges: list or ``None`` + List of tuples for starting and end values of each frame range to be converted or + ``None`` if all frames are to be converted + + Returns + ------- + list + Full list of all frame indices to be converted + """ + if frame_ranges is None: + retval = list(range(1, total_count + 1)) + else: + retval = [] + for rng in frame_ranges: + retval.extend(list(range(rng[0], rng[1] + 1))) + logger.debug("frame_order: %s", retval) + return retval + + def get_output_filename(self, + filename: str, + extension: str, + separate_mask: bool = False) -> list[str]: + """ Obtain the full path for the output file, including the correct extension, for the + given input filename. + + Parameters + ---------- + filename : str + The input frame filename to generate the output file name for + extension : str + The extension to use for the output file + separate_mask: bool, optional + ``True`` if the mask should be saved out to a sub-folder otherwise ``False`` + + Returns + ------- + list + The full path for the output converted frame to be saved to in position 1. The full + path for the mask to be output to in position 2 (if requested) + """ + extension = extension.strip(".") + filename = os.path.splitext(os.path.basename(filename))[0] + out_filename = f"{filename}.{extension}" + retval = [os.path.join(self.output_folder, out_filename)] + if separate_mask: + retval.append(os.path.join(self.output_folder, "masks", out_filename)) + + if separate_mask and not self._subfolders_created: + locations = [os.path.dirname(loc) for loc in retval] + logger.debug("Creating sub-folders: %s", locations) + for location in locations: + os.makedirs(location, exist_ok=True) + + logger.trace("in filename: '%s', out filename: '%s'", filename, retval) # type:ignore + return retval + + def cache_frame(self, filename: str, image: np.ndarray) -> None: + """ Add the incoming converted frame to the cache ready for writing out. + + Used for ffmpeg and gif writers to ensure that the frames are written out in the correct + order. + + Parameters + ---------- + filename: str + The filename of the incoming frame, where the frame index can be extracted from + image: class:`numpy.ndarray` + The converted frame corresponding to the given filename + """ + re_frame = re.search(self.re_search, filename) + assert re_frame is not None + frame_no = int(re_frame.group()) + self.cache[frame_no] = image + logger.trace("Added to cache. Frame no: %s", frame_no) # type: ignore + logger.trace("Current cache: %s", sorted(self.cache.keys())) # type:ignore + + def write(self, filename: str, image: T.Any) -> None: + """ Override for specific frame writing method. + + Parameters + ---------- + filename: str + The incoming frame filename. + image: Any + The converted image to be written. Could be a numpy array, a bytes encoded image or + any other plugin specific format + """ + raise NotImplementedError + + def pre_encode(self, image: np.ndarray, **kwargs) -> T.Any: # pylint:disable=unused-argument + """ Some writer plugins support the pre-encoding of images prior to saving out. As + patching is done in multiple threads, but writing is done in a single thread, it can + speed up the process to do any pre-encoding as part of the converter process. + + If the writer supports pre-encoding then override this to pre-encode the image in + :mod:`lib.convert` to speed up saving. 
+
+        Parameters
+        ----------
+        image: :class:`numpy.ndarray`
+            The converted image that is to be run through the pre-encoding function
+
+        Returns
+        -------
+        Any or ``None``
+            If ``None`` then the writer does not support pre-encoding, otherwise return the
+            output of the plugin specific pre-encode function
+        """
+        return None
+
+    def close(self) -> None:
+        """ Override for specific converted frame writing close methods """
+        raise NotImplementedError
+    def _get_output_filename(self) -> str:
+        """ Return full path to video output file.
+
+        The filename is the same as the input video with `"_converted"` appended to the end. The
+        file extension is as selected in the plugin settings. If a file already exists with the
+        given filename, then `"_1"` is appended to the end of the filename. This number iterates
+        until a valid filename that does not exist is found.
+
+        Returns
+        -------
+        str
+            The full path to the output video filename
+        """
+        filename = os.path.basename(self._source_video)
+        filename = os.path.splitext(filename)[0]
+        ext = cfg.container()
+        idx = 0
+        while True:
+            out_file = f"{filename}_converted{'' if idx == 0 else f'_{idx}'}.{ext}"
+            retval = os.path.join(self.output_folder, out_file)
+            if not os.path.exists(retval):
+                break
+            idx += 1
+        logger.info("Outputting to: '%s'", retval)
+        return retval
+
+    def _get_writer(self, frame_dims: tuple[int, int]) -> Generator[None, np.ndarray, None]:
+        """ Add the requested encoding options and return the writer.
+
+        Parameters
+        ----------
+        frame_dims: tuple
+            The (rows, columns) shape of the input image
+
+        Returns
+        -------
+        generator
+            The imageio ffmpeg writer
+        """
+        audio_codec = self._audio_codec
+        audio_path = None if audio_codec is None else self._source_video
+        logger.debug("writer audio_path: '%s'", audio_path)
+
+        retval = im_ffm.write_frames(self._output_filename,
+                                     size=(frame_dims[1], frame_dims[0]),
+                                     fps=self._video_fps,
+                                     quality=None,
+                                     codec=cfg.codec(),
+                                     macro_block_size=8,
+                                     ffmpeg_log_level="error",
+                                     ffmpeg_timeout=10,
+                                     output_params=self._output_params,
+                                     audio_path=audio_path,
+                                     audio_codec=audio_codec)
+        logger.debug("FFMPEG Writer created: %s", retval)
+        retval.send(None)
+
+        return retval
+
+    def write(self, filename: str, image: np.ndarray) -> None:
+        """ Frames come from the pool in arbitrary order, so frames are cached for writing out
+        in the correct order.
+
+        Parameters
+        ----------
+        filename: str
+            The incoming frame filename.
+        image: :class:`numpy.ndarray`
+            The converted image to be written
+        """
+        logger.trace("Received frame: (filename: '%s', shape: %s)",  # type:ignore[attr-defined]
+                     filename, image.shape)
+        if not self._output_dimensions:
+            input_dims = T.cast(tuple[int, int], image.shape[:2])
+            self._set_dimensions(input_dims)
+            self._writer = self._get_writer(input_dims)
+        self.cache_frame(filename, image)
+        self._save_from_cache()
+
+    def _set_dimensions(self, frame_dims: tuple[int, int]) -> None:
+        """ Set the attribute :attr:`_output_dimensions` based on the first frame received.
+        This protects against different sized images coming in and ensures all images are written
+        to ffmpeg at the same size. Dimensions are mapped to a macro block size of 8.
+
+        Parameters
+        ----------
+        frame_dims: tuple
+            The (rows, columns) shape of the input image
+        """
+        logger.debug("input dimensions: %s", frame_dims)
+        self._output_dimensions = (f"{int(ceil(frame_dims[1] / 8) * 8)}:"
+                                   f"{int(ceil(frame_dims[0] / 8) * 8)}")
+        logger.debug("Set dimensions: %s", self._output_dimensions)
+
+    def _save_from_cache(self) -> None:
+        """ Writes any consecutive frames to the video container that are ready to be output
+        from the cache. """
+        assert self._writer is not None
+        while self._frame_order:
+            if self._frame_order[0] not in self.cache:
+                logger.trace("Next frame not ready. Continuing")  # type:ignore[attr-defined]
+                break
+            save_no = self._frame_order.pop(0)
+            save_image = self.cache.pop(save_no)
+            logger.trace("Rendering from cache. Frame no: %s",  # type:ignore[attr-defined]
+                         save_no)
+            self._writer.send(np.ascontiguousarray(save_image[:, :, ::-1]))
+        logger.trace("Current cache size: %s", len(self.cache))  # type:ignore[attr-defined]
+
+    def close(self) -> None:
+        """ Close the ffmpeg writer and mux the audio """
+        if self._writer is not None:
+            self._writer.close()
+
+
+__all__ = get_module_objects(__name__)
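The writer above leans on imageio-ffmpeg's coroutine protocol: `write_frames` returns a generator that must be primed with `send(None)` before the first frame is sent. A minimal usage sketch (placeholder output path and synthetic frames):

```python
# Sketch of the imageio-ffmpeg generator protocol used by the Writer above.
import numpy as np
import imageio_ffmpeg as im_ffm

writer = im_ffm.write_frames("out.mp4", size=(64, 64), fps=25.0)
writer.send(None)  # prime the generator before the first frame
for _ in range(25):
    frame = np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8)
    writer.send(np.ascontiguousarray(frame))  # expects contiguous RGB data
writer.close()  # finalises the container
```

Note the plugin sends frames as `image[:, :, ::-1]` because faceswap holds frames in BGR order while ffmpeg expects RGB.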
diff --git a/plugins/convert/writer/ffmpeg_defaults.py b/plugins/convert/writer/ffmpeg_defaults.py
new file mode 100755
index 0000000000..3ea4e820cc
--- /dev/null
+++ b/plugins/convert/writer/ffmpeg_defaults.py
@@ -0,0 +1,115 @@
+#!/usr/bin/env python3
+""" The default options for the faceswap Ffmpeg Writer plugin.
+
+Defaults files should be named `<plugin_name>_defaults.py`
+
+Any qualifying items placed into this file will automatically get added to the relevant config
+.ini files within the faceswap/config folder and added to the relevant GUI settings page.
+
+The following variable should be defined:
+
+    Parameters
+    ----------
+    HELPTEXT: str
+        A string describing what this plugin does
+
+Further plugin configuration options are assigned using:
+>>> from lib.config import ConfigItem
+>>> <option_name> = ConfigItem(...)
+
+where <option_name> is the name of the configuration option to be added (lower-case, alpha-numeric
++ underscore only) and ConfigItem(...) is the :class:`~lib.config.objects.ConfigItem` data for the
+option.
+
+See the docstring/ReadtheDocs documentation for the required parameters for the ConfigItem object.
+Items will be grouped together as per their `group` parameter, but otherwise will be processed in
+the order that they are added to this module.
+"""
+from lib.config import ConfigItem
+
+
+HELPTEXT = "Options for encoding converted frames to video."
+
+
+container = ConfigItem(
+    datatype=str,
+    default="mp4",
+    group="codec",
+    info="Video container to use.",
+    choices=["avi", "flv", "mkv", "mov", "mp4", "mpeg", "webm"],
+    gui_radio=True)
+
+codec = ConfigItem(
+    datatype=str,
+    default="libx264",
+    group="codec",
+    info="Video codec to use:"
+         "\n\t libx264: H.264. A widely supported and commonly used codec."
+         "\n\t libx265: H.265 / HEVC video encoder application library.",
+    choices=["libx264", "libx265"],
+    gui_radio=True)
+
+crf = ConfigItem(
+    datatype=int,
+    default=23,
+    group="quality",
+    info="Constant Rate Factor: 0 is lossless and 51 is worst quality possible. A "
+         "lower value generally leads to higher quality, and a subjectively sane range "
+         "is 17-28. Consider 17 or 18 to be visually lossless or nearly so; it should "
+         "look the same or nearly the same as the input but it isn't technically "
+         "lossless.\nThe range is exponential, so increasing the CRF value +6 results "
+         "in roughly half the bitrate / file size, while -6 leads to roughly twice the "
+         "bitrate.",
+    rounding=1,
+    min_max=(0, 51))
+
+preset = ConfigItem(
+    datatype=str,
+    default="medium",
+    group="quality",
+    info="A preset is a collection of options that will provide a certain encoding "
+         "speed to compression ratio.\nA slower preset will provide better compression "
+         "(compression is quality per filesize).\nUse the slowest preset that you have "
+         "patience for.",
+    choices=["ultrafast", "superfast", "veryfast", "faster", "fast", "medium", "slow",
+             "slower", "veryslow"],
+    gui_radio=True)
+
+tune = ConfigItem(
+    datatype=str,
+    default="none",
+    group="settings",
+    info="Change settings based upon the specifics of your input:"
+         "\n\t none: Don't perform any additional tuning."
+         "\n\t film: [H.264 only] Use for high quality movie content; lowers deblocking."
+         "\n\t animation: [H.264 only] Good for cartoons; uses higher deblocking and more "
+         "reference frames."
+         "\n\t grain: Preserves the grain structure in old, grainy film material."
+         "\n\t stillimage: [H.264 only] Good for slideshow-like content."
+         "\n\t fastdecode: Allows faster decoding by disabling certain filters."
+         "\n\t zerolatency: Good for fast encoding and low-latency streaming.",
+    choices=["none", "film", "animation", "grain", "stillimage", "fastdecode", "zerolatency"])
+
+profile = ConfigItem(
+    datatype=str,
+    default="auto",
+    group="settings",
+    info="[H.264 Only] Limit the output to a specific H.264 profile. Don't change this "
+         "unless your target device only supports a certain profile.",
+    choices=["auto", "baseline", "main", "high", "high10", "high422", "high444"])
+
+level = ConfigItem(
+    datatype=str,
+    default="auto",
+    group="settings",
+    info="[H.264 Only] Set the encoder level. Don't change this unless your target "
+         "device only supports a certain level.",
+    choices=["auto", "1", "1b", "1.1", "1.2", "1.3", "2", "2.1", "2.2", "3", "3.1", "3.2", "4",
+             "4.1", "4.2", "5", "5.1", "5.2", "6", "6.1", "6.2"])
+
+skip_mux = ConfigItem(
+    datatype=bool,
+    default=False,
+    group="settings",
+    info="Skip muxing audio to the final video output. This will result in a video without an "
+         "audio track.")
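Following the docstring's instructions, adding a new option to a defaults module is a single assignment. The sketch below mirrors the `ConfigItem` calls in this file; `faststart` is a hypothetical option used purely for illustration, and the exact required parameters of `lib.config.ConfigItem` are assumed from the usage shown here.

```python
# Hedged sketch: declaring a hypothetical extra option in a *_defaults.py file.
from lib.config import ConfigItem

faststart = ConfigItem(
    datatype=bool,
    default=False,
    group="settings",
    info="Hypothetical example option: move the moov atom to the start of the "
         "file so playback can begin before the download completes.")
```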
+ "\n\t film: [H.264 only] Use for high quality movie content; lowers deblocking." + "\n\t animation: [H.264 only] Good for cartoons; uses higher deblocking and more " + "reference frames." + "\n\t grain: Preserves the grain structure in old, grainy film material." + "\n\t stillimage: [H.264 only] Good for slideshow-like content." + "\n\t fastdecode: Allows faster decoding by disabling certain filters." + "\n\t zerolatency: Good for fast encoding and low-latency streaming.", + choices=["none", "film", "animation", "grain", "stillimage", "fastdecode", "zerolatency"]) + +profile = ConfigItem( + datatype=str, + default="auto", + group="settings", + info="[H.264 Only] Limit the output to a specific H.264 profile. Don't change this " + "unless your target device only supports a certain profile.", + choices=["auto", "baseline", "main", "high", "high10", "high422", "high444"]) + +level = ConfigItem( + datatype=str, + default="auto", + group="settings", + info="[H.264 Only] Set the encoder level, Don't change this unless your target " + "device only supports a certain level.", + choices=["auto", "1", "1b", "1.1", "1.2", "1.3", "2", "2.1", "2.2", "3", "3.1", "3.2", "4", + "4.1", "4.2", "5", "5.1", "5.2", "6", "6.1", "6.2"]) + +skip_mux = ConfigItem( + datatype=bool, + default=False, + group="settings", + info="Skip muxing audio to the final video output. This will result in a video without an " + "audio track.") diff --git a/plugins/convert/writer/gif.py b/plugins/convert/writer/gif.py new file mode 100644 index 0000000000..d00171f196 --- /dev/null +++ b/plugins/convert/writer/gif.py @@ -0,0 +1,158 @@ +#!/usr/bin/env python3 +""" Animated GIF writer for faceswap.py converter """ +from __future__ import annotations +import os +import typing as T + +import cv2 +import imageio + +from lib.utils import get_module_objects + +from ._base import Output, logger +from . import gif_defaults as cfg + +if T.TYPE_CHECKING: + from imageio.core import format as im_format # noqa:F401 + + +class Writer(Output): + """ GIF output writer using imageio. + + + Parameters + ---------- + output_folder: str + The folder to save the output gif to + total_count: int + The total number of frames to be converted + frame_ranges: list or ``None`` + List of tuples for starting and end values of each frame range to be converted or ``None`` + if all frames are to be converted + kwargs: dict + Any additional standard :class:`plugins.convert.writer._base.Output` key word arguments. + """ + def __init__(self, + output_folder: str, + total_count: int, + frame_ranges: list[tuple[int, int]] | None, + **kwargs) -> None: + logger.debug("total_count: %s, frame_ranges: %s", total_count, frame_ranges) + super().__init__(output_folder, **kwargs) + self._frame_order: list[int] = self._set_frame_order(total_count, frame_ranges) + # Fix dims on 1st received frame + self._output_dimensions: tuple[int, int] | None = None + # Need to know dimensions of first frame, so set writer then + self._writer: imageio.plugins.pillowmulti.GIFFormat.Writer | None = None + self._gif_file: str | None = None # Set filename based on first file seen + + @property + def _gif_params(self) -> dict: + """ dict: The selected gif plugin configuration options. """ + kwargs = {"fps": cfg.fps(), + "loop": cfg.loop(), + "palettesize": cfg.palettesize(), + "subrectangles": cfg.subrectangles()} + logger.debug(kwargs) + return kwargs + + def _get_writer(self) -> im_format.Format.Writer: + """ Obtain the GIF writer with the requested GIF encoding options. 
+    def _get_writer(self) -> im_format.Format.Writer:
+        """ Obtain the GIF writer with the requested GIF encoding options.
+
+        Returns
+        -------
+        :class:`imageio.plugins.pillowmulti.GIFFormat.Writer`
+            The imageio GIF writer
+        """
+        assert self._gif_file is not None
+        return imageio.get_writer(self._gif_file,
+                                  mode="i",
+                                  **self._gif_params)
+
+    def write(self, filename: str, image) -> None:
+        """ Frames come from the pool in arbitrary order, so frames are cached for writing out
+        in the correct order.
+
+        Parameters
+        ----------
+        filename: str
+            The incoming frame filename.
+        image: :class:`numpy.ndarray`
+            The converted image to be written
+        """
+        logger.trace("Received frame: (filename: '%s', shape: %s)",  # type: ignore
+                     filename, image.shape)
+        if not self._gif_file:
+            self._set_gif_filename(filename)
+            self._set_dimensions(image.shape[:2])
+            self._writer = self._get_writer()
+        if (image.shape[1], image.shape[0]) != self._output_dimensions:
+            image = cv2.resize(image, self._output_dimensions)  # pylint:disable=no-member
+        self.cache_frame(filename, image)
+        self._save_from_cache()
+
+    def _set_gif_filename(self, filename: str) -> None:
+        """ Set the full path to GIF output file to :attr:`_gif_file`
+
+        The filename is created from the source filename of the first input image received, with
+        `"_converted"` appended to the end and a .gif extension. If a file already exists with the
+        given filename, then `"_1"` is appended to the end of the filename. This number iterates
+        until a valid filename that does not exist is found.
+
+        Parameters
+        ----------
+        filename: str
+            The incoming frame filename.
+        """
+        logger.debug("sample filename: '%s'", filename)
+        filename = os.path.splitext(os.path.basename(filename))[0]
+        snip = len(filename)
+        for char in list(filename[::-1]):
+            if not char.isdigit() and char not in ("_", "-"):
+                break
+            snip -= 1
+        filename = filename[:snip]
+
+        idx = 0
+        while True:
+            out_file = f"{filename}_converted{'' if idx == 0 else f'_{idx}'}.gif"
+            retval = os.path.join(self.output_folder, out_file)
+            if not os.path.exists(retval):
+                break
+            idx += 1
+
+        self._gif_file = retval
+        logger.info("Outputting to: '%s'", self._gif_file)
+
+    def _set_dimensions(self, frame_dims: tuple[int, int]) -> None:
+        """ Set the attribute :attr:`_output_dimensions` based on the first frame received. This
+        protects against different sized images coming in and ensures all images get written to
+        the GIF at the same dimensions. """
+        # pylint:disable=duplicate-code
+        logger.debug("input dimensions: %s", frame_dims)
+        self._output_dimensions = (frame_dims[1], frame_dims[0])
+        logger.debug("Set dimensions: %s", self._output_dimensions)
+
+    def _save_from_cache(self) -> None:
+        """ Writes any consecutive frames to the GIF container that are ready to be output
+        from the cache. """
+        # pylint:disable=duplicate-code
+        assert self._writer is not None
+        while self._frame_order:
+            if self._frame_order[0] not in self.cache:
+                logger.trace("Next frame not ready. Continuing")  # type: ignore
+                break
+            save_no = self._frame_order.pop(0)
+            save_image = self.cache.pop(save_no)
+            logger.trace("Rendering from cache. Frame no: %s", save_no)  # type: ignore
+            self._writer.append_data(save_image[:, :, ::-1])
+        logger.trace("Current cache size: %s", len(self.cache))  # type: ignore
+
+    def close(self) -> None:
+        """ Close the GIF writer on completion. """
+        if self._writer is not None:
+            self._writer.close()
+
+
+__all__ = get_module_objects(__name__)
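`_set_gif_filename` above derives a base name by stripping any trailing frame numbers (and joining `_`/`-` characters) from the first incoming frame. A standalone sketch of just that trimming step, so "scene_000123" collapses to "scene":

```python
# Standalone sketch of the trailing-frame-number trim in _set_gif_filename.
import os


def trim_frame_number(filename: str) -> str:
    stem = os.path.splitext(os.path.basename(filename))[0]
    snip = len(stem)
    for char in stem[::-1]:
        if not char.isdigit() and char not in ("_", "-"):
            break  # stop at the first non-digit, non-separator character
        snip -= 1
    return stem[:snip]


assert trim_frame_number("/frames/scene_000123.png") == "scene"
```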
""" + if self._writer is not None: + self._writer.close() + + +__all__ = get_module_objects(__name__) diff --git a/plugins/convert/writer/gif_defaults.py b/plugins/convert/writer/gif_defaults.py new file mode 100755 index 0000000000..c0dd27c580 --- /dev/null +++ b/plugins/convert/writer/gif_defaults.py @@ -0,0 +1,63 @@ +#!/usr/bin/env python3 +""" The default options for the faceswap Gif Writer plugin. + +Defaults files should be named `_defaults.py` + +Any qualifying items placed into this file will automatically get added to the relevant config +.ini files within the faceswap/config folder and added to the relevant GUI settings page. + +The following variable should be defined: + + Parameters + ---------- + HELPTEXT: str + A string describing what this plugin does + +Further plugin configuration options are assigned using: +>>> = ConfigItem(...) + +where is the name of the configuration option to be added (lower-case, alpha-numeric ++ underscore only) and ConfigItem(...) is the [`~lib.config.objects.ConfigItem`] data for the +option. + +See the docstring/ReadtheDocs documentation required parameters for the ConfigItem object. +Items will be grouped together as per their `group` parameter, but otherwise will be processed in +the order that they are added to this module. +from lib.config import ConfigItem +""" +from lib.config import ConfigItem + + +HELPTEXT = "Options for outputting converted frames to an animated gif." + + +fps = ConfigItem( + datatype=int, + default=25, + group="settings", + info="Frames per Second.", + rounding=1, + min_max=(1, 60)) + +loop = ConfigItem( + datatype=int, + default=0, + group="settings", + info="The number of iterations. Set to 0 to loop indefinitely.", + rounding=1, + min_max=(0, 100)) + +palettesize = ConfigItem( + datatype=str, + default="256", + group="settings", + info="The number of colors to quantize the image to. Is rounded to the nearest power of " + "two.", + choices=["2", "4", "8", "16", "32", "64", "128", "256"]) + +subrectangles = ConfigItem( + datatype=bool, + default=False, + group="settings", + info="If True, will try and optimize the GIF by storing only the rectangular parts of " + "each frame that change with respect to the previous.") diff --git a/plugins/convert/writer/opencv.py b/plugins/convert/writer/opencv.py new file mode 100644 index 0000000000..29752551af --- /dev/null +++ b/plugins/convert/writer/opencv.py @@ -0,0 +1,123 @@ +#!/usr/bin/env python3 +""" Image output writer for faceswap.py converter + Uses cv2 for writing as in testing this was a lot faster than both Pillow and ImageIO +""" +import typing as T + +import cv2 +import numpy as np + +from lib.utils import get_module_objects +from ._base import Output, logger +from . import opencv_defaults as cfg + + +class Writer(Output): + """ Images output writer using cv2 + + Parameters + ---------- + output_folder: str + The full path to the output folder where the converted media should be saved + configfile: str, optional + The full path to a custom configuration ini file. If ``None`` is passed + then the file is loaded from the default location. Default: ``None``. + """ + def __init__(self, output_folder: str, **kwargs) -> None: + super().__init__(output_folder, **kwargs) + self._extension = f".{cfg.format()}" + self._check_transparency_format() + self._separate_mask = self.output_alpha and cfg.separate_mask() + self._args = self._get_save_args() + + @property + def output_alpha(self) -> bool: + """ bool : OpenCV can output alpha channel. 
""" + return cfg.draw_transparent() + + def _check_transparency_format(self) -> None: + """ Make sure that the output format is correct if draw_transparent is selected """ + if not self.output_alpha or (self.output_alpha and cfg.format() == "png"): + return + logger.warning("Draw Transparent selected, but the requested format does not support " + "transparency. Changing output format to 'png'") + cfg.format.set("png") + + def _get_save_args(self) -> tuple[int, ...]: + """ Obtain the save parameters for the file format. + + Returns + ------- + tuple + The OpenCV specific arguments for the selected file format + """ + filetype = cfg.format() + args: tuple[int, ...] = tuple() + if filetype == "jpg" and cfg.jpg_quality() > 0: + args = (cv2.IMWRITE_JPEG_QUALITY, + cfg.jpg_quality()) + if filetype == "png" and cfg.png_compress_level() > -1: + args = (cv2.IMWRITE_PNG_COMPRESSION, + cfg.png_compress_level()) + logger.debug(args) + return args + + def write(self, filename: str, image: list[bytes]) -> None: + """ Write out the pre-encoded image to disk. If separate mask has been selected, write out + the encoded mask to a sub-folder in the output directory. + + Parameters + ---------- + filename: str + The full path to write out the image to. + image: list + List of :class:`bytes` objects of length 1 (containing just the image to write out) + or length 2 (containing the image and mask to write out) + """ + logger.trace("Outputting: (filename: '%s'", filename) # type:ignore + filenames = self.get_output_filename(filename, cfg.format(), self._separate_mask) + # pylint:disable=duplicate-code + for fname, img in zip(filenames, image): + try: + with open(fname, "wb") as outfile: + outfile.write(img) + except Exception as err: # pylint:disable=broad-except + logger.error("Failed to save image '%s'. Original Error: %s", filename, err) + + def pre_encode(self, image: np.ndarray, **kwargs) -> list[bytes]: + """ Pre_encode the image in lib/convert.py threads as it is a LOT quicker. + + Parameters + ---------- + image: :class:`numpy.ndarray` + A 3 or 4 channel BGR swapped frame + + Returns + ------- + list + List of :class:`bytes` objects ready for writing. The list will be of length 1 with + image bytes object as the only member unless separate mask has been requested, in which + case it will be length 2 with the image in position 0 and mask in position 1 + """ + logger.trace("Pre-encoding image") # type:ignore + retval = [] + + if self._separate_mask: + mask = image[..., -1] + image = image[..., :3] + + retval.append(cv2.imencode(self._extension, + mask, + self._args)[1]) + + retval.insert(0, cv2.imencode(self._extension, + image, + self._args)[1]) + return T.cast(list[bytes], retval) + + def close(self) -> None: + """ Does nothing as OpenCV writer does not need a close method """ + return + + +__all__ = get_module_objects(__name__) diff --git a/plugins/convert/writer/opencv_defaults.py b/plugins/convert/writer/opencv_defaults.py new file mode 100755 index 0000000000..61ea6b9feb --- /dev/null +++ b/plugins/convert/writer/opencv_defaults.py @@ -0,0 +1,86 @@ +#!/usr/bin/env python3 +""" The default options for the faceswap Opencv Writer plugin. + +Defaults files should be named `_defaults.py` + +Any qualifying items placed into this file will automatically get added to the relevant config +.ini files within the faceswap/config folder and added to the relevant GUI settings page. 
diff --git a/plugins/convert/writer/opencv_defaults.py b/plugins/convert/writer/opencv_defaults.py
new file mode 100755
index 0000000000..61ea6b9feb
--- /dev/null
+++ b/plugins/convert/writer/opencv_defaults.py
@@ -0,0 +1,86 @@
+#!/usr/bin/env python3
+""" The default options for the faceswap Opencv Writer plugin.
+
+Defaults files should be named `<plugin_name>_defaults.py`
+
+Any qualifying items placed into this file will automatically get added to the relevant config
+.ini files within the faceswap/config folder and added to the relevant GUI settings page.
+
+The following variable should be defined:
+
+    Parameters
+    ----------
+    HELPTEXT: str
+        A string describing what this plugin does
+
+Further plugin configuration options are assigned using:
+>>> from lib.config import ConfigItem
+>>> <option_name> = ConfigItem(...)
+
+where <option_name> is the name of the configuration option to be added (lower-case, alpha-numeric
++ underscore only) and ConfigItem(...) is the :class:`~lib.config.objects.ConfigItem` data for the
+option.
+
+See the docstring/ReadtheDocs documentation for the required parameters for the ConfigItem object.
+Items will be grouped together as per their `group` parameter, but otherwise will be processed in
+the order that they are added to this module.
+"""
+# pylint:disable=duplicate-code
+from lib.config import ConfigItem
+
+
+HELPTEXT = (
+    "Options for outputting converted frames to a series of images using OpenCV\n"
+    "OpenCV can be faster than other image writers, but lacks some configuration "
+    "options and formats."
+)
+
+
+format = ConfigItem(  # pylint:disable=redefined-builtin
+    datatype=str,
+    default="png",
+    group="format",
+    info="Image format to use:"
+         "\n\t bmp: Windows bitmap"
+         "\n\t jpg: JPEG format"
+         "\n\t jp2: JPEG 2000 format"
+         "\n\t png: Portable Network Graphics"
+         "\n\t ppm: Portable Pixmap Format",
+    choices=["bmp", "jpg", "jp2", "png", "ppm"],
+    gui_radio=True)
+
+draw_transparent = ConfigItem(
+    datatype=bool,
+    default=False,
+    group="format",
+    info="Place the swapped face on a transparent layer rather than the original frame.\nNB: "
+         "This is only compatible with images saved in png format. If an incompatible format "
+         "is selected then the image will be saved as a png.")
+
+separate_mask = ConfigItem(
+    datatype=bool,
+    default=False,
+    group="format",
+    info="Separate the mask into its own single channel image. This only applies when "
+         "'draw-transparent' is selected. If enabled, the RGB image will be saved into the "
+         "selected output folder whilst the masks will be saved into a sub-folder named "
+         "`masks`. If not enabled then the mask will be included in the alpha-channel of the "
+         "RGBA output.")
+
+jpg_quality = ConfigItem(
+    datatype=int,
+    default=75,
+    group="compression",
+    info="[jpg only] Set the jpg quality. 1 is worst, 95 is best. Higher quality leads to "
+         "larger file sizes.",
+    rounding=1,
+    min_max=(1, 95))
+
+png_compress_level = ConfigItem(
+    datatype=int,
+    default=3,
+    group="compression",
+    info="[png only] ZLIB compression level, 1 gives best speed, 9 gives best compression, 0 "
+         "gives no compression at all.",
+    rounding=1,
+    min_max=(0, 9))
diff --git a/plugins/convert/writer/patch.py b/plugins/convert/writer/patch.py
new file mode 100644
index 0000000000..01f00d7c58
--- /dev/null
+++ b/plugins/convert/writer/patch.py
@@ -0,0 +1,293 @@
+#!/usr/bin/env python3
+""" Face patch output writer for faceswap.py converter
+    Extracts the swapped Face Patch from faceswap rather than the final composited frame, along
+    with the transformation matrix for re-inserting the face into the original frame
+"""
+import json
+import logging
+import re
+import typing as T
+
+import os
+import cv2
+import numpy as np
+
+from lib.image import encode_image, png_read_meta, tiff_read_meta
+from lib.utils import get_module_objects
+from ._base import Output
+from . import patch_defaults as cfg
+
+logger = logging.getLogger(__name__)
+
+
+class Writer(Output):
+    """ Face patch writer for outputting swapped face patches and transformation matrices
+
+    Parameters
+    ----------
+    output_folder: str
+        The full path to the output folder where the face patches should be saved
+    patch_size: int
+        The size of the face patch output from the model
+    configfile: str, optional
+        The full path to a custom configuration ini file. If ``None`` is passed
+        then the file is loaded from the default location. Default: ``None``.
+    """
+    def __init__(self, output_folder: str, patch_size: int, **kwargs) -> None:
+        logger.debug("patch_size: %s", patch_size)
+        super().__init__(output_folder, **kwargs)
+        self._extension = {"png": ".png", "tiff": ".tif"}[cfg.format()]
+        self._separate_mask = cfg.separate_mask()
+        self._fname_split = re.compile("[^0-9a-zA-Z]")
+
+        if self._extension == ".png" and cfg.bit_depth() not in ("8", "16"):
+            logger.warning("Patch Writer: Bit Depth '%s' is unsupported for format '%s'. "
+                           "Updating to '16'", cfg.bit_depth(), cfg.format())
+            cfg.bit_depth.set("16")
+
+        self._dtype = {"8": np.uint8, "16": np.uint16, "32": np.float32}[cfg.bit_depth()]
+        self._multiplier = {"8": 255., "16": 65535., "32": 1.}[cfg.bit_depth()]
+
+        self._dummy_patch = np.zeros((1, patch_size, patch_size, 4), dtype=np.float32)
+
+        tl_box = np.array([[0, 0], [patch_size, 0], [patch_size, patch_size], [0, patch_size]],
+                          dtype=np.float32)
+        self._patch_corner = {"top-left": tl_box[0],
+                              "top-right": tl_box[1],
+                              "bottom-right": tl_box[2],
+                              "bottom-left": tl_box[3]}[cfg.origin()].copy()
+        self._box = tl_box
+        if cfg.origin() in ("top-right", "bottom-left"):
+            self._box[[1, 3], :] = self._box[[3, 1], :]  # keep clockwise from 0,0
+
+        self._args = self._get_save_args()
+        self._matrices: dict[str, dict[str, list[list[float]]]] = {}
+
+    def _get_save_args(self) -> tuple[int, ...]:
+        """ Obtain the save parameters for the file format.
+
+        Returns
+        -------
+        tuple
+            The OpenCV specific arguments for the selected file format
+        """
+        args: tuple[int, ...] = tuple()
+        if self._extension == ".png" and cfg.png_compress_level() > -1:
+            args = (cv2.IMWRITE_PNG_COMPRESSION, cfg.png_compress_level())
+        if self._extension == ".tif" and cfg.bit_depth() != "32":
+            tiff_methods = {"none": 1, "lzw": 5, "deflate": 8}
+            method = cfg.tiff_compression_method()
+            method = "none" if method is None else method
+            args = (cv2.IMWRITE_TIFF_COMPRESSION, tiff_methods[method])
+        logger.debug(args)
+        return args
+
+    def _get_new_filename(self, filename: str, face_index: int) -> str:
+        """ Obtain the filename for the output file based on the frame's filename and the
+        user-selected naming options
+
+        Parameters
+        ----------
+        filename: str
+            The original frame's filename
+        face_index: int
+            The index of the face within the frame
+
+        Returns
+        -------
+        str
+            The new filename for naming the output face patch
+        """
+        face_idx = str(face_index).rjust(2, "0")
+        fname, ext = os.path.splitext(filename)
+        fname = os.path.basename(fname)
+
+        split_fname = self._fname_split.split(fname)
+        if split_fname and split_fname[-1].isdigit():
+            i_frame_no = (int(split_fname[-1]) +
+                          (int(cfg.start_index()) - 1) +
+                          cfg.index_offset())
+            frame_no = f".{str(i_frame_no).rjust(cfg.number_padding(), '0')}"
+            base_fname = fname[:-len(split_fname[-1]) - 1]
+        else:
+            frame_no = ""
+            base_fname = fname
+
+        retval = ""
+        if cfg.include_filename():
+            retval += base_fname
+        if cfg.face_index_location() == "before":
+            retval = f"{retval}_{face_idx}"
+        retval += frame_no
+        if cfg.face_index_location() == "after":
+            retval = f"{retval}.{face_idx}"
+        retval += ext
+        logger.trace("source filename: '%s', output filename: '%s'",  # type:ignore[attr-defined]
+                     filename, retval)
+        return retval
+
+    def write(self, filename: str, image: list[list[bytes]]) -> None:
+        """ Write out the pre-encoded image to disk. If separate mask has been selected, write out
+        the encoded mask to a sub-folder in the output directory.
+
+        Parameters
+        ----------
+        filename: str
+            The full path to write out the image to.
+        image: list[list[bytes]]
+            List of lists of :class:`bytes` objects containing all swapped faces from a frame to
+            write out. The inner list will be of length 1 (mask included in the alpha channel) or
+            length 2 (mask to write out separately)
+        """
+        logger.trace("Outputting: (filename: '%s')", filename)  # type:ignore[attr-defined]
+
+        read_func = png_read_meta if self._extension == ".png" else tiff_read_meta
+        for idx, face in enumerate(image):
+            new_filename = self._get_new_filename(filename, idx)
+            filenames = self.get_output_filename(new_filename, cfg.format(), self._separate_mask)
+            for fname, img in zip(filenames, face):
+                try:
+                    with open(fname, "wb") as outfile:
+                        outfile.write(img)
+                except Exception as err:  # pylint:disable=broad-except
+                    logger.error("Failed to save image '%s'. Original Error: %s", filename, err)
+                if not cfg.json_output():
+                    continue
+                mat = T.cast(dict[str, list[list[float]]], read_func(img))
+                self._matrices[os.path.splitext(os.path.basename(fname))[0]] = mat
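The matrix inversion performed below relies on a standard trick: a 2x3 cv2 affine matrix is not invertible as-is, so it is augmented with a `[0, 0, 1]` row to make it a full 3x3 homogeneous transform before `np.linalg.inv` is applied. A worked sketch with an illustrative matrix:

```python
# Augment a 2x3 affine (frame -> patch) to 3x3, invert it (patch -> frame).
import numpy as np

affine = np.array([[0.5, 0.0, 10.0],
                   [0.0, 0.5, 20.0]], dtype=np.float32)  # frame -> patch
full = np.concatenate([affine, np.array([[0.0, 0.0, 1.0]], dtype=np.float32)])
inverse = np.linalg.inv(full)  # patch -> frame

point = np.array([32.0, 32.0, 1.0])  # homogeneous patch coordinate
print(inverse @ point)  # -> [44. 24.  1.], i.e. back in frame space
```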
+
+    @classmethod
+    def _get_inverse_matrices(cls, matrices: np.ndarray) -> np.ndarray:
+        """ Obtain the inverse matrices for the given matrices. If an empty array is supplied,
+        return a dummy transformation matrix that performs no action
+
+        Parameters
+        ----------
+        matrices : :class:`numpy.ndarray`
+            The original transform matrices that the inverse needs to be calculated for
+
+        Returns
+        -------
+        :class:`numpy.ndarray`
+            The inverse transformation matrices
+        """
+        if not np.any(matrices):
+            return np.array([[[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]]], dtype=np.float32)
+
+        identity = np.array([[[0., 0., 1.]]], dtype=np.float32)
+        mat = np.concatenate([matrices, np.repeat(identity, matrices.shape[0], axis=0)], axis=1)
+        retval = np.linalg.inv(mat)
+        logger.trace("matrix: %s, inverse: %s", mat, retval)  # type:ignore[attr-defined]
+        return retval
+
+    def _adjust_to_origin(self, matrices: np.ndarray, canvas_size: tuple[int, int]) -> None:
+        """ Adjust the transformation matrix to use the correct target coordinates system. The
+        matrix adjustment is done in place, so this does not return a value
+
+        Parameters
+        ----------
+        matrices: :class:`numpy.ndarray`
+            The transformation matrices to be adjusted
+        canvas_size: tuple[int, int]
+            The size of the canvas (width, height) that the transformation matrix applies to.
+        """
+        if cfg.origin() == "top-left":
+            return
+
+        for mat in matrices:
+            og_cnr = cv2.transform(self._patch_corner[None, None], mat[:2, ...]).squeeze()
+            x_shift, y_shift = og_cnr
+            if cfg.origin().split("-")[-1] == "right":
+                x_shift = canvas_size[0] - x_shift
+            if cfg.origin().split("-")[0] == "bottom":
+                y_shift = canvas_size[1] - y_shift
+            mat[:2, 2] = [x_shift, y_shift]
+
+        if cfg.origin() in ("top-right", "bottom-left"):
+            matrices[..., :2, :2] *= [[[1, -1], [-1, 1]]]  # switch shear
+
+    def _get_roi(self, matrices: np.ndarray) -> np.ndarray:
+        """ Obtain the (x, y) ROI points of the patch in the original frame. Points are returned
+        in clockwise order from the origin location
+
+        Parameters
+        ----------
+        matrices: :class:`numpy.ndarray`
+            The transformation matrices for the current frame
+
+        Returns
+        -------
+        np.ndarray
+            The ROI of the patches in original frame co-ordinates in clockwise order from the
+            origin point
+        """
+        retval = [cv2.transform(np.expand_dims(self._box, axis=1), mat[:2, ...]).squeeze()
+                  for mat in matrices]
+        return np.array(retval, dtype=np.float32)
+
+    def pre_encode(self, image: np.ndarray, **kwargs) -> list[list[bytes]]:
+        """ Pre_encode the image in lib/convert.py threads as it is a LOT quicker.
+
+        Parameters
+        ----------
+        image: :class:`numpy.ndarray`
+            A 3 or 4 channel BGR swapped face batch as float32
+        canvas_size: tuple[int, int]
+            The size of the canvas (x, y) that the transformation matrix applies to.
+        matrices: :class:`numpy.ndarray`, optional
+            The transformation matrices for extracting the face patches from the original frame.
+            Must be provided if an image is provided, otherwise ``None`` to insert a dummy matrix
+
+        Returns
+        -------
+        list
+            List of lists of :class:`bytes` objects ready for writing, one inner list per face.
+            Each inner list will be of length 1 with the image bytes object as the only member
+            unless separate mask has been requested, in which case it will be length 2 with the
+            image in position 0 and mask in position 1
+        """
+        logger.trace("Pre-encoding image")  # type:ignore[attr-defined]
+        retval = []
+        canvas_size: tuple[int, int] = kwargs.get("canvas_size", (1, 1))
+        matrices: np.ndarray = kwargs.get("matrices", np.array([]))
+
+        if not np.any(image) and cfg.empty_frames() == "blank":
+            image = self._dummy_patch
+
+        matrices = self._get_inverse_matrices(matrices)
+        self._adjust_to_origin(matrices, canvas_size)
+        rois = self._get_roi(matrices)
+        patches = (image * self._multiplier).astype(self._dtype)
+
+        for patch, matrix, roi in zip(patches, matrices, rois):
+            this_face = []
+            mat = json.dumps({"transform_matrix": matrix.tolist(), "roi": roi.tolist()},
+                             ensure_ascii=True).encode("ascii")
+            if self._separate_mask:
+                mask = patch[..., -1]
+                face = patch[..., :3]
+
+                this_face.append(encode_image(mask,
+                                              self._extension,
+                                              encoding_args=self._args,
+                                              metadata=mat))
+            else:
+                face = patch
+
+            this_face.insert(0, encode_image(face,
+                                             self._extension,
+                                             encoding_args=self._args,
+                                             metadata=mat))
+            retval.append(this_face)
+        return retval
+
+    def close(self) -> None:
+        """ Outputs json file if requested """
+        if not cfg.json_output():
+            return
+        fname = os.path.join(self.output_folder, "matrices.json")
+        with open(fname, "w", encoding="utf-8") as ofile:
+            json.dump(self._matrices, ofile, indent=2, sort_keys=True)
+        logger.info("Patch matrices written to: '%s'", fname)
+
+
+__all__ = get_module_objects(__name__)
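For illustration, a hedged sketch of consuming the optional `matrices.json` written by `close()` above: warp each saved patch back over the original frame with its stored 3x3 matrix. Paths and the stand-in frame are placeholders, and it assumes a "top-left" origin so the stored matrix (patch to frame) can be fed to `cv2.warpAffine` directly; the naive `np.maximum` composite is only for demonstration.

```python
# Re-insert saved face patches into a frame using matrices.json.
import json

import cv2
import numpy as np

with open("matrices.json", "r", encoding="utf-8") as ifile:
    matrices = json.load(ifile)

frame = np.zeros((720, 1280, 3), dtype=np.uint8)  # stand-in original frame
for name, data in matrices.items():
    patch = cv2.imread(f"{name}.png")  # placeholder patch location
    if patch is None:
        continue
    matrix = np.array(data["transform_matrix"], dtype=np.float32)
    warped = cv2.warpAffine(patch,
                            matrix[:2],  # cv2 wants the 2x3 part
                            (frame.shape[1], frame.shape[0]))
    frame = np.maximum(frame, warped)  # naive composite for illustration
```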
diff --git a/plugins/convert/writer/patch_defaults.py b/plugins/convert/writer/patch_defaults.py
new file mode 100755
index 0000000000..4febc3f770
--- /dev/null
+++ b/plugins/convert/writer/patch_defaults.py
@@ -0,0 +1,167 @@
+#!/usr/bin/env python3
+""" The default options for the faceswap patch Writer plugin.
+
+Defaults files should be named `<plugin_name>_defaults.py`
+
+Any qualifying items placed into this file will automatically get added to the relevant config
+.ini files within the faceswap/config folder and added to the relevant GUI settings page.
+
+The following variable should be defined:
+
+    Parameters
+    ----------
+    HELPTEXT: str
+        A string describing what this plugin does
+
+Further plugin configuration options are assigned using:
+>>> from lib.config import ConfigItem
+>>> <option_name> = ConfigItem(...)
+
+where <option_name> is the name of the configuration option to be added (lower-case, alpha-numeric
++ underscore only) and ConfigItem(...) is the :class:`~lib.config.objects.ConfigItem` data for the
+option.
+
+See the docstring/ReadtheDocs documentation for the required parameters for the ConfigItem object.
+Items will be grouped together as per their `group` parameter, but otherwise will be processed in
+the order that they are added to this module.
+"""
+from lib.config import ConfigItem
+
+
+HELPTEXT = (
+    "Options for outputting the raw converted face patches from faceswap\n"
+    "The raw face patches are output along with the transformation matrix, per face, to "
+    "transform the face back into the original frame in external tools"
+)
+
+start_index = ConfigItem(
+    default="0",
+    info="The starting frame number for the first output frame.",
+    datatype=str,
+    choices=["0", "1"],
+    group="file_naming",
+    gui_radio=True)
+
+index_offset = ConfigItem(
+    default=0,
+    datatype=int,
+    group="file_naming",
+    info="How much to offset the frame numbering by.",
+    rounding=1,
+    min_max=(0, 1000))
+
+number_padding = ConfigItem(
+    datatype=int,
+    default=6,
+    group="file_naming",
+    info="Length to pad the frame numbers by.",
+    rounding=6,
+    min_max=(0, 10))
+
+include_filename = ConfigItem(
+    datatype=bool,
+    default=True,
+    group="file_naming",
+    info="Prefix the filename of the original frame to each face patch's output filename.")
+
+face_index_location = ConfigItem(
+    datatype=str,
+    default="before",
+    group="file_naming",
+    info="For frames that contain multiple faces, where the face index should appear in "
+         "the filename:"
+         "\n\t before: places the face index before the frame number."
+         "\n\t after: places the face index after the frame number.",
+    choices=["before", "after"],
+    gui_radio=True)
+
+origin = ConfigItem(
+    datatype=str,
+    default="bottom-left",
+    group="output",
+    info="The origin (0, 0) location of the software that patches will be imported into. "
+         "This impacts the transformation matrix that is supplied with the image patch. "
+         "Setting the correct origin here will make importing into the external tool "
+         "simpler."
+         "\n\t top-left: The origin (0, 0) of the external canvas is at the top left "
+         "corner."
+         "\n\t bottom-left: The origin (0, 0) of the external canvas is at the bottom "
+         "left corner."
+         "\n\t top-right: The origin (0, 0) of the external canvas is at the top right "
+         "corner."
+         "\n\t bottom-right: The origin (0, 0) of the external canvas is at the bottom "
+         "right corner.",
+    choices=["top-left", "bottom-left", "top-right", "bottom-right"],
+    gui_radio=True)
+
+empty_frames = ConfigItem(
+    datatype=str,
+    group="output",
+    default="blank",
+    info="How to handle the output of frames without faces:"
+         "\n\t skip: skips any frames that do not have a face within it. This will lead to "
+         "gaps within the final image sequence."
+         "\n\t blank: outputs a blank (empty) face patch for any frames without faces. "
+         "There will be no gaps within the final image sequence, as those gaps will be "
+         "padded with empty face patches",
+    choices=["skip", "blank"],
+    gui_radio=True)
+
+json_output = ConfigItem(
+    datatype=bool,
+    default=False,
+    group="output",
+    info="The transformation matrix, and other associated metadata, is output within the "
+         "face images' EXIF fields. Some external tools can read this data, others cannot. "
+         "Enable this option to output a json file which contains this same metadata "
+         "mapped to each output face patch's filename.")
+
+separate_mask = ConfigItem(
+    datatype=bool,
+    default=False,
+    group="output",
+    info="Separate the mask into its own single channel patch. If enabled, the RGB image "
+         "will be saved into the selected output folder whilst the masks will be saved "
+         "into a sub-folder named `masks`. 
If not enabled then the mask will be included " + "in the alpha-channel of the RGBA output.") + +bit_depth = ConfigItem( + datatype=str, + default="16", + group="format", + info="The bit-depth for the output images:" + "\n\t 8: 8-bit unsigned - Supported by all formats." + "\n\t 16: 16-bit unsigned - Supported by all formats." + "\n\t 32: 32-bit float - Supported by Tiff only.", + choices=["8", "16", "32"], + gui_radio=True) + +format = ConfigItem( # pylint:disable=redefined-builtin + datatype=str, + default="png", + group="format", + info="File format to save as." + "\n\t png: PNG file format. Transformation matrix is written to the custom iTxt " + "header field 'faceswap'" + "\n\t tiff: TIFF file format. Transformation matrix is written to the " + "'image_description' header field", + choices=["png", "tiff"], + gui_radio=True) + +png_compress_level = ConfigItem( + datatype=int, + default=3, + group="format", + info="ZLIB compression level, 1 gives best speed, 9 gives best compression, 0 gives no " + "compression at all.", + rounding=1, + min_max=(0, 9)) + +tiff_compression_method = ConfigItem( + datatype=str, + default="lzw", + group="format", + info="The compression method to use for Tiff files. Note: For 32bit output, SGILOG " + "compression will always be used regardless of what is selected here.", + choices=["none", "lzw", "deflate"], + gui_radio=True) diff --git a/plugins/convert/writer/pillow.py b/plugins/convert/writer/pillow.py new file mode 100644 index 0000000000..7fb1c75e28 --- /dev/null +++ b/plugins/convert/writer/pillow.py @@ -0,0 +1,144 @@ +#!/usr/bin/env python3 +""" Image output writer for faceswap.py converter """ +from io import BytesIO +from PIL import Image + +import numpy as np + +from lib.utils import get_module_objects +from ._base import Output, logger +from . import pillow_defaults as cfg + + +class Writer(Output): + """ Images output writer using Pillow + + Parameters + ---------- + output_folder: str + The full path to the output folder where the converted media should be saved + configfile: str, optional + The full path to a custom configuration ini file. If ``None`` is passed + then the file is loaded from the default location. Default: ``None``. + """ + def __init__(self, output_folder: str, **kwargs) -> None: + super().__init__(output_folder, **kwargs) + self._check_transparency_format() + # Correct format namings for writing to byte stream + self._format_dict = {"jpg": "JPEG", "jp2": "JPEG 2000", "tif": "TIFF"} + self._separate_mask = self.output_alpha and cfg.separate_mask() + self._kwargs = self._get_save_kwargs() + + @property + def output_alpha(self) -> bool: + """ bool : Pillow can output alpha channel. Returns ``True`` """ + return cfg.draw_transparent() + + def _check_transparency_format(self) -> None: + """ Make sure that the output format is correct if draw_transparent is selected """ + # pylint:disable=duplicate-code + if not self.output_alpha or (self.output_alpha and cfg.format() in ("png", "tif")): + return + logger.warning("Draw Transparent selected, but the requested format does not support " + "transparency. 
Changing output format to 'png'") + cfg.format.set("png") + + def _get_save_kwargs(self) -> dict[str, bool | int | str]: + """ Return the save parameters for the file format + + Returns + ------- + dict + The specific keyword arguments for the selected file format + """ + filetype = cfg.format() + kwargs: dict[str, bool | int | str] = {} + if filetype in ("gif", "jpg", "png"): + kwargs["optimize"] = cfg.optimize() + if filetype == "gif": + kwargs["interlace"] = cfg.gif_interlace() + if filetype == "png": + kwargs["compress_level"] = cfg.png_compress_level() + if filetype == "tif": + kwargs["compression"] = cfg.tif_compression() + logger.debug(kwargs) + return kwargs + + def write(self, filename: str, image: list[BytesIO]) -> None: + """ Write out the pre-encoded image to disk. If separate mask has been selected, write out + the encoded mask to a sub-folder in the output directory. + + Parameters + ---------- + filename: str + The full path to write out the image to. + image: list + List of :class:`BytesIO` objects of length 1 (containing just the image to write out) + or length 2 (containing the image and mask to write out) + """ + logger.trace("Outputting: (filename: '%s'", filename) # type:ignore + filenames = self.get_output_filename(filename, cfg.format(), self._separate_mask) + try: + for fname, img in zip(filenames, image): + with open(fname, "wb") as outfile: + outfile.write(img.read()) + except Exception as err: # pylint:disable=broad-except + logger.error("Failed to save image '%s'. Original Error: %s", filename, err) + + def pre_encode(self, image: np.ndarray, **kwargs) -> list[BytesIO]: + """ Pre_encode the image in lib/convert.py threads as it is a LOT quicker + + Parameters + ---------- + image: :class:`numpy.ndarray` + A 3 or 4 channel BGR swapped frame + + Returns + ------- + list + List of :class:`BytesIO` objects ready for writing. The list will be of length 1 with + image bytes object as the only member unless separate mask has been requested, in which + case it will be length 2 with the image in position 0 and mask in position 1 + """ + logger.trace("Pre-encoding image") # type:ignore + + if self._separate_mask: + encoded_mask = self._encode_image(image[..., -1]) + image = image[..., :3] + + rgb = [2, 1, 0, 3] if image.shape[2] == 4 else [2, 1, 0] + encoded_image = self._encode_image(image[..., rgb]) + + retval = [encoded_image] + + if self._separate_mask: + retval.append(encoded_mask) + + return retval + + def _encode_image(self, image: np.ndarray) -> BytesIO: + """ Encode an image in the correct format as a bytes object for saving + + Parameters + ---------- + image: :class:`np.ndarray` + The single channel mask to encode for saving + + Returns + ------- + :class:`BytesIO` + The image as a bytes object ready for writing to disk + """ + fmt = self._format_dict.get(cfg.format(), cfg.format().upper()) + encoded = BytesIO() + out_image = Image.fromarray(image) + out_image.save(encoded, fmt, **self._kwargs) + encoded.seek(0) + return encoded + + def close(self) -> None: + """ Does nothing as Pillow writer does not need a close method """ + return + + +__all__ = get_module_objects(__name__) diff --git a/plugins/convert/writer/pillow_defaults.py b/plugins/convert/writer/pillow_defaults.py new file mode 100755 index 0000000000..6995be3acc --- /dev/null +++ b/plugins/convert/writer/pillow_defaults.py @@ -0,0 +1,110 @@ +#!/usr/bin/env python3 +""" The default options for the faceswap Pillow Writer plugin. 
+
+Defaults files should be named `<plugin_name>_defaults.py`
+
+Any qualifying items placed into this file will automatically get added to the relevant config
+.ini files within the faceswap/config folder and added to the relevant GUI settings page.
+
+The following variable should be defined:
+
+    Parameters
+    ----------
+    HELPTEXT: str
+        A string describing what this plugin does
+
+Further plugin configuration options are assigned using:
+>>> from lib.config import ConfigItem
+>>> <option_name> = ConfigItem(...)
+
+where <option_name> is the name of the configuration option to be added (lower-case, alpha-numeric
++ underscore only) and ConfigItem(...) is the :class:`~lib.config.objects.ConfigItem` data for the
+option.
+
+See the docstring/ReadtheDocs documentation for the required parameters for the ConfigItem object.
+Items will be grouped together as per their `group` parameter, but otherwise will be processed in
+the order that they are added to this module.
+"""
+# pylint:disable=duplicate-code
+from lib.config import ConfigItem
+
+
+HELPTEXT = (
+    "Options for outputting converted frames to a series of images using Pillow\n"
+    "Pillow is more feature rich than OpenCV but can be slower."
+)
+
+
+format = ConfigItem(  # pylint:disable=redefined-builtin
+    group="format",
+    datatype=str,
+    default="png",
+    info="Image format to use:"
+         "\n\t bmp: Windows bitmap"
+         "\n\t gif: Graphics Interchange Format (NB: Not animated)"
+         "\n\t jpg: JPEG format"
+         "\n\t jp2: JPEG 2000 format"
+         "\n\t png: Portable Network Graphics"
+         "\n\t ppm: Portable Pixmap Format"
+         "\n\t tif: Tag Image File Format",
+    choices=["bmp", "gif", "jpg", "jp2", "png", "ppm", "tif"],
+    gui_radio=True)
+
+draw_transparent = ConfigItem(
+    datatype=bool,
+    default=False,
+    group="format",
+    info="Place the swapped face on a transparent layer rather than the original frame.\nNB: "
+         "This is only compatible with images saved in png or tif format. If an incompatible "
+         "format is selected then the image will be saved as a png.")
+
+separate_mask = ConfigItem(
+    datatype=bool,
+    default=False,
+    group="format",
+    info="Separate the mask into its own single channel image. This only applies when "
+         "'draw-transparent' is selected. If enabled, the RGB image will be saved into the "
+         "selected output folder whilst the masks will be saved into a sub-folder named "
+         "`masks`. If not enabled then the mask will be included in the alpha-channel of the "
+         "RGBA output.")
+
+optimize = ConfigItem(
+    datatype=bool,
+    default=False,
+    group="settings",
+    info="[gif, jpg and png only] If enabled, indicates that the encoder should make an extra "
+         "pass over the image in order to select optimal encoder settings.")
+
+gif_interlace = ConfigItem(
+    datatype=bool,
+    default=True,
+    group="settings",
+    info="[gif only] Set whether to save the gif as interlaced or not.")
+
+jpg_quality = ConfigItem(
+    datatype=int,
+    default=75,
+    group="compression",
+    info="[jpg only] Set the jpg quality. 1 is worst, 95 is best. Higher quality leads to "
+         "larger file sizes.",
+    rounding=1,
+    min_max=(1, 95))
+
+png_compress_level = ConfigItem(
+    datatype=int,
+    default=3,
+    group="compression",
+    info="[png only] ZLIB compression level, 1 gives best speed, 9 gives best compression, 0 "
+         "gives no compression at all. When the optimize option is set to True this has no "
+         "effect (it is set to 9 regardless of the value passed).",
+    rounding=1,
+    min_max=(0, 9))
+
+tif_compression = ConfigItem(
+    datatype=str,
+    default="tiff_deflate",
+    group="compression",
+    info="[tif only] The desired compression method for the file.",
+    choices=["none", "tiff_ccitt", "group3", "group4", "tiff_jpeg", "tiff_adobe_deflate",
+             "tiff_thunderscan", "tiff_deflate", "tiff_sgilog", "tiff_sgilog24",
+             "tiff_raw_16"])
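The Pillow writer shown before these defaults encodes each frame to an in-memory stream in its `pre_encode` step so that the write step is pure file I/O. A minimal standalone sketch of that pattern with a placeholder frame and filename:

```python
# Sketch of the in-memory Pillow encode used by the writer's pre_encode step.
from io import BytesIO

import numpy as np
from PIL import Image

image = np.zeros((64, 64, 3), dtype=np.uint8)  # RGB, placeholder content

encoded = BytesIO()
Image.fromarray(image).save(encoded, "PNG", compress_level=3, optimize=False)
encoded.seek(0)  # rewind so the writer thread can read the bytes back out

with open("frame_000001.png", "wb") as outfile:
    outfile.write(encoded.read())
```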
When optimize option is set to True this has no effect " + "(it is set to 9 regardless of a value passed).", + rounding=1, + min_max=(0, 9)) + +tif_compression = ConfigItem( + datatype=str, + default="tiff_deflate", + group="compression", + info="[tif only] The desired compression method for the file.", + choices=["none", "tiff_ccitt", "group3", "group4", "tiff_jpeg", "tiff_adobe_deflate", + "tiff_thunderscan", "tiff_deflate", "tiff_sgilog", "tiff_sgilog24", + "tiff_raw_16"]) diff --git a/plugins/extract/__init__.py b/plugins/extract/__init__.py index e69de29bb2..3bffbe70b8 100644 --- a/plugins/extract/__init__.py +++ b/plugins/extract/__init__.py @@ -0,0 +1,4 @@ +#!/usr/bin/env python3 +""" Package for Faceswap's extraction pipeline """ +from .extract_media import ExtractMedia +from .pipeline import Extractor diff --git a/plugins/extract/_base.py b/plugins/extract/_base.py new file mode 100644 index 0000000000..9a881b90bb --- /dev/null +++ b/plugins/extract/_base.py @@ -0,0 +1,653 @@ +#!/usr/bin/env python3 +""" Base class for Faceswap :mod:`~plugins.extract.detect`, :mod:`~plugins.extract.align` and +:mod:`~plugins.extract.mask` Plugins +""" +from __future__ import annotations +import logging +import typing as T +from dataclasses import dataclass, field + +import numpy as np +import torch +from keras import device + +from lib.logger import parse_class_init +from lib.multithreading import MultiThread +from lib.queue_manager import queue_manager +from lib.utils import GetModel +from lib.utils import get_backend +from . import extract_config as cfg +from . import ExtractMedia + +if T.TYPE_CHECKING: + from collections.abc import Callable, Generator, Sequence + from queue import Queue + from lib.align import DetectedFace + from .align._base import AlignerBatch + from .detect._base import DetectorBatch + from .mask._base import MaskerBatch + from .recognition._base import RecogBatch + +logger = logging.getLogger(__name__) +BatchType = T.Union["DetectorBatch", "AlignerBatch", "MaskerBatch", "RecogBatch"] + + +@dataclass +class ExtractorBatch: + """ Dataclass for holding a batch flowing through post Detector plugins. + + The batch size for post Detector plugins is not the same as the overall batch size. + An image may contain 0 or more detected faces, and these need to be split and recombined + to be able to utilize a plugin's internal batch size. + + Plugin types will inherit from this class and add required keys. + + Parameters + ---------- + image: list + List of :class:`numpy.ndarray` containing the original frames + detected_faces: list + List of :class:`~lib.align.DetectedFace` objects + filename: list + List of original frame filenames for the batch + feed: :class:`numpy.ndarray` + Batch of feed images to feed the net with + prediction: :class:`numpy.nd.array` + Batch of predictions. 
+    image: list[np.ndarray] = field(default_factory=list)
+    detected_faces: Sequence[DetectedFace | list[DetectedFace]] = field(default_factory=list)
+    filename: list[str] = field(default_factory=list)
+    feed: np.ndarray = field(default_factory=lambda: np.array([]))
+    prediction: np.ndarray = field(default_factory=lambda: np.array([]))
+    data: list[dict[str, T.Any]] = field(default_factory=list)
+
+    def __repr__(self) -> str:
+        """ Prettier repr for debug printing """
+        data = [{k: (v.shape, v.dtype) if isinstance(v, np.ndarray) else v for k, v in dat.items()}
+                for dat in self.data]
+        return (f"{self.__class__.__name__}("
+                f"image={[(img.shape, img.dtype) for img in self.image]}, "
+                f"detected_faces={self.detected_faces}, "
+                f"filename={self.filename}, "
+                f"feed={[(f.shape, f.dtype) for f in self.feed]}, "
+                f"prediction=({self.prediction.shape}, {self.prediction.dtype}), "
+                f"data={data}")
+
+
+@dataclass
+class PluginInfo:
+    """ Dataclass to hold information about a plugin instance
+
+    Parameters
+    ----------
+    instance: int
+        The instance id of the plugin
+    plugin_type: Literal["align", "detect", "mask", "recognition"] | None, optional
+        The plugin type that the plugin instance is. Default: ``None``
+    is_initialized: bool, optional
+        ``True`` if the plugin is initialized. Default: ``False``
+    """
+    instance: int
+    plugin_type: T.Literal["align", "detect", "mask", "recognition"] | None = None
+    is_initialized: bool = False
+
+
+@dataclass
+class SplitTracker:
+    """ Dataclass to hold objects for splitting a frame's detected faces and rejoining them for
+    post-detector plugins
+
+    Parameters
+    ----------
+    faces_per_filename: dict[str, int]
+        Tracking of faces per filename for recompiling batches
+    rollover: :class:`ExtractMedia` | None
+        Batch rollover items
+    output_faces: list[:class:`~lib.align.detected_face.DetectedFace`]
+        Recompiled output faces from the plugin
+    """
+    faces_per_filename: dict[str, int]
+    rollover: ExtractMedia | None
+    output_faces: list[DetectedFace]
+
+
+class Extractor():  # pylint:disable=too-many-instance-attributes
+    """ Extractor Plugin Object
+
+    All ``_base`` classes for Aligners, Detectors and Maskers inherit from this class.
+
+    This class sets up a pipeline for working with ML plugins.
+
+    Plugins are split into 3 threads, to utilize Numpy and CV2's parallel processing, as well as
+    allow the predict function of the model to sit in a dedicated thread.
+    A plugin is expected to have 3 core functions, each in their own thread:
+    - :func:`process_input()` - Prepare the data for feeding into a model
+    - :func:`predict` - Feed the data through the model
+    - :func:`process_output()` - Perform any data post-processing
+
+    Parameters
+    ----------
+    git_model_id: int
+        The second digit in the github tag that identifies this model. See
+        https://github.com/deepfakes-models/faceswap-models for more information
+    model_filename: str
+        The name of the model file to be loaded
+    configfile: str, optional
+        Path to a custom configuration ``ini`` file. Default: Use system configfile
+    instance: int, optional
+        If this plugin is being executed multiple times (i.e. multiple pipelines have been
+        launched), the instance of the plugin must be passed in for naming convention reasons.
+        Default: 0
+
+
+    The following attributes should be set in the plugin's :func:`__init__` method after
+    initializing the parent.
+ + Attributes + ---------- + name: str + Name of this plugin. Used for display purposes. + input_size: int + The input size to the model in pixels across one edge. The input size should always be + square. + color_format: str + Color format for model. Must be ``'BGR'``, ``'RGB'`` or ``'GRAY'``. Defaults to ``'BGR'`` + if not explicitly set. + vram: int + Approximate VRAM used by the model at :attr:`input_size`. Used to calculate the + :attr:`batchsize`. Be conservative to avoid OOM. + vram_per_batch: int + Approximate additional VRAM used by the model for each additional batch. Used to calculate + the :attr:`batchsize`. Be conservative to avoid OOM. + + See Also + -------- + plugins.extract.detect._base : Detector parent class for extraction plugins. + plugins.extract.align._base : Aligner parent class for extraction plugins. + plugins.extract.mask._base : Masker parent class for extraction plugins. + plugins.extract.pipeline : The extract pipeline that configures and calls all plugins + + """ + def __init__(self, + git_model_id: int | None = None, + model_filename: str | list[str] | None = None, + configfile: str | None = None, + instance: int = 0) -> None: + logger.debug(parse_class_init(locals())) + cfg.load_config(configfile) + + self._info = PluginInfo(instance=instance) + """:class:`PluginInfo`: holds information about the plugin instance""" + + self.model_path = self._get_model(git_model_id, model_filename) + """ str or list: Path to the model file(s) (if required). Multiple model files should + be a list of strings """ + + # << SET THE FOLLOWING IN PLUGINS __init__ IF DIFFERENT FROM DEFAULT >> # + self.name: str | None = None + self.input_size = 0 + self.color_format: T.Literal["BGR", "RGB", "GRAY"] = "BGR" + self.vram = 0 + self.vram_per_batch = 0 + + # << THE FOLLOWING ARE SET IN self.initialize METHOD >> # + self.model: T.Any = None + """varies: The model for this plugin. Set in the plugin's :func:`init_model()` method """ + + # For detectors that support batching, this should be set to the calculated batch size + # that the amount of available VRAM will support. + self.batchsize = 1 + """ int: Batchsize for feeding this model. The number of images the model should + feed through at once. """ + + self._queues: dict[str, Queue] = {} + """ dict: in + out queues and internal queues for this plugin, """ + + self._threads: list[MultiThread] = [] + """ list: Internal threads for this plugin """ + + self._extract_media: dict[str, ExtractMedia] = {} + """ dict: The :class:`~plugins.extract.extract_media.ExtractMedia` objects currently being + processed. 
Stored at input for pairing back up on output of extractor process """
+
+        # << THE FOLLOWING PROTECTED ATTRIBUTES ARE SET IN PLUGIN TYPE _base.py >> #
+        self._tracker = SplitTracker({}, None, [])
+        """:class:`SplitTracker`: Holds objects for splitting a frame's detected faces and
+        rejoining them for post-detector plugins """
+
+        logger.debug("Initialized _base %s", self.__class__.__name__)
+
+    # <<< OVERRIDABLE METHODS >>> #
+    def init_model(self) -> None:
+        """ **Override method**
+
+        Override this method to execute the specific model initialization method """
+        raise NotImplementedError
+
+    def process_input(self, batch: BatchType) -> None:
+        """ **Override method**
+
+        Override this method for specific extractor pre-processing of the image
+
+        Parameters
+        ----------
+        batch : :class:`ExtractorBatch`
+            Contains the batch that is currently being passed through the plugin process
+        """
+        raise NotImplementedError
+
+    def predict(self, feed: np.ndarray) -> np.ndarray:
+        """ **Override method**
+
+        Override this method for the specific extractor model prediction function
+
+        Parameters
+        ----------
+        feed: :class:`numpy.ndarray`
+            The feed images for the batch
+
+        Notes
+        -----
+        Input for :func:`predict` should have been set in :func:`process_input`
+
+        Output from the model should populate the key :attr:`prediction` of the :attr:`batch`.
+
+        For Detect:
+            the expected output for the :attr:`prediction` of the :attr:`batch` should be a
+            ``list`` of :attr:`batchsize` detected face points. These points should be either
+            a ``list``, ``tuple`` or ``numpy.ndarray`` with the first 4 items being the `left`,
+            `top`, `right`, `bottom` points, in that order
+        """
+        raise NotImplementedError
+
+    def process_output(self, batch: BatchType) -> None:
+        """ **Override method**
+
+        Override this method for the specific extractor model post predict function
+
+        Parameters
+        ----------
+        batch: :class:`ExtractorBatch`
+            Contains the batch that is currently being passed through the plugin process
+
+        Notes
+        -----
+        For Align:
+            The :attr:`landmarks` must be populated in :attr:`batch` from this method.
+            This should be a ``list`` or :class:`numpy.ndarray` of :attr:`batchsize` containing a
+            ``list``, ``tuple`` or :class:`numpy.ndarray` of `(x, y)` coordinates of the 68 point
+            landmarks as calculated from the :attr:`model`.
+        """
+        raise NotImplementedError
+
+    def on_completion(self) -> None:
+        """ Override to perform an action when the extract process has completed. By default, no
+        action is undertaken """
+        return
+
+    def _predict(self, batch: BatchType) -> BatchType:
+        """ **Override method** (at ``plugin_type`` level)
+
+        This method should be overridden at the ``plugin_type`` level (i.e.
+        ``plugins.extract.detect._base`` or ``plugins.extract.align._base``) and should not
+        be overridden within plugins themselves.
+
+        It acts as a wrapper for the plugin's ``self.predict`` method and handles any
+        predict processing that is consistent for all plugins within the `plugin_type`
+
+        Parameters
+        ----------
+        batch: :class:`ExtractorBatch`
+            Contains the batch that is currently being passed through the plugin process
+        """
+        raise NotImplementedError
+
+    def _process_input(self, batch: BatchType) -> BatchType:
+        """ **Override method** (at ``plugin_type`` level)
+
+        This method should be overridden at the ``plugin_type`` level (i.e.
+        ``plugins.extract.detect._base`` or ``plugins.extract.align._base``) and should not
+        be overridden within plugins themselves.
+
+        It acts as a wrapper for the plugin's :func:`process_input` method and handles any
+        input processing that is consistent for all plugins within the `plugin_type`.
+
+        If this method is not overridden then the plugin's :func:`process_input` is just called.
+
+        Parameters
+        ----------
+        batch: :class:`ExtractorBatch`
+            Contains the batch that is currently being passed through the plugin process
+
+        Notes
+        -----
+        When preparing an input to the model, the attribute :attr:`feed` must be added
+        to the :attr:`batch` which contains this input.
+        """
+        self.process_input(batch)
+        return batch
+
+    def _process_output(self, batch: BatchType) -> BatchType:
+        """ **Override method** (at ``plugin_type`` level)
+
+        This method should be overridden at the ``plugin_type`` level (i.e.
+        ``plugins.extract.detect._base`` or ``plugins.extract.align._base``) and should not
+        be overridden within plugins themselves.
+
+        It acts as a wrapper for the plugin's :func:`process_output` method and handles any
+        output processing that is consistent for all plugins within the `plugin_type`.
+
+        If this method is not overridden then the plugin's :func:`process_output` is just called.
+
+        Parameters
+        ----------
+        batch: :class:`ExtractorBatch`
+            Contains the batch that is currently being passed through the plugin process
+        """
+        self.process_output(batch)
+        return batch
+
+    def finalize(self, batch: BatchType) -> Generator[ExtractMedia, None, None]:
+        """ **Override method** (at ``plugin_type`` level)
+
+        This method should be overridden at the ``plugin_type`` level (i.e.
+        :mod:`plugins.extract.detect._base`, :mod:`plugins.extract.align._base` or
+        :mod:`plugins.extract.mask._base`) and should not be overridden within plugins themselves.
+
+        Handles consistent finalization for all plugins that exist within that plugin type. Its
+        input is always the output from :func:`process_output()`
+
+        Parameters
+        ----------
+        batch: :class:`ExtractorBatch`
+            Contains the batch that is currently being passed through the plugin process
+        """
+        raise NotImplementedError
+
+    def get_batch(self, queue: Queue) -> tuple[bool, BatchType]:
+        """ **Override method** (at ``plugin_type`` level)
+
+        This method should be overridden at the ``plugin_type`` level (i.e.
+        :mod:`plugins.extract.detect._base`, :mod:`plugins.extract.align._base` or
+        :mod:`plugins.extract.mask._base`) and should not be overridden within plugins themselves.
+
+        Get :class:`~plugins.extract.extract_media.ExtractMedia` items from the queue in
+        batches of :attr:`batchsize`
+
+        Parameters
+        ----------
+        queue : queue.Queue()
+            The ``queue`` that the batch will be fed from. This will be the input to the plugin.
+        """
+        raise NotImplementedError
+
+    @classmethod
+    def get_device_context(cls, cpu: bool) -> T.ContextManager:
+        """ Get a device context manager for running inference on the CPU or the default device
+
+        Parameters
+        ----------
+        cpu: bool
+            ``True`` to get a context manager for running on the CPU. ``False`` to get a
+            context manager for the default device
+
+        Returns
+        -------
+        ContextManager
+            The context manager for running ops on the selected device
+        """
+        if cpu:
+            logger.debug("CPU mode selected. Returning CPU device context")
+            return device("cpu")
+
+        # TODO apple_silicon
+        if get_backend() == "apple_silicon":
+            pass
+
+        if torch.cuda.is_available():
+            logger.debug("Cuda available. Returning Cuda device context")
+            return device("cuda")
+
+        logger.debug("Cuda not available. Returning CPU device context")
+        return device("cpu")
+
+    # <<< THREADING METHODS >>> #
+    def start(self) -> None:
+        """ Start all threads
+
+        Exposed for :mod:`~plugins.extract.pipeline` to start the plugin's threads
+        """
+        for thread in self._threads:
+            thread.start()
+
+    def join(self) -> None:
+        """ Join all threads
+
+        Exposed for :mod:`~plugins.extract.pipeline` to join the plugin's threads
+        """
+        for thread in self._threads:
+            thread.join()
+
+    def check_and_raise_error(self) -> None:
+        """ Check all threads for errors
+
+        Exposed for :mod:`~plugins.extract.pipeline` to check the plugin's threads for errors
+        """
+        for thread in self._threads:
+            thread.check_and_raise_error()
+
+    def rollover_collector(self, queue: Queue) -> T.Literal["EOF"] | ExtractMedia:
+        """ For extractors after the Detectors, the number of detected faces per frame vs the
+        extractor batch size means that faces will need to be split/re-joined with frames. The
+        rollover collector can be used to rollover items that don't fit in a batch.
+
+        Collect the item from :attr:`_tracker.rollover` or from the queue. Add the face count
+        per frame to :attr:`_tracker.faces_per_filename` for joining batches back up in finalize
+
+        Parameters
+        ----------
+        queue: :class:`queue.Queue`
+            The input queue to the aligner. Should contain
+            :class:`~plugins.extract.extract_media.ExtractMedia` objects
+
+        Returns
+        -------
+        :class:`~plugins.extract.extract_media.ExtractMedia` or EOF
+            The next extract media object, or EOF if the pipe has ended
+        """
+        if self._tracker.rollover is not None:
+            logger.trace("Getting from _tracker.rollover: "  # type:ignore[attr-defined]
+                         "(filename: `%s`, faces: %s)",
+                         self._tracker.rollover.filename,
+                         len(self._tracker.rollover.detected_faces))
+            item: T.Literal["EOF"] | ExtractMedia = self._tracker.rollover
+            self._tracker.rollover = None
+        else:
+            next_item = self._get_item(queue)
+            # Rollover collector should only be used at entry to plugin
+            assert isinstance(next_item, (ExtractMedia, str))
+            item = next_item
+            if item != "EOF":
+                logger.trace("Getting from queue: (filename: %s, "  # type:ignore[attr-defined]
+                             "faces: %s)",
+                             item.filename, len(item.detected_faces))
+                self._tracker.faces_per_filename[item.filename] = len(item.detected_faces)
+        return item
+
+    # <<< PROTECTED ACCESS METHODS >>> #
+    # <<< INIT METHODS >>> #
+    @classmethod
+    def _get_model(cls,
+                   git_model_id: int | None,
+                   model_filename: str | list[str] | None) -> str | list[str] | None:
+        """ Check if the model is available; if not, download and unzip it """
+        if model_filename is None:
+            logger.debug("No model_filename specified. Returning None")
+            return None
+        if git_model_id is None:
+            logger.debug("No git_model_id specified. 
Returning None") + return None + model = GetModel(model_filename, git_model_id) + return model.model_path + + # <<< PLUGIN INITIALIZATION >>> # + def initialize(self, *args, **kwargs) -> None: + """ Initialize the extractor plugin + + Should be called from :mod:`~plugins.extract.pipeline` + """ + logger.debug("initialize %s: (args: %s, kwargs: %s)", + self.__class__.__name__, args, kwargs) + assert self._info.plugin_type is not None and self.name is not None + if self._info.is_initialized: + # When batch processing, plugins will be initialized on first job in batch + logger.debug("Plugin already initialized: %s (%s)", + self.name, self._info.plugin_type.title()) + return + + logger.info("Initializing %s (%s)...", self.name, self._info.plugin_type.title()) + name = self.name.replace(" ", "_").lower() + self._add_queues(kwargs["in_queue"], + kwargs["out_queue"], + [f"predict_{name}", f"post_{name}"]) + self._compile_threads() + self.init_model() + self._info.is_initialized = True + logger.info("Initialized %s (%s) with batchsize of %s", + self.name, self._info.plugin_type.title(), self.batchsize) + + def _add_queues(self, + in_queue: Queue, + out_queue: Queue, + queues: list[str]) -> None: + """ Add the queues + in_queue and out_queue should be previously created queue manager queues. + queues should be a list of queue names """ + self._queues["in"] = in_queue + self._queues["out"] = out_queue + for q_name in queues: + self._queues[q_name] = queue_manager.get_queue( + name=f"{self._info.plugin_type}{self._info.instance}_{q_name}", + maxsize=1) + + # <<< THREAD METHODS >>> # + def _compile_threads(self) -> None: + """ Compile the threads into self._threads list """ + assert self.name is not None + logger.debug("Compiling %s threads", self._info.plugin_type) + name = self.name.replace(" ", "_").lower() + base_name = f"{self._info.plugin_type}_{name}" + self._add_thread(f"{base_name}_input", + self._process_input, + self._queues["in"], + self._queues[f"predict_{name}"]) + self._add_thread(f"{base_name}_predict", + self._predict, + self._queues[f"predict_{name}"], + self._queues[f"post_{name}"]) + self._add_thread(f"{base_name}_output", + self._process_output, + self._queues[f"post_{name}"], + self._queues["out"]) + logger.debug("Compiled %s threads: %s", self._info.plugin_type, self._threads) + + def _add_thread(self, + name: str, + function: Callable[[BatchType], BatchType], + in_queue: Queue, + out_queue: Queue) -> None: + """ Add a MultiThread thread to self._threads """ + logger.debug("Adding thread: (name: %s, function: %s, in_queue: %s, out_queue: %s)", + name, function, in_queue, out_queue) + self._threads.append(MultiThread(target=self._thread_process, + name=name, + function=function, + in_queue=in_queue, + out_queue=out_queue)) + logger.debug("Added thread: %s", name) + + def _obtain_batch_item(self, function: Callable[[BatchType], BatchType], + in_queue: Queue, + out_queue: Queue) -> BatchType | None: + """ Obtain the batch item from the in queue for the current process. 
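+
+        The input thread compiles :class:`~plugins.extract.extract_media.ExtractMedia` items into
+        batches via :func:`get_batch`, while the predict and output threads receive ready-made
+        batches until the literal string ``"EOF"`` arrives as a sentinel. Illustrative sketch of
+        the two paths handled below:
+
+        >>> exhausted, batch = self.get_batch(in_queue)  # "_process_input" thread only
+        >>> batch = self._get_item(in_queue)             # predict and output threads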
+ + Parameters + ---------- + function: callable + The current plugin function being run + in_queue: :class:`queue.Queue` + The input queue for the function + out_queue: :class:`queue.Queue` + The output queue from the function + + Returns + ------- + :class:`ExtractorBatch` or ``None`` + The batch, if one exists, or ``None`` if queue is exhausted + """ + batch: T.Literal["EOF"] | BatchType | ExtractMedia + if function.__name__ == "_process_input": # Process input items to batches + exhausted, batch = self.get_batch(in_queue) + if exhausted: + if batch.filename: + # Put the final batch + batch = function(batch) + out_queue.put(batch) + return None + else: + batch = self._get_item(in_queue) + if batch == "EOF": + return None + + # ExtractMedia should only ever be the output of _get_item at the entry to a + # plugin's pipeline (ie in _process_input) + assert not isinstance(batch, ExtractMedia) + return batch + + def _thread_process(self, + function: Callable[[BatchType], BatchType], + in_queue: Queue, + out_queue: Queue) -> None: + """ Perform a plugin function in a thread + + Parameters + ---------- + function: callable + The current plugin function being run + in_queue: :class:`queue.Queue` + The input queue for the function + out_queue: :class:`queue.Queue` + The output queue from the function + """ + logger.debug("threading: (function: '%s')", function.__name__) + while True: + batch = self._obtain_batch_item(function, in_queue, out_queue) + if batch is None: + break + if not batch.filename: # Batch not populated. Possible during re-aligns + continue + batch = function(batch) + if function.__name__ == "_process_output": + # Process output items to individual items from batch + for item in self.finalize(batch): + out_queue.put(item) + else: + out_queue.put(batch) + logger.debug("Putting EOF") + out_queue.put("EOF") + + # <<< QUEUE METHODS >>> # + def _get_item(self, queue: Queue) -> T.Literal["EOF"] | ExtractMedia | BatchType: + """ Yield one item from a queue """ + item = queue.get() + if isinstance(item, ExtractMedia): + logger.trace("filename: '%s', image shape: %s, " # type:ignore[attr-defined] + "detected_faces: %s, queue: %s, item: %s", + item.filename, item.image_shape, item.detected_faces, queue, item) + self._extract_media[item.filename] = item + else: + logger.trace("item: %s, queue: %s", item, queue) # type:ignore[attr-defined] + return item diff --git a/plugins/extract/_config.py b/plugins/extract/_config.py deleted file mode 100644 index 092b0665d0..0000000000 --- a/plugins/extract/_config.py +++ /dev/null @@ -1,48 +0,0 @@ -#!/usr/bin/env python3 -""" Default configurations for extract """ - -import logging - -from lib.config import FaceswapConfig - -logger = logging.getLogger(__name__) # pylint: disable=invalid-name - - -class Config(FaceswapConfig): - """ Config File for Models """ - - def set_defaults(self): - """ Set the default values for config """ - logger.debug("Setting defaults") - - # << GLOBAL OPTIONS >> # -# section = "global" -# self.add_section(title=section, -# info="Options that apply to all models") - - # << MTCNN DETECTOR OPTIONS >> # - section = "detect.mtcnn" - self.add_section(title=section, - info="MTCNN Detector options") - self.add_item( - section=section, title="minsize", datatype=int, default=20, rounding=10, - min_max=(20, 1000), - info="The minimum size of a face (in pixels) to be accepted as a positive match.\n" - "Lower values use significantly more VRAM and will detect more false positives") - self.add_item( - section=section, 
title="threshold_1", datatype=float, default=0.6, rounding=2, - min_max=(0.1, 0.9), - info="First stage threshold for face detection. This stage obtains face candidates") - self.add_item( - section=section, title="threshold_2", datatype=float, default=0.7, rounding=2, - min_max=(0.1, 0.9), - info="Second stage threshold for face detection. This stage refines face candidates") - self.add_item( - section=section, title="threshold_3", datatype=float, default=0.7, rounding=2, - min_max=(0.1, 0.9), - info="Third stage threshold for face detection. This stage further refines face " - "candidates") - self.add_item( - section=section, title="scalefactor", datatype=float, default=0.709, rounding=3, - min_max=(0.1, 0.9), - info="The scale factor for the image pyramid") diff --git a/plugins/extract/align/.cache/2DFAN-4.pb b/plugins/extract/align/.cache/2DFAN-4.pb deleted file mode 100755 index 6e52218f3d..0000000000 Binary files a/plugins/extract/align/.cache/2DFAN-4.pb and /dev/null differ diff --git a/plugins/extract/align/_base.py b/plugins/extract/align/_base.py deleted file mode 100644 index b2b7c62e27..0000000000 --- a/plugins/extract/align/_base.py +++ /dev/null @@ -1,144 +0,0 @@ -#!/usr/bin/env python3 -""" Base class for Face Aligner plugins - Plugins should inherit from this class - - See the override methods for which methods are - required. - - The plugin will receive a dict containing: - {"filename": , - "image": , - "detected_faces": } - - For each source item, the plugin must pass a dict to finalize containing: - {"filename": , - "image": , - "detected_faces": , - "landmarks": } - """ - -import logging -import os -import traceback - -from io import StringIO - -from lib.aligner import Extract -from lib.gpu_stats import GPUStats -from lib.faces_detect import DetectedFace - -logger = logging.getLogger(__name__) # pylint: disable=invalid-name - - -class Aligner(): - """ Landmarks Aligner Object """ - def __init__(self, loglevel): - logger.debug("Initializing %s", self.__class__.__name__) - self.loglevel = loglevel - self.cachepath = os.path.join(os.path.dirname(__file__), ".cache") - self.extract = Extract() - self.init = None - - # The input and output queues for the plugin. - # See lib.queue_manager.QueueManager for getting queues - self.queues = {"in": None, "out": None} - - # Path to model if required - self.model_path = self.set_model_path() - - # Approximate VRAM required for aligner. Used to calculate - # how many parallel processes / batches can be run. - # Be conservative to avoid OOM. - self.vram = None - logger.debug("Initialized %s", self.__class__.__name__) - - # <<< OVERRIDE METHODS >>> # - # These methods must be overriden when creating a plugin - @staticmethod - def set_model_path(): - """ path to data file/models - override for specific detector """ - raise NotImplementedError() - - def initialize(self, *args, **kwargs): - """ Inititalize the aligner - Tasks to be run before any alignments are performed. 
- Override for specific detector """ - logger_init = kwargs["log_init"] - log_queue = kwargs["log_queue"] - logger_init(self.loglevel, log_queue) - logger.debug("_base initialize %s: (PID: %s, args: %s, kwargs: %s)", - self.__class__.__name__, os.getpid(), args, kwargs) - self.init = kwargs["event"] - self.queues["in"] = kwargs["in_queue"] - self.queues["out"] = kwargs["out_queue"] - - def align(self, *args, **kwargs): - """ Process landmarks - Override for specific detector - Must return a list of dlib rects""" - if not self.init: - self.initialize(*args, **kwargs) - logger.debug("Launching Align: (args: %s kwargs: %s)", args, kwargs) - - # <<< DETECTION WRAPPER >>> # - def run(self, *args, **kwargs): - """ Parent align process. - This should always be called as the entry point so exceptions - are passed back to parent. - Do not override """ - try: - self.align(*args, **kwargs) - except Exception: # pylint: disable=broad-except - logger.error("Caught exception in child process: %s", os.getpid()) - # Display traceback if in initialization stage - if not self.init.is_set(): - logger.exception("Traceback:") - tb_buffer = StringIO() - traceback.print_exc(file=tb_buffer) - exception = {"exception": (os.getpid(), tb_buffer)} - self.queues["out"].put(exception) - exit(1) - - # <<< FINALIZE METHODS>>> # - def finalize(self, output): - """ This should be called as the final task of each plugin - aligns faces and puts to the out queue """ - if output == "EOF": - logger.trace("Item out: %s", output) - self.queues["out"].put("EOF") - return - logger.trace("Item out: %s", {key: val - for key, val in output.items() - if key != "image"}) - self.queues["out"].put((output)) - - # <<< MISC METHODS >>> # - @staticmethod - def get_vram_free(): - """ Return free and total VRAM on card with most VRAM free""" - stats = GPUStats() - vram = stats.get_card_most_free() - logger.verbose("Using device %s with %sMB free of %sMB", - vram["device"], - int(vram["free"]), - int(vram["total"])) - return int(vram["card_id"]), int(vram["free"]), int(vram["total"]) - - def get_item(self): - """ Yield one item from the queue """ - while True: - item = self.queues["in"].get() - if isinstance(item, dict): - logger.trace("Item in: %s", {key: val - for key, val in item.items() - if key != "image"}) - # Pass Detector failures straight out and quit - if item.get("exception", None): - self.queues["out"].put(item) - exit(1) - else: - logger.trace("Item in: %s", item) - yield item - if item == "EOF": - break diff --git a/plugins/extract/align/_base/__init__.py b/plugins/extract/align/_base/__init__.py new file mode 100644 index 0000000000..6e32deea62 --- /dev/null +++ b/plugins/extract/align/_base/__init__.py @@ -0,0 +1,4 @@ +#!/usr/bin/env python3 +""" Base class for Aligner plugins ALL aligners should at least inherit from this class. """ + +from .aligner import Aligner, AlignerBatch, BatchType diff --git a/plugins/extract/align/_base/aligner.py b/plugins/extract/align/_base/aligner.py new file mode 100644 index 0000000000..49cb5d19cf --- /dev/null +++ b/plugins/extract/align/_base/aligner.py @@ -0,0 +1,837 @@ +#!/usr/bin/env python3 +""" Base class for Face Aligner plugins + +All Aligner Plugins should inherit from this class. +See the override methods for which methods are required. + +The plugin will receive a :class:`~plugins.extract.extract_media.ExtractMedia` object. 
+ +For each source item, the plugin must pass a dict to finalize containing: + +>>> {"filename": [], +>>> "landmarks": [list of 68 point face landmarks] +>>> "detected_faces": []} +""" +from __future__ import annotations +import logging +import typing as T + +from dataclasses import dataclass, field +from time import sleep + +import cv2 +import numpy as np +from torch.cuda import OutOfMemoryError + +from lib.align import LandmarkType +from lib.utils import FaceswapError +from plugins.extract import ExtractMedia, extract_config as cfg +from plugins.extract._base import BatchType, ExtractorBatch, Extractor +from .processing import AlignedFilter, ReAlign + +if T.TYPE_CHECKING: + from collections.abc import Generator + from queue import Queue + from lib.align import DetectedFace + from lib.align.aligned_face import CenteringType + +logger = logging.getLogger(__name__) +_BATCH_IDX: int = 0 + + +def _get_new_batch_id() -> int: + """ Obtain the next available batch index + + Returns + ------- + int + The next available unique batch id + """ + global _BATCH_IDX # pylint:disable=global-statement + _BATCH_IDX += 1 + return _BATCH_IDX + + +@dataclass +class AlignerBatch(ExtractorBatch): + """ Dataclass for holding items flowing through the aligner. + + Inherits from :class:`~plugins.extract._base.ExtractorBatch` + + Parameters + ---------- + batch_id: int + A unique integer for tracking this batch + landmarks: list + List of 68 point :class:`numpy.ndarray` landmark points returned from the aligner + refeeds: list + List of :class:`numpy.ndarrays` for holding each of the feeds that will be put through the + model for each refeed + second_pass: bool, optional + ``True`` if this batch is passing through the aligner for a second time as re-align has + been selected otherwise ``False``. Default: ``False`` + second_pass_masks: :class:`numpy.ndarray`, optional + The masks used to filter out re-feed values for passing to the re-aligner. + """ + batch_id: int = 0 + detected_faces: list[DetectedFace] = field(default_factory=list) + landmarks: np.ndarray = field(default_factory=lambda: np.array([])) + refeeds: list[np.ndarray] = field(default_factory=list) + second_pass: bool = False + second_pass_masks: np.ndarray = field(default_factory=lambda: np.array([])) + + def __repr__(self): + """ Prettier repr for debug printing """ + retval = super().__repr__() + retval += (f", batch_id={self.batch_id}, " + f"landmarks=[({self.landmarks.shape}, {self.landmarks.dtype})], " + f"refeeds={[(f.shape, f.dtype) for f in self.refeeds]}, " + f"second_pass={self.second_pass}, " + f"second_pass_masks={self.second_pass_masks})") + return retval + + def __post_init__(self): + """ Make sure that we have been given a non-zero ID """ + assert self.batch_id != 0, ("A batch ID must be specified for Aligner Batches") + + +class Aligner(Extractor): # pylint:disable=abstract-method + """ Aligner plugin _base Object + + All Aligner plugins must inherit from this class + + Parameters + ---------- + git_model_id: int + The second digit in the github tag that identifies this model. See + https://github.com/deepfakes-models/faceswap-models for more information + model_filename: str + The name of the model file to be loaded + normalize_method: {`None`, 'clahe', 'hist', 'mean'}, optional + Normalize the images fed to the aligner. Default: ``None`` + re_feed: int, optional + The number of times to re-feed a slightly adjusted bounding box into the aligner. 
+        Default: `0`
+    re_align: bool, optional
+        ``True`` to obtain landmarks by passing the initially aligned face back through the
+        aligner. Default: ``False``
+    disable_filter: bool, optional
+        Disable all aligner filters regardless of config option. Default: ``False``
+
+    Other Parameters
+    ----------------
+    configfile: str, optional
+        Path to a custom configuration ``ini`` file. Default: Use system configfile
+
+    See Also
+    --------
+    plugins.extract.pipeline : The extraction pipeline for calling plugins
+    plugins.extract.align : Aligner plugins
+    plugins.extract._base : Parent class for all extraction plugins
+    plugins.extract.detect._base : Detector parent class for extraction plugins.
+    plugins.extract.mask._base : Masker parent class for extraction plugins.
+    """
+
+    def __init__(self,  # pylint:disable=too-many-positional-arguments
+                 git_model_id: int | None = None,
+                 model_filename: str | None = None,
+                 configfile: str | None = None,
+                 instance: int = 0,
+                 normalize_method: T.Literal["none", "clahe", "hist", "mean"] | None = None,
+                 re_feed: int = 0,
+                 re_align: bool = False,
+                 disable_filter: bool = False,
+                 **kwargs) -> None:
+        logger.debug("Initializing %s: (normalize_method: %s, re_feed: %s, re_align: %s, "
+                     "disable_filter: %s)", self.__class__.__name__, normalize_method, re_feed,
+                     re_align, disable_filter)
+        super().__init__(git_model_id,
+                         model_filename,
+                         configfile=configfile,
+                         instance=instance,
+                         **kwargs)
+        self._info.plugin_type = "align"
+        self.realign_centering: CenteringType = "face"  # override for plugin specific centering
+
+        # Override for specific landmark type:
+        self.landmark_type = LandmarkType.LM_2D_68
+
+        self._eof_seen = False
+        self._normalize_method: T.Literal["clahe", "hist", "mean"] | None = None
+        self._re_feed = re_feed
+        self._filter = AlignedFilter(feature_filter=cfg.aligner_features(),
+                                     min_scale=cfg.aligner_min_scale(),
+                                     max_scale=cfg.aligner_max_scale(),
+                                     distance=cfg.aligner_distance(),
+                                     roll=cfg.aligner_roll(),
+                                     save_output=cfg.save_filtered(),
+                                     disable=disable_filter)
+        self._re_align = ReAlign(re_align,
+                                 cfg.realign_refeeds(),
+                                 cfg.filter_realign())
+        self._needs_refeed_masks: bool = self._re_feed > 0 and (
+            cfg.filter_refeed() or (self._re_align.do_refeeds and self._re_align.do_filter))
+        self.set_normalize_method(normalize_method)
+
+        logger.debug("Initialized %s", self.__class__.__name__)
+
+    def set_normalize_method(self, method: T.Literal["none", "clahe", "hist", "mean"] | None
+                             ) -> None:
+        """ Set the normalization method for feeding faces into the aligner.
+
+        Parameters
+        ----------
+        method: {"none", "clahe", "hist", "mean"}
+            The normalization method to apply to faces prior to feeding into the model
+        """
+        method = None if method is None or method.lower() == "none" else method
+        self._normalize_method = T.cast(T.Literal["clahe", "hist", "mean"] | None, method)
+
+    def initialize(self, *args, **kwargs) -> None:
+        """ Add a call to pass the model input size to the re-aligner """
+        self._re_align.set_input_size_and_centering(self.input_size, self.realign_centering)
+        super().initialize(*args, **kwargs)
+
+    def _handle_realigns(self, queue: Queue) -> tuple[bool, AlignerBatch] | None:
+        """ Handle any items waiting for a second pass through the aligner.
+
+        If EOF has been received and items are still being processed through the first pass,
+        then wait for a short time and try again to collect them.
+
+        On EOF return the exhausted flag with an empty batch
+
+        Parameters
+        ----------
+        queue : queue.Queue()
+            The ``queue`` that the plugin will be fed from.
+
+        Returns
+        -------
+        ``None`` or tuple
+            If items are processed then returns (`bool`, :class:`AlignerBatch`) containing the
+            exhausted flag and the batch to be processed. If no items are processed returns
+            ``None``
+        """
+        if not self._re_align.active:
+            return None
+
+        exhausted = False
+        if self._re_align.items_queued:
+            batch = self._re_align.get_batch()
+            logger.trace("Re-align batch: %s", batch)  # type: ignore[attr-defined]
+            return exhausted, batch
+
+        if self._eof_seen and self._re_align.items_tracked:
+            # EOF seen and items still being processed on first pass
+            logger.debug("Tracked re-align items waiting to be flushed, retrying...")
+            sleep(0.25)
+            return self.get_batch(queue)
+
+        if self._eof_seen:
+            exhausted = True
+            logger.debug("All items processed. Returning empty batch")
+            self._filter.output_counts()
+            self._eof_seen = False  # Reset for plugin re-use
+            return exhausted, AlignerBatch(batch_id=-1)
+
+        return None
+
+    def get_batch(self, queue: Queue) -> tuple[bool, AlignerBatch]:
+        """ Get items for inputting into the aligner from the queue in batches
+
+        Items are returned from the ``queue`` in batches of
+        :attr:`~plugins.extract._base.Extractor.batchsize`
+
+        Items are received as :class:`~plugins.extract.extract_media.ExtractMedia` objects and
+        converted to ``dict`` for internal processing.
+
+        To ensure consistent batch sizes for the aligner, the items are split into separate
+        items for each :class:`~lib.align.DetectedFace` object.
+
+        Remember to put ``'EOF'`` to the out queue after processing
+        the final batch
+
+        Outputs items in the following format. All lists are of length
+        :attr:`~plugins.extract._base.Extractor.batchsize`:
+
+        >>> {'filename': [],
+        >>>  'image': [],
+        >>>  'detected_faces': [[ np.ndarray:
+        """ Override for specific plugin processing to convert a batch of face images from UINT8
+        (0-255) into the correct format for the plugin's inference
+
+        Parameters
+        ----------
+        faces: :class:`numpy.ndarray`
+            The batch of faces in UINT8 format
+
+        Returns
+        -------
+        :class:`numpy.ndarray`
+            The batch of faces in the format to feed through the plugin
+        """
+        raise NotImplementedError()
+
+    # <<< FINALIZE METHODS >>> #
+    def finalize(self, batch: BatchType) -> Generator[ExtractMedia, None, None]:
+        """ Finalize the output from Aligner
+
+        This should be called as the final task of each `plugin`.
+
+        Pairs the detected faces back up with their original frame before yielding each frame.
+
+        Parameters
+        ----------
+        batch : :class:`AlignerBatch`
+            The final batch item from the `plugin` process.
+
+        Yields
+        ------
+        :class:`~plugins.extract.extract_media.ExtractMedia`
+            The :attr:`DetectedFaces` list will be populated for this class with the bounding
+            boxes and landmarks for the detected faces found in the frame.
+ """ + assert isinstance(batch, AlignerBatch) + if not batch.second_pass and self._re_align.active: + # Add the batch for second pass re-alignment and return + self._re_align.add_batch(batch) + return + for face, landmarks in zip(batch.detected_faces, batch.landmarks): + if not isinstance(landmarks, np.ndarray): + landmarks = np.array(landmarks) + face.add_landmarks_xy(landmarks) + + logger.trace("Item out: %s", batch) # type: ignore[attr-defined] + + for frame, filename, face in zip(batch.image, batch.filename, batch.detected_faces): + self._tracker.output_faces.append(face) + if len(self._tracker.output_faces) != self._tracker.faces_per_filename[filename]: + continue + + self._tracker.output_faces, folders = self._filter(self._tracker.output_faces, + min(frame.shape[:2])) + + output = self._extract_media.pop(filename) + output.add_detected_faces(self._tracker.output_faces) + output.add_sub_folders(folders) + self._tracker.output_faces = [] + + logger.trace("Final Output: (filename: '%s', image " # type: ignore[attr-defined] + "shape: %s, detected_faces: %s, item: %s)", output.filename, + output.image_shape, output.detected_faces, output) + yield output + self._re_align.untrack_batch(batch.batch_id) + + def on_completion(self) -> None: + """ Output the filter counts when process has completed """ + self._filter.output_counts() + + # <<< PROTECTED METHODS >>> # + # << PROCESS_INPUT WRAPPER >> + def _get_adjusted_boxes(self, original_boxes: np.ndarray) -> np.ndarray: + """ Obtain an array of adjusted bounding boxes based on the number of re-feed iterations + that have been selected and the minimum dimension of the original bounding box. + + Parameters + ---------- + original_boxes: :class:`numpy.ndarray` + The original ('x', 'y', 'w', 'h') detected face boxes corresponding to the incoming + detected face objects + + Returns + ------- + :class:`numpy.ndarray` + The original boxes (in position 0) and the randomly adjusted bounding boxes + """ + if self._re_feed == 0: + return original_boxes[None, ...] + beta = 0.05 + max_shift = np.min(original_boxes[..., 2:], axis=1) * beta + rands = np.random.rand(self._re_feed, *original_boxes.shape) * 2 - 1 + new_boxes = np.rint(original_boxes + (rands * max_shift[None, :, None])).astype("int32") + retval = np.concatenate((original_boxes[None, ...], new_boxes)) + logger.trace(retval) # type: ignore[attr-defined] + return retval + + def _process_input_first_pass(self, batch: AlignerBatch) -> None: + """ Standard pre-processing for aligners for first pass (if re-align selected) or the + only pass. + + Process the input to the aligner model multiple times based on the user selected + `re-feed` command line option. This adjusts the bounding box for the face to be fed + into the model by a random amount within 0.05 pixels of the detected face's shortest axis. 
+ + References + ---------- + https://studios.disneyresearch.com/2020/06/29/high-resolution-neural-face-swapping-for-visual-effects/ + + Parameters + ---------- + batch: :class:`AlignerBatch` + Contains the batch that is currently being passed through the plugin process + """ + original_boxes = np.array([(face.left, face.top, face.width, face.height) + for face in batch.detected_faces]) + adjusted_boxes = self._get_adjusted_boxes(original_boxes) + + # Put in random re-feed data to the bounding boxes + for bounding_boxes in adjusted_boxes: + for face, box in zip(batch.detected_faces, bounding_boxes): + face.left, face.top, face.width, face.height = box + + self.process_input(batch) + batch.feed = self.faces_to_feed(self._normalize_faces(batch.feed)) + # Move the populated feed into the batch refeed list. It will be overwritten at next + # iteration + batch.refeeds.append(batch.feed) + + # Place the original bounding box back to detected face objects + for face, box in zip(batch.detected_faces, original_boxes): + face.left, face.top, face.width, face.height = box.tolist() + + def _get_realign_masks(self, batch: AlignerBatch) -> np.ndarray: + """ Obtain the masks required for processing re-aligns + + Parameters + ---------- + batch: :class:`AlignerBatch` + Contains the batch that is currently being passed through the plugin process + + Returns + ------- + :class:`numpy.ndarray` + The filter masks required for masking the re-aligns + """ + if self._re_align.do_refeeds: + retval = batch.second_pass_masks # Masks already calculated during re-feed + elif self._re_align.do_filter: + retval = self._filter.filtered_mask(batch)[None, ...] + else: + retval = np.zeros((batch.landmarks.shape[0], ), dtype="bool")[None, ...] + return retval + + def _process_input_second_pass(self, batch: AlignerBatch) -> None: + """ Process the input for 2nd-pass re-alignment + + Parameters + ---------- + batch: :class:`AlignerBatch` + Contains the batch that is currently being passed through the plugin process + """ + batch.second_pass_masks = self._get_realign_masks(batch) + + if not self._re_align.do_refeeds: + # Expand the dimensions for re-aligns for consistent handling of code + batch.landmarks = batch.landmarks[None, ...] 
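+            # e.g. landmarks go from (faces, 68, 2) to (1, faces, 68, 2) so that a single
+            # re-align pass flows through the same code paths as multiple re-feed passes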
+ + refeeds = self._re_align.process_batch(batch) + batch.refeeds = [self.faces_to_feed(self._normalize_faces(faces)) for faces in refeeds] + + def _process_input(self, batch: BatchType) -> AlignerBatch: + """ Perform pre-processing depending on whether this is the first/only pass through the + aligner or the 2nd pass when re-align has been selected + + Parameters + ---------- + batch: :class:`AlignerBatch` + Contains the batch that is currently being passed through the plugin process + + Returns + ------- + :class:`AlignerBatch` + The batch with input processed + """ + assert isinstance(batch, AlignerBatch) + if batch.second_pass: + self._process_input_second_pass(batch) + else: + self._process_input_first_pass(batch) + return batch + + # <<< PREDICT WRAPPER >>> # + def _predict(self, batch: BatchType) -> AlignerBatch: + """ Just return the aligner's predict function + + Parameters + ---------- + batch: :class:`AlignerBatch` + The current batch to find alignments for + + Returns + ------- + :class:`AlignerBatch` + The batch item with the :attr:`prediction` populated + + Raises + ------ + FaceswapError + If GPU resources are exhausted + """ + assert isinstance(batch, AlignerBatch) + try: + preds = [self.predict(feed) for feed in batch.refeeds] + try: + batch.prediction = np.array(preds) + logger.trace("Aligner out: %s", # type:ignore[attr-defined] + batch.prediction.shape) + except ValueError as err: + # If refeed batches are different sizes, Numpy will error, so we need to explicitly + # set the dtype to 'object' rather than let it infer + # numpy error: + # ValueError: setting an array element with a sequence. The requested array has an + # inhomogeneous shape after 1 dimensions. The detected shape was (9,) + + # inhomogeneous part + if "inhomogeneous" in str(err): + logger.trace( # type:ignore[attr-defined] + "Mismatched array sizes, setting dtype to object: %s", + [p.shape for p in preds]) + batch.prediction = np.array(preds, dtype="object") + else: + raise + + except OutOfMemoryError as err: + msg = ("You do not have enough GPU memory available to run detection at the " + "selected batch size. You can try a number of things:" + "\n1) Close any other application that is using your GPU (web browsers are " + "particularly bad for this)." + "\n2) Lower the batchsize (the amount of images fed into the model) by " + "editing the plugin settings (GUI: Settings > Configure extract settings, " + "CLI: Edit the file faceswap/config/extract.ini)." + "\n3) Enable 'Single Process' mode.") + raise FaceswapError(msg) from err + + return batch + + def _process_refeeds(self, batch: AlignerBatch) -> list[AlignerBatch]: + """ Process the output for each selected re-feed + + Parameters + ---------- + batch: :class:`AlignerBatch` + The batch object passing through the aligner + + Returns + ------- + list + List of :class:`AlignerBatch` objects. 
Each object in the list contains the
+            results for each selected re-feed
+        """
+        retval: list[AlignerBatch] = []
+        if batch.second_pass:
+            # Re-insert empty sub-patches for re-population in ReAlign for filtered out batches
+            selected_idx = 0
+            for mask in batch.second_pass_masks:
+                all_filtered = np.all(mask)
+                if not all_filtered:
+                    feed = batch.refeeds[selected_idx]
+                    pred = batch.prediction[selected_idx]
+                    data = batch.data[selected_idx] if batch.data else {}
+                    selected_idx += 1
+                else:  # All results have been filtered out
+                    feed = pred = np.array([])
+                    data = {}
+
+                subbatch = AlignerBatch(batch_id=batch.batch_id,
+                                        image=batch.image,
+                                        detected_faces=batch.detected_faces,
+                                        filename=batch.filename,
+                                        feed=feed,
+                                        prediction=pred,
+                                        data=[data],
+                                        second_pass=batch.second_pass)
+
+                if not all_filtered:
+                    self.process_output(subbatch)
+
+                retval.append(subbatch)
+        else:
+            b_data = batch.data if batch.data else [{}]
+            for feed, pred, dat in zip(batch.refeeds, batch.prediction, b_data):
+                subbatch = AlignerBatch(batch_id=batch.batch_id,
+                                        image=batch.image,
+                                        detected_faces=batch.detected_faces,
+                                        filename=batch.filename,
+                                        feed=feed,
+                                        prediction=pred,
+                                        data=[dat],
+                                        second_pass=batch.second_pass)
+                self.process_output(subbatch)
+                retval.append(subbatch)
+        return retval
+
+    def _get_refeed_filter_masks(self,
+                                 subbatches: list[AlignerBatch],
+                                 original_masks: np.ndarray | None = None) -> np.ndarray:
+        """ Obtain the boolean mask array for masking out failed re-feed results if filter refeed
+        has been selected
+
+        Parameters
+        ----------
+        subbatches: list
+            List of sub-batch results for each re-feed performed
+        original_masks: :class:`numpy.ndarray`, Optional
+            If passing in the second pass landmarks, these should be the original filter masks so
+            that we don't calculate the mask again for already filtered faces. Default: ``None``
+
+        Returns
+        -------
+        :class:`numpy.ndarray`
+            Boolean values for every detected face indicating whether the interim landmarks have
+            passed the filter test
+        """
+        retval = np.zeros((len(subbatches), subbatches[0].landmarks.shape[0]), dtype="bool")
+
+        if not self._needs_refeed_masks:
+            return retval
+
+        retval = retval if original_masks is None else original_masks
+        for subbatch, masks in zip(subbatches, retval):
+            masks[:] = self._filter.filtered_mask(subbatch, np.flatnonzero(masks))
+        return retval
+
+    def _get_mean_landmarks(self, landmarks: np.ndarray, masks: np.ndarray) -> np.ndarray:
+        """ Obtain the averaged landmarks from the re-fed alignments. 
If the config option
+        'filter_refeed' is enabled, then average those results which have not been filtered out,
+        otherwise average all results
+
+        Parameters
+        ----------
+        landmarks: :class:`numpy.ndarray`
+            The batch of re-fed alignments
+        masks: :class:`numpy.ndarray`
+            List of boolean values indicating whether each re-fed alignment passed or failed
+            the filter test
+
+        Returns
+        -------
+        :class:`numpy.ndarray`
+            The final averaged landmarks
+        """
+        if any(np.all(masked) for masked in masks.T):
+            # hacky fix for faces which entirely failed the filter
+            # We just unmask one value as it is junk anyway and will be discarded on output
+            for idx, masked in enumerate(masks.T):
+                if np.all(masked):
+                    masks[0, idx] = False
+
+        masks = np.broadcast_to(np.reshape(masks, (*landmarks.shape[:2], 1, 1)),
+                                landmarks.shape)
+        return np.ma.array(landmarks, mask=masks).mean(axis=0).data.astype("float32")
+
+    def _process_output_first_pass(self, subbatches: list[AlignerBatch]) -> tuple[np.ndarray,
+                                                                                  np.ndarray]:
+        """ Process the output from the aligner if this is the first or only pass.
+
+        Parameters
+        ----------
+        subbatches: list
+            List of sub-batch results for each re-feed performed
+
+        Returns
+        -------
+        landmarks: :class:`numpy.ndarray`
+            If re-align is not selected, or if re-align has been selected but only on the final
+            output (i.e. realign_refeeds is ``False``), then the averaged batch of landmarks for
+            all re-feeds is returned.
+            If realign_refeeds has been selected, then this will output each batch of re-feed
+            landmarks.
+        masks: :class:`numpy.ndarray`
+            Boolean mask corresponding to the re-fed landmarks output indicating any values which
+            should be filtered out prior to further processing
+        """
+        masks = self._get_refeed_filter_masks(subbatches)
+        all_landmarks = np.array([sub.landmarks for sub in subbatches])
+
+        # re-align not selected or not filtering the re-feeds
+        if not self._re_align.do_refeeds:
+            retval = self._get_mean_landmarks(all_landmarks, masks)
+            return retval, masks
+
+        # Re-align selected with filter re-feeds
+        return all_landmarks, masks
+
+    def _process_output_second_pass(self,
+                                    subbatches: list[AlignerBatch],
+                                    masks: np.ndarray) -> np.ndarray:
+        """ Process the output from the aligner for the second pass re-alignment.
+
+        Parameters
+        ----------
+        subbatches: list
+            List of sub-batch results for each re-aligned re-feed performed
+        masks: :class:`numpy.ndarray`
+            The original re-feed filter masks from the first pass
+
+        Returns
+        -------
+        :class:`numpy.ndarray`
+            The final averaged landmarks from the re-aligned re-feeds
+        """
+        self._re_align.process_output(subbatches, masks)
+        masks = self._get_refeed_filter_masks(subbatches, original_masks=masks)
+        all_landmarks = np.array([sub.landmarks for sub in subbatches])
+        return self._get_mean_landmarks(all_landmarks, masks)
+
+    def _process_output(self, batch: BatchType) -> AlignerBatch:
+        """ Process the output from the aligner model multiple times based on the user selected
+        `re-feed amount` configuration option, then average the results for final prediction.
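+
+        The averaging is a masked mean over the re-feed axis. Conceptually (illustrative, see
+        :func:`_get_mean_landmarks` for the actual implementation):
+
+        >>> lms = np.array([sub.landmarks for sub in subbatches])  # (re-feeds, faces, 68, 2)
+        >>> # with the filter masks broadcast to the same shape as the landmarks:
+        >>> averaged = np.ma.array(lms, mask=masks).mean(axis=0).data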
+ + If the config option 'filter_refeed' is enabled, then mask out any returned alignments + that fail a filter test + + Parameters + ---------- + batch : :class:`AlignerBatch` + Contains the batch that is currently being passed through the plugin process + + Returns + ------- + :class:`AlignerBatch` + The batch item with :attr:`landmarks` populated + """ + assert isinstance(batch, AlignerBatch) + subbatches = self._process_refeeds(batch) + if batch.second_pass: + batch.landmarks = self._process_output_second_pass(subbatches, batch.second_pass_masks) + else: + landmarks, masks = self._process_output_first_pass(subbatches) + batch.landmarks = landmarks + batch.second_pass_masks = masks + return batch + + # <<< FACE NORMALIZATION METHODS >>> # + def _normalize_faces(self, faces: np.ndarray) -> np.ndarray: + """ Normalizes the face for feeding into model + The normalization method is dictated by the normalization command line argument + + Parameters + ---------- + faces: :class:`numpy.ndarray` + The batch of faces to normalize + + Returns + ------- + :class:`numpy.ndarray` + The normalized faces + """ + if self._normalize_method is None: + return faces + logger.trace("Normalizing faces") # type: ignore[attr-defined] + meth = getattr(self, f"_normalize_{self._normalize_method.lower()}") + faces = np.array([meth(face) for face in faces]) + logger.trace("Normalized faces") # type: ignore[attr-defined] + return faces + + @classmethod + def _normalize_mean(cls, face: np.ndarray) -> np.ndarray: + """ Normalize Face to the Mean + + Parameters + ---------- + face: :class:`numpy.ndarray` + The face to normalize + + Returns + ------- + :class:`numpy.ndarray` + The normalized face + """ + face = face / 255.0 + for chan in range(3): + layer = face[:, :, chan] + layer = (layer - layer.min()) / (layer.max() - layer.min()) + face[:, :, chan] = layer + return face * 255.0 + + @classmethod + def _normalize_hist(cls, face: np.ndarray) -> np.ndarray: + """ Equalize the RGB histogram channels + + Parameters + ---------- + face: :class:`numpy.ndarray` + The face to normalize + + Returns + ------- + :class:`numpy.ndarray` + The normalized face + """ + for chan in range(3): + face[:, :, chan] = cv2.equalizeHist(face[:, :, chan]) + return face + + @classmethod + def _normalize_clahe(cls, face: np.ndarray) -> np.ndarray: + """ Perform Contrast Limited Adaptive Histogram Equalization + + Parameters + ---------- + face: :class:`numpy.ndarray` + The face to normalize + + Returns + ------- + :class:`numpy.ndarray` + The normalized face + """ + clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(4, 4)) + for chan in range(3): + face[:, :, chan] = clahe.apply(face[:, :, chan]) + return face diff --git a/plugins/extract/align/_base/processing.py b/plugins/extract/align/_base/processing.py new file mode 100644 index 0000000000..efdeec9468 --- /dev/null +++ b/plugins/extract/align/_base/processing.py @@ -0,0 +1,489 @@ +#!/usr/bin/env python3 +""" Processing methods for aligner plugins """ +from __future__ import annotations +import logging +import typing as T + +from threading import Lock + +import numpy as np + +from lib.align import AlignedFace + +if T.TYPE_CHECKING: + from lib.align import DetectedFace + from .aligner import AlignerBatch + from lib.align.aligned_face import CenteringType + +logger = logging.getLogger(__name__) + + +class AlignedFilter(): + """ Applies filters on the output of the aligner + + Parameters + ---------- + feature_filter: bool + ``True`` to enable filter to check relative position of 
eyes/eyebrows and mouth. ``False`` + to disable. + min_scale: float + Filters out faces that have been aligned at below this value as a multiplier of the + minimum frame dimension. Set to ``0`` for off. + max_scale: float + Filters out faces that have been aligned at above this value as a multiplier of the + minimum frame dimension. Set to ``0`` for off. + distance: float + Filters out faces that are further than this distance from an "average" face. Set to + ``0`` for off. + roll: float + Filters out faces with a roll value outside of 0 +/- the value given here. Set to ``0`` + for off. + save_output: bool + ``True`` if the filtered faces should be kept as they are being saved. ``False`` if they + should be deleted + disable: bool, Optional + ``True`` to disable the filter regardless of config options. Default: ``False`` + """ + def __init__(self, + feature_filter: bool, + min_scale: float, + max_scale: float, + distance: float, + roll: float, + save_output: bool, + disable: bool = False) -> None: + logger.debug("Initializing %s: (feature_filter: %s, min_scale: %s, max_scale: %s, " + "distance: %s, roll, %s, save_output: %s, disable: %s)", + self.__class__.__name__, feature_filter, min_scale, max_scale, distance, roll, + save_output, disable) + self._features = feature_filter + self._min_scale = min_scale + self._max_scale = max_scale + self._distance = distance / 100. + self._roll = roll + self._save_output = save_output + self._active = not disable and (feature_filter or + max_scale > 0.0 or + min_scale > 0.0 or + distance > 0.0 or + roll > 0.0) + self._counts: dict[str, int] = {"features": 0, + "min_scale": 0, + "max_scale": 0, + "distance": 0, + "roll": 0} + logger.debug("Initialized %s: ", self.__class__.__name__) + + def _scale_test(self, + face: AlignedFace, + minimum_dimension: int) -> T.Literal["min", "max"] | None: + """ Test if a face is below or above the min/max size thresholds. Returns as soon as a test + fails. + + Parameters + ---------- + face: :class:`~lib.aligned.AlignedFace` + The aligned face to test the original size of. + + minimum_dimension: int + The minimum (height, width) of the original frame + + Returns + ------- + "min", "max" or ``None`` + Returns min or max if the face failed the minimum or maximum test respectively. + ``None`` if all tests passed + """ + + if self._min_scale <= 0.0 and self._max_scale <= 0.0: + return None + + roi = face.original_roi.astype("int64") + size = ((roi[1][0] - roi[0][0]) ** 2 + (roi[1][1] - roi[0][1]) ** 2) ** 0.5 + + if self._min_scale > 0.0 and size < minimum_dimension * self._min_scale: + return "min" + + if self._max_scale > 0.0 and size > minimum_dimension * self._max_scale: + return "max" + + return None + + def _handle_filtered(self, + key: str, + face: DetectedFace, + faces: list[DetectedFace], + sub_folders: list[str | None], + sub_folder_index: int) -> None: + """ Add the filtered item to the filter counts. + + If config option `save_filtered` has been enabled then add the face to the output faces + list and update the sub_folder list with the correct name for this face. 
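+
+        For example, a face rejected by the distance test is routed to a sub-folder named after
+        its filter key, the keys being ``features``, ``min_scale``, ``max_scale``, ``distance``
+        and ``roll``:
+
+        >>> sub_folders[sub_folder_index] = "_align_filt_distance"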
+ + Parameters + ---------- + key: str + The key to use for the filter counts dictionary and the sub_folder name + face: :class:`~lib.align.detected_face.DetectedFace` + The detected face object to be filtered out + faces: list + The list of faces that will be returned from the filter + sub_folders: list + List of sub folder names corresponding to the list of detected face objects + sub_folder_index: int + The index within the sub-folder list that the filtered face belongs to + """ + self._counts[key] += 1 + if not self._save_output: + return + + faces.append(face) + sub_folders[sub_folder_index] = f"_align_filt_{key}" + + def __call__(self, faces: list[DetectedFace], minimum_dimension: int + ) -> tuple[list[DetectedFace], list[str | None]]: + """ Apply the filter to the incoming batch + + Parameters + ---------- + faces: list + List of detected face objects to filter out on size + minimum_dimension: int + The minimum (height, width) of the original frame + + Returns + ------- + detected_faces: list + The filtered list of detected face objects, if saving filtered faces has not been + selected or the full list of detected faces + sub_folders: list + List of ``Nones`` if saving filtered faces has not been selected or list of ``Nones`` + and sub folder names corresponding the filtered face location + """ + sub_folders: list[str | None] = [None for _ in range(len(faces))] + if not self._active: + return faces, sub_folders + + retval: list[DetectedFace] = [] + for idx, face in enumerate(faces): + aligned = AlignedFace(landmarks=face.landmarks_xy, centering="face") + + if self._features and aligned.relative_eye_mouth_position < 0.0: + self._handle_filtered("features", face, retval, sub_folders, idx) + continue + + min_max = self._scale_test(aligned, minimum_dimension) + if min_max in ("min", "max"): + self._handle_filtered(f"{min_max}_scale", face, retval, sub_folders, idx) + continue + + if 0.0 < self._distance < aligned.average_distance: + self._handle_filtered("distance", face, retval, sub_folders, idx) + continue + + if self._roll != 0.0 and not 0.0 < abs(aligned.pose.roll) < self._roll: + self._handle_filtered("roll", face, retval, sub_folders, idx) + continue + + retval.append(face) + return retval, sub_folders + + def filtered_mask(self, + batch: AlignerBatch, + skip: np.ndarray | list[int] | None = None) -> np.ndarray: + """ Obtain a list of boolean values for the given batch indicating whether they pass the + filter test. + + Parameters + ---------- + batch: :class:`AlignerBatch` + The batch of face to obtain masks for + skip: list or :class:`numpy.ndarray`, optional + List or 1D numpy array of indices indicating faces that have already been filter + masked and so should not be filtered again. Values in these index positions will be + returned as ``True`` + + Returns + ------- + :class:`numpy.ndarray` + Boolean mask array corresponding to any of the input DetectedFace objects that passed a + test. ``False`` the face passed the test. 
``True`` it failed + """ + skip = [] if skip is None else skip + retval = np.ones((len(batch.detected_faces), ), dtype="bool") + for idx, (landmarks, image) in enumerate(zip(batch.landmarks, batch.image)): + if idx in skip: + continue + face = AlignedFace(landmarks) + if self._features and face.relative_eye_mouth_position < 0.0: + continue + if self._scale_test(face, min(image.shape[:2])) is not None: + continue + if 0.0 < self._distance < face.average_distance: + continue + if self._roll != 0.0 and not 0.0 < abs(face.pose.roll) < self._roll: + continue + retval[idx] = False + return retval + + def output_counts(self): + """ Output the counts of filtered items """ + if not self._active: + return + counts = [f"{key} ({getattr(self, f'_{key}'):.2f}): {count}" + for key, count in self._counts.items() + if count > 0] + if counts: + logger.info("Aligner filtered: (%s)", ", ".join(counts)) + + +class ReAlign(): + """ Holds data and methods for 2nd pass re-aligns + + Parameters + ---------- + active: bool + ``True`` if re-alignment has been requested otherwise ``False`` + do_refeeds: bool + ``True`` if re-feeds should be re-aligned, ``False`` if just the final output of the + re-feeds should be aligned + do_filter: bool + ``True`` if aligner filtered out faces should not be re-aligned. ``False`` if all faces + should be re-aligned + """ + def __init__(self, active: bool, do_refeeds: bool, do_filter: bool) -> None: + logger.debug("Initializing %s: (active: %s, do_refeeds: %s, do_filter: %s)", + self.__class__.__name__, active, do_refeeds, do_filter) + self._active = active + self._do_refeeds = do_refeeds + self._do_filter = do_filter + self._centering: CenteringType = "face" + self._size = 0 + self._tracked_lock = Lock() + self._tracked_batchs: dict[int, + dict[T.Literal["filtered_landmarks"], list[np.ndarray]]] = {} + # TODO. 
Probably does not need to be a list, just alignerbatch
+        self._queue_lock = Lock()
+        self._queued: list[AlignerBatch] = []
+        logger.debug("Initialized %s", self.__class__.__name__)
+
+    @property
+    def active(self) -> bool:
+        """bool: ``True`` if re_aligns have been selected otherwise ``False``"""
+        return self._active
+
+    @property
+    def do_refeeds(self) -> bool:
+        """bool: ``True`` if re-aligning is active and re-aligning re-feeds has been selected
+        otherwise ``False``"""
+        return self._active and self._do_refeeds
+
+    @property
+    def do_filter(self) -> bool:
+        """bool: ``True`` if re-aligning is active and faces which failed the aligner filter test
+        should not be re-aligned otherwise ``False``"""
+        return self._active and self._do_filter
+
+    @property
+    def items_queued(self) -> bool:
+        """bool: ``True`` if re-align is active and items are queued for a 2nd pass otherwise
+        ``False`` """
+        with self._queue_lock:
+            return self._active and bool(self._queued)
+
+    @property
+    def items_tracked(self) -> bool:
+        """bool: ``True`` if items still exist in the tracker and need to be processed """
+        with self._tracked_lock:
+            return bool(self._tracked_batchs)
+
+    def set_input_size_and_centering(self, input_size: int, centering: CenteringType) -> None:
+        """ Set the input size of the loaded plugin once the model has been loaded
+
+        Parameters
+        ----------
+        input_size: int
+            The input size, in pixels, of the aligner plugin
+        centering: ["face", "head" or "legacy"]
+            The centering to align the image at for re-aligning
+        """
+        logger.debug("input_size: %s, centering: %s", input_size, centering)
+        self._size = input_size
+        self._centering = centering
+
+    def track_batch(self, batch_id: int) -> None:
+        """ Add a newly seen batch id from the aligner to the batch tracker, so that we can keep
+        track of whether there are still batches to be processed when the aligner hits 'EOF'
+
+        Parameters
+        ----------
+        batch_id: int
+            The batch id to add to batch tracking
+        """
+        if not self._active:
+            return
+        logger.trace("Tracking batch id: %s", batch_id)  # type: ignore[attr-defined]
+        with self._tracked_lock:
+            self._tracked_batchs[batch_id] = {}
+
+    def untrack_batch(self, batch_id: int) -> None:
+        """ Remove the tracked batch from the tracker once the batch has been fully processed
+
+        Parameters
+        ----------
+        batch_id: int
+            The batch id to remove from batch tracking
+        """
+        if not self._active:
+            return
+        logger.trace("Removing batch id from tracking: %s", batch_id)  # type: ignore[attr-defined]
+        with self._tracked_lock:
+            del self._tracked_batchs[batch_id]
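
``_tracked_batchs`` above is a plain dict guarded by a ``Lock``, so the aligner's feed and
output threads can agree on whether work is still outstanding when 'EOF' is hit. A minimal
standalone sketch of the same pattern; the class and method names are illustrative, not part
of the faceswap API:

    # Sketch of the lock-guarded batch tracking pattern used by ReAlign
    from threading import Lock

    class BatchTracker:
        def __init__(self) -> None:
            self._lock = Lock()
            self._tracked: dict[int, dict] = {}

        def track(self, batch_id: int) -> None:
            with self._lock:  # guard writes from the feeding thread
                self._tracked[batch_id] = {}

        def untrack(self, batch_id: int) -> None:
            with self._lock:  # guard removal from the output thread
                del self._tracked[batch_id]

        @property
        def has_items(self) -> bool:
            with self._lock:  # guard reads so 'EOF' checks see a consistent view
                return bool(self._tracked)

+
+    def add_batch(self, batch: AlignerBatch) -> None:
+        """ Add first pass alignments to the queue for picking up for re-alignment, update their
+        :attr:`second_pass` attribute to ``True`` and clear attributes not required.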
+ + Parameters + ---------- + batch: :class:`AlignerBatch` + aligner batch to perform re-alignment on + """ + with self._queue_lock: + logger.trace("Queueing for second pass: %s", batch) # type: ignore[attr-defined] + batch.second_pass = True + batch.feed = np.array([]) + batch.prediction = np.array([]) + batch.refeeds = [] + batch.data = [] + self._queued.append(batch) + + def get_batch(self) -> AlignerBatch: + """ Retrieve the next batch currently queued for re-alignment + + Returns + ------- + :class:`AlignerBatch` + The next :class:`AlignerBatch` for re-alignment + """ + with self._queue_lock: + retval = self._queued.pop(0) + logger.trace("Retrieving for second pass: %s", # type: ignore[attr-defined] + retval.filename) + return retval + + def process_batch(self, batch: AlignerBatch) -> list[np.ndarray]: + """ Pre process a batch object for re-aligning through the aligner. + + Parameters + ---------- + batch: :class:`AlignerBatch` + aligner batch to perform pre-processing on + + Returns + ------- + list + List of UINT8 aligned faces batch for each selected refeed + """ + logger.trace("Processing batch: %s, landmarks: %s", # type: ignore[attr-defined] + batch.filename, [b.shape for b in batch.landmarks]) + retval: list[np.ndarray] = [] + filtered_landmarks: list[np.ndarray] = [] + for landmarks, masks in zip(batch.landmarks, batch.second_pass_masks): + if not np.all(masks): # At least one face has not already been filtered + aligned_faces = [AlignedFace(lms, + image=image, + size=self._size, + centering=self._centering) + for image, lms, msk in zip(batch.image, landmarks, masks) + if not msk] + faces = np.array([aligned.face for aligned in aligned_faces + if aligned.face is not None]) + retval.append(faces) + batch.data.append({"aligned_faces": aligned_faces}) + + if np.any(masks): + # Track the original landmarks for re-insertion on the other side + filtered_landmarks.append(landmarks[masks]) + + with self._tracked_lock: + self._tracked_batchs[batch.batch_id] = {"filtered_landmarks": filtered_landmarks} + batch.landmarks = np.array([]) # Clear the old landmarks + return retval + + def _transform_to_frame(self, batch: AlignerBatch) -> np.ndarray: + """ Transform the predicted landmarks from the aligned face image back into frame + co-ordinates + + Parameters + ---------- + batch: :class:`AlignerBatch` + An aligner batch containing the aligned faces in the data field and the face + co-ordinate landmarks in the landmarks field + + Returns + ------- + :class:`numpy.ndarray` + The landmarks transformed to frame space + """ + faces: list[AlignedFace] = batch.data[0]["aligned_faces"] + retval = np.array([aligned.transform_points(landmarks, invert=True) + for landmarks, aligned in zip(batch.landmarks, faces)]) + logger.trace("Transformed points: original max: %s, " # type: ignore[attr-defined] + "new max: %s", batch.landmarks.max(), retval.max()) + return retval + + def _re_insert_filtered(self, batch: AlignerBatch, masks: np.ndarray) -> np.ndarray: + """ Re-insert landmarks that were filtered out from the re-align process back into the + landmark results + + Parameters + ---------- + batch: :class:`AlignerBatch` + An aligner batch containing the aligned faces in the data field and the landmarks in + frame space in the landmarks field + masks: np.ndarray + The original filter masks for this batch + + Returns + ------- + :class:`numpy.ndarray` + The full batch of landmarks with filtered out values re-inserted + """ + if not np.any(masks): + logger.trace("No landmarks to re-insert: %s", masks) # 
type: ignore[attr-defined]
+            return batch.landmarks
+
+        with self._tracked_lock:
+            filtered = self._tracked_batchs[batch.batch_id]["filtered_landmarks"].pop(0)
+
+        if np.all(masks):
+            retval = filtered
+        else:
+            retval = np.empty((masks.shape[0], *filtered.shape[1:]), dtype=filtered.dtype)
+            retval[~masks] = batch.landmarks
+            retval[masks] = filtered
+
+        logger.trace("Filtered re-inserted: old shape: %s, "  # type: ignore[attr-defined]
+                     "new shape: %s", batch.landmarks.shape, retval.shape)
+
+        return retval
+
+    def process_output(self, subbatches: list[AlignerBatch], batch_masks: np.ndarray) -> None:
+        """ Process the output from the re-align pass.
+
+        - Transform landmarks from aligned face space to face space
+        - Re-insert faces that were filtered out from the re-align process back into the
+          landmarks list
+
+        Parameters
+        ----------
+        subbatches: list
+            List of sub-batch results for each re-aligned re-feed performed
+        batch_masks: :class:`numpy.ndarray`
+            The original re-feed filter masks from the first pass
+        """
+        for batch, masks in zip(subbatches, batch_masks):
+            if not np.all(masks):
+                batch.landmarks = self._transform_to_frame(batch)
+            batch.landmarks = self._re_insert_filtered(batch, masks)
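
The twin assignments ``retval[~masks] = ...`` and ``retval[masks] = ...`` above interleave the
re-aligned and the filtered landmarks back into their original batch order via a boolean mask.
A standalone sketch of the technique; the array values are illustrative:

    # Sketch of boolean-mask re-insertion as performed by _re_insert_filtered
    import numpy as np

    kept = np.array([[1., 1.], [3., 3.]])     # results for faces that were re-aligned
    filtered = np.array([[2., 2.]])           # stored originals for faces that were skipped
    masks = np.array([False, True, False])    # True marks a face that was filtered out

    merged = np.empty((masks.shape[0], 2), dtype=kept.dtype)
    merged[~masks] = kept                     # fills the False positions, in order
    merged[masks] = filtered                  # fills the True positions, in order
    # merged -> [[1, 1], [2, 2], [3, 3]]

diff --git a/plugins/extract/align/cv2_dnn.py b/plugins/extract/align/cv2_dnn.py
new file mode 100644
index 0000000000..a695f7f11d
--- /dev/null
+++ b/plugins/extract/align/cv2_dnn.py
@@ -0,0 +1,320 @@
+#!/usr/bin/env python3
+""" CV2 DNN landmarks extractor for faceswap.py
+Adapted from: https://github.com/yinguobing/cnn-facial-landmark
+MIT License
+
+Copyright (c) 2017 Yin Guobing
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.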
+""" +from __future__ import annotations +import logging +import typing as T + +import cv2 +import numpy as np + +from lib.utils import get_module_objects +from ._base import Aligner, AlignerBatch, BatchType + +if T.TYPE_CHECKING: + from lib.align.detected_face import DetectedFace + +logger = logging.getLogger(__name__) + + +class Align(Aligner): + """ Perform transformation to align and get landmarks """ + def __init__(self, **kwargs) -> None: + git_model_id = 1 + model_filename = "cnn-facial-landmark_v1.pb" + super().__init__(git_model_id=git_model_id, model_filename=model_filename, **kwargs) + + self.model: cv2.dnn.Net + self.model_path: str + self.name = "cv2-DNN Aligner" + self.input_size = 128 + self.color_format = "RGB" + self.vram = 0 # Doesn't use GPU + self.vram_per_batch = 0 + self.batchsize = 1 + self.realign_centering = "legacy" + + def init_model(self) -> None: + """ Initialize CV2 DNN Detector Model""" + self.model = cv2.dnn.readNetFromTensorflow(self.model_path) + self.model.setPreferableTarget(cv2.dnn.DNN_TARGET_CPU) + + def faces_to_feed(self, faces: np.ndarray) -> np.ndarray: + """ Convert a batch of face images from UINT8 (0-255) to fp32 (0.0-255.0) + + Parameters + ---------- + faces: :class:`numpy.ndarray` + The batch of faces in UINT8 format + + Returns + ------- + class: `numpy.ndarray` + The batch of faces as fp32 + """ + return faces.astype("float32").transpose((0, 3, 1, 2)) + + def process_input(self, batch: BatchType) -> None: + """ Compile the detected faces for prediction + + Parameters + ---------- + batch: :class:`AlignerBatch` + The current batch to process input for + + Returns + ------- + :class:`AlignerBatch` + The batch item with the :attr:`feed` populated and any required :attr:`data` added + """ + assert isinstance(batch, AlignerBatch) + lfaces, roi, offsets = self.align_image(batch) + batch.feed = np.array(lfaces)[..., :3] + batch.data.append({"roi": roi, "offsets": offsets}) + + def _get_box_and_offset(self, face: DetectedFace) -> tuple[list[int], int]: + """Obtain the bounding box and offset from a detected face. + + + Parameters + ---------- + face: :class:`~lib.align.DetectedFace` + The detected face object to obtain the bounding box and offset from + + Returns + ------- + box: list + The [left, top, right, bottom] bounding box + offset: int + The offset of the box (difference between half width vs height) + """ + + box = T.cast(list[int], [face.left, + face.top, + face.right, + face.bottom]) + diff_height_width = T.cast(int, face.height) - T.cast(int, face.width) + offset = int(abs(diff_height_width / 2)) + return box, offset + + def align_image(self, batch: AlignerBatch) -> tuple[list[np.ndarray], + list[list[int]], + list[tuple[int, int]]]: + """ Align the incoming image for prediction + + Parameters + ---------- + batch: :class:`AlignerBatch` + The current batch to align the input for + + Returns + ------- + faces: list + List of feed faces for the aligner + rois: list + List of roi's for the faces + offsets: list + List of offsets for the faces + """ + logger.trace("Aligning image around center") # type:ignore[attr-defined] + sizes = (self.input_size, self.input_size) + rois = [] + faces = [] + offsets = [] + for det_face, image in zip(batch.detected_faces, batch.image): + box, offset_y = self._get_box_and_offset(det_face) + box_moved = self.move_box(box, (0, offset_y)) + # Make box square. 
+ roi = self.get_square_box(box_moved) + + # Pad the image and adjust roi if face is outside of boundaries + image, offset = self.pad_image(roi, image) + face = image[roi[1] + abs(offset[1]): roi[3] + abs(offset[1]), + roi[0] + abs(offset[0]): roi[2] + abs(offset[0])] + interpolation = cv2.INTER_CUBIC if face.shape[0] < self.input_size else cv2.INTER_AREA + face = cv2.resize(face, dsize=sizes, interpolation=interpolation) + faces.append(face) + rois.append(roi) + offsets.append(offset) + return faces, rois, offsets + + @classmethod + def move_box(cls, + box: list[int], + offset: tuple[int, int]) -> list[int]: + """Move the box to direction specified by vector offset + + Parameters + ---------- + box: list + The (`left`, `top`, `right`, `bottom`) box positions + offset: tuple + (x, y) offset to move the box + + Returns + ------- + list + The original box shifted by the offset + """ + left = box[0] + offset[0] + top = box[1] + offset[1] + right = box[2] + offset[0] + bottom = box[3] + offset[1] + return [left, top, right, bottom] + + @staticmethod + def get_square_box(box: list[int]) -> list[int]: + """Get a square box out of the given box, by expanding it. + + Parameters + ---------- + box: list + The (`left`, `top`, `right`, `bottom`) box positions + + Returns + ------- + list + The original box but made square + """ + left = box[0] + top = box[1] + right = box[2] + bottom = box[3] + + box_width = right - left + box_height = bottom - top + + # Check if box is already a square. If not, make it a square. + diff = box_height - box_width + delta = int(abs(diff) / 2) + + if diff == 0: # Already a square. + return box + if diff > 0: # Height > width, a slim box. + left -= delta + right += delta + if diff % 2 == 1: + right += 1 + else: # Width > height, a short box. + top -= delta + bottom += delta + if diff % 2 == 1: + bottom += 1 + + # Make sure box is always square. + assert ((right - left) == (bottom - top)), 'Box is not square.' 
+ + return [left, top, right, bottom] + + @classmethod + def pad_image(cls, box: list[int], image: np.ndarray) -> tuple[np.ndarray, tuple[int, int]]: + """Pad image if face-box falls outside of boundaries + + Parameters + ---------- + box: list + The (`left`, `top`, `right`, `bottom`) roi box positions + image: :class:`numpy.ndarray` + The image to be padded + + Returns + ------- + :class:`numpy.ndarray` + The padded image + """ + height, width = image.shape[:2] + pad_l = 1 - box[0] if box[0] < 0 else 0 + pad_t = 1 - box[1] if box[1] < 0 else 0 + pad_r = box[2] - width if box[2] > width else 0 + pad_b = box[3] - height if box[3] > height else 0 + logger.trace("Padding: (l: %s, t: %s, r: %s, b: %s)", # type:ignore[attr-defined] + pad_l, pad_t, pad_r, pad_b) + padded_image = cv2.copyMakeBorder(image.copy(), + pad_t, + pad_b, + pad_l, + pad_r, + cv2.BORDER_CONSTANT, + value=(0, 0, 0)) + offsets = (pad_l - pad_r, pad_t - pad_b) + logger.trace("image_shape: %s, Padded shape: %s, box: %s, " # type:ignore[attr-defined] + "offsets: %s", + image.shape, padded_image.shape, box, offsets) + return padded_image, offsets + + def predict(self, feed: np.ndarray) -> np.ndarray: + """ Predict the 68 point landmarks + + Parameters + ---------- + feed: :class:`numpy.ndarray` + The batch to feed into the aligner + + Returns + ------- + :class:`numpy.ndarray` + The predictions from the aligner + """ + assert isinstance(self.model, cv2.dnn.Net) + self.model.setInput(feed) + retval = self.model.forward() + return retval + + def process_output(self, batch: BatchType) -> None: + """ Process the output from the model + + Parameters + ---------- + batch: :class:`AlignerBatch` + The current batch from the model with :attr:`predictions` populated + """ + assert isinstance(batch, AlignerBatch) + self.get_pts_from_predict(batch) + + def get_pts_from_predict(self, batch: AlignerBatch): + """ Get points from predictor and populates the :attr:`landmarks` property + + Parameters + ---------- + batch: :class:`AlignerBatch` + The current batch from the model with :attr:`predictions` populated + """ + landmarks = [] + if batch.second_pass: + batch.landmarks = batch.prediction.reshape(self.batchsize, -1, 2) * self.input_size + else: + for prediction, roi, offset in zip(batch.prediction, + batch.data[0]["roi"], + batch.data[0]["offsets"]): + points = np.reshape(prediction, (-1, 2)) + points *= (roi[2] - roi[0]) + points[:, 0] += (roi[0] - offset[0]) + points[:, 1] += (roi[1] - offset[1]) + landmarks.append(points) + batch.landmarks = np.array(landmarks) + logger.trace("Predicted Landmarks: %s", batch.landmarks) # type:ignore[attr-defined] + + +__all__ = get_module_objects(__name__) diff --git a/plugins/extract/align/dlib.py b/plugins/extract/align/dlib.py deleted file mode 100644 index 5e8d367530..0000000000 --- a/plugins/extract/align/dlib.py +++ /dev/null @@ -1,56 +0,0 @@ -#!/usr/bin/env python3 -""" DLib landmarks extractor for faceswap.py """ -import face_recognition_models -import dlib - -from ._base import Aligner, logger - - -class Align(Aligner): - """ Perform transformation to align and get landmarks """ - def __init__(self, **kwargs): - super().__init__(**kwargs) - self.vram = 0 # Doesn't use GPU - self.model = None - - def set_model_path(self): - """ Model path handled by face_recognition_models """ - model_path = face_recognition_models.pose_predictor_model_location() - logger.debug("Loading model: '%s'", model_path) - return model_path - - def initialize(self, *args, **kwargs): - """ Initialization tasks to run prior 
to alignments """
-        super().initialize(*args, **kwargs)
-        logger.info("Initializing Dlib Pose Predictor...")
-        logger.debug("dlib initialize: (args: %s kwargs: %s)", args, kwargs)
-        self.model = dlib.shape_predictor(self.model_path)  # pylint: disable=c-extension-no-member
-        self.init.set()
-        logger.info("Initialized Dlib Pose Predictor.")
-
-    def align(self, *args, **kwargs):
-        """ Perform alignments on detected faces """
-        super().align(*args, **kwargs)
-        for item in self.get_item():
-            if item == "EOF":
-                self.finalize(item)
-                break
-            image = item["image"][:, :, ::-1].copy()
-
-            logger.trace("Algning faces")
-            item["landmarks"] = self.process_landmarks(image, item["detected_faces"])
-            logger.trace("Algned faces: %s", item["landmarks"])
-
-            self.finalize(item)
-        logger.debug("Completed Align")
-
-    def process_landmarks(self, image, detected_faces):
-        """ Align image and process landmarks """
-        logger.trace("Processing Landmarks")
-        retval = list()
-        for detected_face in detected_faces:
-            pts = self.model(image, detected_face).parts()
-            landmarks = [(point.x, point.y) for point in pts]
-            retval.append(landmarks)
-        logger.trace("Processed Landmarks: %s", retval)
-        return retval
diff --git a/plugins/extract/align/external.py b/plugins/extract/align/external.py
new file mode 100644
index 0000000000..ca5630dc37
--- /dev/null
+++ b/plugins/extract/align/external.py
@@ -0,0 +1,287 @@
+#!/usr/bin/env python3
+""" Import 68 point landmarks or ROI boxes from a json file """
+from __future__ import annotations
+import logging
+import typing as T
+import os
+import re
+
+import numpy as np
+
+from lib.align import EXTRACT_RATIOS, LandmarkType
+from lib.utils import get_module_objects, FaceswapError, IMAGE_EXTENSIONS
+
+from ._base import BatchType, Aligner, AlignerBatch
+from . import external_defaults as cfg
+
+if T.TYPE_CHECKING:
+    from lib.align.constants import CenteringType
+
+logger = logging.getLogger(__name__)
+OriginType = T.Literal["top-left", "bottom-left", "top-right", "bottom-right"]
+# pylint:disable=duplicate-code
+
+
+class Align(Aligner):
+    """ Import 68 point landmarks or ROI boxes for faces from an external json file """
+    def __init__(self, **kwargs) -> None:
+        kwargs["normalize_method"] = None  # Disable normalization
+        kwargs["re_feed"] = 0  # Disable re-feed
+        kwargs["re_align"] = False  # Disable re-align
+        kwargs["disable_filter"] = True  # Disable aligner filters
+        super().__init__(git_model_id=None, model_filename=None, **kwargs)
+
+        self.name = "External"
+        self.batchsize = 16
+        self.origin: OriginType = T.cast(OriginType, cfg.origin())
+        """ Literal["top-left", "bottom-left", "top-right", "bottom-right"] : The origin (0, 0)
+        location of the co-ordinates system used """
+        self.file_name = cfg.file_name()
+        """ str : The file name to import landmark data from """
+
+        self._re_frame_no: re.Pattern = re.compile(r"\d+$")
+        self._is_video: bool = False
+        self._imported: dict[str | int, tuple[int, np.ndarray]] = {}
+        """dict[str | int, tuple[int, np.ndarray]]: filename as key, value of [number of faces
+        remaining for the frame, all landmarks in the frame] """
+
+        self._missing: list[str] = []
+        self._roll: dict[T.Literal["bottom-left", "top-right", "bottom-right"], int] = {
+            "bottom-left": 3, "top-right": 1, "bottom-right": 2}
+        """dict[Literal["bottom-left", "top-right", "bottom-right"], int]: Amount to roll the
+        points by for different origins when 4 Point ROI landmarks are provided """
+
+        centering = T.cast("CenteringType", cfg.four_point_centering())
+        self._adjustment: float = 1. if centering == "none" else 1. - EXTRACT_RATIOS[centering]
+        """float: The amount to adjust 4 point ROI landmarks to standardize the points for a
+        'head' sized extracted face """
+
+    def init_model(self) -> None:
+        """ No initialization to perform """
+        logger.debug("No aligner model to initialize")
+
+    def _check_for_video(self, filename: str) -> None:
+        """ Check a sample filename from the import file for a file extension to set
+        :attr:`_is_video`
+
+        Parameters
+        ----------
+        filename: str
+            A sample file name from the imported data
+        """
+        logger.debug("Checking for video from '%s'", filename)
+        ext = os.path.splitext(filename)[-1]
+        if ext.lower() not in IMAGE_EXTENSIONS:
+            self._is_video = True
+        logger.debug("Set is_video to %s from extension '%s'", self._is_video, ext)
+
+    def _get_key(self, key: str) -> str | int:
+        """ Obtain the key for the item in the lookup table. If the inputs are images, the key
+        will be the image filename. If the input is a video, the key will be the frame number
+
+        Parameters
+        ----------
+        key: str
+            The initial key value from import data or an import image/frame
+
+        Returns
+        -------
+        str | int
+            The filename if the input data is images, otherwise the frame number if the input is
+            a video
+        """
+        if not self._is_video:
+            return key
+        original_name = os.path.splitext(key)[0]
+        matches = self._re_frame_no.findall(original_name)
+        if not matches or len(matches) > 1:
+            raise FaceswapError(f"Invalid import name: '{key}'. For video files, the key should "
+                                "end with the frame number.")
+        retval = int(matches[0])
+        logger.trace("Obtained frame number %s from key '%s'",  # type:ignore[attr-defined]
+                     retval, key)
+        return retval
+
+    def _import_face(self, face: dict[str, list[int] | list[list[float]]]) -> np.ndarray:
+        """ Import the landmarks from a single face
+
+        Parameters
+        ----------
+        face: dict[str, list[int] | list[list[float]]]
+            An import dictionary item for a face
+
+        Returns
+        -------
+        :class:`numpy.ndarray`
+            The landmark data imported from the json file
+
+        Raises
+        ------
+        FaceswapError
+            If the landmarks_2d key does not exist or the landmarks are in an incorrect format
+        """
+        landmarks = face.get("landmarks_2d")
+        if landmarks is None:
+            raise FaceswapError("The provided import file is missing the required key "
+                                "'landmarks_2d'")
+        if len(landmarks) not in (4, 68):
+            raise FaceswapError("Imported 'landmarks_2d' should be either 68 facial feature "
+                                "landmarks or 4 ROI corner locations")
+        retval = np.array(landmarks, dtype="float32")
+        if retval.shape[-1] != 2:
+            raise FaceswapError("Imported 'landmarks_2d' should be formatted as a list of (x, y) "
+                                "co-ordinates")
+        if retval.shape[0] == 4:  # Adjust ROI landmarks based on centering selected
+            center = np.mean(retval, axis=0)
+            retval = (retval - center) * self._adjustment + center
+
+        return retval
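
The ``(retval - center) * self._adjustment + center`` line above rescales the four ROI corners
about their centroid, so points supplied at a tighter centering line up with faceswap's 'head'
sized extracts. A small numeric sketch; the 0.8 factor is illustrative only:

    # Sketch of scaling ROI corner points about their centroid, as in _import_face
    import numpy as np

    roi = np.array([[0., 0.], [10., 0.], [10., 10.], [0., 10.]], dtype="float32")
    adjustment = 0.8                  # e.g. 1.0 - EXTRACT_RATIOS[centering]
    center = roi.mean(axis=0)         # centroid at (5, 5)
    scaled = (roi - center) * adjustment + center
    # scaled -> [[1, 1], [9, 1], [9, 9], [1, 9]]: same centroid, 80% of the size

+
+    def import_data(self, data: dict[str, list[dict[str, list[int] | list[list[float]]]]]) -> None:
+        """ Import the aligner data from the json import file and set to :attr:`_imported`
+
+        Parameters
+        ----------
+        data: dict[str, list[dict[str, list[int] | list[list[float]]]]]
+            The data to be imported
+        """
+        logger.debug("Data length: %s", len(data))
+        self._check_for_video(list(data)[0])
+        for key, faces in data.items():
+            try:
+                lms = np.array([self._import_face(face) for face in faces], dtype="float32")
+                if not np.any(lms):
+                    logger.trace("Skipping frame '%s' with no faces",  # type:ignore[attr-defined]
+                                 key)
+                    continue
+
+                store_key = self._get_key(key)
+                self._imported[store_key] = (lms.shape[0], lms)
+            except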
FaceswapError as err: + logger.error(str(err)) + msg = f"The imported frame key that failed was '{key}'" + raise FaceswapError(msg) from err + lm_shape = set(v[1].shape[1:] for v in self._imported.values() if v[0] > 0) + if len(lm_shape) > 1: + raise FaceswapError("All external data should have the same number of landmarks. " + f"Found landmarks of shape: {lm_shape}") + if (4, 2) in lm_shape: + self.landmark_type = LandmarkType.LM_2D_4 + + def process_input(self, batch: BatchType) -> None: + """ Put the filenames and original frame dimensions into `batch.feed` so they can be + collected for mapping in `.predict` + + Parameters + ---------- + batch: :class:`~plugins.extract.detect._base.AlignerBatch` + The batch to be processed by the plugin + """ + batch.feed = np.array([(self._get_key(os.path.basename(f)), i.shape[:2]) + for f, i in zip(batch.filename, batch.image)], dtype="object") + + def faces_to_feed(self, faces: np.ndarray) -> np.ndarray: + """ No action required for import plugin + + Parameters + ---------- + faces: :class:`numpy.ndarray` + The batch of faces in UINT8 format + + Returns + ------- + class: `numpy.ndarray` + the original batch of faces + """ + return faces + + def _adjust_for_origin(self, landmarks: np.ndarray, frame_dims: tuple[int, int]) -> np.ndarray: + """ Adjust the landmarks to be top-left orientated based on the selected import origin + + Parameters + ---------- + landmarks: :class:`np.ndarray` + The imported facial landmarks box at original (0, 0) origin + frame_dims: tuple[int, int] + The (rows, columns) dimensions of the original frame + + Returns + ------- + :class:`numpy.ndarray` + The adjusted landmarks box for a top-left origin + """ + if not np.any(landmarks) or self.origin == "top-left": + return landmarks + + if LandmarkType.from_shape(landmarks.shape) == LandmarkType.LM_2D_4: + landmarks = np.roll(landmarks, self._roll[self.origin], axis=0) + + if self.origin.startswith("bottom"): + landmarks[:, 1] = frame_dims[0] - landmarks[:, 1] + if self.origin.endswith("right"): + landmarks[:, 0] = frame_dims[1] - landmarks[:, 0] + + return landmarks + + def predict(self, feed: np.ndarray) -> np.ndarray: + """ Pair the input filenames to the import file + + Parameters + ---------- + feed: :class:`numpy.ndarray` + The filenames in the batch to return imported alignments for + + Returns + ------- + :class:`numpy.ndarray` + The predictions for the given filenames + """ + preds = [] + for key, frame_dims in feed: + if key not in self._imported: + self._missing.append(key) + continue + + remaining, all_lms = self._imported[key] + preds.append(self._adjust_for_origin(all_lms[all_lms.shape[0] - remaining], + frame_dims)) + + if remaining == 1: + del self._imported[key] + else: + self._imported[key] = (remaining - 1, all_lms) + + return np.array(preds, dtype="float32") + + def process_output(self, batch: BatchType) -> None: + """ Process the imported data to the landmarks attribute + + Parameters + ---------- + batch: :class:`AlignerBatch` + The current batch from the model with :attr:`predictions` populated + """ + assert isinstance(batch, AlignerBatch) + batch.landmarks = batch.prediction + logger.trace("Imported landmarks: %s", batch.landmarks) # type:ignore[attr-defined] + + def on_completion(self) -> None: + """ Output information if: + - Imported items were not matched in input data + - Input data was not matched in imported items + """ + super().on_completion() + + if self._missing: + logger.warning("[ALIGN] %s input frames could not be matched in the import file 
" + "'%s'. Run in verbose mode for a list of frames.", + len(self._missing), cfg.file_name) + logger.verbose( # type:ignore[attr-defined] + "[ALIGN] Input frames not in import file: %s", self._missing) + + if self._imported: + logger.warning("[ALIGN] %s items in the import file '%s' could not be matched to any " + "input frames. Run in verbose mode for a list of items.", + len(self._imported), cfg.file_name) + logger.verbose( # type:ignore[attr-defined] + "[ALIGN] import file items not in input frames: %s", list(self._imported)) + + +__all__ = get_module_objects(__name__) diff --git a/plugins/extract/align/external_defaults.py b/plugins/extract/align/external_defaults.py new file mode 100644 index 0000000000..c027bd483a --- /dev/null +++ b/plugins/extract/align/external_defaults.py @@ -0,0 +1,77 @@ +#!/usr/bin/env python3 +""" The default options for the external faceswap Import Alignments plugin. + +Defaults files should be named `_defaults.py` + +Any qualifying items placed into this file will automatically get added to the relevant config +.ini files within the faceswap/config folder and added to the relevant GUI settings page. + +The following variable should be defined: + + Parameters + ---------- + HELPTEXT: str + A string describing what this plugin does + +Further plugin configuration options are assigned using: +>>> = ConfigItem(...) + +where is the name of the configuration option to be added (lower-case, alpha-numeric ++ underscore only) and ConfigItem(...) is the [`~lib.config.objects.ConfigItem`] data for the +option. + +See the docstring/ReadtheDocs documentation required parameters for the ConfigItem object. +Items will be grouped together as per their `group` parameter, but otherwise will be processed in +the order that they are added to this module. +from lib.config import ConfigItem +""" +# pylint:disable=duplicate-code +from lib.config import ConfigItem + + +HELPTEXT = ( + "Import Aligner options.\n" + "Imports either 68 point 2D landmarks or an aligned bounding box from an external .json file." + ) + + +file_name = ConfigItem( + datatype=str, + default="import.json", + group="settings", + info="The import file should be stored in the same folder as the video (if extracting " + "from a video file) or inside the folder of images (if importing from a folder of " + "images)") + +origin = ConfigItem( + datatype=str, + default="top-left", + group="input", + info="The origin (0, 0) location of the co-ordinates system used. " + "\n\t top-left: The origin (0, 0) of the canvas is at the top left " + "corner." + "\n\t bottom-left: The origin (0, 0) of the canvas is at the bottom " + "left corner." + "\n\t top-right: The origin (0, 0) of the canvas is at the top right " + "corner." + "\n\t bottom-right: The origin (0, 0) of the canvas is at the bottom " + "right corner.", + choices=["top-left", "bottom-left", "top-right", "bottom-right"], + gui_radio=True) + +four_point_centering = ConfigItem( + datatype=str, + default="head", + group="input", + info="4 point ROI landmarks only. The approximate centering for the location of the " + "corner points to be imported. Default faceswap extracts are generated at 'head' " + "centering, but it is possible to pass in ROI points at a tighter centering. " + "Refer to https://github.com/deepfakes/faceswap/pull/1095 for a visual guide" + "\n\t head: The ROI points represent a loose crop enclosing the whole head." + "\n\t face: The ROI points represent a medium crop enclosing the face." 
+ "\n\t legacy: The ROI points represent a tight crop enclosing the central face " + "area." + "\n\t none: Only required if importing 4 point ROI landmarks back into faceswap " + "having generated them from the 'alignments' tool 'export' job.", + choices=["head", "face", "legacy", "none"], + gui_radio=True) diff --git a/plugins/extract/align/fan.py b/plugins/extract/align/fan.py index 8da9269d35..1a38397aa3 100644 --- a/plugins/extract/align/fan.py +++ b/plugins/extract/align/fan.py @@ -3,258 +3,283 @@ Code adapted and modified from: https://github.com/1adrianb/face-alignment """ -import os +from __future__ import annotations +import logging +import typing as T + import cv2 import numpy as np -from ._base import Aligner, logger +from keras.saving import load_model + +from lib.utils import get_module_objects +from ._base import Aligner, AlignerBatch, BatchType +from . import fan_defaults as cfg + +if T.TYPE_CHECKING: + from lib.align import DetectedFace + from keras import Model + +logger = logging.getLogger(__name__) class Align(Aligner): """ Perform transformation to align and get landmarks """ - def __init__(self, **kwargs): - super().__init__(**kwargs) - self.vram = 2240 - self.reference_scale = 195.0 - self.model = None - self.test = None - - def set_model_path(self): - """ Load the mtcnn models """ - model_path = os.path.join(self.cachepath, "2DFAN-4.pb") - if not os.path.exists(model_path): - raise Exception("Error: Unable to find {}, reinstall " - "the lib!".format(model_path)) - logger.debug("Loading model: '%s'", model_path) - return model_path - - def initialize(self, *args, **kwargs): - """ Initialization tasks to run prior to alignments """ - super().initialize(*args, **kwargs) - logger.info("Initializing Face Alignment Network...") - logger.debug("fan initialize: (args: %s kwargs: %s)", args, kwargs) - - _, _, vram_total = self.get_vram_free() - - if vram_total <= self.vram: - tf_ratio = 1.0 - else: - tf_ratio = self.vram / vram_total - logger.verbose("Reserving %sMB for face alignments", self.vram) - - self.model = FAN(self.model_path, ratio=tf_ratio) - - self.init.set() - logger.info("Initialized Face Alignment Network.") - - def align(self, *args, **kwargs): - """ Perform alignments on detected faces """ - super().align(*args, **kwargs) - for item in self.get_item(): - if item == "EOF": - self.finalize(item) - break - image = item["image"][:, :, ::-1].copy() - - logger.trace("Aligning faces") - try: - item["landmarks"] = self.process_landmarks(image, item["detected_faces"]) - logger.trace("Aligned faces: %s", item["landmarks"]) - except ValueError as err: - logger.warning("Image '%s' could not be processed. 
This may be due to corrupted " - "data: %s", item["filename"], str(err)) - item["detected_faces"] = list() - item["landmarks"] = list() - self.finalize(item) - logger.debug("Completed Align") - - def process_landmarks(self, image, detected_faces): - """ Align image and process landmarks """ - logger.trace("Processing landmarks") - retval = list() - for detected_face in detected_faces: - center, scale = self.get_center_scale(detected_face) - aligned_image = self.align_image(image, center, scale) - landmarks = self.predict_landmarks(aligned_image, center, scale) - retval.append(landmarks) - logger.trace("Processed landmarks: %s", retval) + def __init__(self, **kwargs) -> None: + git_model_id = 13 + model_filename = "face-alignment-network_2d4_keras_v3.h5" + super().__init__(git_model_id=git_model_id, model_filename=model_filename, **kwargs) + self.model: Model + self.name = "FAN" + self.input_size = 256 + self.color_format = "RGB" + self.vram = 896 # 810 in testing + self.vram_per_batch = 768 # ~720 in testing + self.realign_centering = "head" + self.batchsize: int = cfg.batch_size() + self.reference_scale = 200. / 195. + + def init_model(self) -> None: + """ Initialize FAN model """ + assert isinstance(self.name, str) + assert isinstance(self.model_path, str) + logging.disable(logging.WARNING) # Disable compile warning from Keras + self.model = load_model(self.model_path, compile=False) + logging.disable(logging.NOTSET) + self.model.make_predict_function() + # Feed a placeholder so Aligner is primed for Manual tool + placeholder_shape = (self.batchsize, self.input_size, self.input_size, 3) + placeholder = np.zeros(placeholder_shape, dtype="float32") + self.model.predict(placeholder, verbose=False, batch_size=self.batchsize) + + def faces_to_feed(self, faces: np.ndarray) -> np.ndarray: + """ Convert a batch of face images from UINT8 (0-255) to fp32 (0.0-1.0) + + Parameters + ---------- + faces: :class:`numpy.ndarray` + The batch of faces in UINT8 format + + Returns + ------- + class: `numpy.ndarray` + The batch of faces as fp32 in 0.0 to 1.0 range + """ + return faces.astype("float32") / 255. 
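
``init_model`` above also primes the network by pushing a batch of zeros through ``predict``,
so the first real request (for example from the manual tool) does not pay the one-off graph
build cost. A minimal sketch of the warm-up pattern, assuming an arbitrary Keras model; the
file name and shapes are illustrative:

    # Sketch of the predict warm-up used by init_model above
    import numpy as np
    from keras.saving import load_model

    model = load_model("my_aligner.h5", compile=False)  # hypothetical model file
    batchsize, size = 12, 256
    placeholder = np.zeros((batchsize, size, size, 3), dtype="float32")
    model.predict(placeholder, verbose=False, batch_size=batchsize)  # builds the graph up front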
+ + def process_input(self, batch: BatchType) -> None: + """ Compile the detected faces for prediction + + Parameters + ---------- + batch: :class:`AlignerBatch` + The current batch to process input for + """ + assert isinstance(batch, AlignerBatch) + logger.trace("Aligning faces around center") # type:ignore[attr-defined] + center_scale = self.get_center_scale(batch.detected_faces) + batch.feed = np.array(self.crop(batch, center_scale))[..., :3] + batch.data.append({"center_scale": center_scale}) + logger.trace("Aligned image around center") # type:ignore[attr-defined] + + def get_center_scale(self, detected_faces: list[DetectedFace]) -> np.ndarray: + """ Get the center and set scale of bounding box + + Parameters + ---------- + detected_faces: list + List of :class:`~lib.align.DetectedFace` objects for the batch + + Returns + ------- + :class:`numpy.ndarray` + The center and scale of the bounding box + """ + logger.trace("Calculating center and scale") # type:ignore[attr-defined] + center_scale = np.empty((len(detected_faces), 68, 3), dtype='float32') + for index, face in enumerate(detected_faces): + x_ctr = (T.cast(int, face.left) + face.right) / 2.0 + y_ctr = (T.cast(int, face.top) + face.bottom) / 2.0 - T.cast(int, face.height) * 0.12 + scale = (T.cast(int, face.width) + T.cast(int, face.height)) * self.reference_scale + center_scale[index, :, 0] = np.full(68, x_ctr, dtype='float32') + center_scale[index, :, 1] = np.full(68, y_ctr, dtype='float32') + center_scale[index, :, 2] = np.full(68, scale, dtype='float32') + logger.trace("Calculated center and scale: %s", center_scale) # type:ignore[attr-defined] + return center_scale + + def _crop_image(self, + image: np.ndarray, + top_left: np.ndarray, + bottom_right: np.ndarray) -> np.ndarray: + """ Crop a single image + + Parameters + ---------- + image: :class:`numpy.ndarray` + The image to crop + top_left: :class:`numpy.ndarray` + The top left (x, y) point to crop from + bottom_right: :class:`numpy.ndarray` + The bottom right (x, y) point to crop to + + Returns + ------- + :class:`numpy.ndarray` + The cropped image + """ + bottom_right_width, bottom_right_height = bottom_right[0].astype('int32') + top_left_width, top_left_height = top_left[0].astype('int32') + new_dim = (bottom_right_height - top_left_height, + bottom_right_width - top_left_width, + 3 if image.ndim > 2 else 1) + new_img = np.zeros(new_dim, dtype=np.uint8) + + new_x = slice(max(0, -top_left_width), + min(bottom_right_width, image.shape[1]) - top_left_width) + new_y = slice(max(0, -top_left_height), + min(bottom_right_height, image.shape[0]) - top_left_height) + old_x = slice(max(0, top_left_width), min(bottom_right_width, image.shape[1])) + old_y = slice(max(0, top_left_height), min(bottom_right_height, image.shape[0])) + new_img[new_y, new_x] = image[old_y, old_x] + + interp = cv2.INTER_CUBIC if new_dim[0] < self.input_size else cv2.INTER_AREA + return cv2.resize(new_img, + dsize=(self.input_size, self.input_size), + interpolation=interp) + + def crop(self, batch: AlignerBatch, center_scale: np.ndarray) -> list[np.ndarray]: + """ Crop image around the center point + + Parameters + ---------- + batch: :class:`AlignerBatch` + The current batch to crop the image for + center_scale: :class:`numpy.ndarray` + The center and scale for the bounding box + + Returns + ------- + list + List of cropped images for the batch + """ + logger.trace("Cropping images") # type:ignore[attr-defined] + batch_shape = center_scale.shape[:2] + resolutions = np.full(batch_shape, self.input_size, 
dtype='float32') + matrix_ones = np.ones(batch_shape + (3,), dtype='float32') + matrix_size = np.full(batch_shape + (3,), self.input_size, dtype='float32') + matrix_size[..., 2] = 1.0 + upper_left = self.transform(matrix_ones, center_scale, resolutions) + bot_right = self.transform(matrix_size, center_scale, resolutions) + + # TODO second pass .. convert to matrix + new_images = [self._crop_image(image, top_left, bottom_right) + for image, top_left, bottom_right in zip(batch.image, upper_left, bot_right)] + logger.trace("Cropped images") # type:ignore[attr-defined] + return new_images + + @classmethod + def transform(cls, + points: np.ndarray, + center_scales: np.ndarray, + resolutions: np.ndarray) -> np.ndarray: + """ Transform Image + + Parameters + ---------- + points: :class:`numpy.ndarray` + The points to transform + center_scales: :class:`numpy.ndarray` + The calculated centers and scales for the batch + resolutions: :class:`numpy.ndarray` + The resolutions + """ + logger.trace("Transforming Points") # type:ignore[attr-defined] + num_images, num_landmarks = points.shape[:2] + transform_matrix = np.eye(3, dtype='float32') + transform_matrix = np.repeat(transform_matrix[None, :], num_landmarks, axis=0) + transform_matrix = np.repeat(transform_matrix[None, :, :], num_images, axis=0) + scales = center_scales[:, :, 2] / resolutions + translations = center_scales[..., 2:3] * -0.5 + center_scales[..., :2] + transform_matrix[:, :, 0, 0] = scales # x scale + transform_matrix[:, :, 1, 1] = scales # y scale + transform_matrix[:, :, 0, 2] = translations[:, :, 0] # x translation + transform_matrix[:, :, 1, 2] = translations[:, :, 1] # y translation + new_points = np.einsum('abij, abj -> abi', transform_matrix, points, optimize='greedy') + retval = new_points[:, :, :2].astype('float32') + logger.trace("Transformed Points: %s", retval) # type:ignore[attr-defined] return retval - def get_center_scale(self, detected_face): - """ Get the center and set scale of bounding box """ - logger.trace("Calculating center and scale") - center = np.array([(detected_face.left() - + detected_face.right()) / 2.0, - (detected_face.top() - + detected_face.bottom()) / 2.0]) - - center[1] -= (detected_face.bottom() - - detected_face.top()) * 0.12 - - scale = (detected_face.right() - - detected_face.left() - + detected_face.bottom() - - detected_face.top()) / self.reference_scale - - logger.trace("Calculated center and scale: %s, %s", center, scale) - return center, scale - - def align_image(self, image, center, scale): - """ Crop and align image around center """ - logger.trace("Aligning image around center") - image = self.crop( - image, - center, - scale).transpose((2, 0, 1)).astype(np.float32) / 255.0 - logger.trace("Aligned image around center") - return np.expand_dims(image, 0) - - def predict_landmarks(self, image, center, scale): - """ Predict the 68 point landmarks """ - logger.trace("Predicting Landmarks") - prediction = self.model.predict(image)[-1] - pts_img = self.get_pts_from_predict(prediction, center, scale) - retval = [(int(pt[0]), int(pt[1])) for pt in pts_img] - logger.trace("Predicted Landmarks: %s", retval) + def predict(self, feed: np.ndarray) -> np.ndarray: + """ Predict the 68 point landmarks + + Parameters + ---------- + batch: :class:`numpy.ndarray` + The batch to feed into the aligner + + Returns + ------- + :class:`numpy.ndarray` + The predictions from the aligner + """ + logger.trace("Predicting Landmarks") # type:ignore[attr-defined] + retval = self.model.predict(feed, + verbose=False, + 
batch_size=self.batchsize)[-1].transpose(0, 3, 1, 2)
+        return retval
+
-    @staticmethod
-    def transform(point, center, scale, resolution):
-        """ Transform Image """
-        logger.trace("Transforming Points")
-        pnt = np.array([point[0], point[1], 1.0])
-        hscl = 200.0 * scale
-        eye = np.eye(3)
-        eye[0, 0] = resolution / hscl
-        eye[1, 1] = resolution / hscl
-        eye[0, 2] = resolution * (-center[0] / hscl + 0.5)
-        eye[1, 2] = resolution * (-center[1] / hscl + 0.5)
-        eye = np.linalg.inv(eye)
-        retval = np.matmul(eye, pnt)[0:2]
-        logger.trace("Transformed Points: %s", retval)
-        return retval
-
-    def crop(self, image, center, scale, resolution=256.0):  # pylint: disable=too-many-locals
-        """ Crop image around the center point """
-        logger.trace("Cropping image")
-        v_ul = self.transform([1, 1], center, scale, resolution).astype(np.int)
-        v_br = self.transform([resolution, resolution],
-                              center,
-                              scale,
-                              resolution).astype(np.int)
-        if image.ndim > 2:
-            new_dim = np.array([v_br[1] - v_ul[1],
-                                v_br[0] - v_ul[0],
-                                image.shape[2]],
-                               dtype=np.int32)
-            self.test = new_dim
-            new_img = np.zeros(new_dim, dtype=np.uint8)
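
``get_pts_from_predict`` below recovers each landmark's (row, column) peak from the predicted
heatmaps with ``argmax`` over the flattened maps followed by ``numpy.unravel_index``. A
standalone sketch of that step; the array contents are illustrative:

    # Sketch of batched heatmap peak recovery, as in get_pts_from_predict below
    import numpy as np

    heatmaps = np.zeros((2, 68, 64, 64), dtype="float32")  # (batch, landmark, height, width)
    heatmaps[0, 0, 10, 20] = 1.0                           # plant a peak to find

    flat_idx = heatmaps.reshape(2, 68, -1).argmax(-1)      # flat peak index per heatmap
    rows, cols = np.unravel_index(flat_idx, heatmaps.shape[2:])
    # rows[0, 0] == 10, cols[0, 0] == 20

+    def process_output(self, batch: BatchType) -> None:
+        """ Process the output from the model
+
+        Parameters
+        ----------
+        batch: :class:`AlignerBatch`
+            The current batch from the model with :attr:`predictions` populated
+        """
+        assert isinstance(batch, AlignerBatch)
+        self.get_pts_from_predict(batch)
+
+    def get_pts_from_predict(self, batch: AlignerBatch) -> None:
+        """ Get points from predictor and populate the :attr:`landmarks` property of the
+        :class:`AlignerBatch`
+
+        Parameters
+        ----------
+        batch: :class:`AlignerBatch`
+            The current batch from the model with :attr:`predictions` populated
+        """
+        logger.trace("Obtain points from prediction")  # type:ignore[attr-defined]
+        num_images, num_landmarks = batch.prediction.shape[:2]
+        image_slice = np.repeat(np.arange(num_images)[:, None], num_landmarks, axis=1)
+        landmark_slice = np.repeat(np.arange(num_landmarks)[None, :], num_images, axis=0)
+        resolution = np.full((num_images, num_landmarks), 64, dtype='int32')
+        subpixel_landmarks = np.ones((num_images, num_landmarks, 3), dtype='float32')
+
+        indices = np.array(np.unravel_index(batch.prediction.reshape(num_images,
+                                                                     num_landmarks,
+                                                                     -1).argmax(-1),
+                                            (batch.prediction.shape[2],    # height
+                                             batch.prediction.shape[3])))  # width
+        min_clipped = np.minimum(indices + 1, batch.prediction.shape[2] - 1)
+        max_clipped = np.maximum(indices - 1, 0)
+        offsets = [(image_slice, landmark_slice, indices[0], min_clipped[1]),
+                   (image_slice, landmark_slice, indices[0], max_clipped[1]),
+                   (image_slice, landmark_slice, min_clipped[0], indices[1]),
+                   (image_slice, landmark_slice, max_clipped[0], indices[1])]
+        x_subpixel_shift = batch.prediction[offsets[0]] - batch.prediction[offsets[1]]
+        y_subpixel_shift = batch.prediction[offsets[2]] - batch.prediction[offsets[3]]
+        # TODO improve rudimentary sub-pixel logic to centroid of 3x3 window algorithm
+        subpixel_landmarks[:, :, 0] = indices[1] + np.sign(x_subpixel_shift) * 0.25 + 0.5
+        subpixel_landmarks[:, :, 1] = indices[0] + np.sign(y_subpixel_shift) * 0.25 + 0.5
+
+        if batch.second_pass:  # Transformation handled by plugin parent for re-aligned faces
+            batch.landmarks = subpixel_landmarks[..., :2] * 4.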
else: - new_dim = np.array([v_br[1] - v_ul[1], - v_br[0] - v_ul[0]], - dtype=np.int) - self.test = new_dim - new_img = np.zeros(new_dim, dtype=np.uint8) - height = image.shape[0] - width = image.shape[1] - new_x = np.array([max(1, -v_ul[0] + 1), min(v_br[0], width) - v_ul[0]], - dtype=np.int32) - new_y = np.array([max(1, -v_ul[1] + 1), - min(v_br[1], height) - v_ul[1]], - dtype=np.int32) - old_x = np.array([max(1, v_ul[0] + 1), min(v_br[0], width)], - dtype=np.int32) - old_y = np.array([max(1, v_ul[1] + 1), min(v_br[1], height)], - dtype=np.int32) - new_img[new_y[0] - 1:new_y[1], - new_x[0] - 1:new_x[1]] = image[old_y[0] - 1:old_y[1], - old_x[0] - 1:old_x[1], :] - # pylint: disable=no-member - new_img = cv2.resize(new_img, - dsize=(int(resolution), int(resolution)), - interpolation=cv2.INTER_LINEAR) - logger.trace("Cropped image") - return new_img - - def get_pts_from_predict(self, var_a, center, scale): - """ Get points from predictor """ - logger.trace("Obtain points from prediction") - var_b = var_a.reshape((var_a.shape[0], - var_a.shape[1] * var_a.shape[2])) - var_c = var_b.argmax(1).reshape((var_a.shape[0], - 1)).repeat(2, - axis=1).astype(np.float) - var_c[:, 0] %= var_a.shape[2] - var_c[:, 1] = np.apply_along_axis( - lambda x: np.floor(x / var_a.shape[2]), - 0, - var_c[:, 1]) - - for i in range(var_a.shape[0]): - pt_x, pt_y = int(var_c[i, 0]), int(var_c[i, 1]) - if pt_x > 0 and pt_x < 63 and pt_y > 0 and pt_y < 63: - diff = np.array([var_a[i, pt_y, pt_x+1] - - var_a[i, pt_y, pt_x-1], - var_a[i, pt_y+1, pt_x] - - var_a[i, pt_y-1, pt_x]]) - - var_c[i] += np.sign(diff)*0.25 - - var_c += 0.5 - retval = [self.transform(var_c[i], center, scale, var_a.shape[2]) - for i in range(var_a.shape[0])] - logger.trace("Obtained points from prediction: %s", retval) - - return retval + batch.landmarks = self.transform(subpixel_landmarks, + batch.data[0]["center_scale"], + resolution) + logger.trace("Obtained points from prediction: %s", # type:ignore[attr-defined] + batch.landmarks) -class FAN(): - """The FAN Model. 
-    Converted from pyTorch via ONNX from:
-    https://github.com/1adrianb/face-alignment """
-
-    def __init__(self, model_path, ratio=1.0):
-        # Must import tensorflow inside the spawned process
-        # for Windows machines
-        import tensorflow as tf
-        self.tf = tf  # pylint: disable=invalid-name
-
-        self.model_path = model_path
-        self.graph = self.load_graph()
-        self.input = self.graph.get_tensor_by_name("fa/input_1:0")
-        self.output = self.graph.get_tensor_by_name("fa/transpose_647:0")
-        self.session = self.set_session(ratio)
-
-    def load_graph(self):
-        """ Load the tensorflow Model and weights """
-        # pylint: disable=not-context-manager
-        logger.verbose("Initializing Face Alignment Network model...")
-
-        with self.tf.gfile.GFile(self.model_path, "rb") as gfile:
-            graph_def = self.tf.GraphDef()
-            graph_def.ParseFromString(gfile.read())
-        fa_graph = self.tf.Graph()
-        with fa_graph.as_default():
-            self.tf.import_graph_def(graph_def, name="fa")
-        return fa_graph
-
-    def set_session(self, vram_ratio):
-        """ Set the TF Session and initialize """
-        # pylint: disable=not-context-manager, no-member
-        placeholder = np.zeros((1, 3, 256, 256))
-        with self.graph.as_default():
-            config = self.tf.ConfigProto()
-            config.gpu_options.per_process_gpu_memory_fraction = vram_ratio
-            session = self.tf.Session(config=config)
-            with session.as_default():
-                if any("gpu" in str(device).lower() for device in session.list_devices()):
-                    logger.debug("Using GPU")
-                else:
-                    logger.warning("Using CPU")
-                session.run(self.output, feed_dict={self.input: placeholder})
-        return session
-
-    def predict(self, feed_item):
-        """ Predict landmarks in session """
-        return self.session.run(self.output,
-                                feed_dict={self.input: feed_item})
+__all__ = get_module_objects(__name__)
diff --git a/plugins/extract/align/fan_defaults.py b/plugins/extract/align/fan_defaults.py
new file mode 100644
index 0000000000..31072d64c8
--- /dev/null
+++ b/plugins/extract/align/fan_defaults.py
@@ -0,0 +1,48 @@
+#!/usr/bin/env python3
+""" The default options for the faceswap FAN Alignments plugin.
+
+Defaults files should be named `<plugin_name>_defaults.py`
+
+Any qualifying items placed into this file will automatically get added to the relevant config
+.ini files within the faceswap/config folder and added to the relevant GUI settings page.
+
+The following variable should be defined:
+
+    Parameters
+    ----------
+    HELPTEXT: str
+        A string describing what this plugin does
+
+Further plugin configuration options are assigned using:
+>>> <option_name> = ConfigItem(...)
+
+where <option_name> is the name of the configuration option to be added (lower-case, alpha-numeric
++ underscore only) and ConfigItem(...) is the [`~lib.config.objects.ConfigItem`] data for the
+option.
+
+See the docstring/ReadtheDocs documentation for the required parameters for the ConfigItem object.
+Items will be grouped together as per their `group` parameter, but otherwise will be processed in
+the order that they are added to this module.
+
+>>> from lib.config import ConfigItem
+"""
+# pylint:disable=duplicate-code
+from lib.config import ConfigItem
+
+
+HELPTEXT = (
+    "FAN Aligner options.\n"
+    "Fast on GPU, slow on CPU. Best aligner."
+    )
+
+
+batch_size = ConfigItem(
+    datatype=int,
+    default=12,
+    group="settings",
+    info="The batch size to use. To a point, higher batch sizes equal better performance, "
+         "but setting it too high can harm performance.\n"
+         "\n\tNvidia users: If the batchsize is set higher than your GPU can "
+         "accommodate then this will automatically be lowered."
+ "\n\tAMD users: A batchsize of 8 requires about 4 GB vram.", + rounding=1, + min_max=(1, 64)) diff --git a/plugins/extract/detect/.cache/det1.npy b/plugins/extract/detect/.cache/det1.npy deleted file mode 100755 index 7c05a2c562..0000000000 Binary files a/plugins/extract/detect/.cache/det1.npy and /dev/null differ diff --git a/plugins/extract/detect/.cache/det2.npy b/plugins/extract/detect/.cache/det2.npy deleted file mode 100755 index 85d5bf09c9..0000000000 Binary files a/plugins/extract/detect/.cache/det2.npy and /dev/null differ diff --git a/plugins/extract/detect/.cache/det3.npy b/plugins/extract/detect/.cache/det3.npy deleted file mode 100755 index 90d5ba9754..0000000000 Binary files a/plugins/extract/detect/.cache/det3.npy and /dev/null differ diff --git a/plugins/extract/detect/_base.py b/plugins/extract/detect/_base.py index d65b202aae..61aafbd1f9 100644 --- a/plugins/extract/detect/_base.py +++ b/plugins/extract/detect/_base.py @@ -1,341 +1,675 @@ #!/usr/bin/env python3 """ Base class for Face Detector plugins - Plugins should inherit from this class - See the override methods for which methods are - required. +All Detector Plugins should inherit from this class. +See the override methods for which methods are required. - For each source frame, the plugin must pass a dict to finalize containing: - {"filename": , - "image": , - "detected_faces": } - """ +The plugin will receive a :class:`~plugins.extract.extract_media.ExtractMedia` object. + +For each source frame, the plugin must pass a dict to finalize containing: + +>>> {'filename': , +>>> 'detected_faces': >> face = self._to_detected_face(, , , ) +""" +from __future__ import annotations import logging -import os -import traceback -from io import StringIO -from math import sqrt +import typing as T + +from dataclasses import dataclass, field import cv2 -import dlib +import numpy as np +from torch.cuda import OutOfMemoryError -from lib.gpu_stats import GPUStats -from lib.utils import rotate_landmarks -from plugins.extract._config import Config +from lib.align import DetectedFace +from lib.utils import FaceswapError -logger = logging.getLogger(__name__) # pylint: disable=invalid-name +from plugins.extract._base import BatchType, Extractor, ExtractorBatch +from plugins.extract import ExtractMedia +if T.TYPE_CHECKING: + from collections.abc import Generator + from queue import Queue -def get_config(plugin_name): - """ Return the config for the requested model """ - return Config(plugin_name).config_dict +logger = logging.getLogger(__name__) -class Detector(): - """ Detector object """ - def __init__(self, loglevel, rotation=None, min_size=0): - logger.debug("Initializing %s: (rotation: %s, min_size: %s)", - self.__class__.__name__, rotation, min_size) - self.config = get_config(".".join(self.__module__.split(".")[-2:])) - self.loglevel = loglevel - self.cachepath = os.path.join(os.path.dirname(__file__), ".cache") - self.rotation = self.get_rotation_angles(rotation) - self.min_size = min_size - self.parent_is_pool = False - self.init = None - - # The input and output queues for the plugin. - # See lib.queue_manager.QueueManager for getting queues - self.queues = {"in": None, "out": None} - - # Path to model if required - self.model_path = self.set_model_path() - - # Target image size for passing images through the detector - # Set to tuple of dimensions (x, y) or int of pixel count - self.target = None - - # Approximate VRAM used for the set target. Used to calculate - # how many parallel processes / batches can be run. 
-        # Be conservative to avoid OOM.
-        self.vram = None
-
-        # For detectors that support batching, this should be set to
-        # the calculated batch size that the amount of available VRAM
-        # will support. It is also used for holding the number of threads/
-        # processes for parallel processing plugins
-        self.batch_size = 1
-        logger.debug("Initialized _base %s", self.__class__.__name__)
+@dataclass
+class DetectorBatch(ExtractorBatch):
+    """ Dataclass for holding items flowing through the detector.
 
-    # <<< OVERRIDE METHODS >>> #
-    # These methods must be overriden when creating a plugin
-    @staticmethod
-    def set_model_path():
-        """ path to data file/models
-            override for specific detector """
-        raise NotImplementedError()
-
-    def initialize(self, *args, **kwargs):
-        """ Inititalize the detector
-            Tasks to be run before any detection is performed.
-            Override for specific detector """
-        logger_init = kwargs["log_init"]
-        log_queue = kwargs["log_queue"]
-        logger_init(self.loglevel, log_queue)
-        logger.debug("initialize %s (PID: %s, args: %s, kwargs: %s)",
-                     self.__class__.__name__, os.getpid(), args, kwargs)
-        self.init = kwargs.get("event", False)
-        self.queues["in"] = kwargs["in_queue"]
-        self.queues["out"] = kwargs["out_queue"]
-
-    def detect_faces(self, *args, **kwargs):
-        """ Detect faces in rgb image
-            Override for specific detector
-            Must return a list of dlib rects"""
-        try:
-            if not self.init:
-                self.initialize(*args, **kwargs)
-        except ValueError as err:
-            logger.error(err)
-            exit(1)
-        logger.debug("Detecting Faces (args: %s, kwargs: %s)", args, kwargs)
-
-    # <<< DETECTION WRAPPER >>> #
-    def run(self, *args, **kwargs):
-        """ Parent detect process.
-            This should always be called as the entry point so exceptions
-            are passed back to parent.
-            Do not override """
-        try:
-            self.detect_faces(*args, **kwargs)
-        except Exception as err:  # pylint: disable=broad-except
-            logger.error("Caught exception in child process: %s: %s", os.getpid(), str(err))
-            # Display traceback if in initialization stage
-            if not self.init.is_set():
-                logger.exception("Traceback:")
-            tb_buffer = StringIO()
-            traceback.print_exc(file=tb_buffer)
-            logger.trace(tb_buffer.getvalue())
-            exception = {"exception": (os.getpid(), tb_buffer)}
-            self.queues["out"].put(exception)
-            exit(1)
+    Inherits from :class:`~plugins.extract._base.ExtractorBatch`
 
-    # <<< FINALIZE METHODS>>> #
-    def finalize(self, output):
-        """ This should be called as the final task of each plugin
-            Performs fianl processing and puts to the out queue """
-        if isinstance(output, dict):
-            logger.trace("Item out: %s", {key: val
-                                          for key, val in output.items()
-                                          if key != "image"})
-            if self.min_size > 0 and output.get("detected_faces", None):
-                output["detected_faces"] = self.filter_small_faces(output["detected_faces"])
-        else:
-            logger.trace("Item out: %s", output)
-        self.queues["out"].put(output)
-
-    def filter_small_faces(self, detected_faces):
-        """ Filter out any faces smaller than the min size threshold """
-        retval = list()
-        for face in detected_faces:
-            face_size = ((face.right() - face.left()) ** 2 +
-                         (face.bottom() - face.top()) ** 2) ** 0.5
-            if face_size < self.min_size:
-                logger.debug("Removing detected face: (face_size: %s, min_size: %s",
-                             face_size, self.min_size)
-                continue
-            retval.append(face)
+    Parameters
+    ----------
+    rotation_matrix: :class:`numpy.ndarray`
+        The rotation matrix for any requested rotations
+    scale: float
+        The scaling factor to take the input image back to original size
+    pad: tuple
+        The amount of padding to apply to the
-    # <<< DETECTION IMAGE COMPILATION METHODS >>> #
-    def compile_detection_image(self, image, is_square, scale_up):
-        """ Compile the detection image """
-        scale = self.set_scale(image, is_square=is_square, scale_up=scale_up)
-        return [self.set_detect_image(image, scale), scale]
-
-    def set_scale(self, image, is_square=False, scale_up=False):
-        """ Set the scale factor for incoming image """
-        height, width = image.shape[:2]
-        if is_square:
-            if isinstance(self.target, int):
-                dims = (self.target ** 0.5, self.target ** 0.5)
-                self.target = dims
-            source = max(height, width)
-            target = max(self.target)
-        else:
-            if isinstance(self.target, tuple):
-                self.target = self.target[0] * self.target[1]
-            source = width * height
-            target = self.target
+class Detector(Extractor):  # pylint:disable=abstract-method
+    """ Detector Object
+
+    Parent class for all Detector plugins
+
+    Parameters
+    ----------
+    git_model_id: int
+        The second digit in the github tag that identifies this model. See
+        https://github.com/deepfakes-models/faceswap-models for more information
+    model_filename: str
+        The name of the model file to be loaded
+    rotation: str, optional
+        Pass in a single number to use increments of that size up to 360, or pass in a
+        comma-separated list of ``ints`` to enumerate exactly what angles to check.
+        Default: ``None``
+    min_size: int, optional
+        Filters out faces detected below this size. Measured in pixels across the diagonal of the
+        bounding box. Set to ``0`` for off. Default: ``0``
+
+    Other Parameters
+    ----------------
+    configfile: str, optional
+        Path to a custom configuration ``ini`` file. Default: Use system configfile
+
+    See Also
+    --------
+    plugins.extract.pipeline : The extraction pipeline for calling plugins
+    plugins.extract.detect : Detector plugins
+    plugins.extract._base : Parent class for all extraction plugins
+    plugins.extract.align._base : Aligner parent class for extraction plugins.
+    plugins.extract.mask._base : Masker parent class for extraction plugins.
+ """ + + def __init__(self, + git_model_id: int | None = None, + model_filename: str | list[str] | None = None, + configfile: str | None = None, + instance: int = 0, + rotation: str | None = None, + min_size: int = 0, + **kwargs) -> None: + logger.debug("Initializing %s: (rotation: %s, min_size: %s)", self.__class__.__name__, + rotation, min_size) + super().__init__(git_model_id, + model_filename, + configfile=configfile, + instance=instance, + **kwargs) + self.rotation = self._get_rotation_angles(rotation) + self.min_size = min_size + + self._info.plugin_type = "detect" - if scale_up or target < source: - scale = sqrt(target / source) + logger.debug("Initialized _base %s", self.__class__.__name__) + + # <<< QUEUE METHODS >>> # + def get_batch(self, queue: Queue) -> tuple[bool, DetectorBatch]: + """ Get items for inputting to the detector plugin in batches + + Items are received as :class:`~plugins.extract.extract_media.ExtractMedia` objects and + converted to ``dict`` for internal processing. + + Items are returned from the ``queue`` in batches of + :attr:`~plugins.extract._base.Extractor.batchsize` + + Remember to put ``'EOF'`` to the out queue after processing + the final batch + + Outputs items in the following format. All lists are of length + :attr:`~plugins.extract._base.Extractor.batchsize`: + + >>> {'filename': [], + >>> 'image': , + >>> 'scale': [], + >>> 'pad': [], + >>> 'detected_faces': [[>> # + def finalize(self, batch: BatchType) -> Generator[ExtractMedia, None, None]: + """ Finalize the output from Detector + + This should be called as the final task of each ``plugin``. + + Parameters + ---------- + batch : :class:`~plugins.extract._base.ExtractorBatch` + The batch object for the current batch + + Yields + ------ + :class:`~plugins.extract.extract_media.ExtractMedia` + The :attr:`DetectedFaces` list will be populated for this class with the bounding boxes + for the detected faces found in the frame. 
+ """ + assert isinstance(batch, DetectorBatch) + logger.trace("Item out: %s", # type:ignore[attr-defined] + {k: len(v) if isinstance(v, (list, np.ndarray)) else v + for k, v in batch.__dict__.items()}) + + batch_faces = [[self._to_detected_face(face[0], face[1], face[2], face[3]) + for face in faces] + for faces in batch.prediction] + # Rotations + if any(m.any() for m in batch.rotation_matrix) and any(batch_faces): + batch_faces = [[self._rotate_face(face, rotmat) if rotmat.any() else face + for face in faces] + for faces, rotmat in zip(batch_faces, batch.rotation_matrix)] + + # Remove zero sized faces + batch_faces = self._remove_zero_sized_faces(batch_faces) + + # Scale back out to original frame + batch.detected_faces = [[self._to_detected_face((face.left - pad[0]) / scale, + (face.top - pad[1]) / scale, + (face.right - pad[0]) / scale, + (face.bottom - pad[1]) / scale) + for face in faces + if face.left is not None and face.top is not None] + for scale, pad, faces in zip(batch.scale, + batch.pad, + batch_faces)] + + if self.min_size > 0 and batch.detected_faces: + batch.detected_faces = self._filter_small_faces(batch.detected_faces) + + for idx, filename in enumerate(batch.filename): + output = self._extract_media.pop(filename) + output.add_detected_faces(batch.detected_faces[idx]) + + logger.trace("final output: (filename: '%s', " # type:ignore[attr-defined] + "image shape: %s, detected_faces: %s, item: %s", + output.filename, output.image_shape, output.detected_faces, output) + yield output + @staticmethod + def _to_detected_face(left: float, top: float, right: float, bottom: float) -> DetectedFace: + """ Convert a bounding box to a detected face object + + Parameters + ---------- + left: float + The left point of the detection bounding box + top: float + The top point of the detection bounding box + right: float + The right point of the detection bounding box + bottom: float + The bottom point of the detection bounding box + + Returns + ------- + class:`~lib.align.DetectedFace` + The detected face object for the given bounding box + """ + return DetectedFace(left=int(round(left)), + width=int(round(right - left)), + top=int(round(top)), + height=int(round(bottom - top))) + + # <<< PROTECTED ACCESS METHODS >>> # + # <<< PREDICT WRAPPER >>> # + def _predict(self, batch: BatchType) -> DetectorBatch: + """ Wrap models predict function in rotations """ + assert isinstance(batch, DetectorBatch) + batch.rotation_matrix = [np.array([]) for _ in range(len(batch.feed))] + found_faces: list[np.ndarray] = [np.array([]) for _ in range(len(batch.feed))] + for angle in self.rotation: + # Rotate the batch and insert placeholders for already found faces + self._rotate_batch(batch, angle) + try: + pred = self.predict(batch.feed) + if angle == 0: + batch.prediction = pred + else: + try: + batch.prediction = np.array([b if b.any() else p + for b, p in zip(batch.prediction, pred)]) + except ValueError as err: + # If batches are different sizes after rotation Numpy will error, so we + # need to explicitly set the dtype to 'object' rather than let it infer + # numpy error: + # ValueError: setting an array element with a sequence. The requested array + # has an inhomogeneous shape after 1 dimensions. 
+                        if "inhomogeneous" in str(err):
+                            batch.prediction = np.array([b if b.any() else p
+                                                         for b, p in zip(batch.prediction, pred)],
+                                                        dtype="object")
+                            logger.trace(  # type:ignore[attr-defined]
+                                "Mismatched array sizes, setting dtype to object: %s",
+                                [p.shape for p in batch.prediction])
+                        else:
+                            raise
+
+                logger.trace("angle: %s, filenames: %s, "  # type:ignore[attr-defined]
+                             "prediction: %s",
+                             angle, batch.filename, pred)
+            except OutOfMemoryError as err:
+                msg = ("You do not have enough GPU memory available to run detection at the "
+                       "selected batch size. You can try a number of things:"
+                       "\n1) Close any other application that is using your GPU (web browsers are "
+                       "particularly bad for this)."
+                       "\n2) Lower the batchsize (the amount of images fed into the model) by "
+                       "editing the plugin settings (GUI: Settings > Configure extract settings, "
+                       "CLI: Edit the file faceswap/config/extract.ini)."
+                       "\n3) Enable 'Single Process' mode.")
+                raise FaceswapError(msg) from err
+
+            if angle != 0 and any(face.any() for face in batch.prediction):
+                logger.verbose("found face(s) by rotating image %s "  # type:ignore[attr-defined]
+                               "degrees",
+                               angle)
+
+            found_faces = T.cast(list[np.ndarray], ([face if not found.any() else found
+                                                     for face, found in zip(batch.prediction,
+                                                                            found_faces)]))
+            if all(face.any() for face in found_faces):
+                logger.trace("Faces found for all images")  # type:ignore[attr-defined]
+                break
+
+        batch.prediction = np.array(found_faces, dtype="object")
+        logger.trace("detect_prediction output: (filenames: %s, "  # type:ignore[attr-defined]
+                     "prediction: %s, rotmat: %s)",
+                     batch.filename, batch.prediction, batch.rotation_matrix)
+        return batch
+
+    # <<< DETECTION IMAGE COMPILATION METHODS >>> #
+    def _compile_detection_image(self, item: ExtractMedia
+                                 ) -> tuple[np.ndarray, float, tuple[int, int]]:
+        """ Compile the detection image for feeding into the model
+
+        Parameters
+        ----------
+        item: :class:`~plugins.extract.extract_media.ExtractMedia`
+            The input item from the pipeline
+
+        Returns
+        -------
+        image: :class:`numpy.ndarray`
+            The original image formatted for detection
+        scale: float
+            The scaling factor for the image
+        pad: tuple
+            The amount of padding applied to the image
+        """
+        image = item.get_image_copy(self.color_format)
+        scale = self._set_scale(item.image_size)
+        pad = self._set_padding(item.image_size, scale)
+
+        image = self._scale_image(image, item.image_size, scale)
+        image = self._pad_image(image)
+        logger.trace("compiled: (images shape: %s, "  # type:ignore[attr-defined]
+                     "scale: %s, pad: %s)",
+                     image.shape, scale, pad)
+        return image, scale, pad
+
+    def _set_scale(self, image_size: tuple[int, int]) -> float:
+        """ Set the scale factor for incoming image
+
+        Parameters
+        ----------
+        image_size: tuple
+            The (height, width) of the original image
+
+        Returns
+        -------
+        float
+            The scaling factor from original image size to model input size
+        """
+        scale = self.input_size / max(image_size)
+        logger.trace("Detector scale: %s", scale)  # type:ignore[attr-defined]
         return scale
+
+    def _set_padding(self, image_size: tuple[int, int], scale: float) -> tuple[int, int]:
+        """ Set the image padding for non-square images
+
+        Parameters
+        ----------
+        image_size: tuple
+            The (height, width) of the original image
+        scale: float
+            The scaling factor from original image size to model input size
+
+        Returns
+        -------
+        tuple
+            The amount of padding to apply to the x and y axes
+        """
+        pad_left = int(self.input_size - int(image_size[1] * scale)) // 2
+        pad_top = int(self.input_size - int(image_size[0] * scale)) // 2
+        return pad_left, pad_top
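A worked example of the two helpers above, assuming a 640px square model input and a 1280x720 frame:

```python
# Letterbox maths for a 1280x720 frame fed to a 640px square model input.
input_size = 640
height, width = 720, 1280

scale = input_size / max(height, width)            # 0.5
pad_left = (input_size - int(width * scale)) // 2  # (640 - 640) // 2 = 0
pad_top = (input_size - int(height * scale)) // 2  # (640 - 360) // 2 = 140
print(scale, (pad_left, pad_top))                  # 0.5 (0, 140)
```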
+    @staticmethod
-    def set_detect_image(input_image, scale):
-        """ Convert the image to RGB and scale """
-        # pylint: disable=no-member
-        image = input_image[:, :, ::-1].copy()
-        if scale == 1.0:
-            return image
+    def _scale_image(image: np.ndarray, image_size: tuple[int, int], scale: float) -> np.ndarray:
+        """ Scale the image and optionally pad to the given size
+
+        Parameters
+        ----------
+        image: :class:`numpy.ndarray`
+            The image to be scaled
+        image_size: tuple
+            The image (height, width)
+        scale: float
+            The scaling factor to apply to the image
+
+        Returns
+        -------
+        :class:`numpy.ndarray`
+            The scaled image
+        """
+        interpln = cv2.INTER_CUBIC if scale > 1.0 else cv2.INTER_AREA
+        if scale != 1.0:
+            dims = (int(image_size[1] * scale), int(image_size[0] * scale))
+            logger.trace("Resizing detection image from %s to %s. "  # type:ignore[attr-defined]
+                         "Scale=%s",
+                         "x".join(str(i) for i in reversed(image_size)),
+                         "x".join(str(i) for i in dims), scale)
+            image = cv2.resize(image, dims, interpolation=interpln)
+        logger.trace("Resized image shape: %s", image.shape)  # type:ignore[attr-defined]
+        return image
-
-        height, width = image.shape[:2]
-        interpln = cv2.INTER_LINEAR if scale > 1.0 else cv2.INTER_AREA
-        dims = (int(width * scale), int(height * scale))
+
+    def _pad_image(self, image: np.ndarray) -> np.ndarray:
+        """ Pad a resized image to input size
-
-        if scale < 1.0:
-            logger.verbose("Resizing image from %sx%s to %s.",
-                           width, height, "x".join(str(i) for i in dims))
+
+        Parameters
+        ----------
+        image: :class:`numpy.ndarray`
+            The image to have padding applied
-
-        image = cv2.resize(image, dims, interpolation=interpln)
+
+        Returns
+        -------
+        :class:`numpy.ndarray`
+            The image with padding applied
+        """
+        height, width = image.shape[:2]
+        if width < self.input_size or height < self.input_size:
+            pad_l = (self.input_size - width) // 2
+            pad_r = (self.input_size - width) - pad_l
+            pad_t = (self.input_size - height) // 2
+            pad_b = (self.input_size - height) - pad_t
+            image = cv2.copyMakeBorder(image,
+                                       pad_t,
+                                       pad_b,
+                                       pad_l,
+                                       pad_r,
+                                       cv2.BORDER_CONSTANT)
+        logger.trace("Padded image shape: %s", image.shape)  # type:ignore[attr-defined]
         return image
+
+    # <<< FINALIZE METHODS >>> #
+    def _remove_zero_sized_faces(self, batch_faces: list[list[DetectedFace]]
+                                 ) -> list[list[DetectedFace]]:
+        """ Remove items from batch_faces where detected face is of zero size or face falls
+        entirely outside of image
+
+        Parameters
+        ----------
+        batch_faces: list
+            List of detected face objects
+
+        Returns
+        -------
+        list
+            List of detected face objects with filtered out faces removed
+        """
+        logger.trace("Input sizes: %s", [len(face) for face in batch_faces])  # type: ignore
+        retval = [[face
+                   for face in faces
+                   if face.right > 0 and face.left is not None and face.left < self.input_size
+                   and face.bottom > 0 and face.top is not None and face.top < self.input_size]
+                  for faces in batch_faces]
+        logger.trace("Output sizes: %s", [len(face) for face in retval])  # type: ignore
+        return retval
+
+    def _filter_small_faces(self, detected_faces: list[list[DetectedFace]]
+                            ) -> list[list[DetectedFace]]:
+        """ Filter out any faces smaller than the min size threshold
+
+        Parameters
+        ----------
+        detected_faces: list
+            List of detected face objects
+
+        Returns
+        -------
+        list
+            List of detected face objects with filtered out faces removed
+        """
+        retval = []
+        for faces in detected_faces:
+            this_image = []
+            for face in faces:
+                assert face.width is not None and face.height is not None
+                face_size = (face.width ** 2 + face.height ** 2) ** 0.5
+                if face_size < self.min_size:
+                    logger.debug("Removing detected face: (face_size: %s, min_size: %s)",
+                                 face_size, self.min_size)
+                    continue
+                this_image.append(face)
+            retval.append(this_image)
+        return retval
+
+    # <<< IMAGE ROTATION METHODS >>> #
     @staticmethod
-    def get_rotation_angles(rotation):
-        """ Set the rotation angles. Includes backwards compatibility for the
-            'on' and 'off' options:
-            - 'on' - increment 90 degrees
-            - 'off' - disable
-            0 is prepended to the list, as whatever happens, we want to
-            scan the image in it's upright state """
+    def _get_rotation_angles(rotation: str | None) -> list[int]:
+        """ Set the rotation angles.
+
+        Parameters
+        ----------
+        rotation: str | None
+            Comma-separated string of the requested rotation angles
+
+        Returns
+        -------
+        list
+            The complete list of rotation angles to apply
+        """
         rotation_angles = [0]
-        if not rotation or rotation.lower() == "off":
+        if not rotation:
             logger.debug("Not setting rotation angles")
             return rotation_angles
-        if rotation.lower() == "on":
-            rotation_angles.extend(range(90, 360, 90))
-        else:
-            passed_angles = [int(angle)
-                             for angle in rotation.split(",")]
-            if len(passed_angles) == 1:
-                rotation_step_size = passed_angles[0]
-                rotation_angles.extend(range(rotation_step_size,
-                                             360,
-                                             rotation_step_size))
-            elif len(passed_angles) > 1:
-                rotation_angles.extend(passed_angles)
+        passed_angles = [int(angle)
+                         for angle in rotation.split(",")
+                         if int(angle) != 0]
+        if len(passed_angles) == 1:
+            rotation_step_size = passed_angles[0]
+            rotation_angles.extend(range(rotation_step_size,
+                                         360,
+                                         rotation_step_size))
+        elif len(passed_angles) > 1:
+            rotation_angles.extend(passed_angles)
         logger.debug("Rotation Angles: %s", rotation_angles)
         return rotation_angles
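For reference, the two accepted rotation formats expand as follows (a standalone re-statement of the logic above; 0 is always scanned first, so the upright image is always tried):

```python
def expand_rotation(rotation: str | None) -> list[int]:
    # A single value is a step size; multiple values are an explicit angle list.
    angles = [0]
    if not rotation:
        return angles
    passed = [int(a) for a in rotation.split(",") if int(a) != 0]
    if len(passed) == 1:
        angles.extend(range(passed[0], 360, passed[0]))
    elif len(passed) > 1:
        angles.extend(passed)
    return angles

print(expand_rotation("90"))      # [0, 90, 180, 270]
print(expand_rotation("45,315"))  # [0, 45, 315]
```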
-    def rotate_image(self, image, angle):
-        """ Rotate the image by given angle and return
-            Image with rotation matrix """
-        if angle == 0:
-            return image, None
-        return self.rotate_image_by_angle(image, angle)
+    def _rotate_batch(self, batch: DetectorBatch, angle: int) -> None:
+        """ Rotate images in a batch by given angle
-
-    @staticmethod
-    def rotate_rect(d_rect, rotation_matrix):
-        """ Rotate a dlib rect based on the rotation_matrix"""
-        logger.trace("Rotating d_rectangle")
-        d_rect = rotate_landmarks(d_rect, rotation_matrix)
-        return d_rect
+
+        If any faces have already been detected for a batch, store the existing rotation
+        matrix and replace the feed image with a placeholder
+
+        Parameters
+        ----------
+        batch: :class:`DetectorBatch`
+            The batch to apply rotation to
+        angle: int
+            The amount of degrees to rotate the image by
+        """
+        if angle == 0:
+            # Set the initial batch so we always rotate from zero
+            batch.initial_feed = batch.feed.copy()
+            return
+
+        feeds: list[np.ndarray] = []
+        rotmats: list[np.ndarray] = []
+        for img, faces, rotmat in zip(batch.initial_feed,
+                                      batch.prediction,
+                                      batch.rotation_matrix):
+            if faces.any():
+                image = np.zeros_like(img)
+                matrix = rotmat
+            else:
+                image, matrix = self._rotate_image_by_angle(img, angle)
+            feeds.append(image)
+            rotmats.append(matrix)
+        batch.feed = np.array(feeds, dtype="float32")
+        batch.rotation_matrix = rotmats
 
     @staticmethod
-    def rotate_image_by_angle(image, angle,
-                              rotated_width=None, rotated_height=None):
+    def _rotate_face(face: DetectedFace, rotation_matrix: np.ndarray) -> DetectedFace:
+        """ Rotates the detection bounding box around the given rotation matrix.
+ + Parameters + ---------- + face: :class:`DetectedFace` + A :class:`DetectedFace` containing the `x`, `w`, `y`, `h` detection bounding box + points. + rotation_matrix: numpy.ndarray + The rotation matrix to rotate the given object by. + + Returns + ------- + :class:`DetectedFace` + The same class with the detection bounding box points rotated by the given matrix. + """ + logger.trace("Rotating face: (face: %s, rotation_matrix: %s)", # type: ignore + face, rotation_matrix) + bounding_box = [[face.left, face.top], + [face.right, face.top], + [face.right, face.bottom], + [face.left, face.bottom]] + rotation_matrix = cv2.invertAffineTransform(rotation_matrix) + + points = np.array(bounding_box, "int32") + points = np.expand_dims(points, axis=0) + transformed = cv2.transform(points, rotation_matrix).astype("int32") + rotated = transformed.squeeze() + + # Bounding box should follow x, y planes, so get min/max for non-90 degree rotations + pt_x = min(pnt[0] for pnt in rotated) + pt_y = min(pnt[1] for pnt in rotated) + pt_x1 = max(pnt[0] for pnt in rotated) + pt_y1 = max(pnt[1] for pnt in rotated) + width = pt_x1 - pt_x + height = pt_y1 - pt_y + + face.left = int(pt_x) + face.top = int(pt_y) + face.width = int(width) + face.height = int(height) + return face + + def _rotate_image_by_angle(self, + image: np.ndarray, + angle: int) -> tuple[np.ndarray, np.ndarray]: """ Rotate an image by a given angle. - From: https://stackoverflow.com/questions/22041699 """ - logger.trace("Rotating image: (angle: %s, rotated_width: %s, rotated_height: %s)", - angle, rotated_width, rotated_height) + Parameters + ---------- + image: :class:`numpy.ndarray` + The image to be rotated + angle: int + The angle, in degrees, to rotate the image by + + Returns + ------- + image: :class:`numpy.ndarray` + The rotated image + rotation_matrix: :class:`numpy.ndarray` + The rotation matrix used to rotate the image + + Reference + --------- + https://stackoverflow.com/questions/22041699 + """ + + logger.trace("Rotating image: (image: %s, angle: %s)", # type:ignore[attr-defined] + image.shape, angle) + channels_first = image.shape[0] <= 4 + if channels_first: + image = np.moveaxis(image, 0, 2) + height, width = image.shape[:2] image_center = (width/2, height/2) - rotation_matrix = cv2.getRotationMatrix2D( # pylint: disable=no-member - image_center, -1.*angle, 1.) - if rotated_width is None or rotated_height is None: - abs_cos = abs(rotation_matrix[0, 0]) - abs_sin = abs(rotation_matrix[0, 1]) - if rotated_width is None: - rotated_width = int(height*abs_sin + width*abs_cos) - if rotated_height is None: - rotated_height = int(height*abs_cos + width*abs_sin) - rotation_matrix[0, 2] += rotated_width/2 - image_center[0] - rotation_matrix[1, 2] += rotated_height/2 - image_center[1] - logger.trace("Rotated image: (rotation_matrix: %s", rotation_matrix) - return (cv2.warpAffine(image, # pylint: disable=no-member - rotation_matrix, - (rotated_width, rotated_height)), - rotation_matrix) - - # << QUEUE METHODS >> # - def get_item(self): - """ Yield one item from the queue """ - item = self.queues["in"].get() - if isinstance(item, dict): - logger.trace("Item in: %s", item["filename"]) - else: - logger.trace("Item in: %s", item) - if item == "EOF": - logger.debug("In Queue Exhausted") - # Re-put EOF into queue for other threads - self.queues["in"].put(item) - return item - - def get_batch(self): - """ Get items from the queue in batches of - self.batch_size - - First item in output tuple indicates whether the - queue is exhausted. 
-            Second item is the batch
-
-            Remember to put "EOF" to the out queue after processing
-            the final batch """
-        exhausted = False
-        batch = list()
-        for _ in range(self.batch_size):
-            item = self.get_item()
-            if item == "EOF":
-                exhausted = True
-                break
-            batch.append(item)
-        logger.trace("Returning batch size: %s", len(batch))
-        return (exhausted, batch)
-
-    # <<< DLIB RECTANGLE METHODS >>> #
-    @staticmethod
-    def is_mmod_rectangle(d_rectangle):
-        """ Return whether the passed in object is
-            a dlib.mmod_rectangle """
-        return isinstance(
-            d_rectangle,
-            dlib.mmod_rectangle)  # pylint: disable=c-extension-no-member
-
-    def convert_to_dlib_rectangle(self, d_rect):
-        """ Convert detected mmod_rects to dlib_rectangle """
-        if self.is_mmod_rectangle(d_rect):
-            return d_rect.rect
-        return d_rect
-
-    # <<< MISC METHODS >>> #
-    @staticmethod
-    def get_vram_free():
-        """ Return total free VRAM on largest card """
-        stats = GPUStats()
-        vram = stats.get_card_most_free()
-        logger.verbose("Using device %s with %sMB free of %sMB",
-                       vram["device"],
-                       int(vram["free"]),
-                       int(vram["total"]))
-        return int(vram["free"])
-
-    @staticmethod
-    def set_predetected(width, height):
-        """ Set a dlib rectangle for predetected faces """
-        # Predetected_face is used for sort tool.
-        # Landmarks should not be extracted again from predetected faces,
-        # because face data is lost, resulting in a large variance
-        # against extract from original image
-        logger.debug("Setting predetected face")
-        return [dlib.rectangle(0, 0, width, height)]  # pylint: disable=c-extension-no-member
+        rotation_matrix = cv2.getRotationMatrix2D(image_center, -1.*angle, 1.)
+        rotation_matrix[0, 2] += self.input_size / 2 - image_center[0]
+        rotation_matrix[1, 2] += self.input_size / 2 - image_center[1]
+        logger.trace("Rotated image: (rotation_matrix: %s)",  # type:ignore[attr-defined]
+                     rotation_matrix)
+        image = cv2.warpAffine(image, rotation_matrix, (self.input_size, self.input_size))
+        if channels_first:
+            image = np.moveaxis(image, 2, 0)
+
+        return image, rotation_matrix
diff --git a/plugins/extract/detect/cv2_dnn.py b/plugins/extract/detect/cv2_dnn.py
new file mode 100644
index 0000000000..7e948eaef6
--- /dev/null
+++ b/plugins/extract/detect/cv2_dnn.py
@@ -0,0 +1,72 @@
+#!/usr/bin/env python3
+""" OpenCV DNN Face detection plugin """
+import logging
+
+import numpy as np
+
+from lib.utils import get_module_objects
+from ._base import BatchType, cv2, Detector, DetectorBatch
+from . import cv2_dnn_defaults as cfg
+
+
+logger = logging.getLogger(__name__)
+
+
+class Detect(Detector):
+    """ CV2 DNN detector for face detection """
+    def __init__(self, **kwargs) -> None:
+        git_model_id = 4
+        model_filename = ["resnet_ssd_v1.caffemodel", "resnet_ssd_v1.prototxt"]
+        super().__init__(git_model_id=git_model_id, model_filename=model_filename, **kwargs)
+        self.name = "cv2-DNN Detector"
+        self.input_size = 300
+        self.vram = 0  # CPU Only. Doesn't use VRAM
+        self.vram_per_batch = 0
+        self.batchsize = 1
+        self.confidence = cfg.confidence() / 100
+
+    def init_model(self) -> None:
+        """ Initialize CV2 DNN Detector Model """
+        assert isinstance(self.model_path, list)
+        self.model = cv2.dnn.readNetFromCaffe(self.model_path[1],
+                                              self.model_path[0])
+        self.model.setPreferableTarget(cv2.dnn.DNN_TARGET_CPU)
+
+    def process_input(self, batch: BatchType) -> None:
+        """ Compile the detection image(s) for prediction """
+        assert isinstance(batch, DetectorBatch)
+        batch.feed = cv2.dnn.blobFromImages(batch.image,
+                                            scalefactor=1.0,
+                                            size=(self.input_size, self.input_size),
+                                            mean=[104, 117, 123],
+                                            swapRB=False,
+                                            crop=False)
+
+    def predict(self, feed: np.ndarray) -> np.ndarray:
+        """ Run model to get predictions """
+        assert isinstance(self.model, cv2.dnn.Net)
+        self.model.setInput(feed)
+        predictions = self.model.forward()
+        return self.finalize_predictions(predictions)
+
+    def finalize_predictions(self, predictions: np.ndarray) -> np.ndarray:
+        """ Filter faces based on confidence level """
+        faces = []
+        for i in range(predictions.shape[2]):
+            confidence = predictions[0, 0, i, 2]
+            if confidence >= self.confidence:
+                logger.trace("Accepting due to confidence %s >= %s",  # type:ignore[attr-defined]
+                             confidence, self.confidence)
+                faces.append([(predictions[0, 0, i, 3] * self.input_size),
+                              (predictions[0, 0, i, 4] * self.input_size),
+                              (predictions[0, 0, i, 5] * self.input_size),
+                              (predictions[0, 0, i, 6] * self.input_size)])
+        logger.trace("faces: %s", faces)  # type:ignore[attr-defined]
+        return np.array(faces)[None, ...]
+
+    def process_output(self, batch: BatchType) -> None:
+        """ No output compilation required for cv2_dnn """
+        return
+
+
+__all__ = get_module_objects(__name__)
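For context on `finalize_predictions` above: the OpenCV Caffe SSD face detector returns a `(1, 1, N, 7)` tensor whose last axis is commonly documented as `[image_id, class_id, confidence, x1, y1, x2, y2]`, with co-ordinates normalised to the 0-1 range. A sketch of the same filtering against dummy data (not real model output):

```python
import numpy as np

input_size, confidence_threshold = 300, 0.5
# Two candidate detections in SSD output layout: [..., confidence, x1, y1, x2, y2]
predictions = np.zeros((1, 1, 2, 7), dtype="float32")
predictions[0, 0, 0, 2:] = [0.9, 0.2, 0.3, 0.6, 0.8]   # kept
predictions[0, 0, 1, 2:] = [0.3, 0.1, 0.1, 0.4, 0.5]   # dropped: below threshold

faces = [predictions[0, 0, i, 3:7] * input_size
         for i in range(predictions.shape[2])
         if predictions[0, 0, i, 2] >= confidence_threshold]
print(faces)  # [array([ 60.,  90., 180., 240.], dtype=float32)]
```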
" + "Use this if not using a GPU and time is important" +) + + +confidence = ConfigItem( + datatype=int, + default=50, + group="settings", + info="The confidence level at which the detector has succesfully found a face.\nHigher " + "levels will be more discriminating, lower levels will have more false positives.", + rounding=5, + min_max=(25, 100)) diff --git a/plugins/extract/detect/dlib_cnn.py b/plugins/extract/detect/dlib_cnn.py deleted file mode 100644 index 40de072dbc..0000000000 --- a/plugins/extract/detect/dlib_cnn.py +++ /dev/null @@ -1,204 +0,0 @@ -#!/usr/bin/env python3 -""" DLIB CNN Face detection plugin """ - -import numpy as np -import face_recognition_models - -from ._base import Detector, dlib, logger - - -class Detect(Detector): - """ Dlib detector for face recognition """ - def __init__(self, **kwargs): - super().__init__(**kwargs) - self.target = (1792, 1792) # Uses approx 1805MB of VRAM - self.vram = 1600 # Lower as batch size of 2 gives wiggle room - self.detector = None - - @staticmethod - def compiled_for_cuda(): - """ Return a message on DLIB Cuda Compilation status """ - cuda = dlib.DLIB_USE_CUDA # pylint: disable=c-extension-no-member - msg = "DLib is " - if not cuda: - msg += "NOT " - msg += "compiled to use CUDA" - logger.verbose(msg) - return cuda - - def set_model_path(self): - """ Model path handled by face_recognition_models """ - model_path = face_recognition_models.cnn_face_detector_model_location() - logger.debug("Loading model: '%s'", model_path) - return model_path - - def initialize(self, *args, **kwargs): - """ Calculate batch size """ - super().initialize(*args, **kwargs) - logger.verbose("Initializing Dlib-CNN Detector...") - self.detector = dlib.cnn_face_detection_model_v1( # pylint: disable=c-extension-no-member - self.model_path) - is_cuda = self.compiled_for_cuda() - if is_cuda: - logger.debug("Using GPU") - vram_free = self.get_vram_free() - else: - logger.verbose("Using CPU") - vram_free = 2048 - - # Batch size of 2 actually uses about 338MB less than a single image?? 
- # From there batches increase at ~680MB per item in the batch - - self.batch_size = int(((vram_free - self.vram) / 680) + 2) - - if self.batch_size < 1: - raise ValueError("Insufficient VRAM available to continue " - "({}MB)".format(int(vram_free))) - - logger.verbose("Processing in batches of %s", self.batch_size) - - self.init.set() - logger.info("Initialized Dlib-CNN Detector...") - - def detect_faces(self, *args, **kwargs): - """ Detect faces in rgb image """ - super().detect_faces(*args, **kwargs) - while True: - exhausted, batch = self.get_batch() - if not batch: - break - filenames = list() - images = list() - for item in batch: - filenames.append(item["filename"]) - images.append(item["image"]) - [detect_images, scales] = self.compile_detection_images(images) - batch_detected = self.detect_batch(detect_images) - processed = self.process_output(batch_detected, - indexes=None, - rotation_matrix=None, - output=None, - scales=scales) - if not all(faces for faces in processed) and self.rotation != [0]: - processed = self.process_rotations(detect_images, processed, scales) - for idx, faces in enumerate(processed): - filename = filenames[idx] - for b_idx, item in enumerate(batch): - if item["filename"] == filename: - output = item - del_idx = b_idx - break - output["detected_faces"] = faces - self.finalize(output) - del batch[del_idx] - if exhausted: - break - self.queues["out"].put("EOF") - del self.detector # Free up VRAM - logger.debug("Detecting Faces complete") - - def compile_detection_images(self, images): - """ Compile the detection images into batches """ - logger.trace("Compiling Detection Images: %s", len(images)) - detect_images = list() - scales = list() - for image in images: - scale = self.set_scale(image, is_square=True, scale_up=True) - detect_images.append(self.set_detect_image(image, scale)) - scales.append(scale) - logger.trace("Compiled Detection Images") - return [detect_images, scales] - - def detect_batch(self, detect_images, disable_message=False): - """ Pass the batch through detector for consistently sized images - or each image separately for inconsitently sized images """ - logger.trace("Detecting Batch") - can_batch = self.check_batch_dims(detect_images) - if can_batch: - logger.trace("Valid for batching") - batch_detected = self.detector(detect_images, 0) - else: - if not disable_message: - logger.verbose("Batch has inconsistently sized images. 
Processing one " - "image at a time") - batch_detected = dlib.mmod_rectangless( # pylint: disable=c-extension-no-member - [self.detector(detect_image, 0) for detect_image in detect_images]) - logger.trace("Detected Batch: %s", [item for item in batch_detected]) - return batch_detected - - @staticmethod - def check_batch_dims(images): - """ Check all images are the same size for batching """ - dims = set(frame.shape[:2] for frame in images) - logger.trace("Batch Dimensions: %s", dims) - return len(dims) == 1 - - def process_output(self, batch_detected, - indexes=None, rotation_matrix=None, output=None, scales=None): - """ Process the output images """ - logger.trace("Processing Output: (batch_detected: %s, indexes: %s, rotation_matrix: %s, " - "output: %s, scales: %s", - batch_detected, indexes, rotation_matrix, output, scales) - output = output if output else list() - for idx, faces in enumerate(batch_detected): - detected_faces = list() - scale = scales[idx] - - if isinstance(rotation_matrix, np.ndarray): - faces = [self.rotate_rect(face.rect, rotation_matrix) - for face in faces] - - for face in faces: - face = self.convert_to_dlib_rectangle(face) - face = dlib.rectangle( # pylint: disable=c-extension-no-member - int(face.left() / scale), - int(face.top() / scale), - int(face.right() / scale), - int(face.bottom() / scale)) - detected_faces.append(face) - if indexes: - target = indexes[idx] - output[target] = detected_faces - else: - output.append(detected_faces) - logger.trace("Processed Output: %s", output) - return output - - def process_rotations(self, detect_images, processed, scales): - """ Rotate frames missing faces until face is found """ - logger.trace("Processing Rotations") - for angle in self.rotation: - if all(faces for faces in processed): - break - if angle == 0: - continue - reprocess, indexes, rotmat = self.compile_reprocess( - processed, - detect_images, - angle) - - batch_detected = self.detect_batch(reprocess, disable_message=True) - if any(item for item in batch_detected): - logger.verbose("found face(s) by rotating image %s degrees", angle) - processed = self.process_output(batch_detected, - indexes=indexes, - rotation_matrix=rotmat, - output=processed, - scales=scales) - logger.trace("Processed Rotations") - return processed - - def compile_reprocess(self, processed, detect_images, angle): - """ Rotate images which did not find a face for reprocessing """ - logger.trace("Compile images for reprocessing") - indexes = list() - to_detect = list() - for idx, faces in enumerate(processed): - if faces: - continue - image = detect_images[idx] - rot_image, rot_matrix = self.rotate_image_by_angle(image, angle) - to_detect.append(rot_image) - indexes.append(idx) - logger.trace("Compiled images for reprocessing") - return to_detect, indexes, rot_matrix diff --git a/plugins/extract/detect/dlib_hog.py b/plugins/extract/detect/dlib_hog.py deleted file mode 100644 index 2f91d0d1e5..0000000000 --- a/plugins/extract/detect/dlib_hog.py +++ /dev/null @@ -1,76 +0,0 @@ -#!/usr/bin/env python3 -""" DLIB CNN Face detection plugin """ -from time import sleep - -import numpy as np - -from ._base import Detector, dlib, logger - - -class Detect(Detector): - """ Dlib detector for face recognition """ - def __init__(self, **kwargs): - super().__init__(**kwargs) - self.parent_is_pool = True - self.target = (2048, 2048) # Doesn't use VRAM - self.vram = 0 - self.detector = dlib.get_frontal_face_detector() # pylint: disable=c-extension-no-member - self.iterator = None - - def set_model_path(self): 
- """ No model for dlib Hog """ - pass - - def initialize(self, *args, **kwargs): - """ Calculate batch size """ - super().initialize(*args, **kwargs) - logger.info("Initializing Dlib-HOG Detector...") - logger.verbose("Using CPU for detection") - self.init = True - logger.info("Initialized Dlib-HOG Detector...") - - def detect_faces(self, *args, **kwargs): - """ Detect faces in rgb image """ - super().detect_faces(*args, **kwargs) - while True: - item = self.get_item() - if item == "EOF": - break - logger.trace("Detecting faces: %s", item["filename"]) - [detect_image, scale] = self.compile_detection_image(item["image"], True, True) - - for angle in self.rotation: - current_image, rotmat = self.rotate_image(detect_image, angle) - - logger.trace("Detecting faces") - faces = self.detector(current_image, 0) - logger.trace("Detected faces: %s", [face for face in faces]) - - if angle != 0 and faces.any(): - logger.verbose("found face(s) by rotating image %s degrees", angle) - - if faces: - break - - detected_faces = self.process_output(faces, rotmat, scale) - item["detected_faces"] = detected_faces - self.finalize(item) - - if item == "EOF": - sleep(3) # Wait for all processes to finish before EOF (hacky!) - self.queues["out"].put("EOF") - logger.debug("Detecting Faces Complete") - - def process_output(self, faces, rotation_matrix, scale): - """ Compile found faces for output """ - logger.trace("Processing Output: (faces: %s, rotation_matrix: %s)", - faces, rotation_matrix) - if isinstance(rotation_matrix, np.ndarray): - faces = [self.rotate_rect(face, rotation_matrix) - for face in faces] - detected = [dlib.rectangle( # pylint: disable=c-extension-no-member - int(face.left() / scale), int(face.top() / scale), - int(face.right() / scale), int(face.bottom() / scale)) - for face in faces] - logger.trace("Processed Output: %s", detected) - return detected diff --git a/plugins/extract/detect/external.py b/plugins/extract/detect/external.py new file mode 100644 index 0000000000..074a978796 --- /dev/null +++ b/plugins/extract/detect/external.py @@ -0,0 +1,357 @@ +#!/usr/bin/env python3 +""" Import face detection ROI boxes from a json file """ +from __future__ import annotations + +import logging +import os +import re +import typing as T + +import numpy as np + +from lib.align import AlignedFace +from lib.utils import get_module_objects, FaceswapError, IMAGE_EXTENSIONS + +from ._base import Detector +from . 
import external_defaults as cfg
+
+if T.TYPE_CHECKING:
+    from lib.align import DetectedFace
+    from plugins.extract import ExtractMedia
+    from ._base import BatchType
+
+logger = logging.getLogger(__name__)
+OriginType = T.Literal["top-left", "bottom-left", "top-right", "bottom-right"]
+# pylint:disable=duplicate-code
+
+
+class Detect(Detector):
+    """ Import face detection bounding boxes from an external json file """
+    def __init__(self, **kwargs) -> None:
+        kwargs["rotation"] = None  # Disable rotation
+        kwargs["min_size"] = 0  # Disable min_size
+        super().__init__(git_model_id=None, model_filename=None, **kwargs)
+
+        self.name = "External"
+        self.batchsize = 16
+
+        self.origin: OriginType = T.cast(OriginType, cfg.origin())
+        """ Literal["top-left", "bottom-left", "top-right", "bottom-right"] : The origin (0, 0)
+        location of the co-ordinates system used"""
+        self.file_name = cfg.file_name()
+        """ str : The file name to import ROI data from """
+
+        self._re_frame_no: re.Pattern = re.compile(r"\d+$")
+        self._missing: list[str] = []
+        self._log_once = True
+        self._is_video = False
+        self._imported: dict[str | int, np.ndarray] = {}
+        """dict[str | int, np.ndarray]: The imported data from external .json file"""
+
+    def init_model(self) -> None:
+        """ No initialization to perform """
+        logger.debug("No detector model to initialize")
+
+    def _compile_detection_image(self, item: ExtractMedia
+                                 ) -> tuple[np.ndarray, float, tuple[int, int]]:
+        """ Override the _compile_detection_image method to obtain the source frame dimensions
+
+        Parameters
+        ----------
+        item: :class:`~plugins.extract.extract_media.ExtractMedia`
+            The input item from the pipeline
+
+        Returns
+        -------
+        image: :class:`numpy.ndarray`
+            dummy empty array
+        scale: float
+            The scaling factor for the image (1.0)
+        pad: tuple
+            The amount of padding applied to the image (0, 0)
+        """
+        return np.array(item.image_shape[:2], dtype="int64"), 1.0, (0, 0)
+
+    def _check_for_video(self, filename: str) -> None:
+        """ Check a sample filename from the import file for a file extension to set
+        :attr:`_is_video`
+
+        Parameters
+        ----------
+        filename: str
+            A sample file name from the imported data
+        """
+        logger.debug("Checking for video from '%s'", filename)
+        ext = os.path.splitext(filename)[-1]
+        if ext.lower() not in IMAGE_EXTENSIONS:
+            self._is_video = True
+        logger.debug("Set is_video to %s from extension '%s'", self._is_video, ext)
+
+    def _get_key(self, key: str) -> str | int:
+        """ Obtain the key for the item in the lookup table. If the inputs are images, the key
+        will be the image filename. If the input is a video, the key will be the frame number
+
+        Parameters
+        ----------
+        key: str
+            The initial key value from import data or an import image/frame
+
+        Returns
+        -------
+        str | int
+            The filename if the input data is images, otherwise the frame number of a video
+        """
+        if not self._is_video:
+            return key
+        original_name = os.path.splitext(key)[0]
+        matches = self._re_frame_no.findall(original_name)
+        if not matches or len(matches) > 1:
+            raise FaceswapError(f"Invalid import name: '{key}'. For video files, the key should "
+                                "end with the frame number.")
+        retval = int(matches[0])
+        logger.trace("Obtained frame number %s from key '%s'",  # type:ignore[attr-defined]
+                     retval, key)
+        return retval
+
+    @classmethod
+    def _bbox_from_detected(cls, bounding_box: list[int]) -> np.ndarray:
+        """ Import the detected face roi from a `detected` item in the import file
+
+        Parameters
+        ----------
+        bounding_box: list[int]
+            a bounding box contained within the import file
+
+        Returns
+        -------
+        :class:`numpy.ndarray`
+            The "left", "top", "right", "bottom" bounding box for the face
+
+        Raises
+        ------
+        FaceswapError
+            If the number of bounding box co-ordinates is incorrect
+        """
+        if len(bounding_box) != 4:
+            raise FaceswapError("Imported 'detected' bounding boxes should be a list of 4 numbers "
+                                "representing the 'left', 'top', 'right', 'bottom' of a face.")
+        return np.rint(bounding_box)
+
+    def _validate_landmarks(self, landmarks: list[list[float]]) -> np.ndarray:
+        """ Validate that there are 4 or 68 landmarks and that they are a complete list of (x, y)
+        co-ordinates
+
+        Parameters
+        ----------
+        landmarks: list[list[float]]
+            The 4 point ROI or 68 point 2D landmarks that are being imported
+
+        Returns
+        -------
+        :class:`numpy.ndarray`
+            The original landmarks as a numpy array
+
+        Raises
+        ------
+        FaceswapError
+            If the landmarks being imported are not correct
+        """
+        if len(landmarks) not in (4, 68):
+            raise FaceswapError("Imported 'landmarks_2d' should be either 68 facial feature "
+                                "landmarks or 4 ROI corner locations")
+        retval = np.array(landmarks, dtype="float32")
+        if retval.shape[-1] != 2:
+            raise FaceswapError("Imported 'landmarks_2d' should be formatted as a list of (x, y) "
+                                "co-ordinates")
+        return retval
+
+    def _bbox_from_landmarks2d(self, landmarks: list[list[float]]) -> np.ndarray:
+        """ Import the detected face roi by estimating from imported landmarks
+
+        Parameters
+        ----------
+        landmarks: list[list[float]]
+            The 4 point ROI or 68 point 2D landmarks that are being imported
+
+        Returns
+        -------
+        :class:`numpy.ndarray`
+            The "left", "top", "right", "bottom" bounding box for the face
+        """
+        n_landmarks = self._validate_landmarks(landmarks)
+        face = AlignedFace(n_landmarks, centering="legacy", coverage_ratio=0.75)
+        return np.concatenate([np.min(face.original_roi, axis=0),
+                               np.max(face.original_roi, axis=0)])
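Stripped of the `AlignedFace` re-projection, the estimate above reduces to taking co-ordinate extremes. A simplified sketch (this skips the alignment step the plugin actually performs, so treat it as an approximation only):

```python
import numpy as np

# Four illustrative (x, y) landmark points
landmarks = np.array([[120.0, 80.0], [180.0, 82.0], [150.0, 130.0], [140.0, 160.0]])
box = np.concatenate([landmarks.min(axis=0), landmarks.max(axis=0)])
print(box)  # [120.  80. 180. 160.] -> left, top, right, bottom
```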
+    def _import_frame_face(self,
+                           face: dict[str, list[int] | list[list[float]]],
+                           align_origin: OriginType | None) -> np.ndarray:
+        """ Import a detected face ROI from the import file
+
+        Parameters
+        ----------
+        face: dict[str, list[int] | list[list[float]]]
+            The data that exists within the import file for the frame
+        align_origin: Literal["top-left", "bottom-left", "top-right", "bottom-right"] | None
+            The origin of the imported aligner data. Used if the detected ROI is being estimated
+            from imported aligner data
+
+        Returns
+        -------
+        :class:`numpy.ndarray`
+            The "left", "top", "right", "bottom" bounding box for the face
+
+        Raises
+        ------
+        FaceswapError
+            If the required keys for the bounding boxes are not present for the face
+        """
+        if "detected" in face:
+            return self._bbox_from_detected(T.cast(list[int], face["detected"]))
+        if "landmarks_2d" in face:
+            if self._log_once and align_origin is None:
+                logger.warning("You are importing Detection data, but have only provided "
+                               "Alignment data. This is most likely incorrect and will lead "
+                               "to poor results")
+                self._log_once = False
+
+            if self._log_once and align_origin is not None and align_origin != self.origin:
+                logger.info("Updating Detect origin from Aligner config to '%s'", align_origin)
+                self.origin = align_origin
+                self._log_once = False
+
+            return self._bbox_from_landmarks2d(T.cast(list[list[float]], face["landmarks_2d"]))
+
+        raise FaceswapError("The provided import file is missing both of the required keys "
+                            "'detected' and 'landmarks_2d'")
+
+    def import_data(self,
+                    data: dict[str, list[dict[str, list[int] | list[list[float]]]]],
+                    align_origin: T.Literal["top-left",
+                                            "bottom-left",
+                                            "top-right",
+                                            "bottom-right"] | None) -> None:
+        """ Import the detection data from the json import file and set to :attr:`_imported`
+
+        Parameters
+        ----------
+        data: dict[str, list[dict[str, list[int] | list[list[float]]]]]
+            The data to be imported
+        align_origin: Literal["top-left", "bottom-left", "top-right", "bottom-right"] | None
+            The origin of the imported aligner data. Used if the detected ROI is being estimated
+            from imported aligner data
+        """
+        logger.debug("Data length: %s, align_origin: %s", len(data), align_origin)
+        self._check_for_video(list(data)[0])
+        for key, faces in data.items():
+            try:
+                store_key = self._get_key(key)
+                self._imported[store_key] = np.array([self._import_frame_face(face, align_origin)
+                                                      for face in faces], dtype="int32")
+            except FaceswapError as err:
+                logger.error(str(err))
+                msg = f"The imported frame key that failed was '{key}'"
+                raise FaceswapError(msg) from err
+
+    def process_input(self, batch: BatchType) -> None:
+        """ Put the lookup key into `batch.feed` so they can be collected for mapping in `.predict`
+
+        Parameters
+        ----------
+        batch: :class:`~plugins.extract.detect._base.DetectorBatch`
+            The batch to be processed by the plugin
+        """
+        batch.feed = np.array([(self._get_key(os.path.basename(f)), i)
+                               for f, i in zip(batch.filename, batch.image)], dtype="object")
+
+    def _adjust_for_origin(self, box: np.ndarray, frame_dims: tuple[int, int]) -> np.ndarray:
+        """ Adjust the bounding box to be top-left orientated based on the selected import origin
+
+        Parameters
+        ----------
+        box: :class:`np.ndarray`
+            The imported bounding box at original (0, 0) origin
+        frame_dims: tuple[int, int]
+            The (rows, columns) dimensions of the original frame
+
+        Returns
+        -------
+        :class:`numpy.ndarray`
+            The adjusted bounding box for a top-left origin
+        """
+        if not np.any(box) or self.origin == "top-left":
+            return box
+        if self.origin.startswith("bottom"):
+            box[:, [1, 3]] = frame_dims[0] - box[:, [1, 3]]
+        if self.origin.endswith("right"):
+            box[:, [0, 2]] = frame_dims[1] - box[:, [0, 2]]
+
+        return box
+
+    def predict(self, feed: np.ndarray) -> list[np.ndarray]:  # type:ignore[override]
+        """ Pair the input filenames to the import file
+
+        Parameters
+        ----------
+        feed: :class:`numpy.ndarray`
+            The filenames with original frame dimensions to obtain the imported bounding boxes for
+
+        Returns
+        -------
+        list[:class:`numpy.ndarray`]
+            The bounding boxes for the given filenames
+        """
+        self._missing.extend(f[0] for f in feed if f[0] not in self._imported)
+        return [self._adjust_for_origin(self._imported.pop(f[0], np.array([], dtype="int32")),
+                                        f[1])
+                for f in feed]
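The origin adjustment only ever mirrors axes. For example, converting a box authored against a bottom-left origin to faceswap's top-left convention (illustrative values):

```python
import numpy as np

frame_dims = (720, 1280)                # (rows, columns)
box = np.array([[100, 600, 200, 500]])  # left, top, right, bottom in a bottom-left origin

box[:, [1, 3]] = frame_dims[0] - box[:, [1, 3]]  # flip the y co-ordinates
print(box)                                       # [[100 120 200 220]]
```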
+    def process_output(self, batch: BatchType) -> None:
+        """ No output processing required for import plugin
+
+        Parameters
+        ----------
+        batch: :class:`~plugins.extract.detect._base.DetectorBatch`
+            The batch to be processed by the plugin
+        """
+        logger.trace("No output processing for import plugin")  # type:ignore[attr-defined]
+
+    def _remove_zero_sized_faces(self, batch_faces: list[list[DetectedFace]]
+                                 ) -> list[list[DetectedFace]]:
+        """ Override _remove_zero_sized_faces to just return the faces that have been imported
+
+        Parameters
+        ----------
+        batch_faces: list[list[DetectedFace]]
+            List of detected face objects
+
+        Returns
+        -------
+        list[list[DetectedFace]]
+            Original list of detected face objects
+        """
+        return batch_faces
+
+    def on_completion(self) -> None:
+        """ Output information if:
+        - Imported items were not matched in input data
+        - Input data was not matched in imported items
+        """
+        super().on_completion()
+
+        if self._missing:
+            logger.warning("[DETECT] %s input frames could not be matched in the import file "
+                           "'%s'. Run in verbose mode for a list of frames.",
+                           len(self._missing), cfg.file_name())
+            logger.verbose(  # type:ignore[attr-defined]
+                "[DETECT] Input frames not in import file: %s", self._missing)
+
+        if self._imported:
+            logger.warning("[DETECT] %s items in the import file '%s' could not be matched to any "
+                           "input frames. Run in verbose mode for a list of items.",
+                           len(self._imported), cfg.file_name())
+            logger.verbose(  # type:ignore[attr-defined]
+                "[DETECT] import file items not in input frames: %s", list(self._imported))
+
+
+__all__ = get_module_objects(__name__)
diff --git a/plugins/extract/detect/external_defaults.py b/plugins/extract/detect/external_defaults.py
new file mode 100644
index 0000000000..dd112566ac
--- /dev/null
+++ b/plugins/extract/detect/external_defaults.py
@@ -0,0 +1,60 @@
+#!/usr/bin/env python3
+""" The default options for the faceswap Import Alignments plugin.
+
+Defaults files should be named `<plugin_name>_defaults.py`
+
+Any qualifying items placed into this file will automatically get added to the relevant config
+.ini files within the faceswap/config folder and added to the relevant GUI settings page.
+
+The following variable should be defined:
+
+    Parameters
+    ----------
+    HELPTEXT: str
+        A string describing what this plugin does
+
+Further plugin configuration options are assigned using:
+>>> <option_name> = ConfigItem(...)
+
+where <option_name> is the name of the configuration option to be added (lower-case, alpha-numeric
++ underscore only) and ConfigItem(...) is the [`~lib.config.objects.ConfigItem`] data for the
+option.
+
+See the docstring/ReadtheDocs documentation for the required parameters of the ConfigItem object.
+Items will be grouped together as per their `group` parameter, but otherwise will be processed in
+the order that they are added to this module.
+"""
+# pylint:disable=duplicate-code
+from lib.config import ConfigItem
+
+
+HELPTEXT = (
+    "Import Detector options.\n"
+    "Imports a detected face bounding box from an external .json file.\n"
+    )
+
+
+file_name = ConfigItem(
+    datatype=str,
+    default="import.json",
+    group="settings",
+    info="The import file should be stored in the same folder as the video (if extracting "
+         "from a video file) or inside the folder of images (if importing from a folder of "
+         "images)")
+
+origin = ConfigItem(
+    datatype=str,
+    default="top-left",
+    group="output",
+    info="The origin (0, 0) location of the co-ordinates system used. "
+         "\n\t top-left: The origin (0, 0) of the canvas is at the top left "
+         "corner."
+         "\n\t bottom-left: The origin (0, 0) of the canvas is at the bottom "
+         "left corner."
+         "\n\t top-right: The origin (0, 0) of the canvas is at the top right "
+         "corner."
+ "\n\t bottom-right: The origin (0, 0) of the canvas is at the bottom " + "right corner.", + choices=["top-left", "bottom-left", "top-right", "bottom-right"], + gui_radio=True) diff --git a/plugins/extract/detect/manual.py b/plugins/extract/detect/manual.py deleted file mode 100644 index 5b890f8761..0000000000 --- a/plugins/extract/detect/manual.py +++ /dev/null @@ -1,37 +0,0 @@ -#!/usr/bin/env python3 -""" Manual face detection plugin """ - -from ._base import Detector, dlib, logger - - -class Detect(Detector): - """ Manual Detector """ - def __init__(self, **kwargs): - super().__init__(**kwargs) - - def set_model_path(self): - """ No model required for Manual Detector """ - return None - - def initialize(self, *args, **kwargs): - """ Create the mtcnn detector """ - super().initialize(*args, **kwargs) - logger.info("Initializing Manual Detector...") - self.init.set() - logger.info("Initialized Manual Detector.") - - def detect_faces(self, *args, **kwargs): - """ Return the given bounding box in a dlib rectangle """ - super().detect_faces(*args, **kwargs) - while True: - item = self.get_item() - if item == "EOF": - break - face = item["face"] - - bounding_box = [dlib.rectangle( # pylint: disable=c-extension-no-member - int(face[0]), int(face[1]), int(face[2]), int(face[3]))] - item["detected_faces"] = bounding_box - self.finalize(item) - - self.queues["out"].put("EOF") diff --git a/plugins/extract/detect/mtcnn.py b/plugins/extract/detect/mtcnn.py index d6aa42c38d..16533e382a 100644 --- a/plugins/extract/detect/mtcnn.py +++ b/plugins/extract/detect/mtcnn.py @@ -1,49 +1,54 @@ #!/usr/bin/env python3 """ MTCNN Face detection plugin """ - -from __future__ import absolute_import, division, print_function - -import os - -from six import string_types, iteritems +from __future__ import annotations +import logging +import typing as T import cv2 import numpy as np -from lib.multithreading import MultiThread -from ._base import Detector, dlib, logger +from keras.models import Model +from keras.layers import Conv2D, Dense, Flatten, Input, MaxPooling2D, Permute, PReLU +from lib.logger import parse_class_init +from lib.utils import get_module_objects +from ._base import BatchType, Detector +from . import mtcnn_defaults as cfg -# Must import tensorflow inside the spawned process -# for Windows machines -tf = None # pylint: disable = invalid-name - -def import_tensorflow(): - """ Import tensorflow from inside spawned process """ - global tf # pylint: disable = invalid-name,global-statement - import tensorflow as tflow - tf = tflow +logger = logging.getLogger(__name__) class Detect(Detector): - """ MTCNN detector for face recognition """ - def __init__(self, **kwargs): - super().__init__(**kwargs) - self.kwargs = self.validate_kwargs() - self.name = "mtcnn" - self.target = 2073600 # Uses approx 1.30 GB of VRAM - self.vram = 1408 - - def validate_kwargs(self): + """ MTCNN detector for face recognition. 
""" + def __init__(self, **kwargs) -> None: + git_model_id = 2 + model_filename = ["mtcnn_det_v2.1.h5", "mtcnn_det_v2.2.h5", "mtcnn_det_v2.3.h5"] + super().__init__(git_model_id=git_model_id, model_filename=model_filename, **kwargs) + self.name = "MTCNN" + self.model: MTCNN + self.input_size = 640 + self.vram = 128 if not cfg.cpu() else 0 # 66 in testing + self.vram_per_batch = 64 if not cfg.cpu() else 0 # ~50 in testing + self.batchsize = cfg.batch_size() + self._kwargs = self._validate_kwargs() + self.color_format = "RGB" + + def _validate_kwargs(self) -> dict[T.Literal["minsize", "threshold", "factor", "input_size"], + int | float | list[float]]: """ Validate that config options are correct. If not reset to default """ valid = True - threshold = [self.config["threshold_1"], - self.config["threshold_2"], - self.config["threshold_3"]] - kwargs = {"minsize": self.config["minsize"], - "threshold": threshold, - "factor": self.config["scalefactor"]} + threshold = [cfg.threshold_1(), cfg.threshold_2(), cfg.threshold_3()] + kwargs: dict[T.Literal["minsize", "threshold", "factor", "input_size"], + int | float | list[float]] = {"minsize": cfg.minsize(), + "threshold": threshold, + "factor": cfg.scalefactor(), + "input_size": self.input_size} + + assert isinstance(kwargs["input_size"], int) + assert isinstance(kwargs["minsize"], int) + assert isinstance(kwargs["threshold"], list) + assert isinstance(kwargs["factor"], float) if kwargs["minsize"] < 10: valid = False @@ -53,175 +58,83 @@ def validate_kwargs(self): valid = False if not valid: - kwargs = {"minsize": 20, # minimum size of face - "threshold": [0.6, 0.7, 0.7], # three steps threshold - "factor": 0.709} # scale factor + kwargs = {} logger.warning("Invalid MTCNN options in config. Running with defaults") + logger.debug("Using mtcnn kwargs: %s", kwargs) return kwargs - def set_model_path(self): - """ Load the mtcnn models """ - for model in ("det1.npy", "det2.npy", "det3.npy"): - model_path = os.path.join(self.cachepath, model) - if not os.path.exists(model_path): - raise Exception("Error: Unable to find {}, reinstall " - "the lib!".format(model_path)) - logger.debug("Loading model: '%s'", model_path) - return self.cachepath - - def initialize(self, *args, **kwargs): - """ Create the mtcnn detector """ - super().initialize(*args, **kwargs) - logger.info("Initializing MTCNN Detector...") - is_gpu = False - - # Must import tensorflow inside the spawned process - # for Windows machines - import_tensorflow() - vram_free = self.get_vram_free() - mtcnn_graph = tf.Graph() - - # Windows machines sometimes misreport available vram, and overuse - # causing OOM. 
Allow growth fixes that - config = tf.ConfigProto() - config.gpu_options.allow_growth = True # pylint: disable=no-member - - with mtcnn_graph.as_default(): # pylint: disable=not-context-manager - sess = tf.Session(config=config) - with sess.as_default(): # pylint: disable=not-context-manager - pnet, rnet, onet = create_mtcnn(sess, self.model_path) - - if any("gpu" in str(device).lower() - for device in sess.list_devices()): - logger.debug("Using GPU") - is_gpu = True - mtcnn_graph.finalize() - - if not is_gpu: - alloc = 2048 - logger.warning("Using CPU") - else: - alloc = vram_free - logger.debug("Allocated for Tensorflow: %sMB", alloc) - - self.batch_size = int(alloc / self.vram) - - if self.batch_size < 1: - raise ValueError("Insufficient VRAM available to continue " - "({}MB)".format(int(alloc))) - - logger.verbose("Processing in %s threads", self.batch_size) - - self.kwargs["pnet"] = pnet - self.kwargs["rnet"] = rnet - self.kwargs["onet"] = onet - - self.init.set() - logger.info("Initialized MTCNN Detector.") - - def detect_faces(self, *args, **kwargs): - """ Detect faces in Multiple Threads """ - super().detect_faces(*args, **kwargs) - workers = MultiThread(target=self.detect_thread, thread_count=self.batch_size) - workers.start() - workers.join() - sentinel = self.queues["in"].get() - self.queues["out"].put(sentinel) - logger.debug("Detecting Faces complete") - - def detect_thread(self): - """ Detect faces in rgb image """ - logger.debug("Launching Detect") - while True: - item = self.get_item() - if item == "EOF": - break - logger.trace("Detecting faces: '%s'", item["filename"]) - [detect_image, scale] = self.compile_detection_image(item["image"], False, False) - - for angle in self.rotation: - current_image, rotmat = self.rotate_image(detect_image, angle) - faces, points = detect_face(current_image, **self.kwargs) - if angle != 0 and faces.any(): - logger.verbose("found face(s) by rotating image %s degrees", angle) - if faces.any(): - break - - detected_faces = self.process_output(faces, points, rotmat, scale) - item["detected_faces"] = detected_faces - self.finalize(item) - - logger.debug("Thread Completed Detect") - - def process_output(self, faces, points, rotation_matrix, scale): - """ Compile found faces for output """ - logger.trace("Processing Output: (faces: %s, points: %s, rotation_matrix: %s)", - faces, points, rotation_matrix) - faces = self.recalculate_bounding_box(faces, points) - faces = [dlib.rectangle( # pylint: disable=c-extension-no-member - int(face[0]), int(face[1]), int(face[2]), int(face[3])) - for face in faces] - if isinstance(rotation_matrix, np.ndarray): - faces = [self.rotate_rect(face, rotation_matrix) - for face in faces] - detected = [dlib.rectangle( # pylint: disable=c-extension-no-member - int(face.left() / scale), - int(face.top() / scale), - int(face.right() / scale), - int(face.bottom() / scale)) - for face in faces] - logger.trace("Processed Output: %s", detected) - return detected + def init_model(self) -> None: + """ Initialize MTCNN Model. 
""" + assert isinstance(self.model_path, list) + placeholder_shape = (self.batchsize, self.input_size, self.input_size, 3) + placeholder = np.zeros(placeholder_shape, dtype="float32") + + assert isinstance(self._kwargs["input_size"], int) + assert isinstance(self._kwargs["minsize"], int) + assert isinstance(self._kwargs["threshold"], list) + assert isinstance(self._kwargs["factor"], float) + + with self.get_device_context(cfg.cpu()): + self.model = MTCNN(self.model_path, + self.batchsize, + input_size=self._kwargs["input_size"], + minsize=self._kwargs["minsize"], + threshold=self._kwargs["threshold"], + factor=self._kwargs["factor"]) + self.model.detect_faces(placeholder) + + def process_input(self, batch: BatchType) -> None: + """ Compile the detection image(s) for prediction + + Parameters + ---------- + batch: :class:`~plugins.extract.detect._base.DetectorBatch` + Contains the batch that is currently being passed through the plugin process + """ + batch.feed = (np.array(batch.image, dtype="float32") - 127.5) / 127.5 - @staticmethod - def recalculate_bounding_box(faces, landmarks): - """ Recalculate the bounding box for Face Alignment. - - Face Alignment was built to expect a DLIB bounding - box and calculates center and scale based on that. - Resize the bounding box around features to present - a better box to Face Alignment. Helps its chances - on edge cases and helps remove 'jitter' """ - logger.trace("Recalculating Bounding Boxes: (faces: %s, landmarks: %s)", - faces, landmarks) - retval = list() - no_faces = len(faces) - if no_faces == 0: - return retval - face_landmarks = np.hsplit(landmarks, no_faces) - for idx in range(no_faces): - pts = np.reshape(face_landmarks[idx], (5, 2), order="F") - nose = pts[2] - - minmax = (np.amin(pts, axis=0), np.amax(pts, axis=0)) - padding = [(minmax[1][0] - minmax[0][0]) / 2, - (minmax[1][1] - minmax[0][1]) / 2] - - center = (minmax[1][0] - padding[0], minmax[1][1] - padding[1]) - offset = (center[0] - nose[0], nose[1] - center[1]) - center = (center[0] + offset[0], center[1] + offset[1]) - - padding[0] += padding[0] - padding[1] += padding[1] - - bounding = [center[0] - padding[0], center[1] - padding[1], - center[0] + padding[0], center[1] + padding[1]] - retval.append(bounding) - logger.trace("Recalculated Bounding Boxes: %s", retval) - return retval + def predict(self, feed: np.ndarray) -> np.ndarray: + """ Run model to get predictions + Parameters + ---------- + batch: :class:`~plugins.extract.detect._base.DetectorBatch` + Contains the batch to pass through the MTCNN model + + Returns + ------- + dict + The batch with the predictions added to the dictionary + """ + assert isinstance(self.model, MTCNN) + with self.get_device_context(cfg.cpu()): + prediction, points = self.model.detect_faces(feed) + logger.trace("prediction: %s, mtcnn_points: %s", # type:ignore[attr-defined] + prediction, points) + return prediction + + def process_output(self, batch: BatchType) -> None: + """ MTCNN performs no post processing so the original batch is returned + + Parameters + ---------- + batch: :class:`~plugins.extract.detect._base.DetectorBatch` + Contains the batch to apply postprocessing to + """ + return -# MTCNN Detector for face alignment -# Code adapted from: https://github.com/davidsandberg/facenet -# Tensorflow implementation of the face detection / alignment algorithm +# MTCNN Detector +# Code adapted from: https://github.com/xiangrufan/keras-mtcnn +# +# Keras implementation of the face detection / alignment algorithm # found at # 
https://github.com/kpzhang93/MTCNN_face_detection_alignment - +# # MIT License # -# Copyright (c) 2016 David Sandberg +# Copyright (c) 2016 Kaipeng Zhang # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -230,8 +143,8 @@ def recalculate_bounding_box(faces, landmarks): # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, @@ -242,575 +155,609 @@ def recalculate_bounding_box(faces, landmarks): # SOFTWARE. -def layer(operator): - """Decorator for composable network layers.""" - - def layer_decorated(self, *args, **kwargs): - # Automatically set a name if not provided. - name = kwargs.setdefault('name', self.get_unique_name(operator.__name__)) - # Figure out the layer inputs. - if len(self.terminals) == 0: # pylint: disable=len-as-condition - raise RuntimeError('No input variables found for layer %s.' % name) - elif len(self.terminals) == 1: - layer_input = self.terminals[0] - else: - layer_input = list(self.terminals) - # Perform the operation and get the output. - layer_output = operator(self, layer_input, *args, **kwargs) - # Add to layer LUT. - self.layers[name] = layer_output - # This output is now the input for the next layer. - self.feed(layer_output) - # Return self for chained calls. - return self - - return layer_decorated - - -class Network(): - """ Tensorflow Network """ - def __init__(self, inputs, trainable=True): - # The input nodes for this network - self.inputs = inputs - # The current list of terminal nodes - self.terminals = [] - # Mapping from layer names to layers - self.layers = dict(inputs) - # If true, the resulting variables are set as trainable - self.trainable = trainable - - self.setup() - - def setup(self): - """Construct the network. """ - raise NotImplementedError('Must be implemented by the subclass.') +class PNet(): + """ Keras P-Net model for MTCNN + + Parameters + ---------- + weights_path: str + The path to the keras model file + batch_size: int + The batch size to feed the model + input_size: int + The input size of the model + minsize: int, optional + The minimum size of a face to accept as a detection. Default: `20` + threshold: list, optional + Threshold for P-Net + """ + def __init__(self, + weights_path: str, + batch_size: int, + input_size: int, + min_size: int, + factor: float, + threshold: float) -> None: + logger.debug(parse_class_init(locals())) + self._batch_size = batch_size + self._model = self._load_model(weights_path) + + self._input_size = input_size + self._threshold = threshold + + self._pnet_scales = self._calculate_scales(min_size, factor) + self._pnet_sizes = [(int(input_size * scale), int(input_size * scale)) + for scale in self._pnet_scales] + self._pnet_input: list[np.ndarray] | None = None + logger.debug("Initialized: %s", self.__class__.__name__) @staticmethod - def load(model_path, session, ignore_missing=False): - """Load network weights. 
- model_path: The path to the numpy-serialized network weights - session: The current TensorFlow session - ignore_missing: If true, serialized weights for missing layers are - ignored. + def _load_model(weights_path: str) -> Model: + """ Keras P-Network Definition for MTCNN + + Parameters + ---------- + weights_path: str + Full path to the model's weights + + Returns + ------- + :class:`keras.models.Model` + The p-net model """ - # pylint: disable=no-member - data_dict = np.load(model_path, encoding='latin1').item() - - for op_name in data_dict: - with tf.variable_scope(op_name, reuse=True): - for param_name, data in iteritems(data_dict[op_name]): - try: - var = tf.get_variable(param_name) - session.run(var.assign(data)) - except ValueError: - if not ignore_missing: - raise - - def feed(self, *args): - """Set the input(s) for the next operation by replacing the terminal nodes. - The arguments can be either layer names or the actual layers. + input_ = Input(shape=(None, None, 3)) + var_x = Conv2D(10, (3, 3), strides=1, padding='valid', name='conv1')(input_) + var_x = PReLU(shared_axes=[1, 2], name='PReLU1')(var_x) + var_x = MaxPooling2D(pool_size=2)(var_x) + var_x = Conv2D(16, (3, 3), strides=1, padding='valid', name='conv2')(var_x) + var_x = PReLU(shared_axes=[1, 2], name='PReLU2')(var_x) + var_x = Conv2D(32, (3, 3), strides=1, padding='valid', name='conv3')(var_x) + var_x = PReLU(shared_axes=[1, 2], name='PReLU3')(var_x) + classifier = Conv2D(2, (1, 1), activation='softmax', name='conv4-1')(var_x) + bbox_regress = Conv2D(4, (1, 1), name='conv4-2')(var_x) + + retval = Model(input_, [classifier, bbox_regress]) + retval.load_weights(weights_path) + retval.make_predict_function() + return retval + + def _calculate_scales(self, + minsize: int, + factor: float) -> list[float]: + """ Calculate multi-scale + + Parameters + ---------- + minsize: int + Minimum size for a face to be accepted + factor: float + Scaling factor + + Returns + ------- + list + List of scale floats + """ + factor_count = 0 + var_m = 12.0 / minsize + minl = self._input_size * var_m + # create scale pyramid + scales = [] + while minl >= 12: + scales += [var_m * np.power(factor, factor_count)] + minl = minl * factor + factor_count += 1 + logger.trace(scales) # type:ignore[attr-defined] + return scales + + def _detect_face_12net(self, + class_probabilities: np.ndarray, + roi: np.ndarray, + size: int, + scale: float) -> tuple[np.ndarray, np.ndarray]: + """ Detect face position and calibrate bounding box on 12net feature map(matrix version) + + Parameters + ---------- + class_probabilities: :class:`numpy.ndarray` + softmax feature map for face classify + roi: :class:`numpy.ndarray` + feature map for regression + size: int + feature map's largest size + scale: float + current input image scale in multi-scales + + Returns + ------- + list + Calibrated face candidates """ - assert len(args) != 0 # pylint: disable=len-as-condition - self.terminals = [] - for fed_layer in args: - if isinstance(fed_layer, string_types): - try: - fed_layer = self.layers[fed_layer] - except KeyError: - raise KeyError('Unknown layer name fed: %s' % fed_layer) - self.terminals.append(fed_layer) - return self - - def get_output(self): - """Returns the current network output.""" - return self.terminals[-1] - - def get_unique_name(self, prefix): - """Returns an index-suffixed unique name for the given prefix. - This is used for auto-generating layer names based on the type-prefix. + in_side = 2 * size + 11 + stride = 0. 
if size == 1 else float(in_side - 12) / (size - 1) + (var_x, var_y) = np.nonzero(class_probabilities >= self._threshold) + boundingbox = np.array([var_x, var_y]).T + + boundingbox = np.concatenate((np.fix((stride * (boundingbox) + 0) * scale), + np.fix((stride * (boundingbox) + 11) * scale)), axis=1) + offset = roi[:4, var_x, var_y].T + boundingbox = boundingbox + offset * 12.0 * scale + rectangles = np.concatenate((boundingbox, + np.array([class_probabilities[var_x, var_y]]).T), axis=1) + rectangles = rect2square(rectangles) + + np.clip(rectangles[..., :4], 0., self._input_size, out=rectangles[..., :4]) + pick = np.where(np.logical_and(rectangles[..., 2] > rectangles[..., 0], + rectangles[..., 3] > rectangles[..., 1]))[0] + rects = rectangles[pick, :4].astype("int") + scores = rectangles[pick, 4] + + return nms(rects, scores, 0.3, "iou") + + def __call__(self, images: np.ndarray) -> list[np.ndarray]: + """ first stage - fast proposal network (p-net) to obtain face candidates + + Parameters + ---------- + images: :class:`numpy.ndarray` + The batch of images to detect faces in + + Returns + ------- + List + List of face candidates from P-Net """ - ident = sum(t.startswith(prefix) for t, _ in self.layers.items()) + 1 - return '%s_%d' % (prefix, ident) + batch_size = images.shape[0] + rectangles: list[list[list[int | float]]] = [[] for _ in range(batch_size)] + scores: list[list[np.ndarray]] = [[] for _ in range(batch_size)] + + if self._pnet_input is None: + self._pnet_input = [np.empty((batch_size, rheight, rwidth, 3), dtype="float32") + for rheight, rwidth in self._pnet_sizes] + + for scale, batch, (rheight, rwidth) in zip(self._pnet_scales, + self._pnet_input, + self._pnet_sizes): + _ = [cv2.resize(images[idx], (rwidth, rheight), dst=batch[idx]) + for idx in range(batch_size)] + cls_prob, roi = self._model.predict(batch, verbose=0, batch_size=self._batch_size) + cls_prob = cls_prob[..., 1] + out_side = max(cls_prob.shape[1:3]) + cls_prob = np.swapaxes(cls_prob, 1, 2) + roi = np.swapaxes(roi, 1, 3) + for idx in range(batch_size): + # first index 0 = class score, 1 = one hot representation + rect, score = self._detect_face_12net(cls_prob[idx, ...], + roi[idx, ...], + out_side, + 1 / scale) + rectangles[idx].extend(rect) + scores[idx].extend(score) + + return [nms(np.array(rect), np.array(score), 0.7, "iou")[0] # don't output scores + for rect, score in zip(rectangles, scores)] + + +class RNet(): + """ Keras R-Net model Definition for MTCNN + + Parameters + ---------- + weights_path: str + The path to the keras model file + batch_size: int + The batch size to feed the model + input_size: int + The input size of the model + threshold: list, optional + Threshold for R-Net - def make_var(self, name, shape): - """Creates a new TensorFlow variable.""" - return tf.get_variable(name, shape, trainable=self.trainable) + """ + def __init__(self, + weights_path: str, + batch_size: int, + input_size: int, + threshold: float) -> None: + logger.debug(parse_class_init(locals())) + self._batch_size = batch_size + self._model = self._load_model(weights_path) + self._input_size = input_size + self._threshold = threshold + logger.debug("Initialized: %s", self.__class__.__name__) @staticmethod - def validate_padding(padding): - """Verifies that the padding is one of the supported ones.""" - assert padding in ('SAME', 'VALID') - - @layer - def conv(self, # pylint: disable=too-many-arguments - inp, - k_h, - k_w, - c_o, - s_h, - s_w, - name, - relu=True, - padding='SAME', - group=1, - biased=True): - """ Conv 
Layer """ - # pylint: disable=too-many-locals - - # Verify that the padding is acceptable - self.validate_padding(padding) - # Get the number of channels in the input - c_i = int(inp.get_shape()[-1]) - # Verify that the grouping parameter is valid - assert c_i % group == 0 - assert c_o % group == 0 - # Convolution for a given input and kernel - convolve = lambda i, k: tf.nn.conv2d(i, k, [1, s_h, s_w, 1], padding=padding) # noqa - with tf.variable_scope(name) as scope: - kernel = self.make_var('weights', - shape=[k_h, k_w, c_i // group, c_o]) - # This is the common-case. Convolve the input without any - # further complications. - output = convolve(inp, kernel) - # Add the biases - if biased: - biases = self.make_var('biases', [c_o]) - output = tf.nn.bias_add(output, biases) - if relu: - # ReLU non-linearity - output = tf.nn.relu(output, name=scope.name) - return output - - @layer - def prelu(self, inp, name): - """ Prelu Layer """ - with tf.variable_scope(name): - i = int(inp.get_shape()[-1]) - alpha = self.make_var('alpha', shape=(i,)) - output = tf.nn.relu(inp) + tf.multiply(alpha, -tf.nn.relu(-inp)) - return output - - @layer - def max_pool(self, inp, k_h, k_w, # pylint: disable=too-many-arguments - s_h, s_w, name, padding='SAME'): - """ Max Pool Layer """ - self.validate_padding(padding) - return tf.nn.max_pool(inp, - ksize=[1, k_h, k_w, 1], - strides=[1, s_h, s_w, 1], - padding=padding, - name=name) - - @layer - def fc(self, inp, num_out, name, relu=True): # pylint: disable=invalid-name - """ FC Layer """ - with tf.variable_scope(name): - input_shape = inp.get_shape() - if input_shape.ndims == 4: - # The input is spatial. Vectorize it first. - dim = 1 - for this_dim in input_shape[1:].as_list(): - dim *= int(this_dim) - feed_in = tf.reshape(inp, [-1, dim]) - else: - feed_in, dim = (inp, input_shape[-1].value) - weights = self.make_var('weights', shape=[dim, num_out]) - biases = self.make_var('biases', [num_out]) - operator = tf.nn.relu_layer if relu else tf.nn.xw_plus_b - fc = operator(feed_in, weights, biases, name=name) # pylint: disable=invalid-name - return fc - - @layer - def softmax(self, target, axis, name=None): # pylint: disable=no-self-use - """ Multi dimensional softmax, - refer to https://github.com/tensorflow/tensorflow/issues/210 - compute softmax along the dimension of target - the native softmax only supports batch_size x dimension """ - - max_axis = tf.reduce_max(target, axis, keepdims=True) - target_exp = tf.exp(target-max_axis) - normalize = tf.reduce_sum(target_exp, axis, keepdims=True) - softmax = tf.div(target_exp, normalize, name) - return softmax - - -class PNet(Network): - """ Tensorflow PNet """ - def setup(self): - (self.feed('data') # pylint: disable=no-value-for-parameter, no-member - .conv(3, 3, 10, 1, 1, padding='VALID', relu=False, name='conv1') - .prelu(name='PReLU1') - .max_pool(2, 2, 2, 2, name='pool1') - .conv(3, 3, 16, 1, 1, padding='VALID', relu=False, name='conv2') - .prelu(name='PReLU2') - .conv(3, 3, 32, 1, 1, padding='VALID', relu=False, name='conv3') - .prelu(name='PReLU3') - .conv(1, 1, 2, 1, 1, relu=False, name='conv4-1') - .softmax(3, name='prob1')) - - (self.feed('PReLU3') # pylint: disable=no-value-for-parameter - .conv(1, 1, 4, 1, 1, relu=False, name='conv4-2')) - - -class RNet(Network): - """ Tensorflow RNet """ - def setup(self): - (self.feed('data') # pylint: disable=no-value-for-parameter, no-member - .conv(3, 3, 28, 1, 1, padding='VALID', relu=False, name='conv1') - .prelu(name='prelu1') - .max_pool(3, 3, 2, 2, name='pool1') - 
.conv(3, 3, 48, 1, 1, padding='VALID', relu=False, name='conv2') - .prelu(name='prelu2') - .max_pool(3, 3, 2, 2, padding='VALID', name='pool2') - .conv(2, 2, 64, 1, 1, padding='VALID', relu=False, name='conv3') - .prelu(name='prelu3') - .fc(128, relu=False, name='conv4') - .prelu(name='prelu4') - .fc(2, relu=False, name='conv5-1') - .softmax(1, name='prob1')) - - (self.feed('prelu4') # pylint: disable=no-value-for-parameter - .fc(4, relu=False, name='conv5-2')) - - -class ONet(Network): - """ Tensorflow ONet """ - def setup(self): - (self.feed('data') # pylint: disable=no-value-for-parameter, no-member - .conv(3, 3, 32, 1, 1, padding='VALID', relu=False, name='conv1') - .prelu(name='prelu1') - .max_pool(3, 3, 2, 2, name='pool1') - .conv(3, 3, 64, 1, 1, padding='VALID', relu=False, name='conv2') - .prelu(name='prelu2') - .max_pool(3, 3, 2, 2, padding='VALID', name='pool2') - .conv(3, 3, 64, 1, 1, padding='VALID', relu=False, name='conv3') - .prelu(name='prelu3') - .max_pool(2, 2, 2, 2, name='pool3') - .conv(2, 2, 128, 1, 1, padding='VALID', relu=False, name='conv4') - .prelu(name='prelu4') - .fc(256, relu=False, name='conv5') - .prelu(name='prelu5') - .fc(2, relu=False, name='conv6-1') - .softmax(1, name='prob1')) - - (self.feed('prelu5') # pylint: disable=no-value-for-parameter - .fc(4, relu=False, name='conv6-2')) - - (self.feed('prelu5') # pylint: disable=no-value-for-parameter - .fc(10, relu=False, name='conv6-3')) - - -def create_mtcnn(sess, model_path): - """ Create the network """ - if not model_path: - model_path, _ = os.path.split(os.path.realpath(__file__)) - - with tf.variable_scope('pnet'): - data = tf.placeholder(tf.float32, (None, None, None, 3), 'input') - pnet = PNet({'data': data}) - pnet.load(os.path.join(model_path, 'det1.npy'), sess) - with tf.variable_scope('rnet'): - data = tf.placeholder(tf.float32, (None, 24, 24, 3), 'input') - rnet = RNet({'data': data}) - rnet.load(os.path.join(model_path, 'det2.npy'), sess) - with tf.variable_scope('onet'): - data = tf.placeholder(tf.float32, (None, 48, 48, 3), 'input') - onet = ONet({'data': data}) - onet.load(os.path.join(model_path, 'det3.npy'), sess) - - pnet_fun = lambda img: sess.run(('pnet/conv4-2/BiasAdd:0', # noqa - 'pnet/prob1:0'), - feed_dict={'pnet/input:0': img}) - rnet_fun = lambda img: sess.run(('rnet/conv5-2/conv5-2:0', # noqa - 'rnet/prob1:0'), - feed_dict={'rnet/input:0': img}) - onet_fun = lambda img: sess.run(('onet/conv6-2/conv6-2:0', # noqa - 'onet/conv6-3/conv6-3:0', - 'onet/prob1:0'), - feed_dict={'onet/input:0': img}) - return pnet_fun, rnet_fun, onet_fun - - -def detect_face(img, minsize, pnet, rnet, # pylint: disable=too-many-arguments - onet, threshold, factor): - """Detects faces in an image, and returns bounding boxes and points for them. - img: input image - minsize: minimum faces' size - pnet, rnet, onet: caffemodel - threshold: threshold=[th1, th2, th3], th1-3 are three steps's threshold - factor: the factor used to create a scaling pyramid of face sizes to - detect in the image. 
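The R-Net filtering defined below (_filter_face_24net) applies the regression output as offsets expressed in fractions of each box's width and height. A worked NumPy example of that update with toy numbers:

    import numpy as np

    box = np.array([10., 20., 50., 100.])        # [left, top, right, bottom]
    delta = np.array([0.1, -0.05, 0.02, 0.0])    # regression offsets from the network
    dims = np.array([box[2] - box[0], box[3] - box[1]])  # (width, height) = (40, 80)
    refined = box + delta * np.tile(dims, 2)     # x offsets scale by width, y by height
    print(refined)                               # -> [14.  16.  50.8  100.]
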
+ def _load_model(weights_path: str) -> Model: + """ Keras R-Network Definition for MTCNN + + Parameters + ---------- + weights_path: str + Full path to the model's weights + + Returns + ------- + :class:`keras.models.Model` + The r-net model + """ + input_ = Input(shape=(24, 24, 3)) + var_x = Conv2D(28, (3, 3), strides=1, padding='valid', name='conv1')(input_) + var_x = PReLU(shared_axes=[1, 2], name='prelu1')(var_x) + var_x = MaxPooling2D(pool_size=3, strides=2, padding='same')(var_x) + + var_x = Conv2D(48, (3, 3), strides=1, padding='valid', name='conv2')(var_x) + var_x = PReLU(shared_axes=[1, 2], name='prelu2')(var_x) + var_x = MaxPooling2D(pool_size=3, strides=2)(var_x) + + var_x = Conv2D(64, (2, 2), strides=1, padding='valid', name='conv3')(var_x) + var_x = PReLU(shared_axes=[1, 2], name='prelu3')(var_x) + var_x = Permute((3, 2, 1))(var_x) + var_x = Flatten()(var_x) + var_x = Dense(128, name='conv4')(var_x) + var_x = PReLU(name='prelu4')(var_x) + classifier = Dense(2, activation='softmax', name='conv5-1')(var_x) + bbox_regress = Dense(4, name='conv5-2')(var_x) + + retval = Model(input_, [classifier, bbox_regress]) + retval.load_weights(weights_path) + retval.make_predict_function() + return retval + + def _filter_face_24net(self, + class_probabilities: np.ndarray, + roi: np.ndarray, + rectangles: np.ndarray, + ) -> np.ndarray: + """ Filter face position and calibrate bounding box on 12net's output + + Parameters + ---------- + class_probabilities: class:`np.ndarray` + Softmax feature map for face classify + roi: :class:`numpy.ndarray` + Feature map for regression + rectangles: list + 12net's predict + + Returns + ------- + list + rectangles in the format [[x, y, x1, y1, score]] + """ + prob = class_probabilities[:, 1] + pick = np.nonzero(prob >= self._threshold) + + bbox = rectangles.T[:4, pick] + scores = np.array([prob[pick]]).T.ravel() + deltas = roi.T[:4, pick] + + dims = np.tile([bbox[2] - bbox[0], bbox[3] - bbox[1]], (2, 1, 1)) + bbox = np.transpose(bbox + deltas * dims).reshape(-1, 4) + bbox = np.clip(rect2square(bbox), 0, self._input_size).astype("int") + return nms(bbox, scores, 0.3, "iou")[0] + + def __call__(self, + images: np.ndarray, + rectangle_batch: list[np.ndarray], + ) -> list[np.ndarray]: + """ second stage - refinement of face candidates with r-net + + Parameters + ---------- + images: :class:`numpy.ndarray` + The batch of images to detect faces in + rectangle_batch: + List of :class:`numpy.ndarray` face candidates from P-Net + + Returns + ------- + List + List of :class:`numpy.ndarray` refined face candidates from R-Net + """ + ret: list[np.ndarray] = [] + for idx, (rectangles, image) in enumerate(zip(rectangle_batch, images)): + if not np.any(rectangles): + ret.append(np.array([])) + continue + + feed_batch = np.empty((rectangles.shape[0], 24, 24, 3), dtype="float32") + + _ = [cv2.resize(image[rect[1]: rect[3], rect[0]: rect[2]], + (24, 24), + dst=feed_batch[idx]) + for idx, rect in enumerate(rectangles)] + + cls_prob, roi_prob = self._model.predict(feed_batch, + verbose=0, + batch_size=self._batch_size) + ret.append(self._filter_face_24net(cls_prob, roi_prob, rectangles)) + return ret + + +class ONet(): + """ Keras O-Net model for MTCNN + + Parameters + ---------- + weights_path: str + The path to the keras model file + batch_size: int + The batch size to feed the model + input_size: int + The input size of the model + threshold: list, optional + Threshold for O-Net + """ + def __init__(self, + weights_path: str, + batch_size: int, + input_size: int, + 
threshold: float) -> None:
+        logger.debug(parse_class_init(locals()))
+        self._batch_size = batch_size
+        self._model = self._load_model(weights_path)
+        self._input_size = input_size
+        self._threshold = threshold
+        logger.debug("Initialized: %s", self.__class__.__name__)
+
+    @staticmethod
+    def _load_model(weights_path: str) -> Model:
+        """ Keras O-Network Definition for MTCNN
+
+        Parameters
+        ----------
+        weights_path: str
+            Full path to the model's weights
+
+        Returns
+        -------
+        :class:`keras.models.Model`
+            The o-net model
+        """
+        input_ = Input(shape=(48, 48, 3))
+        var_x = Conv2D(32, (3, 3), strides=1, padding='valid', name='conv1')(input_)
+        var_x = PReLU(shared_axes=[1, 2], name='prelu1')(var_x)
+        var_x = MaxPooling2D(pool_size=3, strides=2, padding='same')(var_x)
+        var_x = Conv2D(64, (3, 3), strides=1, padding='valid', name='conv2')(var_x)
+        var_x = PReLU(shared_axes=[1, 2], name='prelu2')(var_x)
+        var_x = MaxPooling2D(pool_size=3, strides=2)(var_x)
+        var_x = Conv2D(64, (3, 3), strides=1, padding='valid', name='conv3')(var_x)
+        var_x = PReLU(shared_axes=[1, 2], name='prelu3')(var_x)
+        var_x = MaxPooling2D(pool_size=2)(var_x)
+        var_x = Conv2D(128, (2, 2), strides=1, padding='valid', name='conv4')(var_x)
+        var_x = PReLU(shared_axes=[1, 2], name='prelu4')(var_x)
+        var_x = Permute((3, 2, 1))(var_x)
+        var_x = Flatten()(var_x)
+        var_x = Dense(256, name='conv5')(var_x)
+        var_x = PReLU(name='prelu5')(var_x)
+
+        classifier = Dense(2, activation='softmax', name='conv6-1')(var_x)
+        bbox_regress = Dense(4, name='conv6-2')(var_x)
+        landmark_regress = Dense(10, name='conv6-3')(var_x)
+        retval = Model(input_, [classifier, bbox_regress, landmark_regress])
+        retval.load_weights(weights_path)
+        retval.make_predict_function()
+        return retval
+
+    def _filter_face_48net(self, class_probabilities: np.ndarray,
+                           roi: np.ndarray,
+                           points: np.ndarray,
+                           rectangles: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
+        """ Filter face position and calibrate bounding box on 24net's output
+
+        Parameters
+        ----------
+        class_probabilities: :class:`numpy.ndarray`
+            Array of face probabilities (column 1 is the face probability)
+        roi: :class:`numpy.ndarray`
+            The regression offsets
+        points: :class:`numpy.ndarray`
+            The 5 point face landmarks
+        rectangles: :class:`numpy.ndarray`
+            24net's predictions. rectangles[i][0:4] is the position, rectangles[i][4] is
+            the score
+
+        Returns
+        -------
+        boxes: :class:`numpy.ndarray`
+            The [l, t, r, b, score] bounding boxes
+        points: :class:`numpy.ndarray`
+            The 5 point landmarks
+        """
+        prob = class_probabilities[:, 1]
+        pick = np.nonzero(prob >= self._threshold)[0]
+        scores = np.array([prob[pick]]).T.ravel()
+
+        bbox = rectangles[pick]
+        dims = np.array([bbox[..., 2] - bbox[..., 0], bbox[..., 3] - bbox[..., 1]]).T
+
+        pts = np.vstack(
+            np.hsplit(points[pick], 2)).reshape(2, -1, 5).transpose(1, 2, 0).reshape(-1, 10)
+        pts = np.tile(dims, (1, 5)) * pts + np.tile(bbox[..., :2], (1, 5))
+
+        bbox = np.clip(np.floor(bbox + roi[pick] * np.tile(dims, (1, 2))),
+                       0.,
+                       self._input_size)
+
+        indices = np.where(
+            np.logical_and(bbox[..., 2] > bbox[..., 0], bbox[..., 3] > bbox[..., 1]))[0]
+        picks = np.concatenate([bbox[indices], pts[indices]], axis=-1)
+
+        results, scores = nms(picks, scores[indices], 0.3, "iom")
+        return np.concatenate([results[..., :4], scores[..., None]], axis=-1), results[..., 4:].T
+
+    def __call__(self,
+                 images: np.ndarray,
+                 rectangle_batch: list[np.ndarray]
+                 ) -> list[tuple[np.ndarray, np.ndarray]]:
+        """ Third stage - further refinement and facial landmarks positions with o-net
+
+        
Parameters + ---------- + images: :class:`numpy.ndarray` + The batch of images to detect faces in + rectangle_batch: + List of :class:`numpy.ndarray` face candidates from R-Net + + Returns + ------- + List + List of refined final candidates, scores and landmark points from O-Net + """ + ret: list[tuple[np.ndarray, np.ndarray]] = [] + for idx, rectangles in enumerate(rectangle_batch): + if not np.any(rectangles): + ret.append((np.empty((0, 5)), np.empty(0))) + continue + image = images[idx] + feed_batch = np.empty((rectangles.shape[0], 48, 48, 3), dtype="float32") + + _ = [cv2.resize(image[rect[1]: rect[3], rect[0]: rect[2]], + (48, 48), + dst=feed_batch[idx]) + for idx, rect in enumerate(rectangles)] + + cls_probs, roi_probs, pts_probs = self._model.predict(feed_batch, + verbose=0, + batch_size=self._batch_size) + ret.append(self._filter_face_48net(cls_probs, roi_probs, pts_probs, rectangles)) + return ret + + +class MTCNN(): + """ MTCNN Detector for face alignment + + Parameters + ---------- + weights_path: list + List of paths to the 3 MTCNN subnet weights + batch_size: int + The batch size to feed the model + input_size: int, optional + The height, width input size to the model. Default: 640 + minsize: int, optional + The minimum size of a face to accept as a detection. Default: `20` + threshold: list, optional + List of floats for the three steps, Default: `[0.6, 0.7, 0.7]` + factor: float, optional + The factor used to create a scaling pyramid of face sizes to detect in the image. + Default: `0.709` """ - # pylint: disable=too-many-locals,too-many-statements,too-many-branches - factor_count = 0 - total_boxes = np.empty((0, 9)) - points = np.empty(0) - height = img.shape[0] - width = img.shape[1] - minl = np.amin([height, width]) - var_m = 12.0 / minsize - minl = minl * var_m - # create scale pyramid - scales = [] - while minl >= 12: - scales += [var_m * np.power(factor, factor_count)] - minl = minl * factor - factor_count += 1 - - # # # # # # # # # # # # # - # first stage - fast proposal network (pnet) to obtain face candidates - # # # # # # # # # # # # # - for scale in scales: - height_scale = int(np.ceil(height * scale)) - width_scale = int(np.ceil(width * scale)) - im_data = imresample(img, (height_scale, width_scale)) - im_data = (im_data - 127.5) * 0.0078125 - img_x = np.expand_dims(im_data, 0) - img_y = np.transpose(img_x, (0, 2, 1, 3)) - out = pnet(img_y) - out0 = np.transpose(out[0], (0, 2, 1, 3)) - out1 = np.transpose(out[1], (0, 2, 1, 3)) - - boxes, _ = generate_bounding_box(out1[0, :, :, 1].copy(), - out0[0, :, :, :].copy(), - scale, threshold[0]) - - # inter-scale nms - pick = nms(boxes.copy(), 0.5, 'Union') - if boxes.size > 0 and pick.size > 0: - boxes = boxes[pick, :] - total_boxes = np.append(total_boxes, boxes, axis=0) - - numbox = total_boxes.shape[0] - if numbox > 0: - pick = nms(total_boxes.copy(), 0.7, 'Union') - total_boxes = total_boxes[pick, :] - regw = total_boxes[:, 2]-total_boxes[:, 0] - regh = total_boxes[:, 3]-total_boxes[:, 1] - qq_1 = total_boxes[:, 0]+total_boxes[:, 5] * regw - qq_2 = total_boxes[:, 1]+total_boxes[:, 6] * regh - qq_3 = total_boxes[:, 2]+total_boxes[:, 7] * regw - qq_4 = total_boxes[:, 3]+total_boxes[:, 8] * regh - total_boxes = np.transpose(np.vstack([qq_1, qq_2, qq_3, qq_4, total_boxes[:, 4]])) - total_boxes = rerec(total_boxes.copy()) - total_boxes[:, 0:4] = np.fix(total_boxes[:, 0:4]).astype(np.int32) - d_y, ed_y, d_x, ed_x, var_y, e_y, var_x, e_x, tmpw, tmph = pad(total_boxes.copy(), - width, height) - - numbox = 
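The wrapper documented above chains the three subnets so that each stage prunes the candidates of the last. A stubbed sketch of that data flow (the three callables are placeholders for the real networks):

    import numpy as np

    def cascade(batch: np.ndarray, pnet, rnet, onet):
        """ P-Net proposes, R-Net refines, O-Net scores final boxes and landmarks. """
        candidates = pnet(batch)              # per-image arrays of proposal boxes
        candidates = rnet(batch, candidates)  # calibrated, NMS-filtered proposals
        boxes, points = zip(*onet(batch, candidates))
        return np.array(boxes, dtype="object"), points
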
total_boxes.shape[0] - - # # # # # # # # # # # # # - # second stage - refinement of face candidates with rnet - # # # # # # # # # # # # # - - if numbox > 0: - tempimg = np.zeros((24, 24, 3, numbox)) - for k in range(0, numbox): - tmp = np.zeros((int(tmph[k]), int(tmpw[k]), 3)) - tmp[d_y[k] - 1:ed_y[k], d_x[k] - 1:ed_x[k], :] = img[var_y[k] - 1:e_y[k], - var_x[k]-1:e_x[k], :] - if tmp.shape[0] > 0 and tmp.shape[1] > 0 or tmp.shape[0] == 0 and tmp.shape[1] == 0: - tempimg[:, :, :, k] = imresample(tmp, (24, 24)) - else: - return np.empty() - tempimg = (tempimg-127.5)*0.0078125 - tempimg1 = np.transpose(tempimg, (3, 1, 0, 2)) - out = rnet(tempimg1) - out0 = np.transpose(out[0]) - out1 = np.transpose(out[1]) - score = out1[1, :] - ipass = np.where(score > threshold[1]) - total_boxes = np.hstack([total_boxes[ipass[0], 0:4].copy(), - np.expand_dims(score[ipass].copy(), 1)]) - m_v = out0[:, ipass[0]] - if total_boxes.shape[0] > 0: - pick = nms(total_boxes, 0.7, 'Union') - total_boxes = total_boxes[pick, :] - total_boxes = bbreg(total_boxes.copy(), np.transpose(m_v[:, pick])) - total_boxes = rerec(total_boxes.copy()) - - numbox = total_boxes.shape[0] - - # # # # # # # # # # # # # - # third stage - further refinement and facial landmarks positions with onet - # NB: Facial landmarks code commented out for faceswap - # # # # # # # # # # # # # - - if numbox > 0: - # third stage - total_boxes = np.fix(total_boxes).astype(np.int32) - d_y, ed_y, d_x, ed_x, var_y, e_y, var_x, e_x, tmpw, tmph = pad(total_boxes.copy(), - width, height) - tempimg = np.zeros((48, 48, 3, numbox)) - for k in range(0, numbox): - tmp = np.zeros((int(tmph[k]), int(tmpw[k]), 3)) - tmp[d_y[k] - 1:ed_y[k], d_x[k] - 1:ed_x[k], :] = img[var_y[k] - 1:e_y[k], - var_x[k] - 1:e_x[k], :] - if tmp.shape[0] > 0 and tmp.shape[1] > 0 or tmp.shape[0] == 0 and tmp.shape[1] == 0: - tempimg[:, :, :, k] = imresample(tmp, (48, 48)) - else: - return np.empty() - tempimg = (tempimg-127.5)*0.0078125 - tempimg1 = np.transpose(tempimg, (3, 1, 0, 2)) - out = onet(tempimg1) - out0 = np.transpose(out[0]) - out1 = np.transpose(out[1]) - out2 = np.transpose(out[2]) - score = out2[1, :] - points = out1 - ipass = np.where(score > threshold[2]) - points = points[:, ipass[0]] - total_boxes = np.hstack([total_boxes[ipass[0], 0:4].copy(), - np.expand_dims(score[ipass].copy(), 1)]) - m_v = out0[:, ipass[0]] - - width = total_boxes[:, 2] - total_boxes[:, 0] + 1 - height = total_boxes[:, 3] - total_boxes[:, 1] + 1 - points[0:5, :] = (np.tile(width, (5, 1)) * points[0:5, :] + - np.tile(total_boxes[:, 0], (5, 1)) - 1) - points[5:10, :] = (np.tile(height, (5, 1)) * points[5:10, :] + - np.tile(total_boxes[:, 1], (5, 1)) - 1) - if total_boxes.shape[0] > 0: - total_boxes = bbreg(total_boxes.copy(), np.transpose(m_v)) - pick = nms(total_boxes.copy(), 0.7, 'Min') - total_boxes = total_boxes[pick, :] - points = points[:, pick] - - return total_boxes, points - - -# function [boundingbox] = bbreg(boundingbox,reg) -def bbreg(boundingbox, reg): - """Calibrate bounding boxes""" - if reg.shape[1] == 1: - reg = np.reshape(reg, (reg.shape[2], reg.shape[3])) - - width = boundingbox[:, 2] - boundingbox[:, 0] + 1 - height = boundingbox[:, 3] - boundingbox[:, 1] + 1 - b_1 = boundingbox[:, 0] + reg[:, 0] * width - b_2 = boundingbox[:, 1] + reg[:, 1] * height - b_3 = boundingbox[:, 2] + reg[:, 2] * width - b_4 = boundingbox[:, 3] + reg[:, 3] * height - boundingbox[:, 0:4] = np.transpose(np.vstack([b_1, b_2, b_3, b_4])) - return boundingbox - - -def generate_bounding_box(imap, reg, scale, 
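The legacy bbreg above is the same width/height-relative calibration, written with the historical inclusive-pixel (+1) convention. A NumPy spot check with toy values:

    import numpy as np

    box = np.array([0., 0., 9., 19.])       # a 10x20 box under the inclusive convention
    reg = np.array([0.1, 0.1, -0.1, -0.1])
    width = box[2] - box[0] + 1
    height = box[3] - box[1] + 1
    calibrated = box + reg * np.array([width, height, width, height])
    print(calibrated)                        # -> [1.  2.  8.  17.]
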
threshold): - """Use heatmap to generate bounding boxes""" - # pylint: disable=too-many-locals - stride = 2 - cellsize = 12 - - imap = np.transpose(imap) - d_x1 = np.transpose(reg[:, :, 0]) - d_y1 = np.transpose(reg[:, :, 1]) - d_x2 = np.transpose(reg[:, :, 2]) - d_y2 = np.transpose(reg[:, :, 3]) - dim_y, dim_x = np.where(imap >= threshold) - if dim_y.shape[0] == 1: - d_x1 = np.flipud(d_x1) - d_y1 = np.flipud(d_y1) - d_x2 = np.flipud(d_x2) - d_y2 = np.flipud(d_y2) - score = imap[(dim_y, dim_x)] - reg = np.transpose(np.vstack([d_x1[(dim_y, dim_x)], d_y1[(dim_y, dim_x)], - d_x2[(dim_y, dim_x)], d_y2[(dim_y, dim_x)]])) - if reg.size == 0: - reg = np.empty((0, 3)) - bbox = np.transpose(np.vstack([dim_y, dim_x])) - q_1 = np.fix((stride * bbox + 1) / scale) - q_2 = np.fix((stride * bbox + cellsize - 1 + 1) / scale) - boundingbox = np.hstack([q_1, q_2, np.expand_dims(score, 1), reg]) - return boundingbox, reg - - -# function pick = nms(boxes,threshold,type) -def nms(boxes, threshold, method): - """ Non_Max Suppression """ - # pylint: disable=too-many-locals - if boxes.size == 0: - return np.empty((0, 3)) - x_1 = boxes[:, 0] - y_1 = boxes[:, 1] - x_2 = boxes[:, 2] - y_2 = boxes[:, 3] - var_s = boxes[:, 4] - area = (x_2 - x_1 + 1) * (y_2 - y_1 + 1) - s_sort = np.argsort(var_s) - pick = np.zeros_like(var_s, dtype=np.int16) - counter = 0 - while s_sort.size > 0: - i = s_sort[-1] - pick[counter] = i - counter += 1 - idx = s_sort[0:-1] - xx_1 = np.maximum(x_1[i], x_1[idx]) - yy_1 = np.maximum(y_1[i], y_1[idx]) - xx_2 = np.minimum(x_2[i], x_2[idx]) - yy_2 = np.minimum(y_2[i], y_2[idx]) - width = np.maximum(0.0, xx_2-xx_1+1) - height = np.maximum(0.0, yy_2-yy_1+1) - inter = width * height - if method == 'Min': - var_o = inter / np.minimum(area[i], area[idx]) + def __init__(self, + weights_path: list[str], + batch_size: int, + input_size: int = 640, + minsize: int = 20, + threshold: list[float] | None = None, + factor: float = 0.709) -> None: + logger.debug(parse_class_init(locals())) + threshold = [0.6, 0.7, 0.7] if threshold is None else threshold + self._pnet = PNet(weights_path[0], + batch_size, + input_size, + minsize, + factor, + threshold[0]) + self._rnet = RNet(weights_path[1], + batch_size, + input_size, + threshold[1]) + self._onet = ONet(weights_path[2], + batch_size, + input_size, + threshold[2]) + logger.debug("Initialized: %s", self.__class__.__name__) + + def detect_faces(self, batch: np.ndarray) -> tuple[np.ndarray, tuple[np.ndarray]]: + """Detects faces in an image, and returns bounding boxes and points for them. + + Parameters + ---------- + batch: :class:`numpy.ndarray` + The input batch of images to detect face in + + Returns + ------- + List + list of numpy arrays containing the bounding box and 5 point landmarks + of detected faces + """ + rectangles = self._pnet(batch) + rectangles = self._rnet(batch, rectangles) + + ret_boxes, ret_points = zip(*self._onet(batch, rectangles)) + return np.array(ret_boxes, dtype="object"), ret_points + + +def nms(rectangles: np.ndarray, + scores: np.ndarray, + threshold: float, + method: str = "iom") -> tuple[np.ndarray, np.ndarray]: + """ apply non-maximum suppression on ROIs in same scale(matrix version) + + Parameters + ---------- + rectangles: :class:`np.ndarray` + The [b, l, t, r, b] bounding box detection candidates + threshold: float + Threshold for succesful match + method: str, optional + "iom" method or default. 
Defalt: "iom" + + Returns + ------- + rectangles: :class:`np.ndarray` + The [b, l, t, r, b] bounding boxes + scores :class:`np.ndarray` + The associated scores for the rectangles + + """ + if not np.any(rectangles): + return rectangles, scores + bboxes = rectangles[..., :4].T + area = np.multiply(bboxes[2] - bboxes[0] + 1, bboxes[3] - bboxes[1] + 1) + s_sort = scores.argsort() + + pick = [] + while len(s_sort) > 0: + s_bboxes = np.concatenate([ # s_sort[-1] have highest prob score, s_sort[0:-1]->others + np.maximum(bboxes[:2, s_sort[-1], None], bboxes[:2, s_sort[0:-1]]), + np.minimum(bboxes[2:, s_sort[-1], None], bboxes[2:, s_sort[0:-1]])], axis=0) + + inter = (np.maximum(0.0, s_bboxes[2] - s_bboxes[0] + 1) * + np.maximum(0.0, s_bboxes[3] - s_bboxes[1] + 1)) + + if method == "iom": + var_o = inter / np.minimum(area[s_sort[-1]], area[s_sort[0:-1]]) else: - var_o = inter / (area[i] + area[idx] - inter) - s_sort = s_sort[np.where(var_o <= threshold)] - pick = pick[0:counter] - return pick - - -# function [d_y ed_y d_x ed_x y e_y x e_x tmp_width tmp_height] = pad(total_boxes,width,height) -def pad(total_boxes, width, height): - """Compute the padding coordinates (pad the bounding boxes to square)""" - tmp_width = (total_boxes[:, 2] - total_boxes[:, 0] + 1).astype(np.int32) - tmp_height = (total_boxes[:, 3] - total_boxes[:, 1] + 1).astype(np.int32) - numbox = total_boxes.shape[0] - - d_x = np.ones((numbox), dtype=np.int32) - d_y = np.ones((numbox), dtype=np.int32) - ed_x = tmp_width.copy().astype(np.int32) - ed_y = tmp_height.copy().astype(np.int32) - - dim_x = total_boxes[:, 0].copy().astype(np.int32) - dim_y = total_boxes[:, 1].copy().astype(np.int32) - e_x = total_boxes[:, 2].copy().astype(np.int32) - e_y = total_boxes[:, 3].copy().astype(np.int32) - - tmp = np.where(e_x > width) - ed_x.flat[tmp] = np.expand_dims(-e_x[tmp] + width + tmp_width[tmp], 1) - e_x[tmp] = width - - tmp = np.where(e_y > height) - ed_y.flat[tmp] = np.expand_dims(-e_y[tmp] + height + tmp_height[tmp], 1) - e_y[tmp] = height - - tmp = np.where(dim_x < 1) - d_x.flat[tmp] = np.expand_dims(2 - dim_x[tmp], 1) - dim_x[tmp] = 1 - - tmp = np.where(dim_y < 1) - d_y.flat[tmp] = np.expand_dims(2 - dim_y[tmp], 1) - dim_y[tmp] = 1 - - return d_y, ed_y, d_x, ed_x, dim_y, e_y, dim_x, e_x, tmp_width, tmp_height - - -# function [bbox_a] = rerec(bbox_a) -def rerec(bbox_a): - """Convert bbox_a to square.""" - height = bbox_a[:, 3]-bbox_a[:, 1] - width = bbox_a[:, 2]-bbox_a[:, 0] - length = np.maximum(width, height) - bbox_a[:, 0] = bbox_a[:, 0] + width * 0.5 - length * 0.5 - bbox_a[:, 1] = bbox_a[:, 1] + height * 0.5 - length * 0.5 - bbox_a[:, 2:4] = bbox_a[:, 0:2] + np.transpose(np.tile(length, (2, 1))) - return bbox_a - - -def imresample(img, size): - """ Resample image """ - # pylint: disable=no-member - im_data = cv2.resize(img, (size[1], size[0]), - interpolation=cv2.INTER_AREA) # @UndefinedVariable - return im_data + var_o = inter / (area[s_sort[-1]] + area[s_sort[0:-1]] - inter) + + pick.append(s_sort[-1]) + s_sort = s_sort[np.where(var_o <= threshold)[0]] + + result_rectangle = rectangles[pick] + result_scores = scores[pick] + return result_rectangle, result_scores + + +def rect2square(rectangles: np.ndarray) -> np.ndarray: + """ change rectangles into squares (matrix version) + + Parameters + ---------- + rectangles: :class:`numpy.ndarray` + [b, x, y, x1, y1] rectangles + + Return + ------ + list + Original rectangle changed to a square + """ + width = rectangles[:, 2] - rectangles[:, 0] + height = rectangles[:, 3] - 
+
+
+__all__ = get_module_objects(__name__)
diff --git a/plugins/extract/detect/mtcnn_defaults.py b/plugins/extract/detect/mtcnn_defaults.py
new file mode 100755
index 0000000000..8c4517b7ec
--- /dev/null
+++ b/plugins/extract/detect/mtcnn_defaults.py
@@ -0,0 +1,98 @@
+#!/usr/bin/env python3
+""" The default options for the faceswap Mtcnn Detect plugin.
+
+Defaults files should be named `<plugin_name>_defaults.py`
+
+Any qualifying items placed into this file will automatically get added to the relevant config
+.ini files within the faceswap/config folder and added to the relevant GUI settings page.
+
+The following variable should be defined:
+
+    Parameters
+    ----------
+    HELPTEXT: str
+        A string describing what this plugin does
+
+Further plugin configuration options are assigned using:
+>>> <option_name> = ConfigItem(...)
+
+where <option_name> is the name of the configuration option to be added (lower-case,
+alpha-numeric + underscore only) and ConfigItem(...) is the
+:class:`~lib.config.objects.ConfigItem` data for the option.
+
+See the docstring/ReadtheDocs documentation for the required parameters of the ConfigItem
+object. Items will be grouped together as per their `group` parameter, but otherwise will be
+processed in the order that they are added to this module.
+"""
+# pylint:disable=duplicate-code
+from lib.config import ConfigItem
+
+
+HELPTEXT = (
+    "MTCNN Detector options.\n"
+    "Fast on GPU, slow on CPU. Uses fewer resources than other GPU detectors but can often "
+    "return more false positives."
+)
+
+
+minsize = ConfigItem(
+    datatype=int,
+    default=20,
+    group="settings",
+    info="The minimum size of a face (in pixels) to be accepted as a positive match."
+         "\nLower values use significantly more VRAM and will detect more false positives.",
+    rounding=10,
+    min_max=(20, 1000))
+
+scalefactor = ConfigItem(
+    datatype=float,
+    default=0.709,
+    group="settings",
+    info="The scale factor for the image pyramid.",
+    rounding=3,
+    min_max=(0.1, 0.9))
+
+batch_size = ConfigItem(
+    datatype=int,
+    default=8,
+    group="settings",
+    info="The batch size to use. To a point, higher batch sizes equal better performance, "
+         "but setting it too high can harm performance.\n"
+         "\n\tNvidia users: If the batch size is set higher than your GPU can "
+         "accommodate then this will automatically be lowered.",
+    rounding=1,
+    min_max=(1, 64))
+
+cpu = ConfigItem(
+    datatype=bool,
+    default=True,
+    group="settings",
+    info="MTCNN detector still runs fairly quickly on CPU on some setups. "
+         "Enable CPU mode here to use the CPU for this detector to save some VRAM at a "
+         "speed cost.")
+
+threshold_1 = ConfigItem(
+    datatype=float,
+    default=0.6,
+    group="threshold",
+    info="First stage threshold for face detection. This stage obtains face candidates.",
+    rounding=2,
+    min_max=(0.1, 0.9))
+
+threshold_2 = ConfigItem(
+    datatype=float,
+    default=0.7,
+    group="threshold",
+    info="Second stage threshold for face detection. This stage refines face candidates.",
+    rounding=2,
+    min_max=(0.1, 0.9))
+
+threshold_3 = ConfigItem(
+    datatype=float,
+    default=0.7,
+    group="threshold",
+    info="Third stage threshold for face detection. 
This stage further refines face " + "candidates.", + rounding=2, + min_max=(0.1, 0.9)) diff --git a/plugins/extract/detect/s3fd.py b/plugins/extract/detect/s3fd.py new file mode 100644 index 0000000000..43a5a4822f --- /dev/null +++ b/plugins/extract/detect/s3fd.py @@ -0,0 +1,555 @@ +#!/usr/bin/env python3 +""" S3FD Face detection plugin +https://arxiv.org/abs/1708.05237 + +Adapted from S3FD Port in FAN: +https://github.com/1adrianb/face-alignment +""" +from __future__ import annotations +import logging +import typing as T + +from scipy.special import logsumexp +import numpy as np + +from keras.layers import (Concatenate, Conv2D, Input, Layer, Maximum, MaxPooling2D, ZeroPadding2D) +from keras.models import Model +from keras import initializers, ops + +from lib.logger import parse_class_init +from lib.utils import get_module_objects +from ._base import BatchType, Detector +from . import s3fd_defaults as cfg + +if T.TYPE_CHECKING: + from keras import KerasTensor + +logger = logging.getLogger(__name__) + + +class Detect(Detector): + """ S3FD detector for face recognition """ + def __init__(self, **kwargs) -> None: + git_model_id = 11 + model_filename = "s3fd_keras_v2.h5" + super().__init__(git_model_id=git_model_id, model_filename=model_filename, **kwargs) + self.model: S3fd + self.name = "S3FD" + self.input_size = 640 + self.vram = 1088 # 1034 in testing + self.vram_per_batch = 960 # 922 in testing + self.batchsize = cfg.batch_size() + + def init_model(self) -> None: + """ Initialize S3FD Model""" + assert isinstance(self.model_path, str) + confidence = cfg.confidence() / 100 + self.model = S3fd(self.model_path, self.batchsize, confidence) + placeholder_shape = (self.batchsize, self.input_size, self.input_size, 3) + placeholder = np.zeros(placeholder_shape, dtype="float32") + self.model(placeholder) + + def process_input(self, batch: BatchType) -> None: + """ Compile the detection image(s) for prediction """ + assert isinstance(self.model, S3fd) + batch.feed = self.model.prepare_batch(np.array(batch.image)) + + def predict(self, feed: np.ndarray) -> np.ndarray: + """ Run model to get predictions """ + assert isinstance(self.model, S3fd) + predictions = self.model(feed) + assert isinstance(predictions, list) + return self.model.finalize_predictions(predictions) + + def process_output(self, batch) -> None: + """ Compile found faces for output """ + return + + +################################################################################ +# CUSTOM KERAS LAYERS +################################################################################ +class L2Norm(Layer): # pylint:disable=too-many-ancestors,abstract-method + """ L2 Normalization layer for S3FD. + + Parameters + ---------- + n_channels: int + The number of channels to normalize + scale: float, optional + The scaling for initial weights. Default: `1.0` + """ + def __init__(self, n_channels: int, scale: float = 1.0, **kwargs) -> None: + super().__init__(**kwargs) + self._n_channels = n_channels + self._scale = scale + self.weight = self.add_weight(name="l2norm", + shape=(self._n_channels, ), + trainable=True, + initializer=initializers.Constant(value=self._scale), + dtype="float32") + + def call(self, inputs: KerasTensor, **kwargs # pylint:disable=arguments-differ + ) -> KerasTensor: + """ Call the L2 Normalization Layer. 
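What the L2Norm layer computes is straightforward to pin down in NumPy; a standalone check of the same arithmetic, assuming the initial unit weight (scale=1):

    import numpy as np

    feat = np.random.rand(1, 4, 4, 8).astype("float32")
    weight = np.ones(8, dtype="float32")  # the Constant(value=scale) initializer with scale=1
    norm = np.sqrt(np.sum(feat ** 2, axis=-1, keepdims=True)) + 1e-10
    out = feat / norm * weight
    assert np.allclose(np.linalg.norm(out, axis=-1), 1.0, atol=1e-4)
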
+ + Parameters + ---------- + inputs: :class:`keras.KerasTensor` + The input to the L2 Normalization Layer + + Returns + ------- + :class:`keras.KerasTensor`: + The output from the L2 Normalization Layer + """ + norm = ops.sqrt(ops.sum(ops.power(inputs, 2), axis=-1, keepdims=True)) + 1e-10 + var_x = inputs / norm * self.weight + return var_x + + def get_config(self) -> dict: + """ Returns the config of the layer. + + Returns + ------- + dict + The configuration for the layer + """ + config = super().get_config() + config.update({"n_channels": self._n_channels, + "scale": self._scale}) + return config + + +class SliceO2K(Layer): # pylint:disable=too-many-ancestors,abstract-method + """ Custom Keras Slice layer generated by onnx2keras. """ + def __init__(self, + starts: list[int], + ends: list[int], + axes: list[int] | None = None, + steps: list[int] | None = None, + **kwargs) -> None: + self._starts = starts + self._ends = ends + self._axes = axes + self._steps = steps + super().__init__(**kwargs) + + def _get_slices(self, dimensions: int) -> list[tuple[int, ...]]: + """ Obtain slices for the given number of dimensions. + + Parameters + ---------- + dimensions: int + The number of dimensions to obtain slices for + + Returns + ------- + list + The slices for the given number of dimensions + """ + axes = tuple(range(dimensions)) if self._axes is None else self._axes + steps = (1,) * len(axes) if self._steps is None else self._steps + assert len(axes) == len(steps) == len(self._starts) == len(self._ends) + return list(zip(axes, self._starts, self._ends, steps)) + + def compute_output_shape(self, input_shape: tuple[int, ...] # pylint:disable=arguments-differ + ) -> tuple[int, ...]: + """Computes the output shape of the layer. + + Assumes that the layer will be built to match that input shape provided. + + Parameters + ---------- + input_shape: tuple or list of tuples + Shape tuple (tuple of integers) or list of shape tuples (one per output tensor of the + layer). Shape tuples can include ``None`` for free dimensions, instead of an integer. + + Returns + ------- + tuple + An output shape tuple. + """ + in_shape = list(input_shape) + for a_x, start, end, steps in self._get_slices(len(in_shape)): + size = in_shape[a_x] + if a_x == 0: + raise AttributeError("Can not slice batch axis.") + if size is None: + if start < 0 or end < 0: + raise AttributeError("Negative slices not supported on symbolic axes") + logger.warning("Slicing symbolic axis might lead to problems.") + in_shape[a_x] = (end - start) // steps + continue + if start < 0: + start = size - start + if end < 0: + end = size - end + in_shape[a_x] = (min(size, end) - start) // steps + return tuple(in_shape) + + def call(self, inputs, **kwargs): # pylint:disable=unused-argument,arguments-differ + """This is where the layer's logic lives. + + Parameters + ---------- + inputs: Input tensor, or list/tuple of input tensors. + The input to the layer + **kwargs: Additional keyword arguments. + Required for parent class but unused + Returns + ------- + A tensor or list/tuple of tensors. + The layer output + """ + ax_map = dict((x[0], slice(*x[1:])) for x in self._get_slices(ops.ndim(inputs))) + shape = inputs.shape + slices = [(ax_map[a] if a in ax_map else slice(None)) for a in range(len(shape))] + retval = inputs[tuple(slices)] + return retval + + def get_config(self) -> dict: + """ Returns the config of the layer. 
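SliceO2K.call reduces an ONNX Slice node to plain Python slicing; the equivalent NumPy operation for starts=[1], ends=[2], axes=[3] (as used for the max-out chunks further down) looks like this:

    import numpy as np

    scores = np.random.rand(2, 5, 5, 4)  # e.g. a conf head output with four channels
    chunk = scores[:, :, :, 1:2]         # starts=[1], ends=[2], axes=[3], step 1
    assert chunk.shape == (2, 5, 5, 1)
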
+ + Returns + ------- + dict + The configuration for the layer + """ + config = super().get_config() + config.update({"starts": self._starts, + "ends": self._ends, + "axes": self._axes, + "steps": self._steps}) + return config + + +class S3fd(): + """ Keras Network + + Parameters + ---------- + weights_path: str + Full path to the S3FD weights file + batch_size: int + The batch size to feed the model + confidence: float + The confidence level to accept detections at + """ + def __init__(self, weights_path: str, batch_size: int, confidence: float) -> None: + logger.debug(parse_class_init(locals())) + self._batch_size = batch_size + self._model = self._load_model(weights_path) + self.confidence = confidence + self.average_img = np.array([104.0, 117.0, 123.0]) + logger.debug("Initialized: %s", self.__class__.__name__) + + @classmethod + def conv_block(cls, + inputs: KerasTensor, + filters: int, + idx: int, + recursions: int) -> KerasTensor: + """ First round convolutions with zero padding added. + + Parameters + ---------- + inputs: :class:`keras.KerasTensor` + The input tensor to the convolution block + filters: int + The number of filters + idx: int + The layer index for naming + recursions: int + The number of recursions of the block to perform + + Returns + ------- + :class:`keras.KerasTensor` + The output tensor from the convolution block + """ + name = f"conv{idx}" + var_x = inputs + for i in range(1, recursions + 1): + rec_name = f"{name}_{i}" + var_x = ZeroPadding2D(1, name=f"{rec_name}.zeropad")(var_x) + var_x = Conv2D(filters, + kernel_size=3, + strides=1, + activation="relu", + name=rec_name)(var_x) + return var_x + + @classmethod + def conv_up(cls, inputs: KerasTensor, filters: int, idx: int) -> KerasTensor: + """ Convolution up filter blocks with zero padding added. + + Parameters + ---------- + inputs: :class:`keras.KerasTensor` + The input tensor to the convolution block + filters: int + The initial number of filters + idx: int + The layer index for naming + + Returns + ------- + :class:`keras.KerasTensor` + The output tensor from the convolution block + """ + name = f"conv{idx}" + var_x = inputs + for i in range(1, 3): + rec_name = f"{name}_{i}" + size = 1 if i == 1 else 3 + if i == 2: + var_x = ZeroPadding2D(1, name=f"{rec_name}.zeropad")(var_x) + var_x = Conv2D(filters * i, + kernel_size=size, + strides=i, + activation="relu", + name=rec_name)(var_x) + return var_x + + def _load_model(self, weights_path: str) -> Model: + """ Keras S3FD Model Definition, adapted from FAN pytorch implementation. 
+ + Parameters + ---------- + weights_path: str + Full path to the model's weights + + Returns + ------- + :class:`keras.models.Model` + The S3FD model + """ + input_ = Input(shape=(640, 640, 3)) + var_x = self.conv_block(input_, 64, 1, 2) + var_x = MaxPooling2D(pool_size=2, strides=2)(var_x) + + var_x = self.conv_block(var_x, 128, 2, 2) + var_x = MaxPooling2D(pool_size=2, strides=2)(var_x) + + var_x = self.conv_block(var_x, 256, 3, 3) + f3_3 = var_x + var_x = MaxPooling2D(pool_size=2, strides=2)(var_x) + + var_x = self.conv_block(var_x, 512, 4, 3) + f4_3 = var_x + var_x = MaxPooling2D(pool_size=2, strides=2)(var_x) + + var_x = self.conv_block(var_x, 512, 5, 3) + f5_3 = var_x + var_x = MaxPooling2D(pool_size=2, strides=2)(var_x) + + var_x = ZeroPadding2D(3)(var_x) + var_x = Conv2D(1024, kernel_size=3, strides=1, activation="relu", name="fc6")(var_x) + var_x = Conv2D(1024, kernel_size=1, strides=1, activation="relu", name="fc7")(var_x) + ffc7 = var_x + + f6_2 = self.conv_up(var_x, 256, 6) + f7_2 = self.conv_up(f6_2, 128, 7) + + f3_3 = L2Norm(256, scale=10, name="conv3_3_norm")(f3_3) + f4_3 = L2Norm(512, scale=8, name="conv4_3_norm")(f4_3) + f5_3 = L2Norm(512, scale=5, name="conv5_3_norm")(f5_3) + + classes = [] + regs = [] + + f3_3 = ZeroPadding2D(1)(f3_3) + classes.append(Conv2D(4, kernel_size=3, strides=1, name="conv3_3_norm_mbox_conf")(f3_3)) + regs.append(Conv2D(4, kernel_size=3, strides=1, name="conv3_3_norm_mbox_loc")(f3_3)) + + f4_3 = ZeroPadding2D(1)(f4_3) + classes.append(Conv2D(2, kernel_size=3, strides=1, name="conv4_3_norm_mbox_conf")(f4_3)) + regs.append(Conv2D(4, kernel_size=3, strides=1, name="conv4_3_norm_mbox_loc")(f4_3)) + + f5_3 = ZeroPadding2D(1)(f5_3) + classes.append(Conv2D(2, kernel_size=3, strides=1, name="conv5_3_norm_mbox_conf")(f5_3)) + regs.append(Conv2D(4, kernel_size=3, strides=1, name="conv5_3_norm_mbox_loc")(f5_3)) + + ffc7 = ZeroPadding2D(1)(ffc7) + classes.append(Conv2D(2, kernel_size=3, strides=1, name="fc7_mbox_conf")(ffc7)) + regs.append(Conv2D(4, kernel_size=3, strides=1, name="fc7_mbox_loc")(ffc7)) + + f6_2 = ZeroPadding2D(1)(f6_2) + classes.append(Conv2D(2, kernel_size=3, strides=1, name="conv6_2_mbox_conf")(f6_2)) + regs.append(Conv2D(4, kernel_size=3, strides=1, name="conv6_2_mbox_loc")(f6_2)) + + f7_2 = ZeroPadding2D(1)(f7_2) + classes.append(Conv2D(2, kernel_size=3, strides=1, name="conv7_2_mbox_conf")(f7_2)) + regs.append(Conv2D(4, kernel_size=3, strides=1, name="conv7_2_mbox_loc")(f7_2)) + + # max-out background label + chunks = [SliceO2K(starts=[0], ends=[1], axes=[3], steps=None)(classes[0]), + SliceO2K(starts=[1], ends=[2], axes=[3], steps=None)(classes[0]), + SliceO2K(starts=[2], ends=[3], axes=[3], steps=None)(classes[0]), + SliceO2K(starts=[3], ends=[4], axes=[3], steps=None)(classes[0])] + + bmax = Maximum()([chunks[0], chunks[1], chunks[2]]) + classes[0] = Concatenate()([bmax, chunks[3]]) + + retval = Model(input_, + [classes[0], + regs[0], + classes[1], + regs[1], + classes[2], + regs[2], + classes[3], + regs[3], + classes[4], + regs[4], + classes[5], + regs[5]]) + retval.load_weights(weights_path) + retval.make_predict_function() + return retval + + def prepare_batch(self, batch: np.ndarray) -> np.ndarray: + """ Prepare a batch for prediction. + + Normalizes the feed images. 
+
+        Parameters
+        ----------
+        batch: :class:`numpy.ndarray`
+            The batch to be fed to the model
+
+        Returns
+        -------
+        :class:`numpy.ndarray`
+            The normalized images for feeding to the model
+        """
+        batch = batch - self.average_img
+        return batch
+
+    def finalize_predictions(self, bounding_boxes_scales: list[np.ndarray]) -> np.ndarray:
+        """ Process the output from the model to obtain faces
+
+        Parameters
+        ----------
+        bounding_boxes_scales: list
+            The output predictions from the S3FD model
+
+        Returns
+        -------
+        :class:`numpy.ndarray`
+            The final detected face boxes for each image in the batch
+        """
+        ret = []
+        batch_size = range(bounding_boxes_scales[0].shape[0])
+        for img in batch_size:
+            bboxlist = [scale[img:img+1] for scale in bounding_boxes_scales]
+            boxes = self._post_process(bboxlist)
+            finallist = self._nms(boxes, 0.5)
+            ret.append(finallist)
+        return np.array(ret, dtype="object")
+
+    def _process_bbox(self,
+                      ocls: np.ndarray,
+                      oreg: np.ndarray,
+                      stride: int) -> list[list[np.ndarray]]:
+        """ Process a bounding box """
+        retval = []
+        for pos in zip(*np.where(ocls[:, :, :, 1] > 0.05)):
+            a_c = stride / 2 + pos[2] * stride, stride / 2 + pos[1] * stride
+            score = ocls[0, pos[1], pos[2], 1]
+            if score >= self.confidence:
+                loc = np.ascontiguousarray(oreg[0, pos[1], pos[2], :]).reshape((1, 4))
+                priors = np.array([[a_c[0] / 1.0,
+                                    a_c[1] / 1.0,
+                                    stride * 4 / 1.0,
+                                    stride * 4 / 1.0]])
+                box = self.decode(loc, priors)
+                x_1, y_1, x_2, y_2 = box[0] * 1.0
+                retval.append([x_1, y_1, x_2, y_2, score])
+        return retval
+
+    def _post_process(self, bboxlist: list[np.ndarray]) -> np.ndarray:
+        """ Perform post processing on output
+            TODO: do this on the batch.
+        """
+        retval = []
+        for i in range(len(bboxlist) // 2):
+            bboxlist[i * 2] = self.softmax(bboxlist[i * 2], axis=3)
+        for i in range(len(bboxlist) // 2):
+            ocls, oreg = bboxlist[i * 2], bboxlist[i * 2 + 1]
+            stride = 2 ** (i + 2)  # 4, 8, 16, 32, 64, 128
+            retval.extend(self._process_bbox(ocls, oreg, stride))
+
+        return_numpy = np.array(retval) if len(retval) != 0 else np.zeros((1, 5))
+        return return_numpy
+
+    @staticmethod
+    def softmax(inp, axis: int) -> np.ndarray:
+        """ Compute softmax values for each set of scores in the input. """
+        return np.exp(inp - logsumexp(inp, axis=axis, keepdims=True))
+
+    @staticmethod
+    def decode(location: np.ndarray, priors: np.ndarray) -> np.ndarray:
+        """Decode locations from predictions using priors to undo the encoding we did for offset
+        regression at train time.
+
+        Parameters
+        ----------
+        location: :class:`numpy.ndarray`
+            Location predictions for the location layers
+        priors: :class:`numpy.ndarray`
+            Prior boxes in center-offset form.
+
+        Returns
+        -------
+        :class:`numpy.ndarray`
+            decoded bounding box predictions
+        """
+        variances = [0.1, 0.2]
+        boxes = np.concatenate((priors[:, :2] + location[:, :2] * variances[0] * priors[:, 2:],
+                                priors[:, 2:] * np.exp(location[:, 2:] * variances[1])), axis=1)
+        boxes[:, :2] -= boxes[:, 2:] / 2
+        boxes[:, 2:] += boxes[:, :2]
+        return boxes
+
+    @staticmethod
+    def _nms(boxes: np.ndarray, threshold: float) -> np.ndarray:
+        """ Perform Non-Maximum Suppression """
+        retained_box_indices = []
+
+        areas = (boxes[:, 2] - boxes[:, 0] + 1) * (boxes[:, 3] - boxes[:, 1] + 1)
+        ranked_indices = boxes[:, 4].argsort()[::-1]
+        while ranked_indices.size > 0:
+            best_rest = ranked_indices[0], ranked_indices[1:]
+
+            max_of_xy = np.maximum(boxes[best_rest[0], :2], boxes[best_rest[1], :2])
+            min_of_xy = np.minimum(boxes[best_rest[0], 2:4], boxes[best_rest[1], 2:4])
+            width_height = np.maximum(0, min_of_xy - max_of_xy + 1)
+            intersection_areas = width_height[:, 0] * width_height[:, 1]
+            iou = intersection_areas / (areas[best_rest[0]] +
+                                        areas[best_rest[1]] - intersection_areas)
+
+            overlapping_boxes = (iou > threshold).nonzero()[0]
+            if len(overlapping_boxes) != 0:
+                overlap_set = ranked_indices[overlapping_boxes + 1]
+                vote = np.average(boxes[overlap_set, :4], axis=0, weights=boxes[overlap_set, 4])
+                boxes[best_rest[0], :4] = vote
+            retained_box_indices.append(best_rest[0])
+
+            non_overlapping_boxes = (iou <= threshold).nonzero()[0]
+            ranked_indices = ranked_indices[non_overlapping_boxes + 1]
+        return boxes[retained_box_indices]
+
+    def __call__(self, inputs: np.ndarray) -> np.ndarray:
+        """ Get predictions from the S3FD model
+
+        Parameters
+        ----------
+        inputs: :class:`numpy.ndarray`
+            The input to S3FD
+
+        Returns
+        -------
+        :class:`numpy.ndarray`
+            The output from S3FD
+        """
+        return self._model.predict(inputs, verbose=0, batch_size=self._batch_size)
+
+
+__all__ = get_module_objects(__name__)
diff --git a/plugins/extract/detect/s3fd_defaults.py b/plugins/extract/detect/s3fd_defaults.py
new file mode 100755
index 0000000000..1ecf1948d8
--- /dev/null
+++ b/plugins/extract/detect/s3fd_defaults.py
@@ -0,0 +1,59 @@
+#!/usr/bin/env python3
+""" The default options for the faceswap S3Fd Detect plugin.
+
+Defaults files should be named `<plugin_name>_defaults.py`
+
+Any qualifying items placed into this file will automatically get added to the relevant config
+.ini files within the faceswap/config folder and added to the relevant GUI settings page.
+
+The following variable should be defined:
+
+    Parameters
+    ----------
+    HELPTEXT: str
+        A string describing what this plugin does
+
+Further plugin configuration options are assigned using:
+>>> from lib.config import ConfigItem
+>>> <option_name> = ConfigItem(...)
+
+where <option_name> is the name of the configuration option to be added (lower-case, alpha-numeric
++ underscore only) and ConfigItem(...) is the [`~lib.config.objects.ConfigItem`] data for the
+option.
+
+See the docstring/ReadTheDocs documentation for the required parameters for the ConfigItem object.
+Items will be grouped together as per their `group` parameter, but otherwise will be processed in
+the order that they are added to this module.
+"""
+# pylint:disable=duplicate-code
+from lib.config import ConfigItem
+
+
+HELPTEXT = (
+    "S3FD Detector options.\n"
+    "Fast on GPU, slow on CPU. Can detect more faces and fewer false positives than other GPU "
+    "detectors, but is a lot more resource intensive."
+    )
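The items defined below are registered at module level; elsewhere in this diff (e.g. `bisenet_fp.py`) such items are read back by calling them through an imported `cfg` alias. A hedged sketch of that access pattern; the import path mirrors this file's location, and the percent-to-probability scaling is an assumption, not shown in this hunk:

    # Hypothetical read-back sketch, mirroring the cfg.<option>() pattern used
    # by the mask plugins later in this diff; not itself part of the commit.
    from plugins.extract.detect import s3fd_defaults as cfg

    batch = cfg.batch_size()             # ConfigItems are read by calling them
    threshold = cfg.confidence() / 100   # assumed scaling of the percent value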
+
+
+confidence = ConfigItem(
+    datatype=int,
+    default=70,
+    group="settings",
+    info="The confidence level at which the detector has successfully found a face.\n"
+         "Higher levels will be more discriminating, lower levels will have more false "
+         "positives.",
+    rounding=5,
+    min_max=(25, 100))
+
+batch_size = ConfigItem(
+    datatype=int,
+    default=4,
+    group="settings",
+    info="The batch size to use. To a point, higher batch sizes equal better performance, "
+         "but setting it too high can harm performance.\n"
+         "\n\tNvidia users: If the batch size is set higher than your GPU can "
+         "accommodate then this will automatically be lowered."
+         "\n\tAMD users: A batch size of 8 requires about 2 GB of VRAM.",
+    rounding=1,
+    min_max=(1, 64))
diff --git a/plugins/extract/extract_config.py b/plugins/extract/extract_config.py
new file mode 100644
index 0000000000..2361864cc4
--- /dev/null
+++ b/plugins/extract/extract_config.py
@@ -0,0 +1,150 @@
+#!/usr/bin/env python3
+""" Default configurations for extract """
+
+import gettext
+import logging
+import os
+
+from lib.config import FaceswapConfig
+from lib.config import ConfigItem
+
+# LOCALES
+_LANG = gettext.translation("plugins.extract.extract_config", localedir="locales", fallback=True)
+_ = _LANG.gettext
+
+logger = logging.getLogger(__name__)
+
+
+class _Config(FaceswapConfig):
+    """ Config File for Extraction """
+
+    def set_defaults(self, helptext="") -> None:
+        """ Set the default values for config """
+        super().set_defaults(helptext=_("Options that apply to all extraction plugins"))
+        self._defaults_from_plugin(os.path.dirname(__file__))
+
+
+aligner_min_scale = ConfigItem(
+    datatype=float,
+    default=0.03,
+    group=_("filters"),
+    info=_(
+        "Filters out faces below this size. This is a multiplier of the minimum dimension of "
+        "the frame (i.e. 1280x720 = 720). If the original face extract box is smaller than "
+        "the minimum dimension times this multiplier, it is considered a false positive and "
+        "discarded. Faces which are found to be unusually smaller than the frame tend to be "
+        "misaligned images, except in extreme long-shots. These can usually be safely "
+        "discarded."),
+    min_max=(0.0, 1.0),
+    rounding=2)
+
+
+aligner_max_scale = ConfigItem(
+    datatype=float,
+    default=4.00,
+    group=_("filters"),
+    info=_(
+        "Filters out faces above this size. This is a multiplier of the minimum dimension of "
+        "the frame (i.e. 1280x720 = 720). If the original face extract box is larger than the "
+        "minimum dimension times this multiplier, it is considered a false positive and "
+        "discarded. Faces which are found to be unusually larger than the frame tend to be "
+        "misaligned images except in extreme close-ups. These can usually be safely "
+        "discarded."),
+    min_max=(0.0, 10.0),
+    rounding=2)
+
+
+aligner_distance = ConfigItem(
+    datatype=float,
+    default=40.0,
+    group=_("filters"),
+    info=_(
+        "Filters out faces whose landmarks are above this distance from an 'average' face. "
+        "Values above 15 tend to be fairly safe. Values above 10 will remove more false "
+        "positives, but may also filter out some faces at extreme angles."),
+    min_max=(0.0, 45.0),
+    rounding=1)
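As a worked illustration of the min/max scale filters defined above, the test reduces to a single comparison against the frame's minimum dimension. The function and values below are illustrative only, not the pipeline's actual filter implementation:

    # Hypothetical sketch of the scale test described in the info strings above.
    def passes_scale_filter(face_box_size: int, frame_dims: tuple[int, int],
                            min_scale: float = 0.03, max_scale: float = 4.0) -> bool:
        """ True if the extract box size is within the allowed multiplier range. """
        min_dim = min(frame_dims)  # e.g. 720 for a 1280x720 frame
        return min_dim * min_scale <= face_box_size <= min_dim * max_scale

    # A 20px face box in a 720p frame is below 720 * 0.03 = 21.6 and is discarded
    assert not passes_scale_filter(20, (1280, 720))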
+
+
+aligner_roll = ConfigItem(
+    datatype=float,
+    default=0.0,
+    group=_("filters"),
+    info=_(
+        "Filters out faces whose calculated roll is greater than zero +/- this value in "
+        "degrees. Aligned faces should have a roll value close to zero. Values that are a "
+        "significant distance from 0 degrees tend to be misaligned images. These can usually "
+        "be safely discarded."),
+    min_max=(0.0, 90.0),
+    rounding=1)
+
+
+aligner_features = ConfigItem(
+    datatype=bool,
+    default=True,
+    group=_("filters"),
+    info=_(
+        "Filters out faces where the lowest point of the aligned face's eye or eyebrow is "
+        "lower than the highest point of the aligned face's mouth. Any faces where this "
+        "occurs are misaligned and can be safely discarded."))
+
+
+filter_refeed = ConfigItem(
+    datatype=bool,
+    default=True,
+    group=_("filters"),
+    info=_(
+        "If enabled, and 're-feed' has been selected for extraction, then interim alignments "
+        "will be filtered prior to averaging the final landmarks. This can help improve the "
+        "final alignments by removing any obvious misaligns from the interim results, and may "
+        "also help pick up difficult alignments. If disabled, then all re-feed results will "
+        "be averaged."))
+
+
+save_filtered = ConfigItem(
+    datatype=bool,
+    default=False,
+    group=_("filters"),
+    info=_(
+        "If enabled, saves any filtered out images into a sub-folder during the extraction "
+        "process. If disabled, filtered faces are deleted. Note: The faces will always be "
+        "filtered out of the alignments file, regardless of whether you keep the faces or "
+        "not."))
+
+
+realign_refeeds = ConfigItem(
+    datatype=bool,
+    default=True,
+    group=_("re-align"),
+    info=_(
+        "If enabled, and 're-align' has been selected for extraction, then all re-feed "
+        "iterations are re-aligned. If disabled, then only the final averaged output from re-"
+        "feed will be re-aligned."))
+
+
+filter_realign = ConfigItem(
+    datatype=bool,
+    default=True,
+    group=_("re-align"),
+    info=_(
+        "If enabled, and 're-align' has been selected for extraction, then any alignments "
+        "which would be filtered out will not be re-aligned."))
+
+
+# pylint:disable=duplicate-code
+_IS_LOADED: bool = False
+
+
+def load_config(config_file: str | None = None) -> None:
+    """ Load the Extraction configuration .ini file
+
+    Parameters
+    ----------
+    config_file : str | None, optional
+        Path to a custom .ini configuration file to load. Default: ``None`` (use default
+        configuration file)
+    """
+    global _IS_LOADED  # pylint:disable=global-statement
+    if not _IS_LOADED:
+        _Config(configfile=config_file)
+        _IS_LOADED = True
diff --git a/plugins/extract/extract_media.py b/plugins/extract/extract_media.py
new file mode 100644
index 0000000000..22700914ee
--- /dev/null
+++ b/plugins/extract/extract_media.py
@@ -0,0 +1,214 @@
+#!/usr/bin/env python3
+""" Object for holding and manipulating media passing through a faceswap extraction pipeline """
+from __future__ import annotations
+import logging
+import typing as T
+
+import cv2
+
+from lib.logger import parse_class_init
+from lib.utils import get_module_objects
+
+if T.TYPE_CHECKING:
+    import numpy as np
+    from lib.align.alignments import PNGHeaderSourceDict
+    from lib.align.detected_face import DetectedFace
+
+logger = logging.getLogger(__name__)
+
+
+class ExtractMedia:
+    """ An object that passes through the :class:`~plugins.extract.pipeline.Extractor` pipeline.
+
+    Parameters
+    ----------
+    filename: str
+        The base name of the original frame's filename
+    image: :class:`numpy.ndarray`
+        The original frame or a faceswap aligned face image
+    detected_faces: list, optional
+        A list of :class:`~lib.align.DetectedFace` objects. Detected faces can be added
+        later with :func:`add_detected_faces`. Setting ``None`` will default to an empty list.
+ Default: ``None`` + is_aligned: bool, optional + ``True`` if the :attr:`image` is an aligned faceswap image otherwise ``False``. Used for + face filtering with vggface2. Aligned faceswap images will automatically skip detection, + alignment and masking. Default: ``False`` + """ + + def __init__(self, + filename: str, + image: np.ndarray, + detected_faces: list[DetectedFace] | None = None, + is_aligned: bool = False) -> None: + logger.trace(parse_class_init(locals())) # type:ignore[attr-defined] + self._filename = filename + self._image: np.ndarray | None = image + self._image_shape = T.cast(tuple[int, int, int], image.shape) + self._detected_faces: list[DetectedFace] = ([] if detected_faces is None + else detected_faces) + self._is_aligned = is_aligned + self._frame_metadata: PNGHeaderSourceDict | None = None + self._sub_folders: list[str | None] = [] + + @property + def filename(self) -> str: + """ str: The base name of the :attr:`image` filename. """ + return self._filename + + @property + def image(self) -> np.ndarray: + """ :class:`numpy.ndarray`: The source frame for this object. """ + assert self._image is not None + return self._image + + @property + def image_shape(self) -> tuple[int, int, int]: + """ tuple: The shape of the stored :attr:`image`. """ + return self._image_shape + + @property + def image_size(self) -> tuple[int, int]: + """ tuple: The (`height`, `width`) of the stored :attr:`image`. """ + return self._image_shape[:2] + + @property + def detected_faces(self) -> list[DetectedFace]: + """list: A list of :class:`~lib.align.DetectedFace` objects in the :attr:`image`. """ + return self._detected_faces + + @property + def is_aligned(self) -> bool: + """ bool. ``True`` if :attr:`image` is an aligned faceswap image otherwise ``False`` """ + return self._is_aligned + + @property + def frame_metadata(self) -> PNGHeaderSourceDict: + """ dict: The frame metadata that has been added from an aligned image. This property + should only be called after :func:`add_frame_metadata` has been called when processing + an aligned face. For all other instances an assertion error will be raised. + + Raises + ------ + AssertionError + If frame metadata has not been populated from an aligned image + """ + assert self._frame_metadata is not None + return self._frame_metadata + + @property + def sub_folders(self) -> list[str | None]: + """ list: The sub_folders that the faces should be output to. Used when binning filter + output is enabled. The list corresponds to the list of detected faces + """ + return self._sub_folders + + def get_image_copy(self, color_format: T.Literal["BGR", "RGB", "GRAY"]) -> np.ndarray: + """ Get a copy of the image in the requested color format. + + Parameters + ---------- + color_format: ['BGR', 'RGB', 'GRAY'] + The requested color format of :attr:`image` + + Returns + ------- + :class:`numpy.ndarray`: + A copy of :attr:`image` in the requested :attr:`color_format` + """ + logger.trace("Requested color format '%s' for frame '%s'", # type:ignore[attr-defined] + color_format, self._filename) + image = getattr(self, f"_image_as_{color_format.lower()}")() + return image + + def add_detected_faces(self, faces: list[DetectedFace]) -> None: + """ Add detected faces to the object. Called at the end of each extraction phase. + + Parameters + ---------- + faces: list + A list of :class:`~lib.align.DetectedFace` objects + """ + logger.trace("Adding detected faces for filename: '%s'. 
" # type:ignore[attr-defined] + "(faces: %s, lrtb: %s)", self._filename, faces, + [(face.left, face.right, face.top, face.bottom) for face in faces]) + self._detected_faces = faces + + def add_sub_folders(self, folders: list[str | None]) -> None: + """ Add detected faces to the object. Called at the end of each extraction phase. + + Parameters + ---------- + folders: list + A list of str sub folder names or ``None`` if no sub folder is required. Should + correspond to the detected faces list + """ + logger.trace("Adding sub folders for filename: '%s'. " # type:ignore[attr-defined] + "(folders: %s)", self._filename, folders,) + self._sub_folders = folders + + def remove_image(self) -> None: + """ Delete the image and reset :attr:`image` to ``None``. + + Required for multi-phase extraction to avoid the frames stacking RAM. + """ + logger.trace("Removing image for filename: '%s'", # type:ignore[attr-defined] + self._filename) + del self._image + self._image = None + + def set_image(self, image: np.ndarray) -> None: + """ Add the image back into :attr:`image` + + Required for multi-phase extraction adds the image back to this object. + + Parameters + ---------- + image: :class:`numpy.ndarry` + The original frame to be re-applied to for this :attr:`filename` + """ + logger.trace("Reapplying image: (filename: `%s`, " # type:ignore[attr-defined] + "image shape: %s)", self._filename, image.shape) + self._image = image + + def add_frame_metadata(self, metadata: PNGHeaderSourceDict) -> None: + """ Add the source frame metadata from an aligned PNG's header data. + + metadata: dict + The contents of the 'source' field in the PNG header + """ + logger.trace("Adding PNG Source data for '%s': %s", # type:ignore[attr-defined] + self._filename, metadata) + dims = T.cast(tuple[int, int], metadata["source_frame_dims"]) + self._image_shape = (*dims, 3) + self._frame_metadata = metadata + + def _image_as_bgr(self) -> np.ndarray: + """ Get a copy of the source frame in BGR format. + + Returns + ------- + :class:`numpy.ndarray`: + A copy of :attr:`image` in BGR color format """ + return self.image[..., :3].copy() + + def _image_as_rgb(self) -> np.ndarray: + """ Get a copy of the source frame in RGB format. + + Returns + ------- + :class:`numpy.ndarray`: + A copy of :attr:`image` in RGB color format """ + return self.image[..., 2::-1].copy() + + def _image_as_gray(self) -> np.ndarray: + """ Get a copy of the source frame in gray-scale format. + + Returns + ------- + :class:`numpy.ndarray`: + A copy of :attr:`image` in gray-scale color format """ + return cv2.cvtColor(self.image.copy(), cv2.COLOR_BGR2GRAY) + + +__all__ = get_module_objects(__name__) diff --git a/plugins/extract/mask/__init__.py b/plugins/extract/mask/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/plugins/extract/mask/_base.py b/plugins/extract/mask/_base.py new file mode 100644 index 0000000000..62c43adca9 --- /dev/null +++ b/plugins/extract/mask/_base.py @@ -0,0 +1,342 @@ +#!/usr/bin/env python3 +""" Base class for Face Masker plugins + +Plugins should inherit from this class + +See the override methods for which methods are required. + +The plugin will receive a :class:`~plugins.extract.extract_media.ExtractMedia` object. 
diff --git a/plugins/extract/mask/__init__.py b/plugins/extract/mask/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/plugins/extract/mask/_base.py b/plugins/extract/mask/_base.py
new file mode 100644
index 0000000000..62c43adca9
--- /dev/null
+++ b/plugins/extract/mask/_base.py
@@ -0,0 +1,342 @@
+#!/usr/bin/env python3
+""" Base class for Face Masker plugins
+
+Plugins should inherit from this class
+
+See the override methods for which methods are required.
+
+The plugin will receive a :class:`~plugins.extract.extract_media.ExtractMedia` object.
+
+For each source item, the plugin must pass a dict to finalize containing:
+
+>>> {"filename": <filename of source frame>,
+>>>  "detected_faces": <list of DetectedFace objects>}
+"""
+from __future__ import annotations
+import logging
+import typing as T
+
+from dataclasses import dataclass, field
+
+import cv2
+import numpy as np
+from torch.cuda import OutOfMemoryError
+
+from lib.align import AlignedFace, LandmarkType, transform_image
+from lib.utils import FaceswapError
+from plugins.extract import ExtractMedia
+from plugins.extract._base import BatchType, ExtractorBatch, Extractor
+
+if T.TYPE_CHECKING:
+    from collections.abc import Generator
+    from queue import Queue
+    from lib.align import DetectedFace
+    from lib.align.aligned_face import CenteringType
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class MaskerBatch(ExtractorBatch):
+    """ Dataclass for holding items flowing through the masker.
+
+    Inherits from :class:`~plugins.extract._base.ExtractorBatch`
+
+    Parameters
+    ----------
+    detected_faces: list
+        The :class:`~lib.align.DetectedFace` objects for the batch
+    roi_masks: list
+        The region of interest masks for the batch
+    feed_faces: list
+        The :class:`~lib.align.AlignedFace` objects holding the faces to feed the model
+    """
+    detected_faces: list[DetectedFace] = field(default_factory=list)
+    roi_masks: list[np.ndarray] = field(default_factory=list)
+    feed_faces: list[AlignedFace] = field(default_factory=list)
+
+
+class Masker(Extractor):  # pylint:disable=abstract-method
+    """ Masker plugin _base Object
+
+    All Masker plugins must inherit from this class
+
+    Parameters
+    ----------
+    git_model_id: int
+        The second digit in the github tag that identifies this model. See
+        https://github.com/deepfakes-models/faceswap-models for more information
+    model_filename: str
+        The name of the model file to be loaded
+
+    Other Parameters
+    ----------------
+    configfile: str, optional
+        Path to a custom configuration ``ini`` file. Default: Use system configfile
+
+    See Also
+    --------
+    plugins.extract.pipeline : The extraction pipeline for calling plugins
+    plugins.extract.align : Aligner plugins
+    plugins.extract._base : Parent class for all extraction plugins
+    plugins.extract.detect._base : Detector parent class for extraction plugins.
+    plugins.extract.align._base : Aligner parent class for extraction plugins.
+    """
+
+    _logged_lm_count_once = False
+
+    def __init__(self,
+                 git_model_id: int | None = None,
+                 model_filename: str | None = None,
+                 configfile: str | None = None,
+                 instance: int = 0,
+                 **kwargs) -> None:
+        # pylint:disable=duplicate-code
+        logger.debug("Initializing %s: (configfile: %s)", self.__class__.__name__, configfile)
+        super().__init__(git_model_id,
+                         model_filename,
+                         configfile=configfile,
+                         instance=instance,
+                         **kwargs)
+        self.input_size = 256  # Override for model specific input_size
+        self.coverage_ratio = 1.0  # Override for model specific coverage_ratio
+
+        self._info.plugin_type = "mask"
+        # Override if a specific type of landmark data is required:
+        self.landmark_type: LandmarkType | None = None
+
+        self._storage_name = self.__module__.rsplit(".", maxsplit=1)[-1].replace("_", "-")
+        self._storage_centering: CenteringType = "face"  # Centering to store the mask at
+        self._storage_size = 128  # Size to store masks at. Leave this at default
+        logger.debug("Initialized %s", self.__class__.__name__)
+
+    @property
+    def storage_centering(self) -> CenteringType:
+        """ Literal["face", "head", "legacy"] : The centering that the mask is stored at """
+        return self._storage_centering
+
+    def _maybe_log_warning(self, face: AlignedFace) -> None:
+        """ Log a warning, once, if we do not have full facial landmarks
+
+        Parameters
+        ----------
+        face: :class:`~lib.align.aligned_face.AlignedFace`
+            The aligned face object to test the landmark type for
+        """
+        # Only warn when we only have 4-point ROI landmarks, and only warn once
+        if face.landmark_type != LandmarkType.LM_2D_4 or self._logged_lm_count_once:
+            return
+
+        msg = "are likely to be sub-standard"
+        msg = "cannot be generated" if self.name in ("Components", "Extended") else msg
+
+        logger.warning("Extracted faces do not contain facial landmark data. '%s' masks %s.",
+                       self.name, msg)
+        self._logged_lm_count_once = True
+
+    def get_batch(self, queue: Queue) -> tuple[bool, MaskerBatch]:
+        """ Get items for inputting into the masker from the queue in batches
+
+        Items are returned from the ``queue`` in batches of
+        :attr:`~plugins.extract._base.Extractor.batchsize`
+
+        Items are received as :class:`~plugins.extract.extract_media.ExtractMedia` objects and
+        converted to ``dict`` for internal processing.
+
+        To ensure consistent batch sizes for the masker the items are split into separate items
+        for each :class:`~lib.align.DetectedFace` object.
+
+        Remember to put ``'EOF'`` to the out queue after processing
+        the final batch
+
+        Outputs items in the following format. All lists are of length
+        :attr:`~plugins.extract._base.Extractor.batchsize`:
+
+        >>> {'filename': [<filenames of source frames>],
+        >>>  'detected_faces': [[<lib.align.DetectedFace objects>]]}
+
+    def _predict(self, batch: BatchType) -> MaskerBatch:
+        """ Just return the masker's predict function """
+        assert isinstance(batch, MaskerBatch)
+        assert self.name is not None
+        # slightly hacky workaround to deal with landmarks based masks:
+        if self.name.lower() in ("components", "extended"):
+            feed = np.empty(2, dtype="object")
+            feed[0] = batch.feed
+            feed[1] = batch.feed_faces
+        else:
+            feed = batch.feed
+
+        try:
+            batch.prediction = self.predict(feed)
+        except OutOfMemoryError as err:
+            msg = ("You do not have enough GPU memory available to run detection at the "
+                   "selected batch size. You can try a number of things:"
+                   "\n1) Close any other application that is using your GPU (web browsers are "
+                   "particularly bad for this)."
+                   "\n2) Lower the batchsize (the amount of images fed into the model) by "
+                   "editing the plugin settings (GUI: Settings > Configure extract settings, "
+                   "CLI: Edit the file faceswap/config/extract.ini)."
+                   "\n3) Enable 'Single Process' mode.")
+            raise FaceswapError(msg) from err
+
+        return batch
+
+    def finalize(self, batch: BatchType) -> Generator[ExtractMedia, None, None]:
+        """ Finalize the output from Masker
+
+        This should be called as the final task of each `plugin`.
+
+        Pairs the detected faces back up with their original frame before yielding each frame.
+
+        Parameters
+        ----------
+        batch : dict
+            The final ``dict`` from the `plugin` process. It must contain the `keys`:
+            ``detected_faces``, ``filename``, ``feed_faces``, ``roi_masks``
+
+        Yields
+        ------
+        :class:`~plugins.extract.extract_media.ExtractMedia`
+            The :attr:`DetectedFaces` list will be populated for this class with the bounding
+            boxes, landmarks and masks for the detected faces found in the frame.
+ """ + assert isinstance(batch, MaskerBatch) + for mask, face, feed_face, roi_mask in zip(batch.prediction, + batch.detected_faces, + batch.feed_faces, + batch.roi_masks): + if self.name in ("Components", "Extended") and not np.any(mask): + # Components/Extended masks can return empty when called from the manual tool with + # 4 Point ROI landmarks + continue + self._crop_out_of_bounds(mask, roi_mask) + face.add_mask(self._storage_name, + mask, + feed_face.adjusted_matrix, + feed_face.interpolators[1], + storage_size=self._storage_size, + storage_centering=self._storage_centering) + del batch.feed + + logger.trace("Item out: %s", # type: ignore + {key: val.shape if isinstance(val, np.ndarray) else val + for key, val in batch.__dict__.items()}) + for filename, face in zip(batch.filename, batch.detected_faces): + self._tracker.output_faces.append(face) + if len(self._tracker.output_faces) != self._tracker.faces_per_filename[filename]: + continue + + output = self._extract_media.pop(filename) + output.add_detected_faces(self._tracker.output_faces) + self._tracker.output_faces = [] + logger.trace("Yielding: (filename: '%s', image: %s, " # type:ignore[attr-defined] + "detected_faces: %s)", output.filename, output.image_shape, + len(output.detected_faces)) + yield output + + # <<< PROTECTED ACCESS METHODS >>> # + @classmethod + def _resize(cls, image: np.ndarray, target_size: int) -> np.ndarray: + """ resize input and output of mask models appropriately """ + height, width, channels = image.shape + image_size = max(height, width) + scale = target_size / image_size + if scale == 1.: + return image + method = cv2.INTER_CUBIC if scale > 1. else cv2.INTER_AREA + resized = cv2.resize(image, (0, 0), fx=scale, fy=scale, interpolation=method) + resized = resized if channels > 1 else resized[..., None] + return resized + + @classmethod + def _crop_out_of_bounds(cls, mask: np.ndarray, roi_mask: np.ndarray) -> None: + """ Un-mask any area of the predicted mask that falls outside of the original frame. + + Parameters + ---------- + masks: :class:`numpy.ndarray` + The predicted masks from the plugin + roi_mask: :class:`numpy.ndarray` + The roi mask. In frame is white, out of frame is black + """ + if np.all(roi_mask): + return # The whole of the face is within the frame + roi_mask = roi_mask[..., None] if mask.ndim == 3 else roi_mask + mask *= roi_mask diff --git a/plugins/extract/mask/bisenet_fp.py b/plugins/extract/mask/bisenet_fp.py new file mode 100644 index 0000000000..8e95638eb6 --- /dev/null +++ b/plugins/extract/mask/bisenet_fp.py @@ -0,0 +1,609 @@ +#!/usr/bin/env python3 +""" BiSeNet Face-Parsing mask plugin + +Architecture and Pre-Trained Model ported from PyTorch to Keras by TorzDF from +https://github.com/zllrunning/face-parsing.PyTorch +""" +from __future__ import annotations +import logging +import typing as T + +import numpy as np + +import keras.backend as K +from keras.layers import ( + Activation, Add, BatchNormalization, Concatenate, Conv2D, GlobalAveragePooling2D, Input, + MaxPooling2D, Multiply, Reshape, UpSampling2D, ZeroPadding2D) +from keras.models import Model + +from lib.logger import parse_class_init +from lib.utils import get_module_objects +from plugins.extract.extract_config import load_config +from ._base import BatchType, Masker, MaskerBatch +from . 
import bisenet_fp_defaults as cfg + +if T.TYPE_CHECKING: + from keras import KerasTensor + +logger = logging.getLogger(__name__) + + +class Mask(Masker): # pylint:disable=too-many-instance-attributes + """ Neural network to process face image into a segmentation mask of the face """ + def __init__(self, **kwargs) -> None: + # We need access to user config prior to parent being initialized to correctly set the + # model filename + load_config(kwargs.get("configfile")) + self._is_faceswap, version = self._check_weights_selection() + + git_model_id = 14 + model_filename = f"bisnet_face_parsing_v{version}.h5" + super().__init__(git_model_id=git_model_id, model_filename=model_filename, **kwargs) + + self.model: BiSeNet + self.name = "BiSeNet - Face Parsing" + self.input_size = 512 + self.color_format = "RGB" + self.vram = 384 if not cfg.cpu() else 0 # 378 in testing + self.vram_per_batch = 384 if not cfg.cpu() else 0 # ~328 in testing + self.batchsize = cfg.batch_size() + + self._segment_indices = self._get_segment_indices() + self._storage_centering = "head" if cfg.include_hair() else "face" + """ Literal["head", "face"] The mask type/storage centering to use """ + # Separate storage for face and head masks + self._storage_name = f"{self._storage_name}_{self._storage_centering}" + + def _check_weights_selection(self) -> tuple[bool, int]: + """ Check which weights have been selected. + + This is required for passing along the correct file name for the corresponding weights + selection. + + Returns + ------- + is_faceswap : bool + ``True`` if `faceswap` trained weights have been selected. ``False`` if `original` + weights have been selected. + version : int + ``1`` for non-faceswap, ``2`` if faceswap and full-head model is required. ``3`` if + faceswap and full-face is required + """ + is_faceswap = cfg.weights() == "faceswap" + version = 1 if not is_faceswap else 2 if cfg.include_hair() else 3 + return is_faceswap, version + + def _get_segment_indices(self) -> list[int]: + """ Obtain the segment indices to include within the face mask area based on user + configuration settings. + + Returns + ------- + list + The segment indices to include within the face mask area + + Notes + ----- + 'original' Model segment indices: + 0: background, 1: skin, 2: left brow, 3: right brow, 4: left eye, 5: right eye, 6: glasses + 7: left ear, 8: right ear, 9: earing, 10: nose, 11: mouth, 12: upper lip, 13: lower_lip, + 14: neck, 15: neck ?, 16: cloth, 17: hair, 18: hat + + 'faceswap' Model segment indices: + 0: background, 1: skin, 2: ears, 3: hair, 4: glasses + """ + retval = [1] if self._is_faceswap else [1, 2, 3, 4, 5, 10, 11, 12, 13] + + if cfg.include_glasses(): + retval.append(4 if self._is_faceswap else 6) + if cfg.include_ears(): + retval.extend([2] if self._is_faceswap else [7, 8, 9]) + if cfg.include_hair(): + retval.append(3 if self._is_faceswap else 17) + logger.debug("Selected segment indices: %s", retval) + return retval + + def init_model(self) -> None: + """ Initialize the BiSeNet Face Parsing model. 
""" + assert isinstance(self.model_path, str) + lbls = 5 if self._is_faceswap else 19 + placeholder = np.zeros((self.batchsize, self.input_size, self.input_size, 3), + dtype="float32") + + with self.get_device_context(cfg.cpu()): + self.model = BiSeNet(self.model_path, self.batchsize, self.input_size, lbls) + self.model(placeholder) + + def process_input(self, batch: BatchType) -> None: + """ Compile the detected faces for prediction """ + assert isinstance(batch, MaskerBatch) + mean = (0.384, 0.314, 0.279) if self._is_faceswap else (0.485, 0.456, 0.406) + std = (0.324, 0.286, 0.275) if self._is_faceswap else (0.229, 0.224, 0.225) + + batch.feed = ((np.array([T.cast(np.ndarray, feed.face)[..., :3] + for feed in batch.feed_faces], + dtype="float32") / 255.0) - mean) / std + logger.trace("feed shape: %s", batch.feed.shape) # type:ignore[attr-defined] + + def predict(self, feed: np.ndarray) -> np.ndarray: + """ Run model to get predictions """ + with self.get_device_context(cfg.cpu()): + return self.model(feed)[0] + + def process_output(self, batch: BatchType) -> None: + """ Compile found faces for output """ + pred = batch.prediction.argmax(-1).astype("uint8") + batch.prediction = np.isin(pred, self._segment_indices).astype("float32") + +# BiSeNet Face-Parsing Model + +# MIT License + +# Copyright (c) 2019 zll + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + + +_NAME_TRACKER: set[str] = set() + + +def _get_name(name: str, start_idx: int = 1) -> str: + """ Auto numbering to keep track of layer names. + + Names are kept the same as the PyTorch original model, to enable easier porting of weights. + + Names are tracked and auto-appended with an integer to ensure they are unique. + + Parameters + ---------- + name: str + The name of the layer to get auto named. + start_idx + The first index number to start auto naming layers with the same name. Usually 0 or 1. + Pass -1 if the name should not be auto-named (i.e. should not have an integer appended + to the end) + + Returns + ------- + str + A unique version of the original name + """ + i = start_idx + while True: + retval = f"{name}{i}" if i != -1 else name + if retval not in _NAME_TRACKER: + break + i += 1 + _NAME_TRACKER.add(retval) + return retval + + +class ConvBn(): + """ Convolutional 3D with Batch Normalization block. + + Parameters + ---------- + filters: int + The dimensionality of the output space (i.e. the number of output filters in the + convolution). 
+
+
+class ConvBn():
+    """ 2D Convolution with Batch Normalization block.
+
+    Parameters
+    ----------
+    filters: int
+        The dimensionality of the output space (i.e. the number of output filters in the
+        convolution).
+    kernel_size: int, optional
+        The height and width of the 2D convolution window. Default: `3`
+    strides: int, optional
+        The strides of the convolution along the height and width. Default: `1`
+    padding: int, optional
+        The amount of padding to apply prior to the first Convolutional Layer. Default: `1`
+    activation: bool, optional
+        Whether to include ReLU activation at the end of the block. Default: ``True``
+    prefix: str, optional
+        The prefix to name the layers within the block. Default: ``""`` (empty string, i.e. no
+        prefix)
+    start_idx: int, optional
+        The starting index for naming the layers within the block. See :func:`_get_name` for
+        more information. Default: `1`
+    """
+    def __init__(self, filters: int,  # pylint:disable=too-many-positional-arguments
+                 kernel_size: int = 3,
+                 strides: int = 1,
+                 padding: int = 1,
+                 activation: bool = True,
+                 prefix: str = "",
+                 start_idx: int = 1) -> None:
+        self._filters = filters
+        self._kernel_size = kernel_size
+        self._strides = strides
+        self._padding = padding
+        self._activation = activation
+        self._prefix = f"{prefix}-" if prefix else prefix
+        self._start_idx = start_idx
+
+    def __call__(self, inputs: KerasTensor) -> KerasTensor:
+        """ Call the Convolutional Batch Normalization block.
+
+        Parameters
+        ----------
+        inputs: :class:`keras.KerasTensor`
+            The input to the block
+
+        Returns
+        -------
+        :class:`keras.KerasTensor`
+            The output from the block
+        """
+        var_x = inputs
+        if self._padding > 0 and self._kernel_size != 1:
+            var_x = ZeroPadding2D(self._padding,
+                                  name=_get_name(f"{self._prefix}zeropad",
+                                                 start_idx=self._start_idx))(var_x)
+        padding = "valid" if self._padding != -1 else "same"
+        var_x = Conv2D(self._filters,
+                       self._kernel_size,
+                       strides=self._strides,
+                       padding=padding,
+                       use_bias=False,
+                       name=_get_name(f"{self._prefix}conv", start_idx=self._start_idx))(var_x)
+        var_x = BatchNormalization(epsilon=1e-5,
+                                   name=_get_name(f"{self._prefix}bn",
+                                                  start_idx=self._start_idx))(var_x)
+        if self._activation:
+            var_x = Activation("relu",
+                               name=_get_name(f"{self._prefix}relu",
+                                              start_idx=self._start_idx))(var_x)
+        return var_x
+
+
+class ResNet18():
+    """ ResNet 18 block. Used at the start of BiSeNet Face Parsing. """
+    def __init__(self):
+        self._feature_index = 1 if K.image_data_format() == "channels_first" else -1
+
+    def _basic_block(self,
+                     inputs: KerasTensor,
+                     prefix: str,
+                     filters: int,
+                     strides: int = 1) -> KerasTensor:
+        """ The basic building block for ResNet 18.
+
+        Parameters
+        ----------
+        inputs: :class:`keras.KerasTensor`
+            The input to the block
+        prefix: str
+            The prefix to name the layers within the block
+        filters: int
+            The dimensionality of the output space (i.e. the number of output filters in the
+            convolution).
+        strides: int, optional
+            The strides of the convolution along the height and width.
Default: `1` + + Returns + ------- + :class:`keras.KerasTensor` + The output from the block + """ + res = ConvBn(filters, strides=strides, padding=1, prefix=prefix)(inputs) + res = ConvBn(filters, strides=1, padding=1, activation=False, prefix=prefix)(res) + + shortcut = inputs + filts = (shortcut.shape[self._feature_index], res.shape[self._feature_index]) + if strides != 1 or filts[0] != filts[1]: # Downsample + name = f"{prefix}-downsample-" + shortcut = Conv2D(filters, 1, + strides=strides, + use_bias=False, + name=_get_name(f"{name}", start_idx=0))(shortcut) + shortcut = BatchNormalization(epsilon=1e-5, + name=_get_name(f"{name}", start_idx=0))(shortcut) + + var_x = Add(name=f"{prefix}-add")([res, shortcut]) + var_x = Activation("relu", name=f"{prefix}-relu")(var_x) + return var_x + + def _basic_layer(self, # pylint:disable=too-many-positional-arguments + inputs: KerasTensor, + prefix: str, + filters: int, + num_blocks: int, + strides: int = 1) -> KerasTensor: + """ The basic layer for ResNet 18. Recursively builds from :func:`_basic_block`. + + Parameters + ---------- + inputs: :class:`keras.KerasTensor` + The input to the block + prefix: str + The prefix to name the layers within the block + filters: int + The dimensionality of the output space (i.e. the number of output filters in the + convolution). + num_blocks: int + The number of basic blocks to recursively build + strides: int, optional + The strides of the convolution along the height and width. Default: `1` + + Returns + ------- + :class:`keras.KerasTensor` + The output from the block + """ + var_x = self._basic_block(inputs, f"{prefix}-0", filters, strides=strides) + for i in range(num_blocks - 1): + var_x = self._basic_block(var_x, f"{prefix}-{i + 1}", filters, strides=1) + return var_x + + def __call__(self, inputs: KerasTensor) -> KerasTensor: + """ Call the ResNet 18 block. + + Parameters + ---------- + inputs: :class:`keras.KerasTensor` + The input to the block + + Returns + ------- + :class:`keras.KerasTensor` + The output from the block + """ + var_x = ConvBn(64, kernel_size=7, strides=2, padding=3, prefix="cp-resnet")(inputs) + var_x = ZeroPadding2D(1, name="cp-resnet-zeropad")(var_x) + var_x = MaxPooling2D(pool_size=3, strides=2, name="cp-resnet-maxpool")(var_x) + + var_x = self._basic_layer(var_x, "cp-resnet-layer1", 64, 2) + feat8 = self._basic_layer(var_x, "cp-resnet-layer2", 128, 2, strides=2) + feat16 = self._basic_layer(feat8, "cp-resnet-layer3", 256, 2, strides=2) + feat32 = self._basic_layer(feat16, "cp-resnet-layer4", 512, 2, strides=2) + + return feat8, feat16, feat32 + + +class AttentionRefinementModule(): + """ The Attention Refinement block for BiSeNet Face Parsing + + Parameters + ---------- + filters: int + The dimensionality of the output space (i.e. the number of output filters in the + convolution). + """ + def __init__(self, filters: int) -> None: + self._filters = filters + + def __call__(self, inputs: KerasTensor, feats: int) -> KerasTensor: + """ Call the Attention Refinement block. + + Parameters + ---------- + inputs: :class:`keras.KerasTensor` + The input to the block + feats: int + The number of features. Used for naming. 
+
+        Returns
+        -------
+        :class:`keras.KerasTensor`
+            The output from the block
+        """
+        prefix = f"cp-arm{feats}"
+        feat = ConvBn(self._filters, prefix=f"{prefix}-conv", start_idx=-1, padding=-1)(inputs)
+        atten = GlobalAveragePooling2D(name=f"{prefix}-avgpool")(feat)
+        atten = Reshape((1, 1, atten.shape[-1]))(atten)
+        atten = Conv2D(self._filters, 1, use_bias=False, name=f"{prefix}-conv_atten")(atten)
+        atten = BatchNormalization(epsilon=1e-5, name=f"{prefix}-bn_atten")(atten)
+        atten = Activation("sigmoid", name=f"{prefix}-sigmoid")(atten)
+        var_x = Multiply(name=f"{prefix}.mul")([feat, atten])
+        return var_x
+
+
+class ContextPath():
+    """ The Context Path block for BiSeNet Face Parsing. """
+    def __init__(self):
+        self._resnet = ResNet18()
+
+    def __call__(self, inputs: KerasTensor) -> tuple[KerasTensor, KerasTensor, KerasTensor]:
+        """ Call the Context Path block.
+
+        Parameters
+        ----------
+        inputs: :class:`keras.KerasTensor`
+            The input to the block
+
+        Returns
+        -------
+        tuple
+            The three feature tensors output from the block
+        """
+        feat8, feat16, feat32 = self._resnet(inputs)
+
+        avg = GlobalAveragePooling2D(name="cp-avgpool")(feat32)
+        avg = Reshape((1, 1, avg.shape[-1]))(avg)
+        avg = ConvBn(128, kernel_size=1, padding=0, prefix="cp-conv_avg", start_idx=-1)(avg)
+
+        avg_up = UpSampling2D(size=feat32.shape[1:3], name="cp-upsample")(avg)
+
+        feat32 = AttentionRefinementModule(128)(feat32, 32)
+        feat32 = Add(name="cp-add")([feat32, avg_up])
+        feat32 = UpSampling2D(name="cp-upsample1")(feat32)
+        feat32 = ConvBn(128, kernel_size=3, prefix="cp-conv_head32", start_idx=-1)(feat32)
+
+        feat16 = AttentionRefinementModule(128)(feat16, 16)
+        feat16 = Add(name="cp-add2")([feat16, feat32])
+        feat16 = UpSampling2D(name="cp-upsample2")(feat16)
+        feat16 = ConvBn(128, kernel_size=3, prefix="cp-conv_head16", start_idx=-1)(feat16)
+
+        return feat8, feat16, feat32
+
+
+class FeatureFusionModule():
+    """ The Feature Fusion block for BiSeNet Face Parsing
+
+    Parameters
+    ----------
+    filters: int
+        The dimensionality of the output space (i.e. the number of output filters in the
+        convolution).
+    """
+    def __init__(self, filters: int) -> None:
+        self._filters = filters
+
+    def __call__(self, inputs: KerasTensor) -> KerasTensor:
+        """ Call the Feature Fusion block.
+
+        Parameters
+        ----------
+        inputs: :class:`keras.KerasTensor`
+            The input to the block
+
+        Returns
+        -------
+        :class:`keras.KerasTensor`
+            The output from the block
+        """
+        feat = Concatenate(name="ffm-concat")(inputs)
+        feat = ConvBn(self._filters,
+                      kernel_size=1,
+                      padding=0,
+                      prefix="ffm-convblk",
+                      start_idx=-1)(feat)
+
+        atten = GlobalAveragePooling2D(name="ffm-avgpool")(feat)
+        atten = Reshape((1, 1, atten.shape[-1]))(atten)
+        atten = Conv2D(self._filters // 4, 1, use_bias=False, name="ffm-conv1")(atten)
+        atten = Activation("relu", name="ffm-relu")(atten)
+        atten = Conv2D(self._filters, 1, use_bias=False, name="ffm-conv2")(atten)
+        atten = Activation("sigmoid", name="ffm-sigmoid")(atten)
+
+        var_x = Multiply(name="ffm-mul")([feat, atten])
+        var_x = Add(name="ffm-add")([var_x, feat])
+        return var_x
+
+
+class BiSeNetOutput():
+    """ The BiSeNet Output block for Face Parsing
+
+    Parameters
+    ----------
+    filters: int
+        The dimensionality of the output space (i.e. the number of output filters in the
+        convolution).
+    num_classes: int
+        The number of classes to generate
+    label: str, optional
+        The label for this output (for naming). Default: `""` (i.e. empty string, or no label)
+    """
+    def __init__(self, filters: int, num_classes: int, label: str = "") -> None:
+        self._filters = filters
+        self._num_classes = num_classes
+        self._label = label
+
+    def __call__(self, inputs: KerasTensor) -> KerasTensor:
+        """ Call the BiSeNet Output block.
+
+        Parameters
+        ----------
+        inputs: :class:`keras.KerasTensor`
+            The input to the block
+
+        Returns
+        -------
+        :class:`keras.KerasTensor`
+            The output from the block
+        """
+        var_x = ConvBn(self._filters, prefix=f"conv_out{self._label}-conv", start_idx=-1)(inputs)
+        var_x = Conv2D(self._num_classes, 1,
+                       use_bias=False, name=f"conv_out{self._label}-conv_out")(var_x)
+        return var_x
+
+
+class BiSeNet():
+    """ BiSeNet Face-Parsing Mask from https://github.com/zllrunning/face-parsing.PyTorch
+
+    PyTorch model implemented in Keras by TorzDF
+
+    Parameters
+    ----------
+    weights_path: str
+        The path to the keras weights file
+    batch_size: int
+        The batch size to feed the model
+    input_size: int
+        The input size to the model
+    num_classes: int
+        The number of segmentation classes to create
+    """
+    def __init__(self,
+                 weights_path: str,
+                 batch_size: int,
+                 input_size: int,
+                 num_classes: int) -> None:
+        logger.debug(parse_class_init(locals()))
+        self._batch_size = batch_size
+        self._input_size = input_size
+        self._num_classes = num_classes
+        self._cp = ContextPath()
+        self._model = self._load_model(weights_path)
+        logger.debug("Initialized: %s", self.__class__.__name__)
+
+    def _load_model(self, weights_path: str) -> Model:
+        """ Definition of the BiSeNet-FP Model.
+
+        Parameters
+        ----------
+        weights_path: str
+            Full path to the model's weights
+
+        Returns
+        -------
+        :class:`keras.models.Model`
+            The BiSeNet-FP model
+        """
+        input_ = Input((self._input_size, self._input_size, 3))
+
+        features = self._cp(input_)  # res8, cp8, cp16
+        feat_fuse = FeatureFusionModule(256)([features[0], features[1]])
+
+        feats = [BiSeNetOutput(256, self._num_classes)(feat_fuse),
+                 BiSeNetOutput(64, self._num_classes, label="16")(features[1]),
+                 BiSeNetOutput(64, self._num_classes, label="32")(features[2])]
+
+        height, width = input_.shape[1:3]
+        output = [UpSampling2D(size=(height // feat.shape[1], width // feat.shape[2]),
+                               interpolation="bilinear")(feat)
+                  for feat in feats]
+
+        retval = Model(input_, output)
+        retval.load_weights(weights_path)
+        retval.make_predict_function()
+        return retval
+
+    def __call__(self, inputs: np.ndarray) -> np.ndarray:
+        """ Get predictions from the BiSeNet-FP model
+
+        Parameters
+        ----------
+        inputs: :class:`numpy.ndarray`
+            The input to BiSeNet-FP
+
+        Returns
+        -------
+        :class:`numpy.ndarray`
+            The output from BiSeNet-FP
+        """
+        return self._model.predict(inputs, verbose=0, batch_size=self._batch_size)
+
+
+__all__ = get_module_objects(__name__)
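Since `BiSeNet` above is a plain callable wrapper, a hedged instantiation sketch may help orientation. The weights filename mirrors the `bisnet_face_parsing_v{version}.h5` pattern used by the `Mask` plugin earlier in this file and is illustrative; a real weights file is required for `load_weights` to succeed:

    import numpy as np

    # Hypothetical usage sketch of the wrapper defined above; not part of the commit.
    net = BiSeNet("bisnet_face_parsing_v2.h5", batch_size=8, input_size=512, num_classes=5)
    feed = np.zeros((8, 512, 512, 3), dtype="float32")  # normalized RGB batch
    seg_main, seg_16, seg_32 = net(feed)                # three upsampled outputs
    mask = seg_main.argmax(-1)                          # per-pixel class indices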
diff --git a/plugins/extract/mask/bisenet_fp_defaults.py b/plugins/extract/mask/bisenet_fp_defaults.py
new file mode 100644
index 0000000000..b335b7e493
--- /dev/null
+++ b/plugins/extract/mask/bisenet_fp_defaults.py
@@ -0,0 +1,87 @@
+#!/usr/bin/env python3
+""" The default options for the faceswap BiSeNet Face Parsing plugin.
+
+Defaults files should be named `<plugin_name>_defaults.py`
+
+Any qualifying items placed into this file will automatically get added to the relevant config
+.ini files within the faceswap/config folder and added to the relevant GUI settings page.
+
+The following variable should be defined:
+
+    Parameters
+    ----------
+    HELPTEXT: str
+        A string describing what this plugin does
+
+Further plugin configuration options are assigned using:
+>>> from lib.config import ConfigItem
+>>> <option_name> = ConfigItem(...)
+
+where <option_name> is the name of the configuration option to be added (lower-case, alpha-numeric
++ underscore only) and ConfigItem(...) is the [`~lib.config.objects.ConfigItem`] data for the
+option.
+
+See the docstring/ReadTheDocs documentation for the required parameters for the ConfigItem object.
+Items will be grouped together as per their `group` parameter, but otherwise will be processed in
+the order that they are added to this module.
+"""
+# pylint:disable=duplicate-code
+from lib.config import ConfigItem
+
+
+HELPTEXT = (
+    "BiSeNet Face Parsing options.\n"
+    "Mask ported from https://github.com/zllrunning/face-parsing.PyTorch."
+    )
+
+
+batch_size = ConfigItem(
+    datatype=int,
+    default=8,
+    group="settings",
+    info="The batch size to use. To a point, higher batch sizes equal better performance, "
+         "but setting it too high can harm performance.\n"
+         "\n\tNvidia users: If the batch size is set higher than your GPU can "
+         "accommodate then this will automatically be lowered.",
+    rounding=1,
+    min_max=(1, 64))
+
+cpu = ConfigItem(
+    datatype=bool,
+    default=False,
+    group="settings",
+    info="BiseNet mask still runs fairly quickly on CPU on some setups. Enable "
+         "CPU mode here to use the CPU for this masker to save some VRAM at a speed cost.")
+
+weights = ConfigItem(
+    datatype=str,
+    default="faceswap",
+    group="settings",
+    info="The trained weights to use.\n"
+         "\n\tfaceswap - Weights trained on wildly varied Faceswap extracted data to "
+         "better handle varying conditions, obstructions, glasses and multiple targets "
+         "within a single extracted image."
+         "\n\toriginal - The original weights trained on the CelebAMask-HQ dataset.",
+    choices=["faceswap", "original"],
+    gui_radio=True)
+
+include_ears = ConfigItem(
+    datatype=bool,
+    default=False,
+    group="settings",
+    info="Whether to include ears within the face mask.")
+
+include_hair = ConfigItem(
+    datatype=bool,
+    default=False,
+    group="settings",
+    info="Whether to include hair within the face mask.")
+
+include_glasses = ConfigItem(
+    datatype=bool,
+    default=True,
+    group="settings",
+    info="Whether to include glasses within the face mask.\n\tFor 'original' weights "
+         "excluding glasses will mask out the lenses as well as the frames.\n\tFor "
+         "'faceswap' weights, the model has been trained to mask out lenses if eyes cannot "
+         "be seen (i.e. dark sunglasses) or just the frames if the eyes can be seen.")
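The Components mask in the next file builds its output from convex hulls over groups of 68-point landmarks. A minimal illustration of the hull-fill primitive it relies on, with hypothetical point values (not part of the commit):

    import cv2
    import numpy as np

    # Fill the convex hull of a few landmark points into a blank float mask,
    # the same cv2.fillConvexPoly call used by parse_parts consumers below.
    mask = np.zeros((256, 256, 1), dtype="float32")
    points = np.array([[60, 100], [120, 80], [180, 110], [150, 200], [90, 210]],
                      dtype="int32")
    hull = cv2.convexHull(points)
    cv2.fillConvexPoly(mask, hull, [1.0], lineType=cv2.LINE_AA)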
diff --git a/plugins/extract/mask/components.py b/plugins/extract/mask/components.py
new file mode 100644
index 0000000000..c785673ebf
--- /dev/null
+++ b/plugins/extract/mask/components.py
@@ -0,0 +1,84 @@
+#!/usr/bin/env python3
+""" Components Mask for faceswap.py """
+from __future__ import annotations
+import logging
+import typing as T
+
+import cv2
+import numpy as np
+
+from lib.align import LandmarkType
+from lib.utils import get_module_objects
+
+from ._base import BatchType, Masker
+
+if T.TYPE_CHECKING:
+    from lib.align.aligned_face import AlignedFace
+
+logger = logging.getLogger(__name__)
+
+
+class Mask(Masker):
+    # pylint:disable=duplicate-code
+    """ Apply a landmarks based components mask """
+    def __init__(self, **kwargs) -> None:
+        git_model_id = None
+        model_filename = None
+        super().__init__(git_model_id=git_model_id, model_filename=model_filename, **kwargs)
+        self.input_size = 256
+        self.name = "Components"
+        self.vram = 0  # Doesn't use GPU
+        self.vram_per_batch = 0
+        self.batchsize = 1
+        self.landmark_type = LandmarkType.LM_2D_68
+
+    def init_model(self) -> None:
+        logger.debug("No mask model to initialize")
+
+    def process_input(self, batch: BatchType) -> None:
+        """ Compile the detected faces for prediction """
+        batch.feed = np.zeros((self.batchsize, self.input_size, self.input_size, 1),
+                              dtype="float32")
+
+    def predict(self, feed: np.ndarray) -> np.ndarray:
+        """ Run model to get predictions """
+        faces: list[AlignedFace] = feed[1]
+        feed = feed[0]
+        for mask, face in zip(feed, faces):
+            if LandmarkType.from_shape(face.landmarks.shape) != self.landmark_type:
+                # Called from the manual tool. # TODO This will only work with BS1
+                feed = np.zeros_like(feed)
+                continue
+            parts = self.parse_parts(np.array(face.landmarks))
+            for item in parts:
+                a_item = np.rint(np.concatenate(item)).astype("int32")
+                hull = cv2.convexHull(a_item)
+                cv2.fillConvexPoly(mask, hull, [1.0], lineType=cv2.LINE_AA)
+        return feed
+
+    def process_output(self, batch: BatchType) -> None:
+        """ Compile found faces for output """
+        return
+
+    @staticmethod
+    def parse_parts(landmarks: np.ndarray) -> list[tuple[np.ndarray, ...]]:
+        """ Component face hull mask """
+        r_jaw = (landmarks[0:9], landmarks[17:18])
+        l_jaw = (landmarks[8:17], landmarks[26:27])
+        r_cheek = (landmarks[17:20], landmarks[8:9])
+        l_cheek = (landmarks[24:27], landmarks[8:9])
+        nose_ridge = (landmarks[19:25], landmarks[8:9],)
+        r_eye = (landmarks[17:22],
+                 landmarks[27:28],
+                 landmarks[31:36],
+                 landmarks[8:9])
+        l_eye = (landmarks[22:27],
+                 landmarks[27:28],
+                 landmarks[31:36],
+                 landmarks[8:9])
+        nose = (landmarks[27:31], landmarks[31:36])
+        parts = [r_jaw, l_jaw, r_cheek, l_cheek, nose_ridge, r_eye, l_eye, nose]
+        return parts
+
+
+__all__ = get_module_objects(__name__)
diff --git a/plugins/extract/mask/custom.py b/plugins/extract/mask/custom.py
new file mode 100644
index 0000000000..2d7b353ab7
--- /dev/null
+++ b/plugins/extract/mask/custom.py
@@ -0,0 +1,55 @@
+#!/usr/bin/env python3
+""" Custom Mask for faceswap.py """
+from __future__ import annotations
+import logging
+import typing as T
+
+import numpy as np
+from lib.utils import get_module_objects
+from ._base import BatchType, Masker
+
+from . import custom_defaults as cfg
import custom_defaults as cfg
+
+if T.TYPE_CHECKING:
+    from lib.align.constants import CenteringType
+
+logger = logging.getLogger(__name__)
+
+
+class Mask(Masker):
+    """ A mask that fills the whole face area with 1s or 0s (depending on user selected settings)
+    for custom editing. """
+    # pylint:disable=duplicate-code
+    def __init__(self, **kwargs) -> None:
+        git_model_id = None
+        model_filename = None
+        super().__init__(git_model_id=git_model_id, model_filename=model_filename, **kwargs)
+        self.input_size = 256
+        self.name = "Custom"
+        self.vram = 0  # Doesn't use GPU
+        self.vram_per_batch = 0
+        self.batchsize = cfg.batch_size()
+        self._storage_centering = T.cast("CenteringType", cfg.centering())
+        # Separate storage for face and head masks
+        self._storage_name = f"{self._storage_name}_{self._storage_centering}"
+
+    def init_model(self) -> None:
+        logger.debug("No mask model to initialize")
+
+    def process_input(self, batch: BatchType) -> None:
+        """ Compile the detected faces for prediction """
+        batch.feed = np.zeros((self.batchsize, self.input_size, self.input_size, 1),
+                              dtype="float32")
+
+    def predict(self, feed: np.ndarray) -> np.ndarray:
+        """ Run model to get predictions """
+        if cfg.fill():
+            feed[:] = 1.0
+        return feed
+
+    def process_output(self, batch: BatchType) -> None:
+        """ Compile found faces for output """
+        return
+
+
+__all__ = get_module_objects(__name__)
diff --git a/plugins/extract/mask/custom_defaults.py b/plugins/extract/mask/custom_defaults.py
new file mode 100644
index 0000000000..4eea21fcf7
--- /dev/null
+++ b/plugins/extract/mask/custom_defaults.py
@@ -0,0 +1,63 @@
+#!/usr/bin/env python3
+""" The default options for the faceswap Custom mask plugin.
+
+Defaults files should be named `<plugin_name>_defaults.py`
+
+Any qualifying items placed into this file will automatically get added to the relevant config
+.ini files within the faceswap/config folder and added to the relevant GUI settings page.
+
+The following variable should be defined:
+
+    Parameters
+    ----------
+    HELPTEXT: str
+        A string describing what this plugin does
+
+Further plugin configuration options are assigned using:
+>>> from lib.config import ConfigItem
+>>> <option_name> = ConfigItem(...)
+
+where <option_name> is the name of the configuration option to be added (lower-case,
+alpha-numeric + underscore only) and ConfigItem(...) is the
+:class:`~lib.config.objects.ConfigItem` data for the option.
+
+See the docstring/ReadtheDocs documentation for the required parameters of the ConfigItem
+object. Items will be grouped together as per their `group` parameter, but otherwise will be
+processed in the order that they are added to this module.
+"""
+# pylint:disable=duplicate-code
+from lib.config import ConfigItem
+
+
+HELPTEXT = (
+    "Custom (dummy) Mask options.\n"
+    "The custom mask just fills a face patch with all 0's (masked out) or all 1's (masked in) "
+    "for later manual editing. It does not use the GPU for creation."
+    )
+
+
+batch_size = ConfigItem(
+    datatype=int,
+    default=8,
+    group="settings",
+    info="The batch size to use. 
To a point, higher batch sizes equal better performance, " + "but setting it too high can harm performance.", + rounding=1, + min_max=(1, 64)) + +centering = ConfigItem( + datatype=str, + group="settings", + default="face", + info="Whether to create a dummy mask with face or head centering.", + choices=["face", "head"], + gui_radio=True) + +fill = ConfigItem( + datatype=bool, + default=False, + group="settings", + info="Whether the mask should be filled (True) in which case the custom mask will be " + "created with the whole area masked in (i.e. you would need to manually edit out " + "the background) or unfilled (False) in which case you would need to manually " + "edit in the face.") diff --git a/plugins/extract/mask/extended.py b/plugins/extract/mask/extended.py new file mode 100644 index 0000000000..e88ba959dc --- /dev/null +++ b/plugins/extract/mask/extended.py @@ -0,0 +1,113 @@ +#!/usr/bin/env python3 +""" Extended Mask for faceswap.py """ +from __future__ import annotations +import logging +import typing as T + +import cv2 +import numpy as np + +from lib.align import LandmarkType +from lib.utils import get_module_objects + +from ._base import BatchType, Masker + +logger = logging.getLogger(__name__) + +if T.TYPE_CHECKING: + from lib.align.aligned_face import AlignedFace + + +class Mask(Masker): + """ Apply a landmarks based extended mask """ + def __init__(self, **kwargs): + git_model_id = None + model_filename = None + super().__init__(git_model_id=git_model_id, model_filename=model_filename, **kwargs) + self.input_size = 256 + self.name = "Extended" + self.vram = 0 # Doesn't use GPU + self.vram_per_batch = 0 + self.batchsize = 1 + self.landmark_type = LandmarkType.LM_2D_68 + + def init_model(self) -> None: + logger.debug("No mask model to initialize") + + def process_input(self, batch: BatchType) -> None: + """ Compile the detected faces for prediction """ + batch.feed = np.zeros((self.batchsize, self.input_size, self.input_size, 1), + dtype="float32") + + def predict(self, feed: np.ndarray) -> np.ndarray: + """ Run model to get predictions """ + faces: list[AlignedFace] = feed[1] + feed = feed[0] + for mask, face in zip(feed, faces): + if LandmarkType.from_shape(face.landmarks.shape) != self.landmark_type: + # Called from the manual tool. 
# TODO This will only work with BS1 + feed = np.zeros_like(feed) + continue + parts = self.parse_parts(np.array(face.landmarks)) + for item in parts: + a_item = np.rint(np.concatenate(item)).astype("int32") + hull = cv2.convexHull(a_item) + cv2.fillConvexPoly(mask, hull, [1.0], lineType=cv2.LINE_AA) + return feed + + def process_output(self, batch: BatchType) -> None: + """ Compile found faces for output """ + return + + @classmethod + def _adjust_mask_top(cls, landmarks: np.ndarray) -> None: + """ Adjust the top of the mask to extend above eyebrows + + Parameters + ---------- + landmarks: :class:`numpy.ndarray` + The 68 point landmarks to be adjusted + """ + # mid points between the side of face and eye point + ml_pnt = (landmarks[36] + landmarks[0]) // 2 + mr_pnt = (landmarks[16] + landmarks[45]) // 2 + + # mid points between the mid points and eye + ql_pnt = (landmarks[36] + ml_pnt) // 2 + qr_pnt = (landmarks[45] + mr_pnt) // 2 + + # Top of the eye arrays + bot_l = np.array((ql_pnt, landmarks[36], landmarks[37], landmarks[38], landmarks[39])) + bot_r = np.array((landmarks[42], landmarks[43], landmarks[44], landmarks[45], qr_pnt)) + + # Eyebrow arrays + top_l = landmarks[17:22] + top_r = landmarks[22:27] + + # Adjust eyebrow arrays + landmarks[17:22] = top_l + ((top_l - bot_l) // 2) + landmarks[22:27] = top_r + ((top_r - bot_r) // 2) + + def parse_parts(self, landmarks: np.ndarray) -> list[tuple[np.ndarray, ...]]: + """ Extended face hull mask """ + self._adjust_mask_top(landmarks) + + r_jaw = (landmarks[0:9], landmarks[17:18]) + l_jaw = (landmarks[8:17], landmarks[26:27]) + r_cheek = (landmarks[17:20], landmarks[8:9]) + l_cheek = (landmarks[24:27], landmarks[8:9]) + nose_ridge = (landmarks[19:25], landmarks[8:9],) + r_eye = (landmarks[17:22], + landmarks[27:28], + landmarks[31:36], + landmarks[8:9]) + l_eye = (landmarks[22:27], + landmarks[27:28], + landmarks[31:36], + landmarks[8:9]) + nose = (landmarks[27:31], landmarks[31:36]) + parts = [r_jaw, l_jaw, r_cheek, l_cheek, nose_ridge, r_eye, l_eye, nose] + return parts + + +__all__ = get_module_objects(__name__) diff --git a/plugins/extract/mask/unet_dfl.py b/plugins/extract/mask/unet_dfl.py new file mode 100644 index 0000000000..ec196b1137 --- /dev/null +++ b/plugins/extract/mask/unet_dfl.py @@ -0,0 +1,254 @@ +#!/usr/bin/env python3 +""" UNET DFL face mask plugin + +Architecture and Pre-Trained Model based on... +TernausNet: U-Net with VGG11 Encoder Pre-Trained on ImageNet for Image Segmentation +https://arxiv.org/abs/1801.05746 +https://github.com/ternaus/TernausNet + +Source Implementation and fine-tune training.... +https://github.com/iperov/DeepFaceLab/blob/master/nnlib/TernausNet.py + +Model file sourced from... +https://github.com/iperov/DeepFaceLab/blob/master/nnlib/FANSeg_256_full_face.h5 +""" +from __future__ import annotations + +import logging +import typing as T + +import numpy as np +from keras import backend as K, layers as kl, Model + +from lib.logger import parse_class_init +from lib.utils import get_module_objects +from ._base import BatchType, Masker, MaskerBatch +from . 
import unet_dfl_defaults as cfg
+
+if T.TYPE_CHECKING:
+    from keras import KerasTensor
+
+
+logger = logging.getLogger(__name__)
+
+
+class Mask(Masker):
+    """ Neural network to process face image into a segmentation mask of the face """
+    def __init__(self, **kwargs) -> None:
+        git_model_id = 6
+        model_filename = "DFL_256_sigmoid_v1.h5"
+        super().__init__(git_model_id=git_model_id, model_filename=model_filename, **kwargs)
+        self.model: UnetDFL
+        self.name = "U-Net"
+        self.input_size = 256
+        self.vram = 320  # 276 in testing
+        self.vram_per_batch = 256  # ~215 in testing
+        self.batchsize = cfg.batch_size()
+        self._storage_centering = "legacy"
+
+    def init_model(self) -> None:
+        assert self.name is not None and isinstance(self.model_path, str)
+        self.model = UnetDFL(self.model_path, self.batchsize)
+        placeholder = np.zeros((self.batchsize, self.input_size, self.input_size, 3),
+                               dtype="float32")
+        self.model(placeholder)
+
+    def process_input(self, batch: BatchType) -> None:
+        """ Compile the detected faces for prediction """
+        assert isinstance(batch, MaskerBatch)
+        batch.feed = np.array([T.cast(np.ndarray, feed.face)[..., :3]
+                               for feed in batch.feed_faces], dtype="float32") / 255.0
+        logger.trace("feed shape: %s", batch.feed.shape)  # type: ignore
+
+    def predict(self, feed: np.ndarray) -> np.ndarray:
+        """ Run model to get predictions """
+        return self.model(feed)
+
+    def process_output(self, batch: BatchType) -> None:
+        """ Compile found faces for output """
+        return
+
+
+class UnetDFL:
+    """ UNet DFL Definition for Keras 3 with PyTorch backend
+
+    Parameters
+    ----------
+    weights_path: str
+        Full path to the location of the weights file for the model
+    batch_size: int
+        The batch size at which to feed the model
+
+    Note
+    ----
+    The model definition is explicitly stated as there is an incompatibility for certain
+    Conv2DTranspose combinations when a model is trained on one backend but inferred on another:
+    https://github.com/keras-team/keras-core/issues/774
+    The effect of this is a misaligned mask and poor inference for this model.
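+
+    Example
+    -------
+    A minimal, illustrative sketch (the weights path is a placeholder; inputs are expected in
+    the 0-1 range, as prepared by :func:`Mask.process_input`):
+    >>> import numpy as np
+    >>> net = UnetDFL("/path/to/DFL_256_sigmoid_v1.h5", batch_size=4)
+    >>> faces = np.zeros((4, 256, 256, 3), dtype="float32")
+    >>> masks = net(faces)  # single-channel sigmoid mask probabilities per face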
+ """ + def __init__(self, weights_path: str, batch_size: int) -> None: + logger.debug(parse_class_init(locals())) + self._batch_size = batch_size + self._model = self._load_model(weights_path) + logger.debug("Initialized: %s", self.__class__.__name__) + + @classmethod + def conv_block(cls, + inputs: KerasTensor, + filters: int, + recursions: int, + idx: int) -> KerasTensor: + """ Convolution block for UnetDFL downscales + + Parameters + ---------- + inputs: :class:`keras.KerasTensor` + The inputs to the block + filters: int + The number of filters for the convolution + recursions: int + The number of convolutions to run + idx: The index id of the first convolution (used for naming) + + Returns + ------- + :class:`keras.KerasTensor` + The output from the convolution block + """ + output = inputs + + for _ in range(recursions): + output = kl.Conv2D(filters, + 3, + padding="same", + activation="relu", + kernel_initializer="random_uniform", + name=f"features_{idx}")(output) + idx += 2 + + return output + + @classmethod + def skip_block(cls, # pylint:disable=too-many-positional-arguments + input_1: KerasTensor, + input_2: KerasTensor, + conv_filters: int, + trans_filters: int, + linear: bool, + idx: int) -> KerasTensor: + """ Deconvolution + skip connection for UnetDFL upscales + + Parameters + ---------- + input_1: :class:`keras.KerasTensor` + The input to be upscaled + input_2: :class:`keras.KerasTensor` + The skip connection to be concatenated to the upscaled tensor + conv_filters: int + The number of filters to be used for the convolution + trans_filters: int + The number of filters to be used for the conv-transpose + linear: bool + ``True`` to use linear activation in the convolution, ``False`` to use ReLu + idx: int + The index for naming the layers + + Returns + ------- + :class:`keras.KerasTensor` + The output from the upscaled/skip connection + """ + output = kl.Conv2D(conv_filters, + 3, + padding="same", + activation="linear" if linear else "relu", + kernel_initializer="random_uniform", + name=f"conv2d_{idx}")(input_1) + + # TF vs PyTorch paddng is different. We need to negative pad the output for Torch + padding = "valid" if K.backend() == "torch" else "same" + output = kl.Conv2DTranspose(trans_filters, + 3, + strides=2, + padding=padding, + activation="relu", + kernel_initializer="random_uniform", + name=f"conv2d_transpose_{idx}")(output) + + if K.backend() == "torch": + output = output[:, :-1, :-1, :] + + return kl.Concatenate(name=f"concatenate_{idx}")([output, input_2]) + + def _load_model(self, weights_path: str) -> Model: + """ Definition of the UNet-DFL Model. 
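+
+        The topology is a TernausNet-style U-Net: a VGG11-like encoder (Conv2D blocks of
+        64/128/256/512/512 filters, each followed by max-pooling) whose features are
+        concatenated back in via Conv2DTranspose upscales, finishing in a single-channel
+        sigmoid output.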
+
+        Parameters
+        ----------
+        weights_path: str
+            Full path to the model's weights
+
+        Returns
+        -------
+        :class:`keras.models.Model`
+            The UNet-DFL model
+        """
+        features = []
+        input_ = kl.Input(shape=(256, 256, 3), name="input_1")
+
+        features.append(self.conv_block(input_, 64, 1, 0))
+        var_x = kl.MaxPool2D(pool_size=2, strides=2, name="max_pooling2d_1")(features[-1])
+
+        features.append(self.conv_block(var_x, 128, 1, 3))
+        var_x = kl.MaxPool2D(pool_size=2, strides=2, name="max_pooling2d_2")(features[-1])
+
+        features.append(self.conv_block(var_x, 256, 2, 6))
+        var_x = kl.MaxPool2D(pool_size=2, strides=2, name="max_pooling2d_3")(features[-1])
+
+        features.append(self.conv_block(var_x, 512, 2, 11))
+        var_x = kl.MaxPool2D(pool_size=2, strides=2, name="max_pooling2d_4")(features[-1])
+
+        features.append(self.conv_block(var_x, 512, 2, 16))
+        var_x = kl.MaxPool2D(pool_size=2, strides=2, name="max_pooling2d_5")(features[-1])
+
+        convs = [512, 512, 512, 256, 128]
+        for idx, (feats, filts) in enumerate(zip(reversed(features), convs)):
+            linear = idx == 0
+            trans_filts = filts // 2 if idx < 2 else filts // 4
+            var_x = self.skip_block(var_x, feats, filts, trans_filts, linear, idx + 1)
+
+        var_x = kl.Conv2D(64,
+                          3,
+                          padding="same",
+                          activation="relu",
+                          kernel_initializer="random_uniform",
+                          name="conv2d_6")(var_x)
+        output = kl.Conv2D(1,
+                           3,
+                           padding="same",
+                           activation="sigmoid",
+                           kernel_initializer="random_uniform",
+                           name="conv2d_7")(var_x)
+
+        model = Model(input_, output)
+        model.load_weights(weights_path)
+        model.make_predict_function()
+        return model
+
+    def __call__(self, inputs: np.ndarray) -> np.ndarray:
+        """ Obtain predictions from the UNet-DFL Model
+
+        Parameters
+        ----------
+        inputs: :class:`numpy.ndarray`
+            The input to UNet-DFL
+
+        Returns
+        -------
+        :class:`numpy.ndarray`
+            The output from UNet-DFL
+        """
+        return self._model.predict(inputs, verbose=0, batch_size=self._batch_size)
+
+
+__all__ = get_module_objects(__name__)
diff --git a/plugins/extract/mask/unet_dfl_defaults.py b/plugins/extract/mask/unet_dfl_defaults.py
new file mode 100644
index 0000000000..4d20870c4e
--- /dev/null
+++ b/plugins/extract/mask/unet_dfl_defaults.py
@@ -0,0 +1,48 @@
+#!/usr/bin/env python3
+""" The default options for the faceswap UNet-DFL plugin.
+
+Defaults files should be named `<plugin_name>_defaults.py`
+
+Any qualifying items placed into this file will automatically get added to the relevant config
+.ini files within the faceswap/config folder and added to the relevant GUI settings page.
+
+The following variable should be defined:
+
+    Parameters
+    ----------
+    HELPTEXT: str
+        A string describing what this plugin does
+
+Further plugin configuration options are assigned using:
+>>> from lib.config import ConfigItem
+>>> <option_name> = ConfigItem(...)
+
+where <option_name> is the name of the configuration option to be added (lower-case,
+alpha-numeric + underscore only) and ConfigItem(...) is the
+:class:`~lib.config.objects.ConfigItem` data for the option.
+
+See the docstring/ReadtheDocs documentation for the required parameters of the ConfigItem
+object. Items will be grouped together as per their `group` parameter, but otherwise will be
+processed in the order that they are added to this module.
+"""
+# pylint:disable=duplicate-code
+from lib.config import ConfigItem
+
+
+HELPTEXT = (
+    "UNET_DFL options. Mask designed to provide smart segmentation of mostly frontal faces.\n"
+    "The mask model has been trained by community members. Profile faces may result in "
+    "sub-par performance."
+    )
+
+
+batch_size = ConfigItem(
+    datatype=int,
+    default=8,
+    group="settings",
+    info="The batch size to use. To a point, higher batch sizes equal better performance, "
+         "but setting it too high can harm performance.\n"
+         "\n\tNvidia users: If the batchsize is set higher than your GPU can "
+         "accommodate then this will automatically be lowered.",
+    rounding=1,
+    min_max=(1, 64))
diff --git a/plugins/extract/mask/vgg_clear.py b/plugins/extract/mask/vgg_clear.py
new file mode 100644
index 0000000000..dc1a32a73c
--- /dev/null
+++ b/plugins/extract/mask/vgg_clear.py
@@ -0,0 +1,246 @@
+#!/usr/bin/env python3
+""" VGG Clear face mask plugin. """
+from __future__ import annotations
+import logging
+import typing as T
+
+import numpy as np
+
+from keras import layers as kl, Model
+
+from lib.logger import parse_class_init
+from lib.utils import get_module_objects
+from ._base import BatchType, Masker, MaskerBatch
+from . import vgg_clear_defaults as cfg
+
+if T.TYPE_CHECKING:
+    from keras import KerasTensor
+
+logger = logging.getLogger(__name__)
+
+
+class Mask(Masker):
+    """ Neural network to process face image into a segmentation mask of the face """
+    def __init__(self, **kwargs) -> None:
+        git_model_id = 8
+        model_filename = "Nirkin_300_softmax_v1.h5"
+        super().__init__(git_model_id=git_model_id, model_filename=model_filename, **kwargs)
+        self.model: VGGClear
+        self.name = "VGG Clear"
+        self.input_size = 300
+        self.vram = 1344  # 1308 in testing
+        self.vram_per_batch = 448  # ~402 in testing
+        self.batchsize = cfg.batch_size()
+
+    def init_model(self) -> None:
+        assert isinstance(self.model_path, str)
+        self.model = VGGClear(self.model_path, self.batchsize)
+        placeholder = np.zeros((self.batchsize, self.input_size, self.input_size, 3),
+                               dtype="float32")
+        self.model(placeholder)
+
+    def process_input(self, batch: BatchType) -> None:
+        """ Compile the detected faces for prediction """
+        assert isinstance(batch, MaskerBatch)
+        input_ = np.array([T.cast(np.ndarray, feed.face)[..., :3]
+                           for feed in batch.feed_faces], dtype="float32")
+        batch.feed = input_ - np.mean(input_, axis=(1, 2))[:, None, None, :]
+        logger.trace("feed shape: %s", batch.feed.shape)  # type: ignore
+
+    def predict(self, feed: np.ndarray) -> np.ndarray:
+        """ Run model to get predictions """
+        predictions = self.model(feed)
+        assert isinstance(predictions, np.ndarray)
+        return predictions[..., -1]
+
+    def process_output(self, batch: BatchType) -> None:
+        """ Compile found faces for output """
+        return
+
+
+class VGGClear():
+    """ VGG Clear mask for Faceswap.
+
+    Caffe model re-implemented in Keras by Kyle Vrooman and re-implemented again for Keras 3
+    by TorzDF.
+
+    Parameters
+    ----------
+    weights_path: str
+        The path to the keras model file
+    batch_size: int
+        The batch size to feed the model
+
+    References
+    ----------
+    On Face Segmentation, Face Swapping, and Face Perception (https://arxiv.org/abs/1704.06729)
+
+    Source Implementation: https://github.com/YuvalNirkin/face_segmentation
+
+    Model file sourced from:
+    https://github.com/YuvalNirkin/face_segmentation/releases/download/1.1/face_seg_fcn8s_300_no_aug.zip
+
+    """
+    def __init__(self, weights_path: str, batch_size: int) -> None:
+        logger.debug(parse_class_init(locals()))
+        self._batch_size = batch_size
+        self._model = self._load_model(weights_path)
+        logger.debug("Initialized: %s", self.__class__.__name__)
+
+    @classmethod
+    def _load_model(cls, weights_path: str) -> Model:
+        """ Definition of the VGG Clear Model.
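+
+        The architecture follows the FCN-8s segmentation pattern (per the
+        ``face_seg_fcn8s_300_no_aug`` weights file): a VGG-style encoder whose scaled
+        ``score_pool3``/``score_pool4`` skip connections are fused into progressively
+        upsampled score maps, ending in a stride-8 Conv2DTranspose and a softmax.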
+ + Parameters + ---------- + weights_path: str + Full path to the model's weights + + Returns + ------- + :class:`keras.models.Model` + The VGG-Clear model + """ + input_ = kl.Input(shape=(300, 300, 3)) + var_x = kl.ZeroPadding2D(padding=((100, 100), (100, 100)), name="zero_padding2d_1")(input_) + + var_x = _ConvBlock(1, 64, 2)(var_x) + var_x = _ConvBlock(2, 128, 2)(var_x) + pool3 = _ConvBlock(3, 256, 3)(var_x) + pool4 = _ConvBlock(4, 512, 3)(pool3) + var_x = _ConvBlock(5, 512, 3)(pool4) + + score_pool3 = _ScorePool(3, 0.0001, (9, 8))(pool3) + score_pool4 = _ScorePool(4, 0.01, (5, 5))(pool4) + + var_x = kl.Conv2D(4096, 7, activation="relu", name="fc6")(var_x) + var_x = kl.Dropout(rate=0.5, name="drop6")(var_x) + var_x = kl.Conv2D(4096, 1, activation="relu", name="fc7")(var_x) + var_x = kl.Dropout(rate=0.5, name="drop7")(var_x) + var_x = kl.Conv2D(2, 1, activation="linear", name="score_fr_r")(var_x) + var_x = kl.Conv2DTranspose(2, + 4, + strides=2, + activation="linear", + use_bias=False, name="upscore2_r")(var_x) + + var_x = kl.Add(name="fuse_pool4")([var_x, score_pool4]) + var_x = kl.Conv2DTranspose(2, + 4, + strides=2, + activation="linear", + use_bias=False, + name="upscore_pool4_r")(var_x) + var_x = kl.Add(name="fuse_pool3")([var_x, score_pool3]) + var_x = kl.Conv2DTranspose(2, + 16, + strides=8, + activation="linear", + use_bias=False, + name="upscore8_r")(var_x) + var_x = kl.Cropping2D(cropping=((31, 45), (31, 45)), name="score")(var_x) + var_x = kl.Activation("softmax", name="softmax")(var_x) + + retval = Model(input_, var_x) + retval.load_weights(weights_path) + retval.make_predict_function() + return retval + + def __call__(self, inputs: np.ndarray) -> np.ndarray: + """ Get predictions from the VGG-Clear model + + Parameters + ---------- + inputs: :class:`numpy.ndarray` + The input to VGG-Clear + + Returns + ------- + :class:`numpy.ndarray` + The output from VGG-Clear + """ + return self._model.predict(inputs, verbose=0, batch_size=self._batch_size) + + +class _ConvBlock(): + """ Convolutional loop with max pooling layer for VGG Clear. + + Parameters + ---------- + level: int + For naming. The current level for this convolutional loop + filters: int + The number of filters that should appear in each Conv2D layer + iterations: int + The number of consecutive Conv2D layers to create + """ + def __init__(self, level: int, filters: int, iterations: int) -> None: + self._name = f"conv{level}_" + self._level = level + self._filters = filters + self._iterator = range(1, iterations + 1) + + def __call__(self, inputs: KerasTensor) -> KerasTensor: + """ Call the convolutional loop. + + Parameters + ---------- + inputs: :class:`keras.KerasTensor` + The input tensor to the block + + Returns + ------- + :class:`keras.KerasTensor` + The output tensor from the convolutional block + """ + var_x = inputs + for i in self._iterator: + padding = "valid" if self._level == i == 1 else "same" + var_x = kl.Conv2D(self._filters, + 3, + padding=padding, + activation="relu", + name=f"{self._name}{i}")(var_x) + var_x = kl.MaxPooling2D(padding="same", + strides=(2, 2), + name=f"pool{self._level}")(var_x) + return var_x + + +class _ScorePool(): + """ Cropped scaling of the pooling layer. + + Parameters + ---------- + level: int + For naming. The current level for this score pool + scale: float + The scaling to apply to the pool + crop: tuple + The amount of 2D cropping to apply. 
Tuple of `ints`
+    """
+    def __init__(self, level: int, scale: float, crop: tuple[int, int]) -> None:
+        self._name = f"_pool{level}"
+        self._cropping = (crop, crop)
+        self._scale = scale
+
+    def __call__(self, inputs: KerasTensor) -> KerasTensor:
+        """ Score pool block.
+
+        Parameters
+        ----------
+        inputs: :class:`keras.KerasTensor`
+            The input tensor to the block
+
+        Returns
+        -------
+        :class:`keras.KerasTensor`
+            The output tensor from the score pool block
+        """
+        var_x = kl.Lambda(lambda x: x * self._scale, name="scale" + self._name)(inputs)
+        var_x = kl.Conv2D(2, 1, activation="linear", name="score" + self._name + "_r")(var_x)
+        var_x = kl.Cropping2D(cropping=self._cropping, name="score" + self._name + "c")(var_x)
+        return var_x
+
+
+__all__ = get_module_objects(__name__)
diff --git a/plugins/extract/mask/vgg_clear_defaults.py b/plugins/extract/mask/vgg_clear_defaults.py
new file mode 100644
index 0000000000..48ee92f329
--- /dev/null
+++ b/plugins/extract/mask/vgg_clear_defaults.py
@@ -0,0 +1,47 @@
+#!/usr/bin/env python3
+""" The default options for the faceswap VGG clear plugin.
+
+Defaults files should be named `<plugin_name>_defaults.py`
+
+Any qualifying items placed into this file will automatically get added to the relevant config
+.ini files within the faceswap/config folder and added to the relevant GUI settings page.
+
+The following variable should be defined:
+
+    Parameters
+    ----------
+    HELPTEXT: str
+        A string describing what this plugin does
+
+Further plugin configuration options are assigned using:
+>>> from lib.config import ConfigItem
+>>> <option_name> = ConfigItem(...)
+
+where <option_name> is the name of the configuration option to be added (lower-case,
+alpha-numeric + underscore only) and ConfigItem(...) is the
+:class:`~lib.config.objects.ConfigItem` data for the option.
+
+See the docstring/ReadtheDocs documentation for the required parameters of the ConfigItem
+object. Items will be grouped together as per their `group` parameter, but otherwise will be
+processed in the order that they are added to this module.
+"""
+# pylint:disable=duplicate-code
+from lib.config import ConfigItem
+
+
+HELPTEXT = (
+    "VGG_Clear options. Mask designed to provide smart segmentation of mostly frontal faces "
+    "clear of obstructions.\nProfile faces and obstructions may result in sub-par performance."
+    )
+
+
+batch_size = ConfigItem(
+    datatype=int,
+    default=6,
+    group="settings",
+    info="The batch size to use. To a point, higher batch sizes equal better performance, "
+         "but setting it too high can harm performance.\n"
+         "\n\tNvidia users: If the batchsize is set higher than your GPU can "
+         "accommodate then this will automatically be lowered.",
+    rounding=1,
+    min_max=(1, 64))
diff --git a/plugins/extract/mask/vgg_obstructed.py b/plugins/extract/mask/vgg_obstructed.py
new file mode 100644
index 0000000000..b2733c5bc2
--- /dev/null
+++ b/plugins/extract/mask/vgg_obstructed.py
@@ -0,0 +1,251 @@
+#!/usr/bin/env python3
+""" VGG Obstructed face mask plugin """
+from __future__ import annotations
+import logging
+import typing as T
+
+import numpy as np
+
+from keras import layers as kl, Model
+
+from lib.logger import parse_class_init
+from lib.utils import get_module_objects
+from ._base import BatchType, Masker, MaskerBatch
+from . 
import vgg_obstructed_defaults as cfg + +if T.TYPE_CHECKING: + from keras import KerasTensor + +logger = logging.getLogger(__name__) + +# pylint:disable=duplicate-code + + +class Mask(Masker): + """ Neural network to process face image into a segmentation mask of the face """ + def __init__(self, **kwargs) -> None: + git_model_id = 5 + model_filename = "Nirkin_500_softmax_v1.h5" + super().__init__(git_model_id=git_model_id, model_filename=model_filename, **kwargs) + self.model: VGGObstructed + self.name = "VGG Obstructed" + self.input_size = 500 + self.vram = 1728 # 1710 in testing + self.vram_per_batch = 896 # ~886 in testing + self.batchsize = cfg.batch_size() + + def init_model(self) -> None: + assert isinstance(self.model_path, str) + self.model = VGGObstructed(self.model_path, self.batchsize) + placeholder = np.zeros((self.batchsize, self.input_size, self.input_size, 3), + dtype="float32") + self.model(placeholder) + + def process_input(self, batch: BatchType) -> None: + """ Compile the detected faces for prediction """ + assert isinstance(batch, MaskerBatch) + input_ = [T.cast(np.ndarray, feed.face)[..., :3] for feed in batch.feed_faces] + batch.feed = input_ - np.mean(input_, axis=(1, 2))[:, None, None, :] + logger.trace("feed shape: %s", batch.feed.shape) # type:ignore[attr-defined] + + def predict(self, feed: np.ndarray) -> np.ndarray: + """ Run model to get predictions """ + predictions = self.model(feed) + assert isinstance(predictions, np.ndarray) + return predictions[..., 0] * -1.0 + 1.0 + + def process_output(self, batch: BatchType) -> None: + """ Compile found faces for output """ + return + + +class VGGObstructed(): + """ VGG Obstructed mask for Faceswap. + + Caffe model re-implemented in Keras by Kyle Vrooman. + Re-implemented for Keras by TorzDF + + Parameters + ---------- + weights_path: str + The path to the keras model file + batch_size: int + The batch size to feed the model + + References + ---------- + On Face Segmentation, Face Swapping, and Face Perception (https://arxiv.org/abs/1704.06729) + Source Implementation: https://github.com/YuvalNirkin/face_segmentation + Model file sourced from: + https://github.com/YuvalNirkin/face_segmentation/releases/download/1.0/face_seg_fcn8s.zip + """ + def __init__(self, weights_path: str, batch_size: int) -> None: + logger.debug(parse_class_init(locals())) + self._batch_size = batch_size + self._model = self._load_model(weights_path) + logger.debug("Initialized: %s", self.__class__.__name__) + + @classmethod + def _load_model(cls, weights_path: str) -> Model: + """ Definition of the VGG Obstructed Model. 
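+
+        As with the VGG-Clear masker, this is an FCN-8s style network (the weights derive
+        from ``face_seg_fcn8s.zip``): a VGG encoder whose ``score_pool3``/``score_pool4``
+        skips are fused with upsampled score maps before a final stride-8 deconvolution and
+        a softmax over 21 channels.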
+ + Parameters + ---------- + weights_path: str + Full path to the model's weights + + Returns + ------- + :class:`keras.models.Model` + The VGG-Obstructed model + """ + input_ = kl.Input(shape=(500, 500, 3)) + var_x = kl.ZeroPadding2D(padding=((100, 100), (100, 100)))(input_) + + var_x = _ConvBlock(1, 64, 2)(var_x) + var_x = _ConvBlock(2, 128, 2)(var_x) + var_x = _ConvBlock(3, 256, 3)(var_x) + + score_pool3 = _ScorePool(3, 0.0001, 9)(var_x) + var_x = _ConvBlock(4, 512, 3)(var_x) + score_pool4 = _ScorePool(4, 0.01, 5)(var_x) + var_x = _ConvBlock(5, 512, 3)(var_x) + + var_x = kl.Conv2D(4096, 7, padding="valid", activation="relu", name="fc6")(var_x) + var_x = kl.Dropout(rate=0.5)(var_x) + var_x = kl.Conv2D(4096, 1, padding="valid", activation="relu", name="fc7")(var_x) + var_x = kl.Dropout(rate=0.5)(var_x) + + var_x = kl.Conv2D(21, 1, padding="valid", activation="linear", name="score_fr")(var_x) + var_x = kl.Conv2DTranspose(21, + 4, + strides=2, + activation="linear", + use_bias=False, + name="upscore2")(var_x) + + var_x = kl.Add()([var_x, score_pool4]) + var_x = kl.Conv2DTranspose(21, + 4, + strides=2, + activation="linear", + use_bias=False, + name="upscore_pool4")(var_x) + + var_x = kl.Add()([var_x, score_pool3]) + var_x = kl.Conv2DTranspose(21, + 16, + strides=8, + activation="linear", + use_bias=False, + name="upscore8")(var_x) + var_x = kl.Cropping2D(cropping=((31, 37), (31, 37)), name="score")(var_x) + var_x = kl.Activation("softmax", name="softmax")(var_x) + + retval = Model(input_, var_x) + retval.load_weights(weights_path) + retval.make_predict_function() + return retval + + def __call__(self, inputs: np.ndarray) -> np.ndarray: + """ Get predictions from the VGG-Clear model + + Parameters + ---------- + inputs: :class:`numpy.ndarray` + The input to VGG-Obstructed + + Returns + ------- + :class:`numpy.ndarray` + The output from VGG-Obstructed + """ + return self._model.predict(inputs, verbose=0, batch_size=self._batch_size) + + +class _ConvBlock(): + """ Convolutional loop with max pooling layer for VGG Obstructed. + + Parameters + ---------- + level: int + For naming. The current level for this convolutional loop + filters: int + The number of filters that should appear in each Conv2D layer + iterations: int + The number of consecutive Conv2D layers to create + """ + def __init__(self, level: int, filters: int, iterations: int) -> None: + self._name = f"conv{level}_" + self._level = level + self._filters = filters + self._iterator = range(1, iterations + 1) + + def __call__(self, inputs: KerasTensor) -> KerasTensor: + """ Call the convolutional loop. + + Parameters + ---------- + inputs: :class:`keras.KerasTensor` + The input tensor to the block + + Returns + ------- + :class:`keras.KerasTensor` + The output tensor from the convolutional block + """ + var_x = inputs + for i in self._iterator: + padding = "valid" if self._level == i == 1 else "same" + var_x = kl.Conv2D(self._filters, + 3, + padding=padding, + activation="relu", + name=f"{self._name}{i}")(var_x) + var_x = kl.MaxPooling2D(padding="same", + strides=(2, 2), + name=f"pool{self._level}")(var_x) + return var_x + + +class _ScorePool(): + """ Cropped scaling of the pooling layer. + + Parameters + ---------- + level: int + For naming. 
The current level for this score pool
+    scale: float
+        The scaling to apply to the pool
+    crop: int
+        The amount of 2D cropping to apply
+    """
+    def __init__(self, level: int, scale: float, crop: int) -> None:
+        self._name = f"_pool{level}"
+        self._cropping = ((crop, crop), (crop, crop))
+        self._scale = scale
+
+    def __call__(self, inputs: KerasTensor) -> KerasTensor:
+        """ Score pool block.
+
+        Parameters
+        ----------
+        inputs: :class:`keras.KerasTensor`
+            The input tensor to the block
+
+        Returns
+        -------
+        :class:`keras.KerasTensor`
+            The output tensor from the score pool block
+        """
+        var_x = kl.Lambda(lambda x: x * self._scale, name="scale" + self._name)(inputs)
+        var_x = kl.Conv2D(21,
+                          1,
+                          padding="valid",
+                          activation="linear",
+                          name="score" + self._name)(var_x)
+        var_x = kl.Cropping2D(cropping=self._cropping, name="score" + self._name + "c")(var_x)
+        return var_x
+
+
+__all__ = get_module_objects(__name__)
diff --git a/plugins/extract/mask/vgg_obstructed_defaults.py b/plugins/extract/mask/vgg_obstructed_defaults.py
new file mode 100644
index 0000000000..9a42624a40
--- /dev/null
+++ b/plugins/extract/mask/vgg_obstructed_defaults.py
@@ -0,0 +1,48 @@
+#!/usr/bin/env python3
+""" The default options for the faceswap VGG obstructed plugin.
+
+Defaults files should be named `<plugin_name>_defaults.py`
+
+Any qualifying items placed into this file will automatically get added to the relevant config
+.ini files within the faceswap/config folder and added to the relevant GUI settings page.
+
+The following variable should be defined:
+
+    Parameters
+    ----------
+    HELPTEXT: str
+        A string describing what this plugin does
+
+Further plugin configuration options are assigned using:
+>>> from lib.config import ConfigItem
+>>> <option_name> = ConfigItem(...)
+
+where <option_name> is the name of the configuration option to be added (lower-case,
+alpha-numeric + underscore only) and ConfigItem(...) is the
+:class:`~lib.config.objects.ConfigItem` data for the option.
+
+See the docstring/ReadtheDocs documentation for the required parameters of the ConfigItem
+object. Items will be grouped together as per their `group` parameter, but otherwise will be
+processed in the order that they are added to this module.
+"""
+# pylint:disable=duplicate-code
+from lib.config import ConfigItem
+
+
+HELPTEXT = (
+    "VGG_Obstructed options. Mask designed to provide smart segmentation of mostly frontal "
+    "faces.\nThe mask model has been specifically trained to recognize some facial obstructions "
+    "(hands and eyeglasses). Profile faces may result in sub-par performance."
+    )
+
+
+batch_size = ConfigItem(
+    datatype=int,
+    default=2,
+    group="settings",
+    info="The batch size to use. To a point, higher batch sizes equal better performance, "
+         "but setting it too high can harm performance.\n"
+         "\n\tNvidia users: If the batchsize is set higher than your GPU can "
+         "accommodate then this will automatically be lowered.",
+    rounding=1,
+    min_max=(1, 64))
diff --git a/plugins/extract/pipeline.py b/plugins/extract/pipeline.py
new file mode 100644
index 0000000000..52ec0e4373
--- /dev/null
+++ b/plugins/extract/pipeline.py
@@ -0,0 +1,875 @@
+#!/usr/bin/env python3
+"""
+Return a requested detector/aligner/masker pipeline
+
+This module sets up a pipeline for the extraction workflow, loading detect, align and mask
+plugins either in parallel or in series, giving easy access to input and output.
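+
+Example
+-------
+An illustrative sketch of driving the pipeline (plugin names and paths are examples only;
+``image`` is a loaded frame):
+>>> extractor = Extractor("s3fd", "fan", "components")
+>>> extract_media = ExtractMedia("path/to/image/file", image)
+>>> for _ in range(extractor.passes):
+...     extractor.launch()
+...     extractor.input_queue.put(extract_media)
+...     for extract_media in extractor.detected_faces():
+...         detected_faces = extract_media.detected_faces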
+""" +from __future__ import annotations +import logging +import os +import typing as T + +from lib.align import LandmarkType +from lib.gpu_stats import GPUStats +from lib.logger import parse_class_init +from lib.queue_manager import EventQueue, queue_manager, QueueEmpty +from lib.serializer import get_serializer +from lib.utils import get_backend, get_module_objects, FaceswapError +from plugins.plugin_loader import PluginLoader + +if T.TYPE_CHECKING: + from collections.abc import Generator + from ._base import Extractor as PluginExtractor + from .align._base import Aligner + from .align.external import Align as AlignImport + from .detect._base import Detector + from .detect.external import Detect as DetectImport + from .mask._base import Masker + from .recognition._base import Identity + from . import ExtractMedia + +logger = logging.getLogger(__name__) +_INSTANCES = -1 # Tracking for multiple instances of pipeline + + +def _get_instance(): + """ Increment the global :attr:`_INSTANCES` and obtain the current instance value """ + global _INSTANCES # pylint:disable=global-statement + _INSTANCES += 1 + return _INSTANCES + + +class Extractor(): # pylint:disable=too-many-instance-attributes + """ Creates a :mod:`~plugins.extract.detect`/:mod:`~plugins.extract.align``/\ + :mod:`~plugins.extract.mask` pipeline and yields results frame by frame from the + :attr:`detected_faces` generator + + :attr:`input_queue` is dynamically set depending on the current :attr:`phase` of extraction + + Parameters + ---------- + detector: str or ``None`` + The name of a detector plugin as exists in :mod:`plugins.extract.detect` + aligner: str or ``None`` + The name of an aligner plugin as exists in :mod:`plugins.extract.align` + masker: str or list or ``None`` + The name of a masker plugin(s) as exists in :mod:`plugins.extract.mask`. + This can be a single masker or a list of multiple maskers + recognition: str or ``None`` + The name of the recognition plugin to use. ``None`` to not do face recognition. + Default: ``None`` + configfile: str, optional + The path to a custom ``extract.ini`` configfile. If ``None`` then the system + :file:`config/extract.ini` file will be used. + multiprocess: bool, optional + Whether to attempt processing the plugins in parallel. This may get overridden + internally depending on the plugin combination. Default: ``False`` + rotate_images: str, optional + Used to set the :attr:`plugins.extract.detect.rotation` attribute. Pass in a single number + to use increments of that size up to 360, or pass in a ``list`` of ``ints`` to enumerate + exactly what angles to check. Can also pass in ``'on'`` to increment at 90 degree + intervals. Default: ``None`` + min_size: int, optional + Used to set the :attr:`plugins.extract.detect.min_size` attribute. Filters out faces + detected below this size. Length, in pixels across the diagonal of the bounding box. Set + to ``0`` for off. Default: ``0`` + normalize_method: {`None`, 'clahe', 'hist', 'mean'}, optional + Used to set the :attr:`plugins.extract.align.normalize_method` attribute. Normalize the + images fed to the aligner.Default: ``None`` + re_feed: int + The number of times to re-feed a slightly adjusted bounding box into the aligner. + Default: `0` + re_align: bool, optional + ``True`` to obtain landmarks by passing the initially aligned face back through the + aligner. Default ``False`` + disable_filter: bool, optional + Disable all aligner filters regardless of config option. 
Default: ``False`` + + Attributes + ---------- + phase: str + The current phase that the pipeline is running. Used in conjunction with :attr:`passes` and + :attr:`final_pass` to indicate to the caller which phase is being processed + """ + def __init__(self, # pylint:disable=too-many-arguments,too-many-positional-arguments + detector: str | None, + aligner: str | None, + masker: str | list[str] | None, + recognition: str | None = None, + configfile: str | None = None, + multiprocess: bool = False, + rotate_images: str | None = None, + min_size: int = 0, + normalize_method: T.Literal["none", "clahe", "hist", "mean"] | None = None, + re_feed: int = 0, + re_align: bool = False, + disable_filter: bool = False) -> None: + logger.debug(parse_class_init(locals())) + self._instance = _get_instance() + maskers = [T.cast(str | None, + masker)] if not isinstance(masker, list) else T.cast(list[str | None], + masker) + self._flow = self._set_flow(detector, aligner, maskers, recognition) + # TODO Calculate scaling for more plugins than currently exist in _parallel_scaling + self._scaling_fallback = 0.4 + self._vram_stats = self._get_vram_stats() + self._detect = self._load_detect(detector, aligner, rotate_images, min_size, configfile) + self._align = self._load_align(aligner, + configfile, + normalize_method, + re_feed, + re_align, + disable_filter) + self._recognition = self._load_recognition(recognition, configfile) + self._mask = [self._load_mask(mask, configfile) for mask in maskers] + self._phases = self._set_phases(multiprocess) + self._phase_index = 0 + self._set_extractor_batchsize() + self._queues = self._add_queues() + logger.debug("Initialized %s", self.__class__.__name__) + + @property + def input_queue(self) -> EventQueue: + """ queue: Return the correct input queue depending on the current phase + + The input queue is the entry point into the extraction pipeline. An :class:`ExtractMedia` + object should be put to the queue. + + For detect/single phase operations the :attr:`ExtractMedia.filename` and + :attr:`~ExtractMedia.image` attributes should be populated. + + For align/mask (2nd/3rd pass operations) the :attr:`ExtractMedia.detected_faces` should + also be populated by calling :func:`ExtractMedia.set_detected_faces`. + """ + qname = f"extract{self._instance}_{self._current_phase[0]}_in" + retval = self._queues[qname] + logger.trace("%s: %s", qname, retval) # type: ignore + return retval + + @property + def passes(self) -> int: + """ int: Returns the total number of passes the extractor needs to make. + + This is calculated on several factors (vram available, plugin choice, + :attr:`multiprocess` etc.). It is useful for iterating over the pipeline + and handling accordingly. + + Example + ------- + >>> for phase in extractor.passes: + >>> if phase == 1: + >>> extract_media = ExtractMedia("path/to/image/file", image) + >>> extractor.input_queue.put(extract_media) + >>> else: + >>> extract_media.set_image(image) + >>> extractor.input_queue.put(extract_media) + """ + retval = len(self._phases) + logger.trace(retval) # type: ignore + return retval + + @property + def phase_text(self) -> str: + """ str: The plugins that are running in the current phase, formatted for info text + output. 
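+
+        Example
+        -------
+        Illustrative output when the detect and align plugins share the current phase:
+        >>> extractor.phase_text
+        'Detect, Align'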
""" + plugin_types = set(self._get_plugin_type_and_index(phase)[0] + for phase in self._current_phase) + retval = ", ".join(plugin_type.title() for plugin_type in list(plugin_types)) + logger.trace(retval) # type: ignore + return retval + + @property + def final_pass(self) -> bool: + """ bool, Return ``True`` if this is the final extractor pass otherwise ``False`` + + Useful for iterating over the pipeline :attr:`passes` or :func:`detected_faces` and + handling accordingly. + + Example + ------- + >>> for face in extractor.detected_faces(): + >>> if extractor.final_pass: + >>> + >>> else: + >>> extract_media.set_image(image) + >>> + >>> extractor.input_queue.put(extract_media) + """ + retval = self._phase_index == len(self._phases) - 1 + logger.trace(retval) # type:ignore[attr-defined] + return retval + + @property + def aligner(self) -> Aligner: + """ The currently selected aligner plugin """ + assert self._align is not None + return self._align + + @property + def recognition(self) -> Identity: + """ The currently selected recognition plugin """ + assert self._recognition is not None + return self._recognition + + def reset_phase_index(self) -> None: + """ Reset the current phase index back to 0. Used for when batch processing is used in + extract. """ + self._phase_index = 0 + + def set_batchsize(self, + plugin_type: T.Literal["align", "detect"], + batchsize: int) -> None: + """ Set the batch size of a given :attr:`plugin_type` to the given :attr:`batchsize`. + + This should be set prior to :func:`launch` if the batch size is to be manually overridden + + Parameters + ---------- + plugin_type: {'align', 'detect'} + The plugin_type to be overridden + batchsize: int + The batch size to use for this plugin type + """ + logger.debug("Overriding batchsize for plugin_type: %s to: %s", plugin_type, batchsize) + plugin = getattr(self, f"_{plugin_type}") + plugin.batchsize = batchsize + + def launch(self) -> None: + """ Launches the plugin(s) + + This launches the plugins held in the pipeline, and should be called at the beginning + of each :attr:`phase`. To ensure VRAM is conserved, It will only launch the plugin(s) + required for the currently running phase + + Example + ------- + >>> for phase in extractor.passes: + >>> extractor.launch(): + >>> + """ + for phase in self._current_phase: + self._launch_plugin(phase) + + def detected_faces(self) -> Generator[ExtractMedia, None, None]: + """ Generator that returns results, frame by frame from the extraction pipeline + + This is the exit point for the extraction pipeline and is used to obtain the output + of any pipeline :attr:`phase` + + Yields + ------ + faces: :class:`~plugins.extract.extract_media.ExtractMedia` + The populated extracted media object. + + Example + ------- + >>> for extract_media in extractor.detected_faces(): + >>> filename = extract_media.filename + >>> image = extract_media.image + >>> detected_faces = extract_media.detected_faces + """ + logger.debug("Running Detection. 
Phase: '%s'", self._current_phase) + # If not multiprocessing, intercept the align in queue for + # detection phase + out_queue = self._output_queue + while True: + try: + self._check_and_raise_error() + faces = out_queue.get(True, 1) + if faces == "EOF": + break + except QueueEmpty: + continue + yield faces + + self._join_threads() + if self.final_pass: + for plugin in self._all_plugins: + plugin.on_completion() + logger.debug("Detection Complete") + else: + self._phase_index += 1 + logger.debug("Switching to phase: %s", self._current_phase) + + def _disable_lm_maskers(self) -> None: + """ Disable any 68 point landmark based maskers if alignment data is not 2D 68 + point landmarks and update the process flow/phases accordingly """ + logger.warning("Alignment data is not 68 point 2D landmarks. Some Faceswap functionality " + "will be unavailable for these faces") + + rem_maskers = [m.name for m in self._mask + if m is not None and m.landmark_type == LandmarkType.LM_2D_68] + self._mask = [m for m in self._mask if m is None or m.name not in rem_maskers] + + self._flow = [ + item for item in self._flow + if not item.startswith("mask") + or item.startswith("mask") and int(item.rsplit("_", maxsplit=1)[-1]) < len(self._mask)] + + self._phases = [[s for s in p if s in self._flow] for p in self._phases + if any(t in p for t in self._flow)] + + for queue in self._queues: + queue_manager.del_queue(queue) + del self._queues + self._queues = self._add_queues() + + logger.warning("The following maskers have been disabled due to unsupported landmarks: %s", + rem_maskers) + + def import_data(self, input_location: str) -> None: + """ Import json data to the detector and/or aligner if 'import' plugin has been selected + + Parameters + ---------- + input_location: str + Full path to the input location for the extract process + """ + assert self._detect is not None + import_plugins: list[DetectImport | AlignImport] = [ + p for p in (self._detect, self.aligner) # type:ignore[misc] + if T.cast(str, p.name).lower() == "external"] + + if not import_plugins: + return + + align_origin = None + if len(import_plugins) == 2: + align_origin = import_plugins[-1].origin + + logger.info("Importing external data for %s from json file...", + " and ".join([p.__class__.__name__ for p in import_plugins])) + + folder = input_location + folder = folder if os.path.isdir(folder) else os.path.dirname(folder) + + last_fname = "" + is_68_point = True + data = {} + for plugin in import_plugins: + plugin_type = plugin.__class__.__name__ + path = os.path.join(folder, plugin.file_name) + if not os.path.isfile(path): + raise FaceswapError(f"{plugin_type} import file could not be found at '{path}'") + + if path != last_fname: # Different import file for aligner data + last_fname = path + data = get_serializer("json").load(path) + + if plugin_type == "Detect": + plugin.import_data(data, align_origin) # type:ignore[call-arg] + else: + plugin.import_data(data) # type:ignore[call-arg] + is_68_point = plugin.landmark_type == LandmarkType.LM_2D_68 # type:ignore[union-attr] # noqa:E501 # pylint:disable="line-too-long" + + if not is_68_point: + self._disable_lm_maskers() + + logger.info("Imported external data") + + # <<< INTERNAL METHODS >>> # + @property + def _parallel_scaling(self) -> dict[int, float]: + """ dict: key is number of parallel plugins being loaded, value is the scaling factor that + the total base vram for those plugins should be scaled by + + Notes + ----- + VRAM for parallel plugins does not stack in a linear manner. 
Calculating the precise + scaling for any given plugin combination is non trivial, however the following are + calculations based on running 2-5 plugins in parallel using s3fd, fan, unet, vgg-clear + and vgg-obstructed. The worst ratio is selected for each combination, plus a little extra + to ensure that vram is not used up. + + If OOM errors are being reported, then these ratios should be relaxed some more + """ + retval = {0: 1.0, + 1: 1.0, + 2: 0.7, + 3: 0.55, + 4: 0.5, + 5: 0.4} + logger.trace(retval) # type: ignore + return retval + + @property + def _vram_per_phase(self) -> dict[str, float]: + """ dict: The amount of vram required for each phase in :attr:`_flow`. """ + retval = {} + for phase in self._flow: + plugin_type, idx = self._get_plugin_type_and_index(phase) + attr = getattr(self, f"_{plugin_type}") + attr = attr[idx] if idx is not None else attr + retval[phase] = attr.vram + logger.trace(retval) # type: ignore + return retval + + @property + def _total_vram_required(self) -> float: + """ Return vram required for all phases plus the buffer """ + vrams = self._vram_per_phase + vram_required_count = sum(1 for p in vrams.values() if p > 0) + logger.debug("VRAM requirements: %s. Plugins requiring VRAM: %s", + vrams, vram_required_count) + retval = (sum(vrams.values()) * + self._parallel_scaling.get(vram_required_count, self._scaling_fallback)) + logger.debug("Total VRAM required: %s", retval) + return retval + + @property + def _current_phase(self) -> list[str]: + """ list: The current phase from :attr:`_phases` that is running through the extractor. """ + retval = self._phases[self._phase_index] + logger.trace(retval) # type: ignore + return retval + + @property + def _final_phase(self) -> str: + """ Return the final phase from the flow list """ + retval = self._flow[-1] + logger.trace(retval) # type: ignore + return retval + + @property + def _output_queue(self) -> EventQueue: + """ Return the correct output queue depending on the current phase """ + if self.final_pass: + qname = f"extract{self._instance}_{self._final_phase}_out" + else: + qname = f"extract{self._instance}_{self._phases[self._phase_index + 1][0]}_in" + retval = self._queues[qname] + logger.trace("%s: %s", qname, retval) # type: ignore + return retval + + @property + def _all_plugins(self) -> list[PluginExtractor]: + """ Return list of all plugin objects in this pipeline """ + retval = [] + for phase in self._flow: + plugin_type, idx = self._get_plugin_type_and_index(phase) + attr = getattr(self, f"_{plugin_type}") + attr = attr[idx] if idx is not None else attr + retval.append(attr) + logger.trace("All Plugins: %s", retval) # type: ignore + return retval + + @property + def _active_plugins(self) -> list[PluginExtractor]: + """ Return the plugins that are currently active based on pass """ + retval = [] + for phase in self._current_phase: + plugin_type, idx = self._get_plugin_type_and_index(phase) + attr = getattr(self, f"_{plugin_type}") + retval.append(attr[idx] if idx is not None else attr) + logger.trace("Active plugins: %s", retval) # type: ignore + return retval + + @staticmethod + def _set_flow(detector: str | None, + aligner: str | None, + masker: list[str | None], + recognition: str | None) -> list[str]: + """ Set the flow list based on the input plugins + + Parameters + ---------- + detector: str or ``None`` + The name of a detector plugin as exists in :mod:`plugins.extract.detect` + aligner: str or ``None + The name of an aligner plugin as exists in :mod:`plugins.extract.align` + masker: str or list 
or ``None + The name of a masker plugin(s) as exists in :mod:`plugins.extract.mask`. + This can be a single masker or a list of multiple maskers + recognition: str or ``None`` + The name of the recognition plugin to use. ``None`` to not do face recognition. + """ + logger.debug("detector: %s, aligner: %s, masker: %s recognition: %s", + detector, aligner, masker, recognition) + retval = [] + if detector is not None and detector.lower() != "none": + retval.append("detect") + if aligner is not None and aligner.lower() != "none": + retval.append("align") + if recognition is not None and recognition.lower() != "none": + retval.append("recognition") + retval.extend([f"mask_{idx}" + for idx, mask in enumerate(masker) + if mask is not None and mask.lower() != "none"]) + logger.debug("flow: %s", retval) + return retval + + @staticmethod + def _get_plugin_type_and_index(flow_phase: str) -> tuple[str, int | None]: + """ Obtain the plugin type and index for the plugin for the given flow phase. + + When multiple plugins for the same phase are allowed (e.g. Mask) this will return + the plugin type and the index of the plugin required. If only one plugin is allowed + then the plugin type will be returned and the index will be ``None``. + + Parameters + ---------- + flow_phase: str + The phase within :attr:`_flow` that is to have the plugin type and index returned + + Returns + ------- + plugin_type: str + The plugin type for the given flow phase + index: int + The index of this plugin type within the flow, if there are multiple plugins in use + otherwise ``None`` if there is only 1 plugin in use for the given phase + """ + sidx = flow_phase.split("_")[-1] + if sidx.isdigit(): + idx: int | None = int(sidx) + plugin_type = "_".join(flow_phase.split("_")[:-1]) + else: + plugin_type = flow_phase + idx = None + return plugin_type, idx + + def _add_queues(self) -> dict[str, EventQueue]: + """ Add the required processing queues to Queue Manager """ + queues = {} + tasks = [f"extract{self._instance}_{phase}_in" for phase in self._flow] + tasks.append(f"extract{self._instance}_{self._final_phase}_out") + for task in tasks: + # Limit queue size to avoid stacking ram + queue_manager.add_queue(task, maxsize=1) + queues[task] = queue_manager.get_queue(task) + logger.debug("Queues: %s", queues) + return queues + + @staticmethod + def _get_vram_stats() -> dict[str, int | str]: + """ Obtain statistics on available VRAM and subtract a constant buffer from available vram. + + Returns + ------- + dict + Statistics on available VRAM + """ + vram_buffer = 256 # Leave a buffer for VRAM allocation + assert GPUStats is not None + gpu_stats = GPUStats() + stats = gpu_stats.get_card_most_free() + retval: dict[str, int | str] = {"count": gpu_stats.device_count, + "device": stats.device, + "vram_free": int(stats.free - vram_buffer), + "vram_total": int(stats.total)} + logger.debug(retval) + return retval + + def _set_parallel_processing(self, multiprocess: bool) -> bool: + """ Set whether to run detect, align, and mask together or separately. + + Parameters + ---------- + multiprocess: bool + ``True`` if the single-process command line flag has not been set otherwise ``False`` + """ + if not multiprocess: + logger.debug("Parallel processing disabled by cli.") + return False + + if self._vram_stats["count"] == 0: + logger.debug("No GPU detected. 
Enabling parallel processing.") + return True + + logger.verbose("%s - %sMB free of %sMB", # type: ignore + self._vram_stats["device"], + self._vram_stats["vram_free"], + self._vram_stats["vram_total"]) + if T.cast(int, self._vram_stats["vram_free"]) <= self._total_vram_required: + logger.warning("Not enough free VRAM for parallel processing. " + "Switching to serial") + return False + return True + + def _set_phases(self, multiprocess: bool) -> list[list[str]]: + """ If not enough VRAM is available, then chunk :attr:`_flow` up into phases that will fit + into VRAM, otherwise return the single flow. + + Parameters + ---------- + multiprocess: bool + ``True`` if the single-process command line flag has not been set otherwise ``False`` + + Returns + ------- + list: + The jobs to be undertaken split into phases that fit into GPU RAM + """ + phases: list[list[str]] = [] + current_phase: list[str] = [] + available = T.cast(int, self._vram_stats["vram_free"]) + for phase in self._flow: + num_plugins = len([p for p in current_phase if self._vram_per_phase[p] > 0]) + num_plugins += 1 if self._vram_per_phase[phase] > 0 else 0 + scaling = self._parallel_scaling.get(num_plugins, self._scaling_fallback) + required = sum(self._vram_per_phase[p] for p in current_phase + [phase]) * scaling + logger.debug("Num plugins for phase: %s, scaling: %s, vram required: %s", + num_plugins, scaling, required) + if required <= available and multiprocess: + logger.debug("Required: %s, available: %s. Adding phase '%s' to current phase: %s", + required, available, phase, current_phase) + current_phase.append(phase) + elif len(current_phase) == 0 or not multiprocess: + # Amount of VRAM required to run a single plugin is greater than available. We add + # it anyway, and hope it will run with warnings, as the alternative is to not run + # at all. + # This will also run if forcing single process + logger.debug("Required: %s, available: %s. Single plugin has higher requirements " + "than available or forcing single process: '%s'", + required, available, phase) + phases.append([phase]) + else: + logger.debug("Required: %s, available: %s. Adding phase to flow: %s", + required, available, current_phase) + phases.append(current_phase) + current_phase = [phase] + if current_phase: + phases.append(current_phase) + logger.debug("Total phases: %s, Phases: %s", len(phases), phases) + return phases + + # << INTERNAL PLUGIN HANDLING >> # + def _load_align(self, + aligner: str | None, + configfile: str | None, + normalize_method: T.Literal["none", "clahe", "hist", "mean"] | None, + re_feed: int, + re_align: bool, + disable_filter: bool) -> Aligner | None: + """ Set global arguments and load aligner plugin + + Parameters + ---------- + aligner: str + The aligner plugin to load or ``None`` for no aligner + configfile: str + Optional full path to custom config file + normalize_method: str + Optional normalization method to use + re_feed: int + The number of times to adjust the image and re-feed to get an average score + re_align: bool + ``True`` to obtain landmarks by passing the initially aligned face back through the + aligner. + disable_filter: bool + Disable all aligner filters regardless of config option + + Returns + ------- + Aligner plugin if one is specified otherwise ``None`` + """ + if aligner is None or aligner.lower() == "none": + logger.debug("No aligner selected. 
Returning None") + return None + aligner_name = aligner.replace("-", "_").lower() + logger.debug("Loading Aligner: '%s'", aligner_name) + plugin = PluginLoader.get_aligner(aligner_name)(configfile=configfile, + normalize_method=normalize_method, + re_feed=re_feed, + re_align=re_align, + disable_filter=disable_filter, + instance=self._instance) + return plugin + + def _load_detect(self, + detector: str | None, + aligner: str | None, + rotation: str | None, + min_size: int, + configfile: str | None) -> Detector | None: + """ Set global arguments and load detector plugin + + Parameters + ---------- + detector: str | None + The name of the face detection plugin to use. ``None`` for no detection + aligner: str | None + The name of the face aligner plugin to use. ``None`` for no aligner + rotation: str | None + The rotation to perform on detection. ``None`` for no rotation + min_size: int + The minimum size of detected faces to accept + configfile: str | None + Full path to a custom config file to use. ``None`` for default config + + Returns + ------- + :class:`~plugins.extract.detect._base.Detector` | None + The face detection plugin to use, or ``None`` if no detection to be performed + """ + if detector is None or detector.lower() == "none": + logger.debug("No detector selected. Returning None") + return None + detector_name = detector.replace("-", "_").lower() + + if aligner == "external" and detector_name != "external": + logger.warning("Unsupported '%s' detector selected for 'External' aligner. Switching " + "detector to 'External'", detector_name) + detector_name = aligner + + logger.debug("Loading Detector: '%s'", detector_name) + plugin = PluginLoader.get_detector(detector_name)(rotation=rotation, + min_size=min_size, + configfile=configfile, + instance=self._instance) + return plugin + + def _load_mask(self, + masker: str | None, + configfile: str | None) -> Masker | None: + """ Set global arguments and load masker plugin + + Parameters + ---------- + masker: str or ``none`` + The name of the masker plugin to use or ``None`` if no masker + configfile: str + Full path to custom config.ini file or ``None`` to use default + + Returns + ------- + :class:`~plugins.extract.mask._base.Masker` or ``None`` + The masker plugin to use or ``None`` if no masker selected + """ + if masker is None or masker.lower() == "none": + logger.debug("No masker selected. Returning None") + return None + masker_name = masker.replace("-", "_").lower() + logger.debug("Loading Masker: '%s'", masker_name) + plugin = PluginLoader.get_masker(masker_name)(configfile=configfile, + instance=self._instance) + return plugin + + def _load_recognition(self, + recognition: str | None, + configfile: str | None) -> Identity | None: + """ Set global arguments and load recognition plugin """ + if recognition is None or recognition.lower() == "none": + logger.debug("No recognition selected. 
Returning None") + return None + recognition_name = recognition.replace("-", "_").lower() + logger.debug("Loading Recognition: '%s'", recognition_name) + plugin = PluginLoader.get_recognition(recognition_name)(configfile=configfile, + instance=self._instance) + return plugin + + def _launch_plugin(self, phase: str) -> None: + """ Launch an extraction plugin """ + logger.debug("Launching %s plugin", phase) + in_qname = f"extract{self._instance}_{phase}_in" + if phase == self._final_phase: + out_qname = f"extract{self._instance}_{self._final_phase}_out" + else: + next_phase = self._flow[self._flow.index(phase) + 1] + out_qname = f"extract{self._instance}_{next_phase}_in" + logger.debug("in_qname: %s, out_qname: %s", in_qname, out_qname) + kwargs = {"in_queue": self._queues[in_qname], "out_queue": self._queues[out_qname]} + + plugin_type, idx = self._get_plugin_type_and_index(phase) + plugin = getattr(self, f"_{plugin_type}") + plugin = plugin[idx] if idx is not None else plugin + plugin.initialize(**kwargs) + plugin.start() + logger.debug("Launched %s plugin", phase) + + def _set_plugins_batchsize(self, gpu_plugins: list[str], vram_free: int) -> None: + """ Set the batch size for the current phase so that it will fit in available VRAM. + + Do not update plugins which have a vram_per_batch of 0 (CPU plugins) due to + zero division error. + + Reduces the batchsize of the plugin which has a batch size > 1 and the largest VRAM + requirements. The final reduction is the plugin which has a batch size > 1 and the + smallest VRAM requirements that would fit the pipeline inside VRAM + + Parameters + ---------- + gpu_plugins: list[str] + The name of the plugins that use the GPU for the current phase + vram_free: int + The amount of available VRAM, in MBs + """ + logger.debug("GPU plugins: %s, Available vram: %s", gpu_plugins, vram_free) + plugins = [self._active_plugins[idx] + for idx, plugin in enumerate(self._current_phase) + if plugin in gpu_plugins] + base_vram = sum(p.vram for p in plugins) + vram_free = vram_free - base_vram + logger.debug("Base vram: %s, remaining vram: %s", base_vram, vram_free) + + to_allocate = [(p.batchsize, p.vram_per_batch) for p in plugins] + excess = sum(a[0] * a[1] for a in to_allocate) - vram_free + logger.debug("Plugins to allocate: %s, excess vram: %s", to_allocate, excess) + + while excess > 0: + chosen = next(p for p in to_allocate + if p[0] > 1 and p[1] == max(p[1] for p in to_allocate if p[0] > 1)) + + if excess - chosen[1] <= 0: + chosen = next(p for p in to_allocate + if p[0] > 1 and p[1] == min(p[1] for p in to_allocate + if p[0] > 1 and p[1] >= excess)) + + excess -= chosen[1] + logger.debug("Reducing batch size for item %s. Remaining %s", chosen, excess) + to_allocate[to_allocate.index(chosen)] = (chosen[0] - 1, chosen[1]) + + msg = [] + for plugin, alloc in zip(plugins, to_allocate): + if plugin.batchsize != alloc[0]: + logger.debug("Updating batchsize for plugin %s from %s to %s", + plugin.name, plugin.batchsize, alloc[0]) + plugin.batchsize = alloc[0] + msg.append(f"{plugin.__class__.__name__}: {plugin.batchsize}") + + logger.info("Reset batch sizes due to available VRAM: %s", ", ".join(msg)) + + def _set_extractor_batchsize(self) -> None: + """ + Sets the batch size of the requested plugins based on their vram, their + vram_per_batch_requirements and the number of plugins being loaded in the current phase. + Only adjusts if the the configured batch size requires more vram than is available. 
+        """
+        backend = get_backend()
+        if backend not in ("nvidia", "rocm"):
+            logger.debug("Not updating batchsize requirements for backend: '%s'", backend)
+            return
+        if sum(plugin.vram for plugin in self._active_plugins) == 0:
+            logger.debug("No plugins use VRAM. Not updating batchsize requirements.")
+            return
+
+        batch_required = sum(plugin.vram_per_batch * plugin.batchsize
+                             for plugin in self._active_plugins)
+
+        gpu_plugins = [p for p in self._current_phase if self._vram_per_phase[p] > 0]
+
+        scaling = self._parallel_scaling.get(len(gpu_plugins), self._scaling_fallback)
+        plugins_required = sum(self._vram_per_phase[p] for p in gpu_plugins) * scaling
+
+        vram_free = T.cast(int, self._vram_stats["vram_free"])
+        total_required = plugins_required + batch_required
+        if total_required <= vram_free:
+            logger.debug("Plugin requirements within threshold: (plugins_required: %sMB, "
+                         "vram_free: %sMB)", plugins_required, self._vram_stats["vram_free"])
+            return
+
+        self._set_plugins_batchsize(gpu_plugins, vram_free)
+
+    def _join_threads(self):
+        """ Join threads for current pass """
+        for plugin in self._active_plugins:
+            plugin.join()
+
+    def _check_and_raise_error(self) -> None:
+        """ Check all threads for errors and raise if one occurs """
+        for plugin in self._active_plugins:
+            plugin.check_and_raise_error()
+
+
+__all__ = get_module_objects(__name__)
diff --git a/plugins/extract/recognition/__init__.py b/plugins/extract/recognition/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/plugins/extract/recognition/_base.py b/plugins/extract/recognition/_base.py
new file mode 100644
index 0000000000..e00abbe2a7
--- /dev/null
+++ b/plugins/extract/recognition/_base.py
@@ -0,0 +1,495 @@
+#!/usr/bin/env python3
+""" Base class for Face Recognition plugins
+
+All Recognition Plugins should inherit from this class.
+See the override methods for which methods are required.
+
+The plugin will receive a :class:`~plugins.extract.extract_media.ExtractMedia` object.
+
+For each source frame, the plugin must pass a dict to finalize containing:
+
+>>> {'filename': <filename of source frame>,
+>>>  'detected_faces': <list of DetectedFace objects>}
+
+To get a :class:`~lib.align.DetectedFace` object use the function:
+
+>>> face = self.to_detected_face(<face left>, <face top>, <face right>, <face bottom>)
+"""
+from __future__ import annotations
+import logging
+import typing as T
+
+from dataclasses import dataclass, field
+
+import numpy as np
+from torch.cuda import OutOfMemoryError
+
+from lib.align import AlignedFace, DetectedFace, LandmarkType
+from lib.image import read_image_meta
+from lib.utils import FaceswapError
+from plugins.extract import ExtractMedia, extract_config as cfg
+from plugins.extract._base import BatchType, ExtractorBatch, Extractor
+
+if T.TYPE_CHECKING:
+    from collections.abc import Generator
+    from queue import Queue
+    from lib.align.aligned_face import CenteringType
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class RecogBatch(ExtractorBatch):
+    """ Dataclass for holding items flowing through the recognition plugin.
+
+    Inherits from :class:`~plugins.extract._base.ExtractorBatch`
+    """
+    detected_faces: list[DetectedFace] = field(default_factory=list)
+    feed_faces: list[AlignedFace] = field(default_factory=list)
+
+
+class Identity(Extractor):  # pylint:disable=abstract-method
+    """ Face Recognition Object
+
+    Parent class for all Recognition plugins
+
+    Parameters
+    ----------
+    git_model_id: int
+        The second digit in the github tag that identifies this model.
See + https://github.com/deepfakes-models/faceswap-models for more information + model_filename: str + The name of the model file to be loaded + + Other Parameters + ---------------- + configfile: str, optional + Path to a custom configuration ``ini`` file. Default: Use system configfile + + See Also + -------- + plugins.extract.pipeline : The extraction pipeline for calling plugins + plugins.extract.detect : Detector plugins + plugins.extract._base : Parent class for all extraction plugins + plugins.extract.align._base : Aligner parent class for extraction plugins. + plugins.extract.mask._base : Masker parent class for extraction plugins. + """ + + _logged_lm_count_once = False + + def __init__(self, + git_model_id: int | None = None, + model_filename: str | None = None, + configfile: str | None = None, + instance: int = 0, + **kwargs): + logger.debug("Initializing %s", self.__class__.__name__) + super().__init__(git_model_id, # pylint:disable=duplicate-code + model_filename, + configfile=configfile, + instance=instance, + **kwargs) + self.input_size = 256 # Override for model specific input_size + self.centering: CenteringType = "legacy" # Override for model specific centering + self.coverage_ratio = 1.0 # Override for model specific coverage_ratio + + self._info.plugin_type = "recognition" + self._filter = IdentityFilter(cfg.save_filtered()) + logger.debug("Initialized _base %s", self.__class__.__name__) + + def _get_detected_from_aligned(self, item: ExtractMedia) -> None: + """ Obtain detected face objects for when loading in aligned faces and a detected face + object does not exist + + Parameters + ---------- + item: :class:`~plugins.extract.extract_media.ExtractMedia` + The extract media to populate the detected face for + """ + detected_face = DetectedFace() + meta = read_image_meta(item.filename).get("itxt", {}).get("alignments") + if meta: + detected_face.from_png_meta(meta) + item.add_detected_faces([detected_face]) + self._tracker.faces_per_filename[item.filename] += 1 # Track this added face + logger.debug("Obtained detected face: (filename: %s, detected_face: %s)", + item.filename, item.detected_faces) + + def _maybe_log_warning(self, face: AlignedFace) -> None: + """ Log a warning, once, if we do not have full facial landmarks + + Parameters + ---------- + face: :class:`~lib.align.aligned_face.AlignedFace` + The aligned face object to test the landmark type for + """ + if face.landmark_type != LandmarkType.LM_2D_4 or self._logged_lm_count_once: + return + logger.warning("Extracted faces do not contain facial landmark data. '%s' " + "identity data is likely to be sub-standard.", self.name) + self._logged_lm_count_once = True + + def get_batch(self, queue: Queue) -> tuple[bool, RecogBatch]: + """ Get items for inputting into the recognition from the queue in batches + + Items are returned from the ``queue`` in batches of + :attr:`~plugins.extract._base.Extractor.batchsize` + + Items are received as :class:`~plugins.extract.extract_media.ExtractMedia` objects and + converted to :class:`RecogBatch` for internal processing. + + To ensure consistent batch sizes for masker the items are split into separate items for + each :class:`~lib.align.DetectedFace` object. + + Remember to put ``'EOF'`` to the out queue after processing + the final batch + + Outputs items in the following format. 
All lists are of length
+        :attr:`~plugins.extract._base.Extractor.batchsize`:
+
+        >>> {'filename': [<filenames of source frames>],
+        >>>  'detected_faces': [[<lib.align.DetectedFace objects>]]}
+
+        Parameters
+        ----------
+        queue : queue.Queue()
+            The ``queue`` that the plugin will be fed from.
+
+        Returns
+        -------
+        exhausted, bool
+            ``True`` if the queue is exhausted, ``False`` if not
+        batch, :class:`RecogBatch`
+            The batch object for the current batch
+        """
+
+    def _predict(self, batch: BatchType) -> RecogBatch:
+        """ Just return the recognition's predict function """
+        # pylint:disable=duplicate-code
+        assert isinstance(batch, RecogBatch)
+        # slightly hacky workaround to deal with landmarks based masks:
+        try:
+            batch.prediction = self.predict(batch.feed)
+        except OutOfMemoryError as err:
+            msg = ("You do not have enough GPU memory available to run recognition at the "
+                   "selected batch size. You can try a number of things:"
+                   "\n1) Close any other application that is using your GPU (web browsers are "
+                   "particularly bad for this)."
+                   "\n2) Lower the batchsize (the amount of images fed into the model) by "
+                   "editing the plugin settings (GUI: Settings > Configure extract settings, "
+                   "CLI: Edit the file faceswap/config/extract.ini)."
+                   "\n3) Enable 'Single Process' mode.")
+            raise FaceswapError(msg) from err
+
+        return batch
+
+    def finalize(self, batch: BatchType) -> Generator[ExtractMedia, None, None]:
+        """ Finalize the output from the Recognition plugin
+
+        This should be called as the final task of each `plugin`.
+
+        Pairs the detected faces back up with their original frame before yielding each frame.
+
+        Parameters
+        ----------
+        batch : :class:`RecogBatch`
+            The final batch item from the `plugin` process.
+
+        Yields
+        ------
+        :class:`~plugins.extract.extract_media.ExtractMedia`
+            The :attr:`detected_faces` list will be populated for this class with the identity
+            encodings added for each detected face found in the frame.
+        """
+        assert isinstance(batch, RecogBatch)
+        assert isinstance(self.name, str)
+        for identity, face in zip(batch.prediction, batch.detected_faces):
+            face.add_identity(self.name.lower(), identity)
+        del batch.feed
+
+        logger.trace("Item out: %s",  # type: ignore
+                     {key: val.shape if isinstance(val, np.ndarray) else val
+                      for key, val in batch.__dict__.items()})
+
+        for filename, face in zip(batch.filename, batch.detected_faces):
+            self._tracker.output_faces.append(face)
+            if len(self._tracker.output_faces) != self._tracker.faces_per_filename[filename]:
+                continue
+
+            output = self._extract_media.pop(filename)
+            self._tracker.output_faces = self._filter(self._tracker.output_faces,
+                                                      output.sub_folders)
+
+            output.add_detected_faces(self._tracker.output_faces)
+            self._tracker.output_faces = []
+            logger.trace("Yielding: (filename: '%s', image: %s, "  # type:ignore[attr-defined]
+                         "detected_faces: %s)", output.filename, output.image_shape,
+                         len(output.detected_faces))
+            yield output
+
+    def add_identity_filters(self,
+                             filters: np.ndarray,
+                             nfilters: np.ndarray,
+                             threshold: float) -> None:
+        """ Add identity encodings to filter by identity in the recognition plugin
+
+        Parameters
+        ----------
+        filters: :class:`numpy.ndarray`
+            The array of filter embeddings to use
+        nfilters: :class:`numpy.ndarray`
+            The array of nfilter embeddings to use
+        threshold: float
+            The threshold for a positive filter match
+        """
+        logger.debug("Adding identity filters")
+        self._filter.add_filters(filters, nfilters, threshold)
+        logger.debug("Added identity filters")
+
+
+class IdentityFilter():
+    """ Applies filters on the output of the recognition plugin
+
+    Parameters
+    ----------
+    save_output: bool
+        ``True`` if the filtered faces should be kept as they are being saved. ``False`` if they
+        should be deleted
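+
+    Example
+    -------
+    A minimal sketch; the embedding arrays, threshold and face lists are illustrative
+    assumptions, not values produced by this module:
+
+    >>> identity_filter = IdentityFilter(save_output=False)
+    >>> identity_filter.add_filters(filters, nfilters, threshold=0.4)
+    >>> kept_faces = identity_filter(detected_faces, sub_folders)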
+    """
+    def __init__(self, save_output: bool) -> None:
+        logger.debug("Initializing %s: (save_output: %s)", self.__class__.__name__, save_output)
+        self._save_output = save_output
+        self._filter: np.ndarray | None = None
+        self._nfilter: np.ndarray | None = None
+        self._threshold = 0.0
+        self._filter_enabled: bool = False
+        self._nfilter_enabled: bool = False
+        self._active: bool = False
+        self._counts = 0
+        logger.debug("Initialized %s", self.__class__.__name__)
+
+    def add_filters(self, filters: np.ndarray, nfilters: np.ndarray, threshold: float) -> None:
+        """ Add identity encodings to the filter and set whether each filter is enabled
+
+        Parameters
+        ----------
+        filters: :class:`numpy.ndarray`
+            The array of filter embeddings to use
+        nfilters: :class:`numpy.ndarray`
+            The array of nfilter embeddings to use
+        threshold: float
+            The threshold for a positive filter match
+        """
+        logger.debug("Adding filters: %s, nfilters: %s, threshold: %s",
+                     filters.shape, nfilters.shape, threshold)
+        self._filter = filters
+        self._nfilter = nfilters
+        self._threshold = threshold
+        self._filter_enabled = bool(np.any(self._filter))
+        self._nfilter_enabled = bool(np.any(self._nfilter))
+        self._active = self._filter_enabled or self._nfilter_enabled
+        logger.debug("filter active: %s, nfilter active: %s, all active: %s",
+                     self._filter_enabled, self._nfilter_enabled, self._active)

+    @classmethod
+    def _find_cosine_similiarity(cls,
+                                 source_identities: np.ndarray,
+                                 test_identity: np.ndarray) -> np.ndarray:
+        """ Find the cosine similarity between a source face identity and a test face identity
+
+        Parameters
+        ----------
+        source_identities: :class:`numpy.ndarray`
+            The identity encoding for the source face identities
+        test_identity: :class:`numpy.ndarray`
+            The identity encoding for the face identity to test against the sources
+
+        Returns
+        -------
+        :class:`numpy.ndarray`:
+            The cosine similarity between a face identity and the source identities
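+
+        Example
+        -------
+        Illustrative values only:
+
+        >>> sources = np.array([[1.0, 0.0], [0.0, 1.0]])
+        >>> IdentityFilter._find_cosine_similiarity(sources, np.array([1.0, 0.0]))
+        array([1., 0.])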
+        """
+        s_norm = np.linalg.norm(source_identities, axis=1)
+        i_norm = np.linalg.norm(test_identity)
+        retval = source_identities @ test_identity / (s_norm * i_norm)
+        return retval
+
+    def _get_matches(self,
+                     filter_type: T.Literal["filter", "nfilter"],
+                     identities: np.ndarray) -> np.ndarray:
+        """ Check each face's identity encoding against the source identities for the given
+        filter type
+
+        Parameters
+        ----------
+        filter_type: ["filter", "nfilter"]
+            The filter type to use for calculating the distance
+        identities: :class:`numpy.ndarray`
+            The identity encodings for the current face(s) being checked
+
+        Returns
+        -------
+        :class:`numpy.ndarray`
+            Boolean array. ``True`` if identity should be filtered otherwise ``False``
+        """
+        encodings = self._filter if filter_type == "filter" else self._nfilter
+        assert encodings is not None
+        distances = np.array([self._find_cosine_similiarity(encodings, identity)
+                              for identity in identities])
+        is_match = np.any(distances >= self._threshold, axis=-1)
+        # Invert for filter (set the `True` match to `False` for should filter)
+        retval = np.invert(is_match) if filter_type == "filter" else is_match
+        logger.trace("filter_type: %s, distances shape: %s, is_match: %s, "  # type: ignore
+                     "retval: %s", filter_type, distances.shape, is_match, retval)
+        return retval
+
+    def _filter_faces(self,
+                      faces: list[DetectedFace],
+                      sub_folders: list[str | None],
+                      should_filter: list[bool]) -> list[DetectedFace]:
+        """ Filter the detected faces, either removing filtered faces from the list of detected
+        faces or setting the output subfolder to `"_identity_filt"` for any filtered faces if
+        saving output is enabled.
+
+        Parameters
+        ----------
+        faces: list
+            List of detected face objects to filter on identity
+        sub_folders: list
+            List of subfolder locations for any faces that have already been filtered when
+            config option `save_filtered` has been enabled.
+        should_filter: list
+            List of ``bool`` corresponding to faces that have not already been marked for
+            filtering. ``True`` indicates face should be filtered, ``False`` indicates face
+            should be kept
+
+        Returns
+        -------
+        detected_faces: list
+            The filtered list of detected face objects if saving filtered faces has not been
+            selected, otherwise the full list of detected faces
+        """
+        retval: list[DetectedFace] = []
+        self._counts += sum(should_filter)
+        for idx, face in enumerate(faces):
+            fldr = sub_folders[idx]
+            if fldr is not None:
+                # Saving to sub folder is selected and face is already filtered
+                # so this face was excluded from identity check
+                retval.append(face)
+                continue
+            to_filter = should_filter.pop(0)
+            if not to_filter or self._save_output:
+                # Keep the face if not marked as filtered or we are to output to a subfolder
+                retval.append(face)
+            if to_filter and self._save_output:
+                sub_folders[idx] = "_identity_filt"
+
+        return retval
+
+    def __call__(self,
+                 faces: list[DetectedFace],
+                 sub_folders: list[str | None]) -> list[DetectedFace]:
+        """ Call the identity filter function
+
+        Parameters
+        ----------
+        faces: list
+            List of detected face objects to filter on identity
+        sub_folders: list
+            List of subfolder locations for any faces that have already been filtered when
+            config option `save_filtered` has been enabled.
+
+        Returns
+        -------
+        detected_faces: list
+            The filtered list of detected face objects if saving filtered faces has not been
+            selected, otherwise the full list of detected faces
+        """
+        if not self._active:
+            return faces
+
+        identities = np.array([face.identity["vggface2"] for face, fldr in zip(faces, sub_folders)
+                               if fldr is None])
+        logger.trace("face_count: %s, already_filtered: %s, identity_shape: %s",  # type: ignore
+                     len(faces), sum(x is not None for x in sub_folders), identities.shape)
+
+        if not np.any(identities):
+            logger.trace("All faces already filtered: %s", sub_folders)  # type: ignore
+            return faces
+
+        should_filter: list[np.ndarray] = []
+        for f_type in T.get_args(T.Literal["filter", "nfilter"]):
+            if not getattr(self, f"_{f_type}_enabled"):
+                continue
+            should_filter.append(self._get_matches(f_type, identities))
+
+        # If any of the filter or nfilter evaluate to 'should filter' then filter out face
+        final_filter: list[bool] = np.array(should_filter).max(axis=0).tolist()
+        logger.trace("should_filter: %s, final_filter: %s",  # type: ignore
+                     should_filter, final_filter)
+        return self._filter_faces(faces, sub_folders, final_filter)
+
+    def output_counts(self):
+        """ Output the counts of filtered items """
+        if not self._active or not self._counts:
+            return
+        logger.info("Identity filtered (%s): %s", self._threshold, self._counts)
diff --git a/plugins/extract/recognition/vgg_face2.py b/plugins/extract/recognition/vgg_face2.py
new file mode 100644
index 0000000000..76776fc702
--- /dev/null
+++ b/plugins/extract/recognition/vgg_face2.py
@@ -0,0 +1,598 @@
+#!/usr/bin/env python3
+""" VGG_Face2 inference and sorting """
+
+from __future__ import annotations
+import logging
+import typing as T
+
+import numpy as np
+import psutil
+from fastcluster import linkage, linkage_vector
+from keras.layers import (Activation, add, AveragePooling2D, BatchNormalization, Conv2D, Dense,
+                          Flatten, Input, MaxPooling2D)
+from keras.models import Model
+from keras.regularizers import L2
+
+from lib.logger import parse_class_init
+from lib.model.layers import L2Normalize
+from lib.utils import get_module_objects, FaceswapError
+from ._base import BatchType, RecogBatch, Identity
+from . import vgg_face2_defaults as cfg
+
+if T.TYPE_CHECKING:
+    from keras import KerasTensor
+    from collections.abc import Generator
+
+logger = logging.getLogger(__name__)
+
+
+class Recognition(Identity):
+    """ VGG Face feature extraction.
+
+    Extracts feature vectors from faces in order to compare similarity.
+
+    Notes
+    -----
+    Input images should be in BGR Order
+
+    Model exported from: https://github.com/WeidiXie/Keras-VGGFace2-ResNet50 which is based on:
+    https://www.robots.ox.ac.uk/~vgg/software/vgg_face/
+
+
+    Licensed under the Creative Commons Attribution-NonCommercial 4.0 License.
+    https://creativecommons.org/licenses/by-nc/4.0/
+    """
+
+    def __init__(self, **kwargs) -> None:
+        logger.debug("Initializing %s", self.__class__.__name__)
+        git_model_id = 10
+        model_filename = "vggface2_resnet50_v2.h5"
+        super().__init__(git_model_id=git_model_id, model_filename=model_filename, **kwargs)
+        self.model: Model
+        self.name: str = "VGGFace2"
+        self.input_size = 224
+        self.color_format = "BGR"
+
+        self.vram = 384 if not cfg.cpu() else 0  # 334 in testing
+        self.vram_per_batch = 192 if not cfg.cpu() else 0  # ~155 in testing
+        self.batchsize = cfg.batch_size()
+
+        # Average image provided in https://github.com/ox-vgg/vgg_face2
+        self._average_img = np.array([91.4953, 103.8827, 131.0912])
+        logger.debug("Initialized %s", self.__class__.__name__)
+
+    # <<< GET MODEL >>> #
+    def init_model(self) -> None:
+        """ Initialize VGG Face 2 Model. """
+        assert isinstance(self.model_path, str)
+        placeholder = np.zeros((self.batchsize, self.input_size, self.input_size, 3),
+                               dtype="float32")
+
+        with self.get_device_context(cfg.cpu()):
+            self.model = VGGFace2(self.input_size, self.model_path, self.batchsize)
+            self.model(placeholder)
+
+    def process_input(self, batch: BatchType) -> None:
+        """ Compile the detected faces for prediction """
+        assert isinstance(batch, RecogBatch)
+        batch.feed = np.array([T.cast(np.ndarray, feed.face)[..., :3]
+                               for feed in batch.feed_faces],
+                              dtype="float32") - self._average_img
+        logger.trace("feed shape: %s", batch.feed.shape)  # type:ignore[attr-defined]
+
+    def predict(self, feed: np.ndarray) -> np.ndarray:
+        """ Return encodings for given image from vgg_face2.
+
+        Parameters
+        ----------
+        feed: numpy.ndarray
+            The face to be fed through the predictor. Should be in BGR channel order
+
+        Returns
+        -------
+        numpy.ndarray
+            The encodings for the face
+        """
+        with self.get_device_context(cfg.cpu()):
+            retval = self.model(feed)
+        assert isinstance(retval, np.ndarray)
+        return retval
+
+    def process_output(self, batch: BatchType) -> None:
+        """ No output processing for vgg_face2 """
+        return
+
+
+class ResNet50:
+    """ ResNet50 imported for VGG-Face2 adapted from
+    https://github.com/WeidiXie/Keras-VGGFace2-ResNet50
+
+    Parameters
+    ----------
+    input_shape: tuple[int, int, int] | None, optional
+        The input shape for the model. Default: ``None``
+    use_truncated: bool, optional
+        ``True`` to use a truncated version of resnet. Default: ``False``
+    weight_decay: float
+        L2 Regularizer weight decay. Default: 1e-4
+    trainable: bool, optional
+        ``True`` if the block should be trainable. Default: ``True``
+    """
+    def __init__(self,
+                 input_shape: tuple[int, int, int] | None = None,
+                 use_truncated: bool = False,
+                 weight_decay: float = 1e-4,
+                 trainable: bool = True) -> None:
+        logger.debug("Initializing %s: input_shape: %s, use_truncated: %s, weight_decay: %s, "
+                     "trainable: %s", self.__class__.__name__, input_shape, use_truncated,
+                     weight_decay, trainable)
+
+        self._input_shape = (None, None, 3) if input_shape is None else input_shape
+        self._weight_decay = weight_decay
+        self._trainable = trainable
+
+        self._kernel_initializer = "orthogonal"
+        self._use_bias = False
+        self._bn_axis = 3
+        self._block_suffix = {0: "_reduce", 1: "", 2: "_increase"}
+
+        self._identity_calls = [2, 3, 5, 2]
+        self._filters = [(64, 64, 256), (128, 128, 512), (256, 256, 1024), (512, 512, 2048)]
+        if use_truncated:
+            self._identity_calls = self._identity_calls[:-1]
+            self._filters = self._filters[:-1]
+
+        logger.debug("Initialized %s", self.__class__.__name__)
+
+    def _identity_block(self,
+                        inputs: KerasTensor,
+                        kernel_size: int,
+                        filters: tuple[int, int, int],
+                        stage: int,
+                        block: int) -> KerasTensor:
+        """ The identity block is the block that has no conv layer at shortcut.
+
+        Parameters
+        ----------
+        inputs: :class:`keras.KerasTensor`
+            Input tensor
+        kernel_size: int
+            The kernel size of the middle conv layer of the block
+        filters: tuple[int, int, int]
+            The filters of 3 conv layers in the main path
+        stage: int
+            The current stage label, used for generating layer names
+        block: int
+            The current block label, used for generating layer names
+
+        Returns
+        -------
+        :class:`keras.KerasTensor`
+            Output tensor for the block
+        """
+        assert len(filters) == 3
+        var_x = inputs
+
+        for idx, filts in enumerate(filters):
+            k_size = kernel_size if idx == 1 else 1
+            conv_name = f"conv{stage}_{block}_{k_size}x{k_size}{self._block_suffix[idx]}"
+            bn_name = f"{conv_name}_bn"
+
+            var_x = Conv2D(filts,
+                           k_size,
+                           padding="same" if idx == 1 else "valid",
+                           kernel_initializer=self._kernel_initializer,
+                           use_bias=self._use_bias,
+                           kernel_regularizer=L2(self._weight_decay),
+                           trainable=self._trainable,
+                           name=conv_name)(var_x)
+            var_x = BatchNormalization(axis=self._bn_axis, name=bn_name)(var_x)
+            if idx < 2:
+                var_x = Activation("relu")(var_x)
+
+        var_x = add([var_x, inputs])
+        var_x = Activation("relu")(var_x)
+        return var_x
+
+    def _conv_block(self,
+                    inputs: KerasTensor,
+                    kernel_size: int,
+                    filters: tuple[int, int, int],
+                    stage: int,
+                    block: int,
+                    strides: tuple[int, int] = (2, 2)) -> KerasTensor:
+        """ A block that has a conv layer at shortcut.
+
+        Parameters
+        ----------
+        inputs: :class:`keras.KerasTensor`
+            Input tensor
+        kernel_size: int
+            The kernel size of the middle conv layer of the block
+        filters: tuple[int, int, int]
+            The filters of 3 conv layers in the main path
+        stage: int
+            The current stage label, used for generating layer names
+        block: int
+            The current block label, used for generating layer names
+        strides: tuple[int, int], optional
+            The stride length for the first convolution and the shortcut projection. Default: (2, 2)
+
+        Returns
+        -------
+        :class:`keras.KerasTensor`
+            Output tensor for the block
+
+        Notes
+        -----
+        From stage 3, the first conv layer at main path is with `strides = (2,2)` and the shortcut
+        should have `strides = (2,2)` as well
+        """
+        assert len(filters) == 3
+        var_x = inputs
+
+        for idx, filts in enumerate(filters):
+            k_size = kernel_size if idx == 1 else 1
+            conv_name = f"conv{stage}_{block}_{k_size}x{k_size}{self._block_suffix[idx]}"
+            bn_name = f"{conv_name}_bn"
+
+            var_x = Conv2D(filts,
+                           k_size,
+                           strides=strides if idx == 0 else (1, 1),
+                           padding="same" if idx == 1 else "valid",
+                           kernel_initializer=self._kernel_initializer,
+                           use_bias=self._use_bias,
+                           kernel_regularizer=L2(self._weight_decay),
+                           trainable=self._trainable,
+                           name=conv_name)(var_x)
+            var_x = BatchNormalization(axis=self._bn_axis, name=bn_name)(var_x)
+            if idx < 2:
+                var_x = Activation("relu")(var_x)
+
+        conv_name = f"conv{stage}_{block}_1x1_proj"
+        bn_name = f"{conv_name}_bn"
+
+        shortcut = Conv2D(filters[-1],
+                          (1, 1),
+                          strides=strides,
+                          kernel_initializer=self._kernel_initializer,
+                          use_bias=self._use_bias,
+                          kernel_regularizer=L2(self._weight_decay),
+                          trainable=self._trainable,
+                          name=conv_name)(inputs)
+        shortcut = BatchNormalization(axis=self._bn_axis, name=bn_name)(shortcut)
+
+        var_x = add([var_x, shortcut])
+        var_x = Activation("relu")(var_x)
+        return var_x
+
+    def __call__(self, inputs: KerasTensor) -> KerasTensor:
+        """ Call the resnet50 Network
+
+        Parameters
+        ----------
+        inputs: :class:`keras.KerasTensor`
+            Input tensor
+
+        Returns
+        -------
+        :class:`keras.KerasTensor`
+            Output tensor from resnet50
+        """
+        var_x = Conv2D(64,
+                       (7, 7),
+                       strides=(2, 2),
+                       padding="same",
+                       use_bias=self._use_bias,
+                       kernel_initializer=self._kernel_initializer,
+                       kernel_regularizer=L2(self._weight_decay),
+                       trainable=self._trainable,
+                       name="conv1_7x7_s2")(inputs)
+
+        var_x = BatchNormalization(axis=self._bn_axis, name="conv1_7x7_s2_bn")(var_x)
+        var_x = Activation("relu")(var_x)
+        var_x = MaxPooling2D((3, 3), strides=(2, 2))(var_x)
+
+        for idx, (recursions, filters) in enumerate(zip(self._identity_calls, self._filters)):
+            stage = idx + 2
+            strides = (1, 1) if stage == 2 else (2, 2)
+            var_x = self._conv_block(var_x, 3, filters, stage=stage, block=1, strides=strides)
+
+            for recursion in range(recursions):
+                block = recursion + 2
+                var_x = self._identity_block(var_x, 3, filters, stage=stage, block=block)
+
+        return var_x
+
+
+class VGGFace2():
+    """ VGG-Face 2 model with resnet 50 backbone. Adapted from
+    https://github.com/WeidiXie/Keras-VGGFace2-ResNet50
+
+    Parameters
+    ----------
+    input_size: int
+        The input size for the model.
+    weights_path: str
+        The path to the keras weights file
+    batch_size: int
+        The batch size to feed the model
+    num_classes: int, optional
+        Number of classes to train the model on. Default: 8631
+    weight_decay: float
+        L2 Regularizer weight decay. Default: 1e-4
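+
+    Example
+    -------
+    A minimal sketch (the weights path is a placeholder, not a real file):
+
+    >>> vgg = VGGFace2(224, "/path/to/vggface2_resnet50_v2.h5", batch_size=16)
+    >>> faces = np.zeros((16, 224, 224, 3), dtype="float32")
+    >>> embeddings = vgg(faces)  # (16, 512) L2-normalized identity encodings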
+    """
+    def __init__(self,
+                 input_size: int,
+                 weights_path: str,
+                 batch_size: int,
+                 num_classes: int = 8631,
+                 weight_decay: float = 1e-4) -> None:
+        logger.debug(parse_class_init(locals()))
+        self._input_shape = (input_size, input_size, 3)
+        self._batch_size = batch_size
+        self._weight_decay = weight_decay
+        self._num_classes = num_classes
+        self._resnet = ResNet50(input_shape=self._input_shape, weight_decay=self._weight_decay)
+        self._model = self._load_model(weights_path)
+        logger.debug("Initialized %s", self.__class__.__name__)
+
+    def _load_model(self, weights_path: str) -> Model:
+        """ Load the vgg-face2 model
+
+        Parameters
+        ----------
+        weights_path: str
+            Full path to the model's weights
+
+        Returns
+        -------
+        :class:`keras.models.Model`
+            The VGG-Face2 model
+        """
+        inputs = Input(self._input_shape)
+        var_x = self._resnet(inputs)
+
+        var_x = AveragePooling2D((7, 7), name="avg_pool")(var_x)
+        var_x = Flatten()(var_x)
+        var_x = Dense(512, activation="relu", name="dim_proj")(var_x)
+        var_x = L2Normalize(axis=1)(var_x)
+
+        retval = Model(inputs, var_x)
+        retval.load_weights(weights_path)
+        retval.make_predict_function()
+        return retval
+
+    def __call__(self, inputs: np.ndarray) -> np.ndarray:
+        """ Get output from the vgg-face2 model
+
+        Parameters
+        ----------
+        inputs: :class:`numpy.ndarray`
+            The input to vgg-face2
+
+        Returns
+        -------
+        :class:`numpy.ndarray`
+            The output from vgg-face2
+        """
+        return self._model.predict(inputs, verbose=0, batch_size=self._batch_size)
+
+
+class Cluster():
+    """ Cluster the outputs from a VGG-Face 2 Model
+
+    Parameters
+    ----------
+    predictions: numpy.ndarray
+        A stacked matrix of vgg_face2 predictions of the shape (`N`, `D`) where `N` is the
+        number of observations and `D` is the number of dimensions. NB: The given
+        :attr:`predictions` will be overwritten to save memory. If you still require the
+        original values you should take a copy prior to running this method
+    method: ['single','centroid','median','ward']
+        The clustering method to use.
+    threshold: float, optional
+        The threshold to start creating bins for. Set to ``None`` to disable binning
+    """
+
+    def __init__(self,
+                 predictions: np.ndarray,
+                 method: T.Literal["single", "centroid", "median", "ward"],
+                 threshold: float | None = None) -> None:
+        logger.debug("Initializing: %s (predictions: %s, method: %s, threshold: %s)",
+                     self.__class__.__name__, predictions.shape, method, threshold)
+        self._num_predictions = predictions.shape[0]
+
+        self._should_output_bins = threshold is not None
+        self._threshold = 0.0 if threshold is None else threshold
+        self._bins: dict[int, int] = {}
+        self._iterator = self._integer_iterator()
+
+        self._result_linkage = self._do_linkage(predictions, method)
+        logger.debug("Initialized %s", self.__class__.__name__)
+
+    @classmethod
+    def _integer_iterator(cls) -> Generator[int, None, None]:
+        """ Iterator that just yields consecutive integers """
+        i = -1
+        while True:
+            i += 1
+            yield i
+
+    def _use_vector_linkage(self, dims: int) -> bool:
+        """ Calculate the RAM that will be required to sort these images and select the appropriate
+        clustering method.
+
+        From fastcluster documentation:
+            "While the linkage method requires Θ(N:sup:`2`) memory for clustering of N points, this
+            [vector] method needs Θ(N D) for N points in R:sup:`D`, which is usually much smaller."
+        also:
+            "half the memory can be saved by specifying :attr:`preserve_input`=``False``"
+
+        To avoid under-calculating we divide the memory calculation by 1.8 instead of 2
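+
+        For example (illustrative figures only): for N = 100,000 predictions of dimension
+        D = 512, the linkage method needs roughly (100,000² x 24) / 1.8 bytes ≈ 127,157MB,
+        while the vector method needs roughly 100,000 x 512 x 24 bytes ≈ 1,172MB.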
+
+        Parameters
+        ----------
+        dims: int
+            The number of dimensions in the vgg_face output
+
+        Returns
+        -------
+        bool:
+            ``True`` if vector_linkage should be used. ``False`` if linkage should be used
+        """
+        np_float = 24  # bytes per float, with headroom (a numpy float64 is 8 bytes)
+        divider = 1024 * 1024  # bytes to MB
+
+        free_ram = psutil.virtual_memory().available / divider
+        linkage_required = (((self._num_predictions ** 2) * np_float) / 1.8) / divider
+        vector_required = ((self._num_predictions * dims) * np_float) / divider
+        logger.debug("free_ram: %sMB, linkage_required: %sMB, vector_required: %sMB",
+                     int(free_ram), int(linkage_required), int(vector_required))
+
+        if linkage_required < free_ram:
+            logger.verbose("Using linkage method")  # type:ignore[attr-defined]
+            retval = False
+        elif vector_required < free_ram:
+            logger.warning("Not enough RAM to perform linkage clustering. Using vector "
+                           "clustering. This will be significantly slower. Free RAM: %sMB. "
+                           "Required for linkage method: %sMB",
+                           int(free_ram), int(linkage_required))
+            retval = True
+        else:
+            raise FaceswapError("Not enough RAM available to sort faces. Try reducing "
+                                f"the size of your dataset. Free RAM: {int(free_ram)}MB. "
+                                f"Required RAM: {int(vector_required)}MB")
+        logger.debug(retval)
+        return retval
+
+    def _do_linkage(self,
+                    predictions: np.ndarray,
+                    method: T.Literal["single", "centroid", "median", "ward"]) -> np.ndarray:
+        """ Use FastCluster to perform vector or standard linkage
+
+        Parameters
+        ----------
+        predictions: :class:`numpy.ndarray`
+            A stacked matrix of vgg_face2 predictions of the shape (`N`, `D`) where `N` is the
+            number of observations and `D` is the number of dimensions.
+        method: ['single','centroid','median','ward']
+            The clustering method to use.
+
+        Returns
+        -------
+        :class:`numpy.ndarray`
+            The (`num_predictions` - 1, 4) linkage matrix
+        """
+        dims = predictions.shape[-1]
+        if self._use_vector_linkage(dims):
+            retval = linkage_vector(predictions, method=method)
+        else:
+            retval = linkage(predictions, method=method, preserve_input=False)
+        logger.debug("Linkage shape: %s", retval.shape)
+        return retval
+
+    def _process_leaf_node(self,
+                           current_index: int,
+                           current_bin: int) -> list[tuple[int, int]]:
+        """ Process the output when we have hit a leaf node """
+        if not self._should_output_bins:
+            return [(current_index, 0)]
+
+        if current_bin not in self._bins:
+            next_val = 0 if not self._bins else max(self._bins.values()) + 1
+            self._bins[current_bin] = next_val
+        return [(current_index, self._bins[current_bin])]
+
+    def _get_bin(self,
+                 tree: np.ndarray,
+                 points: int,
+                 current_index: int,
+                 current_bin: int) -> int:
+        """ Obtain the bin that we are currently in.
+
+        If we are not currently below the threshold for binning, get a new bin ID from the integer
+        iterator.
+
+        Parameters
+        ----------
+        tree: numpy.ndarray
+            A hierarchical tree (dendrogram)
+        points: int
+            The number of points given to the clustering process
+        current_index: int
+            The position in the tree for the recursive traversal
+        current_bin: int, optional
+            The ID for the bin we are currently in. Only used when binning is enabled
+
+        Returns
+        -------
+        int
+            The current bin ID for the node
+        """
+        if tree[current_index - points, 2] >= self._threshold:
+            current_bin = next(self._iterator)
+            logger.debug("Creating new bin ID: %s", current_bin)
+        return current_bin
+
+    def _seriation(self,
+                   tree: np.ndarray,
+                   points: int,
+                   current_index: int,
+                   current_bin: int = 0) -> list[tuple[int, int]]:
+        """ Seriation method for sorted similarity.
+
+        Seriation computes the order implied by a hierarchical tree (dendrogram).
+
+        Parameters
+        ----------
+        tree: numpy.ndarray
+            A hierarchical tree (dendrogram)
+        points: int
+            The number of points given to the clustering process
+        current_index: int
+            The position in the tree for the recursive traversal
+        current_bin: int, optional
+            The ID for the bin we are currently in. Only used when binning is enabled
+
+        Returns
+        -------
+        list:
+            The indices in the order implied by the hierarchical tree
+        """
+        if current_index < points:  # Output the leaf node
+            return self._process_leaf_node(current_index, current_bin)
+
+        if self._should_output_bins:
+            current_bin = self._get_bin(tree, points, current_index, current_bin)
+
+        left = int(tree[current_index-points, 0])
+        right = int(tree[current_index-points, 1])
+
+        serate_left = self._seriation(tree, points, left, current_bin=current_bin)
+        serate_right = self._seriation(tree, points, right, current_bin=current_bin)
+
+        return serate_left + serate_right  # type: ignore
+
+    def __call__(self) -> list[tuple[int, int]]:
+        """ Process the linkages.
+
+        Transforms a distance matrix into a sorted distance matrix according to the order implied
+        by the hierarchical tree (dendrogram).
+
+        Returns
+        -------
+        list:
+            List of indices with the order implied by the hierarchical tree or list of tuples of
+            (`index`, `bin`) if a binning threshold was provided
+        """
+        logger.info("Sorting face distances. Depending on your dataset this may take some time...")
+        if self._threshold:
+            self._threshold = self._result_linkage[:, 2].max() * self._threshold
+        result_order = self._seriation(self._result_linkage,
+                                       self._num_predictions,
+                                       self._num_predictions + self._num_predictions - 2)
+        return result_order
+
+
+__all__ = get_module_objects(__name__)
diff --git a/plugins/extract/recognition/vgg_face2_defaults.py b/plugins/extract/recognition/vgg_face2_defaults.py
new file mode 100644
index 0000000000..6d32466b0f
--- /dev/null
+++ b/plugins/extract/recognition/vgg_face2_defaults.py
@@ -0,0 +1,55 @@
+#!/usr/bin/env python3
+""" The default options for the faceswap VGG Face2 recognition plugin.
+
+
+Defaults files should be named `<plugin_name>_defaults.py`
+
+Any qualifying items placed into this file will automatically get added to the relevant config
+.ini files within the faceswap/config folder and added to the relevant GUI settings page.
+
+The following variable should be defined:
+
+    Parameters
+    ----------
+    HELPTEXT: str
+        A string describing what this plugin does
+
+Further plugin configuration options are assigned using:
+>>> <option_name> = ConfigItem(...)
+
+where <option_name> is the name of the configuration option to be added (lower-case, alpha-numeric
++ underscore only) and ConfigItem(...) is the :class:`~lib.config.objects.ConfigItem` data for the
+option.
+
+See the docstring/ReadTheDocs documentation for the required parameters of the ConfigItem object.
+Items will be grouped together as per their `group` parameter, but otherwise will be processed in
+the order that they are added to this module.
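+
+For example (an illustrative option, not one used by this plugin):
+
+>>> my_option = ConfigItem(datatype=bool,
+>>>                        default=False,
+>>>                        group="settings",
+>>>                        info="An example configuration option")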
+"""
+from lib.config import ConfigItem
+
+
+HELPTEXT = (
+    "VGG Face 2 identity recognition.\n"
+    "A Keras port of the model trained for VGGFace2: A dataset for recognising faces across pose "
+    "and age. (https://arxiv.org/abs/1710.08092)"
+    )
+
+
+batch_size = ConfigItem(
+    datatype=int,
+    default=16,
+    group="settings",
+    info="The batch size to use. To a point, higher batch sizes equal better performance, "
+         "but setting it too high can harm performance.\n"
+         "\n\tNvidia users: If the batchsize is set higher than your GPU can "
+         "accommodate then this will automatically be lowered.",
+    rounding=1,
+    min_max=(1, 64))
+
+cpu = ConfigItem(
+    datatype=bool,
+    default=False,
+    group="settings",
+    info="VGG Face2 still runs fairly quickly on CPU on some setups. Enable "
+         "CPU mode here to use the CPU for this plugin to save some VRAM at a speed cost.")
diff --git a/plugins/plugin_loader.py b/plugins/plugin_loader.py
index 3506cb84cc..02c3fb36ca 100644
--- a/plugins/plugin_loader.py
+++ b/plugins/plugin_loader.py
@@ -1,86 +1,303 @@
-#!/usr/bin/env python3
-""" Plugin loader for extract, training and model tasks """
-
-import logging
-import os
-from importlib import import_module
-
-logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
-
-
-class PluginLoader():
-    """ Plugin loader for extract, training and model tasks """
-    @staticmethod
-    def get_detector(name):
-        """ Return requested detector plugin """
-        return PluginLoader._import("extract.detect", name)
-
-    @staticmethod
-    def get_aligner(name):
-        """ Return requested detector plugin """
-        return PluginLoader._import("extract.align", name)
-
-    @staticmethod
-    def get_converter(name):
-        """ Return requested converter plugin """
-        return PluginLoader._import("convert", name)
-
-    @staticmethod
-    def get_model(name):
-        """ Return requested model plugin """
-        return PluginLoader._import("train.model", name)
-
-    @staticmethod
-    def get_trainer(name):
-        """ Return requested trainer plugin """
-        return PluginLoader._import("train.trainer", name)
-
-    @staticmethod
-    def _import(attr, name):
-        """ Import the plugin's module """
-        name = name.replace("-", "_")
-        ttl = attr.split(".")[-1].title()
-        logger.info("Loading %s from %s plugin...", ttl, name.title())
-        attr = "model" if attr == "Trainer" else attr.lower()
-        mod = ".".join(("plugins", attr, name))
-        module = import_module(mod)
-        return getattr(module, ttl)
-
-    @staticmethod
-    def get_available_models():
-        """ Return a list of available models """
-        modelpath = os.path.join(os.path.dirname(__file__), "train", "model")
-        models = sorted(item.name.replace(".py", "").replace("_", "-")
-                        for item in os.scandir(modelpath)
-                        if not item.name.startswith("_")
-                        and item.name.endswith(".py"))
-        return models
-
-    @staticmethod
-    def get_available_converters():
-        """ Return a list of available converters """
-        converter_path = os.path.join(os.path.dirname(__file__), "convert")
-        converters = sorted(item.name.replace(".py", "").replace("_", "-")
-                            for item in os.scandir(converter_path)
-                            if not item.name.startswith("_")
-                            and item.name.endswith(".py"))
-        return converters
-
-    @staticmethod
-    def get_available_extractors(extractor_type):
-        """ Return a list of available models """
-        extractpath = os.path.join(os.path.dirname(__file__),
-                                   "extract",
-                                   extractor_type)
-        extractors = sorted(item.name.replace(".py", "").replace("_", "-")
-                            for item in os.scandir(extractpath)
-                            if not item.name.startswith("_")
-                            and item.name.endswith(".py")
-                            and item.name != "manual.py")
-        return extractors
-
-    @staticmethod
-    def get_default_model():
-        """ Return the default model """
-        models = PluginLoader.get_available_models()
-        return 'original' if 'original' in models else models[0]
+#!/usr/bin/env python3
+""" Plugin loader for Faceswap extract, training and convert tasks """
+from __future__ import annotations
+import logging
+import os
+import typing as T
+
+from importlib import import_module
+
+from lib.utils import get_module_objects
+
+if T.TYPE_CHECKING:
+    from collections.abc import Callable
+    from plugins.extract.detect._base import Detector
+    from plugins.extract.align._base import Aligner
+    from plugins.extract.mask._base import Masker
+    from plugins.extract.recognition._base import Identity
+    from plugins.train.model._base import ModelBase
+    from plugins.train.trainer._base import TrainerBase
+
+logger = logging.getLogger(__name__)
+
+
+class PluginLoader():
+    """ Retrieve, or get information on, Faceswap plugins
+
+    Return a specific plugin, list available plugins, or get the default plugin for a
+    task.
+
+    Example
+    -------
+    >>> from plugins.plugin_loader import PluginLoader
+    >>> align_plugins = PluginLoader.get_available_extractors('align')
+    >>> aligner = PluginLoader.get_aligner('cv2-dnn')
+    """
+    @staticmethod
+    def get_detector(name: str, disable_logging: bool = False) -> type[Detector]:
+        """ Return requested detector plugin
+
+        Parameters
+        ----------
+        name: str
+            The name of the requested detector plugin
+        disable_logging: bool, optional
+            Whether to disable the INFO log message that the plugin is being imported.
+            Default: `False`
+
+        Returns
+        -------
+        :class:`plugins.extract.detect` object:
+            An extraction detector plugin
+        """
+        return PluginLoader._import("extract.detect", name, disable_logging)
+
+    @staticmethod
+    def get_aligner(name: str, disable_logging: bool = False) -> type[Aligner]:
+        """ Return requested aligner plugin
+
+        Parameters
+        ----------
+        name: str
+            The name of the requested aligner plugin
+        disable_logging: bool, optional
+            Whether to disable the INFO log message that the plugin is being imported.
+            Default: `False`
+
+        Returns
+        -------
+        :class:`plugins.extract.align` object:
+            An extraction aligner plugin
+        """
+        return PluginLoader._import("extract.align", name, disable_logging)
+
+    @staticmethod
+    def get_masker(name: str, disable_logging: bool = False) -> type[Masker]:
+        """ Return requested masker plugin
+
+        Parameters
+        ----------
+        name: str
+            The name of the requested masker plugin
+        disable_logging: bool, optional
+            Whether to disable the INFO log message that the plugin is being imported.
+            Default: `False`
+
+        Returns
+        -------
+        :class:`plugins.extract.mask` object:
+            An extraction masker plugin
+        """
+        return PluginLoader._import("extract.mask", name, disable_logging)
+
+    @staticmethod
+    def get_recognition(name: str, disable_logging: bool = False) -> type[Identity]:
+        """ Return requested recognition plugin
+
+        Parameters
+        ----------
+        name: str
+            The name of the requested recognition plugin
+        disable_logging: bool, optional
+            Whether to disable the INFO log message that the plugin is being imported.
+            Default: `False`
+
+        Returns
+        -------
+        :class:`plugins.extract.recognition` object:
+            An extraction recognition plugin
+        """
+        return PluginLoader._import("extract.recognition", name, disable_logging)
+
+    @staticmethod
+    def get_model(name: str, disable_logging: bool = False) -> type[ModelBase]:
+        """ Return requested training model plugin
+
+        Parameters
+        ----------
+        name: str
+            The name of the requested training model plugin
+        disable_logging: bool, optional
+            Whether to disable the INFO log message that the plugin is being imported.
+            Default: `False`
+
+        Returns
+        -------
+        :class:`plugins.train.model` object:
+            A training model plugin
+        """
+        return PluginLoader._import("train.model", name, disable_logging)
+
+    @staticmethod
+    def get_trainer(name: str, disable_logging: bool = False) -> type[TrainerBase]:
+        """ Return requested training trainer plugin
+
+        Parameters
+        ----------
+        name: str
+            The name of the requested training trainer plugin
+        disable_logging: bool, optional
+            Whether to disable the INFO log message that the plugin is being imported.
+            Default: `False`
+
+        Returns
+        -------
+        :class:`plugins.train.trainer` object:
+            A training trainer plugin
+        """
+        return PluginLoader._import("train.trainer", name, disable_logging)
+
+    @staticmethod
+    def get_converter(category: str, name: str, disable_logging: bool = False) -> Callable:
+        """ Return requested converter plugin
+
+        Converters work slightly differently to other faceswap plugins. They are created to do a
+        specific task (e.g. color adjustment, mask blending etc.), so multiple plugins will be
+        loaded in the convert phase, rather than just one plugin for the other phases.
+
+        Parameters
+        ----------
+        category: str
+            The category of the requested converter plugin (e.g. 'color', 'mask', 'scaling',
+            'writer')
+        name: str
+            The name of the requested converter plugin
+        disable_logging: bool, optional
+            Whether to disable the INFO log message that the plugin is being imported.
+            Default: `False`
+
+        Returns
+        -------
+        :class:`plugins.convert` object:
+            A converter sub plugin
+        """
+        return PluginLoader._import(f"convert.{category}", name, disable_logging)
+
+    @staticmethod
+    def _import(attr: str, name: str, disable_logging: bool):
+        """ Import the plugin's module
+
+        Parameters
+        ----------
+        attr: str
+            The plugin sub-package to load the plugin from (e.g. 'extract.detect')
+        name: str
+            The name of the requested plugin
+        disable_logging: bool
+            Whether to disable the INFO log message that the plugin is being imported.
+
+        Returns
+        -------
+        :class:`plugin` object:
+            A plugin
+        """
+        name = name.replace("-", "_")
+        ttl = attr.split(".")[-1].title()
+        if not disable_logging:
+            logger.info("Loading %s from %s plugin...", ttl, name.title())
+        attr = "model" if attr == "Trainer" else attr.lower()
+        mod = ".".join(("plugins", attr, name))
+        module = import_module(mod)
+        return getattr(module, ttl)
+
+    @staticmethod
+    def get_available_extractors(extractor_type: T.Literal["align", "detect", "mask"],
+                                 add_none: bool = False,
+                                 extend_plugin: bool = False) -> list[str]:
+        """ Return a list of available extractors of the given type
+
+        Parameters
+        ----------
+        extractor_type: {'align', 'detect', 'mask'}
+            The type of extractor to return the plugins for
+        add_none: bool, optional
+            Append "none" to the list of returned plugins. Default: False
+        extend_plugin: bool, optional
+            Some plugins have configuration options that mean that multiple 'pseudo-plugins'
+            can be generated based on their settings. An example of this is the bisenet-fp mask
+            which, whilst selected as 'bisenet-fp', can be stored as 'bisenet-fp_face' and
+            'bisenet-fp_head' depending on whether hair has been included in the mask or not.
+            ``True`` will generate each pseudo-plugin, ``False`` will generate the original
+            plugin name. Default: ``False``
+
+        Returns
+        -------
+        list:
+            A list of the available extractor plugin names for the given type
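+
+        Example
+        -------
+        Illustrative only; the returned list depends on the plugins present:
+
+        >>> PluginLoader.get_available_extractors("mask", extend_plugin=True)
+        ['bisenet-fp_face', 'bisenet-fp_head', 'components', ...]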
+        """
+        extractpath = os.path.join(os.path.dirname(__file__),
+                                   "extract",
+                                   extractor_type)
+        extractors = [item.name.replace(".py", "").replace("_", "-")
+                      for item in os.scandir(extractpath)
+                      if not item.name.startswith("_")
+                      and not item.name.endswith("defaults.py")
+                      and item.name.endswith(".py")]
+        extendable = ["bisenet-fp", "custom"]
+        if extend_plugin and extractor_type == "mask" and any(ext in extendable
+                                                              for ext in extractors):
+            for msk in extendable:
+                extractors.remove(msk)
+                extractors.extend([f"{msk}_face", f"{msk}_head"])
+
+        extractors = sorted(extractors)
+        if add_none:
+            extractors.insert(0, "none")
+        return extractors
+
+    @staticmethod
+    def get_available_models() -> list[str]:
+        """ Return a list of available training models
+
+        Returns
+        -------
+        list:
+            A list of the available training model plugin names
+        """
+        modelpath = os.path.join(os.path.dirname(__file__), "train", "model")
+        models = sorted(item.name.replace(".py", "").replace("_", "-")
+                        for item in os.scandir(modelpath)
+                        if not item.name.startswith("_")
+                        and not item.name.endswith("defaults.py")
+                        and item.name.endswith(".py"))
+        return models
+
+    @staticmethod
+    def get_default_model() -> str:
+        """ Return the default training model plugin name
+
+        Returns
+        -------
+        str:
+            The default faceswap training model
+        """
+        models = PluginLoader.get_available_models()
+        return 'original' if 'original' in models else models[0]
+
+    @staticmethod
+    def get_available_convert_plugins(convert_category: str, add_none: bool = True) -> list[str]:
+        """ Return a list of available converter plugins in the given category
+
+        Parameters
+        ----------
+        convert_category: {'color', 'mask', 'scaling', 'writer'}
+            The category of converter plugin to return the plugins for
+        add_none: bool, optional
+            Append "none" to the list of returned plugins. Default: True
+
+        Returns
+        -------
+        list
+            A list of the available converter plugin names in the given category
+        """
+        convertpath = os.path.join(os.path.dirname(__file__),
+                                   "convert",
+                                   convert_category)
+        converters = sorted(item.name.replace(".py", "").replace("_", "-")
+                            for item in os.scandir(convertpath)
+                            if not item.name.startswith("_")
+                            and not item.name.endswith("defaults.py")
+                            and item.name.endswith(".py"))
+        if add_none:
+            converters.insert(0, "none")
+        return converters
+
+
+__all__ = get_module_objects(__name__)
diff --git a/plugins/train/__init__.py b/plugins/train/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/plugins/train/_config.py b/plugins/train/_config.py
deleted file mode 100644
index 994c9cd067..0000000000
--- a/plugins/train/_config.py
+++ /dev/null
@@ -1,180 +0,0 @@
-#!/usr/bin/env python3
-""" Default configurations for models """
-
-import logging
-
-from lib.config import FaceswapConfig
-
-logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
-
-MASK_TYPES = ["none", "dfaker", "dfl_full"]
-MASK_INFO = "The mask to be used for training. Select none to not use a mask"
-COVERAGE_INFO = ("How much of the extracted image to train on. Generally the model is optimized\n"
-                 "to the default value. Sensible values to use are:"
-                 "\n\t62.5%% spans from eyebrow to eyebrow."
-                 "\n\t75.0%% spans from temple to temple."
-                 "\n\t87.5%% spans from ear to ear."
- "\n\t100.0%% is a mugshot.") - - -class Config(FaceswapConfig): - """ Config File for Models """ - - def set_defaults(self): - """ Set the default values for config """ - logger.debug("Setting defaults") - # << GLOBAL OPTIONS >> # - section = "global" - self.add_section(title=section, - info="Options that apply to all models") - self.add_item( - section=section, title="icnr_init", datatype=bool, default=False, - info="Use ICNR Kernel Initializer for upscaling.\nThis can help reduce the " - "'checkerboard effect' when upscaling the image.") - self.add_item( - section=section, title="subpixel_upscaling", datatype=bool, default=False, - info="Use subpixel upscaling rather than pixel shuffler.\n" - "Might increase speed at cost of VRAM") - self.add_item( - section=section, title="reflect_padding", datatype=bool, default=False, - info="Use reflect padding rather than zero padding.") - self.add_item( - section=section, title="dssim_mask_loss", datatype=bool, default=True, - info="If using a mask, Use DSSIM loss for Mask training rather than Mean Absolute " - "Error\nMay increase overall quality.") - self.add_item( - section=section, title="penalized_mask_loss", datatype=bool, default=True, - info="If using a mask, Use Penalized loss for Mask training. Can stack with DSSIM.\n" - "May increase overall quality.") - - # << DFAKER OPTIONS >> # - section = "model.dfaker" - self.add_section(title=section, - info="Dfaker Model (Adapted from https://github.com/dfaker/df)") - self.add_item( - section=section, title="mask_type", datatype=str, default="dfaker", - choices=MASK_TYPES, info=MASK_INFO) - self.add_item( - section=section, title="coverage", datatype=float, default=100.0, rounding=1, - min_max=(62.5, 100.0), info=COVERAGE_INFO) - - # << DFL MODEL OPTIONS >> # - section = "model.dfl_h128" - self.add_section(title=section, - info="DFL H128 Model (Adapted from " - "https://github.com/iperov/DeepFaceLab)") - self.add_item( - section=section, title="lowmem", datatype=bool, default=False, - info="Lower memory mode. Set to 'True' if having issues with VRAM useage.\nNB: Models " - "with a changed lowmem mode are not compatible with each other.") - self.add_item( - section=section, title="mask_type", datatype=str, default="dfl_full", - choices=MASK_TYPES, info=MASK_INFO) - self.add_item( - section=section, title="coverage", datatype=float, default=62.5, rounding=1, - min_max=(62.5, 100.0), info=COVERAGE_INFO) - - # << IAE MODEL OPTIONS >> # - section = "model.iae" - self.add_section(title=section, - info="Intermediate Auto Encoder. Based on Original Model, uses " - "intermediate layers to try to better get details") - self.add_item( - section=section, title="dssim_loss", datatype=bool, default=False, - info="Use DSSIM for Loss rather than Mean Absolute Error\n" - "May increase overall quality.") - self.add_item( - section=section, title="mask_type", datatype=str, default="none", - choices=MASK_TYPES, info=MASK_INFO) - self.add_item( - section=section, title="coverage", datatype=float, default=62.5, rounding=1, - min_max=(62.5, 100.0), info=COVERAGE_INFO) - - # << ORIGINAL MODEL OPTIONS >> # - section = "model.original" - self.add_section(title=section, - info="Original Faceswap Model") - self.add_item( - section=section, title="lowmem", datatype=bool, default=False, - info="Lower memory mode. 
Set to 'True' if having issues with VRAM useage.\nNB: Models " - "with a changed lowmem mode are not compatible with each other.") - self.add_item( - section=section, title="dssim_loss", datatype=bool, default=False, - info="Use DSSIM for Loss rather than Mean Absolute Error\n" - "May increase overall quality.") - self.add_item( - section=section, title="mask_type", datatype=str, default="none", - choices=MASK_TYPES, info=MASK_INFO) - self.add_item( - section=section, title="coverage", datatype=float, default=62.5, rounding=1, - min_max=(62.5, 100.0), info=COVERAGE_INFO) - - # << UNBALANCED MODEL OPTIONS >> # - section = "model.unbalanced" - self.add_section(title=section, - info="An unbalanced model with adjustable input size options.\n" - "This is an unbalanced model so b>a swaps may not work well") - self.add_item( - section=section, title="lowmem", datatype=bool, default=False, - info="Lower memory mode. Set to 'True' if having issues with VRAM useage.\nNB: Models " - "with a changed lowmem mode are not compatible with each other. NB: lowmem will " - "override cutom nodes and complexity settings.") - self.add_item( - section=section, title="dssim_loss", datatype=bool, default=False, - info="Use DSSIM for Loss rather than Mean Absolute Error\n" - "May increase overall quality.") - self.add_item( - section=section, title="mask_type", datatype=str, default="none", - choices=MASK_TYPES, info=MASK_INFO) - self.add_item( - section=section, title="nodes", datatype=int, default=1024, rounding=64, - min_max=(512, 4096), - info="Number of nodes for decoder. Don't change this unless you " - "know what you are doing!") - self.add_item( - section=section, title="complexity_encoder", datatype=int, default=128, - rounding=16, min_max=(64, 1024), - info="Encoder Convolution Layer Complexity. sensible ranges: " - "128 to 160") - self.add_item( - section=section, title="complexity_decoder_a", datatype=int, default=384, - rounding=16, min_max=(64, 1024), - info="Decoder A Complexity.") - self.add_item( - section=section, title="complexity_decoder_b", datatype=int, default=512, - rounding=16, min_max=(64, 1024), - info="Decoder B Complexity.") - self.add_item( - section=section, title="input_size", datatype=int, default=128, - rounding=64, min_max=(64, 512), - info="Resolution (in pixels) of the image to train on.\n" - "BE AWARE Larger resolution will dramatically increase" - "VRAM requirements.\n" - "Make sure your resolution is divisible by 64 (e.g. 64, 128, 256 etc.).\n" - "NB: Your faceset must be at least 1.6x larger than your required input size.\n" - " (e.g. 160 is the maximum input size for a 256x256 faceset)") - self.add_item( - section=section, title="coverage", datatype=float, default=62.5, rounding=1, - min_max=(62.5, 100.0), info=COVERAGE_INFO) - - # << VILLAIN MODEL OPTIONS >> # - section = "model.villain" - self.add_section(title=section, - info="A Higher resolution version of the Original " - "Model by VillainGuy.\n" - "Extremely VRAM heavy. Full model requires 9GB+ for batchsize 16") - self.add_item( - section=section, title="lowmem", datatype=bool, default=False, - info="Lower memory mode. 
Set to 'True' if having issues with VRAM useage.\nNB: Models " - "with a changed lowmem mode are not compatible with each other.") - self.add_item( - section=section, title="dssim_loss", datatype=bool, default=False, - info="Use DSSIM for Loss rather than Mean Absolute Error\n" - "May increase overall quality.") - self.add_item( - section=section, title="mask_type", datatype=str, default="none", - choices=["none", "dfaker", "dfl_full"], - info="The mask to be used for training. Select none to not use a mask") - self.add_item( - section=section, title="coverage", datatype=float, default=62.5, rounding=1, - min_max=(62.5, 100.0), info=COVERAGE_INFO) diff --git a/plugins/train/model/_base.py b/plugins/train/model/_base.py deleted file mode 100644 index c11f568e9f..0000000000 --- a/plugins/train/model/_base.py +++ /dev/null @@ -1,651 +0,0 @@ -#!/usr/bin/env python3 -""" Base class for Models. ALL Models should at least inherit from this class - - When inheriting model_data should be a list of NNMeta objects. - See the class for details. -""" -import logging -import os -import sys -import time - -from json import JSONDecodeError - -from keras import losses -from keras.models import load_model -from keras.optimizers import Adam -from keras.utils import get_custom_objects, multi_gpu_model - -from lib import Serializer -from lib.model.losses import DSSIMObjective, PenalizedLoss -from lib.model.nn_blocks import NNBlocks -from lib.multithreading import MultiThread -from plugins.train._config import Config - -logger = logging.getLogger(__name__) # pylint: disable=invalid-name -_CONFIG = None - - -class ModelBase(): - """ Base class that all models should inherit from """ - def __init__(self, - model_dir, - gpus, - no_logs=False, - warp_to_landmarks=False, - no_flip=False, - training_image_size=256, - alignments_paths=None, - preview_scale=100, - input_shape=None, - encoder_dim=None, - trainer="original", - predict=False): - logger.debug("Initializing ModelBase (%s): (model_dir: '%s', gpus: %s, " - "training_image_size, %s, alignments_paths: %s, preview_scale: %s, " - "input_shape: %s, encoder_dim: %s)", self.__class__.__name__, model_dir, gpus, - training_image_size, alignments_paths, preview_scale, input_shape, - encoder_dim) - self.predict = predict - self.model_dir = model_dir - self.gpus = gpus - self.blocks = NNBlocks(use_subpixel=self.config["subpixel_upscaling"], - use_icnr_init=self.config["icnr_init"], - use_reflect_padding=self.config["reflect_padding"]) - self.input_shape = input_shape - self.output_shape = None # set after model is compiled - self.encoder_dim = encoder_dim - self.trainer = trainer - - self.state = State(self.model_dir, self.name, no_logs, training_image_size) - self.rename_legacy() - self.load_state_info() - - self.networks = dict() # Networks for the model - self.predictors = dict() # Predictors for model - self.history = dict() # Loss history per save iteration) - - # Training information specific to the model should be placed in this - # dict for reference by the trainer. 
- self.training_opts = {"alignments": alignments_paths, - "preview_scaling": preview_scale / 100, - "warp_to_landmarks": warp_to_landmarks, - "no_flip": no_flip} - - self.build() - self.set_training_data() - logger.debug("Initialized ModelBase (%s)", self.__class__.__name__) - - @property - def config(self): - """ Return config dict for current plugin """ - global _CONFIG # pylint: disable=global-statement - if not _CONFIG: - model_name = ".".join(self.__module__.split(".")[-2:]) - logger.debug("Loading config for: %s", model_name) - _CONFIG = Config(model_name).config_dict - return _CONFIG - - @property - def name(self): - """ Set the model name based on the subclass """ - basename = os.path.basename(sys.modules[self.__module__].__file__) - retval = os.path.splitext(basename)[0].lower() - logger.debug("model name: '%s'", retval) - return retval - - def set_training_data(self): - """ Override to set model specific training data. - - super() this method for defaults otherwise be sure to add """ - logger.debug("Setting training data") - self.training_opts["training_size"] = self.state.training_size - self.training_opts["no_logs"] = self.state.current_session["no_logs"] - self.training_opts["mask_type"] = self.config.get("mask_type", None) - self.training_opts["coverage_ratio"] = self.calculate_coverage_ratio() - self.training_opts["preview_images"] = 14 - logger.debug("Set training data: %s", self.training_opts) - - def calculate_coverage_ratio(self): - """ Coverage must be a ratio, leading to a cropped shape divisible by 2 """ - coverage_ratio = self.config.get("coverage", 62.5) / 100 - logger.debug("Requested coverage_ratio: %s", coverage_ratio) - cropped_size = (self.state.training_size * coverage_ratio) // 2 * 2 - coverage_ratio = cropped_size / self.state.training_size - logger.debug("Final coverage_ratio: %s", coverage_ratio) - return coverage_ratio - - def build(self): - """ Build the model. Override for custom build methods """ - self.add_networks() - self.load_models(swapped=False) - self.build_autoencoders() - self.log_summary() - self.compile_predictors() - - def build_autoencoders(self): - """ Override for Model Specific autoencoder builds - - NB! ENSURE YOU NAME YOUR INPUTS. At least the following input names - are expected: - face (the input for image) - mask (the input for mask if it is used) - - """ - raise NotImplementedError - - def add_networks(self): - """ Override to add neural networks """ - raise NotImplementedError - - def load_state_info(self): - """ Load the input shape from state file if it exists """ - logger.debug("Loading Input Shape from State file") - if not self.state.inputs: - logger.debug("No input shapes saved. Using model config") - return - if not self.state.face_shapes: - logger.warning("Input shapes stored in State file, but no matches for 'face'." 
- "Using model config") - return - input_shape = self.state.face_shapes[0] - logger.debug("Setting input shape from state file: %s", input_shape) - self.input_shape = input_shape - - def add_network(self, network_type, side, network): - """ Add a NNMeta object """ - logger.debug("network_type: '%s', side: '%s', network: '%s'", network_type, side, network) - filename = "{}_{}".format(self.name, network_type.lower()) - name = network_type.lower() - if side: - side = side.lower() - filename += "_{}".format(side.upper()) - name += "_{}".format(side) - filename += ".h5" - logger.debug("name: '%s', filename: '%s'", name, filename) - self.networks[name] = NNMeta(str(self.model_dir / filename), network_type, side, network) - - def add_predictor(self, side, model): - """ Add a predictor to the predictors dictionary """ - logger.debug("Adding predictor: (side: '%s', model: %s)", side, model) - if self.gpus > 1: - logger.debug("Converting to multi-gpu: side %s", side) - model = multi_gpu_model(model, self.gpus) - self.predictors[side] = model - if not self.state.inputs: - self.store_input_shapes(model) - if not self.output_shape: - self.set_output_shape(model) - - def store_input_shapes(self, model): - """ Store the input and output shapes to state """ - logger.debug("Adding input shapes to state for model") - inputs = {tensor.name: tensor.get_shape().as_list()[-3:] for tensor in model.inputs} - if not any(inp for inp in inputs.keys() if inp.startswith("face")): - raise ValueError("No input named 'face' was found. Check your input naming. " - "Current input names: {}".format(inputs)) - self.state.inputs = inputs - logger.debug("Added input shapes: %s", self.state.inputs) - - def set_output_shape(self, model): - """ Set the output shape for use in training and convert """ - logger.debug("Setting output shape") - out = [tensor.get_shape().as_list()[-3:] for tensor in model.outputs] - if not out: - raise ValueError("No outputs found! Check your model.") - self.output_shape = tuple(out[0]) - logger.debug("Added output shape: %s", self.output_shape) - - def compile_predictors(self): - """ Compile the predictors """ - logger.debug("Compiling Predictors") - optimizer = Adam(lr=5e-5, beta_1=0.5, beta_2=0.999, clipnorm=1.0) - - for side, model in self.predictors.items(): - loss_names = ["loss"] - loss_funcs = [self.loss_function(side)] - mask = [inp for inp in model.inputs if inp.name.startswith("mask")] - if mask: - loss_names.insert(0, "mask_loss") - loss_funcs.insert(0, self.mask_loss_function(mask[0], side)) - model.compile(optimizer=optimizer, loss=loss_funcs) - - if len(loss_names) > 1: - loss_names.insert(0, "total_loss") - self.state.add_session_loss_names(side, loss_names) - self.history[side] = list() - logger.debug("Compiled Predictors. 
Losses: %s", loss_names) - - def loss_function(self, side): - """ Set the loss function """ - if self.config.get("dssim_loss", False): - if side == "a" and not self.predict: - logger.verbose("Using DSSIM Loss") - loss_func = DSSIMObjective() - else: - if side == "a" and not self.predict: - logger.verbose("Using Mean Absolute Error Loss") - loss_func = losses.mean_absolute_error - logger.debug(loss_func) - return loss_func - - def mask_loss_function(self, mask, side): - """ Set the loss function for masks - Side is input so we only log once """ - if self.config.get("dssim_mask_loss", False): - if side == "a" and not self.predict: - logger.verbose("Using DSSIM Loss for mask") - mask_loss_func = DSSIMObjective() - else: - if side == "a" and not self.predict: - logger.verbose("Using Mean Absolute Error Loss for mask") - mask_loss_func = losses.mean_absolute_error - - if self.config.get("penalized_mask_loss", False): - if side == "a" and not self.predict: - logger.verbose("Using Penalized Loss for mask") - mask_loss_func = PenalizedLoss(mask, mask_loss_func) - logger.debug(mask_loss_func) - return mask_loss_func - - def converter(self, swap): - """ Converter for autoencoder models """ - logger.debug("Getting Converter: (swap: %s)", swap) - if swap: - retval = self.predictors["a"].predict - else: - retval = self.predictors["b"].predict - logger.debug("Got Converter: %s", retval) - return retval - - @property - def iterations(self): - "Get current training iteration number" - return self.state.iterations - - def map_models(self, swapped): - """ Map the models for A/B side for swapping """ - logger.debug("Map models: (swapped: %s)", swapped) - models_map = {"a": dict(), "b": dict()} - sides = ("a", "b") if not swapped else ("b", "a") - for network in self.networks.values(): - if network.side == sides[0]: - models_map["a"][network.type] = network.filename - if network.side == sides[1]: - models_map["b"][network.type] = network.filename - logger.debug("Mapped models: (models_map: %s)", models_map) - return models_map - - def log_summary(self): - """ Verbose log the model summaries """ - if self.predict: - return - for side in sorted(list(self.predictors.keys())): - logger.verbose("[%s %s Summary]:", self.name.title(), side.upper()) - self.predictors[side].summary(print_fn=lambda x: logger.verbose("R|%s", x)) - for name, nnmeta in self.networks.items(): - if nnmeta.side is not None and nnmeta.side != side: - continue - logger.verbose("%s:", name.title()) - nnmeta.network.summary(print_fn=lambda x: logger.verbose("R|%s", x)) - - def load_models(self, swapped): - """ Load models from file """ - logger.debug("Load model: (swapped: %s)", swapped) - model_mapping = self.map_models(swapped) - for network in self.networks.values(): - if not network.side: - is_loaded = network.load(predict=self.predict) - else: - is_loaded = network.load(fullpath=model_mapping[network.side][network.type], - predict=self.predict) - if not is_loaded: - break - if is_loaded: - logger.info("Loaded model from disk: '%s'", self.model_dir) - return is_loaded - - def save_models(self): - """ Backup and save the models """ - logger.debug("Backing up and saving models") - should_backup = self.get_save_averages() - save_threads = list() - for network in self.networks.values(): - name = "save_{}".format(network.name) - save_threads.append(MultiThread(network.save, name=name, should_backup=should_backup)) - save_threads.append(MultiThread(self.state.save, - name="save_state", should_backup=should_backup)) - for thread in save_threads: - 
thread.start() - for thread in save_threads: - if thread.has_error: - logger.error(thread.errors[0]) - thread.join() - # Put in a line break to avoid jumbled console - print("\n") - logger.info("saved models") - - def get_save_averages(self): - """ Return the loss averages since last save and reset historical losses - - This protects against model corruption by only backing up the model - if any of the loss values have fallen. - TODO This is not a perfect system. If the model corrupts on save_iteration - 1 - then model may still backup - """ - logger.debug("Getting Average loss since last save") - avgs = dict() - backup = True - - for side, loss in self.history.items(): - if not loss: - backup = False - break - - avgs[side] = sum(loss) / len(loss) - self.history[side] = list() # Reset historical loss - - if not self.state.lowest_avg_loss.get(side, None): - logger.debug("Setting initial save iteration loss average for '%s': %s", - side, avgs[side]) - self.state.lowest_avg_loss[side] = avgs[side] - continue - - if backup: - # Only run this if backup is true. All losses must have dropped for a valid backup - backup = self.check_loss_drop(side, avgs[side]) - - logger.debug("Lowest historical save iteration loss average: %s", - self.state.lowest_avg_loss) - logger.debug("Average loss since last save: %s", avgs) - - if backup: # Update lowest loss values to the state - for side, avg_loss in avgs.items(): - logger.debug("Updating lowest save iteration average for '%s': %s", side, avg_loss) - self.state.lowest_avg_loss[side] = avg_loss - - logger.debug("Backing up: %s", backup) - return backup - - def check_loss_drop(self, side, avg): - """ Check whether total loss has dropped since lowest loss """ - if avg < self.state.lowest_avg_loss[side]: - logger.debug("Loss for '%s' has dropped", side) - return True - logger.debug("Loss for '%s' has not dropped", side) - return False - - def rename_legacy(self): - """ Legacy Original, LowMem and IAE models had inconsistent naming conventions - Rename them if they are found and update """ - legacy_mapping = {"iae": [("IAE_decoder.h5", "iae_decoder.h5"), - ("IAE_encoder.h5", "iae_encoder.h5"), - ("IAE_inter_A.h5", "iae_intermediate_A.h5"), - ("IAE_inter_B.h5", "iae_intermediate_B.h5"), - ("IAE_inter_both.h5", "iae_inter.h5")], - "original": [("encoder.h5", "original_encoder.h5"), - ("decoder_A.h5", "original_decoder_A.h5"), - ("decoder_B.h5", "original_decoder_B.h5"), - ("lowmem_encoder.h5", "original_encoder.h5"), - ("lowmem_decoder_A.h5", "original_decoder_A.h5"), - ("lowmem_decoder_B.h5", "original_decoder_B.h5")]} - if self.name not in legacy_mapping.keys(): - return - logger.debug("Renaming legacy files") - - set_lowmem = False - updated = False - for old_name, new_name in legacy_mapping[self.name]: - old_path = os.path.join(str(self.model_dir), old_name) - new_path = os.path.join(str(self.model_dir), new_name) - if os.path.exists(old_path) and not os.path.exists(new_path): - logger.info("Updating legacy model name from: '%s' to '%s'", old_name, new_name) - os.rename(old_path, new_path) - if old_name.startswith("lowmem"): - set_lowmem = True - updated = True - - if not updated: - logger.debug("No legacy files to rename") - return - - logger.debug("Creating state file for legacy model") - self.state.inputs = {"face:0": [64, 64, 3]} - self.state.training_size = 256 - self.state.config["coverage"] = 62.5 - self.state.config["subpixel_upscaling"] = False - self.state.config["reflect_padding"] = False - self.state.config["mask_type"] = None - 
self.state.config["lowmem"] = False - self.encoder_dim = 1024 - - if set_lowmem: - logger.debug("Setting encoder_dim and lowmem flag for legacy lowmem model") - self.encoder_dim = 512 - self.state.config["lowmem"] = True - - self.state.replace_config() - self.state.save() - - -class NNMeta(): - """ Class to hold a neural network and it's meta data - - filename: The full path and filename of the model file for this network. - type: The type of network. For networks that can be swapped - The type should be identical for the corresponding - A and B networks, and should be unique for every A/B pair. - Otherwise the type should be completely unique. - side: A, B or None. Used to identify which networks can - be swapped. - network: Define network to this. - """ - - def __init__(self, filename, network_type, side, network): - logger.debug("Initializing %s: (filename: '%s', network_type: '%s', side: '%s', " - "network: %s", self.__class__.__name__, filename, network_type, - side, network) - self.filename = filename - self.type = network_type.lower() - self.side = side - self.name = self.set_name() - self.network = network - self.network.name = self.name - logger.debug("Initialized %s", self.__class__.__name__) - - def set_name(self): - """ Set the network name """ - name = self.type - if self.side: - name += "_{}".format(self.side) - return name - - def load(self, fullpath=None, predict=False): - """ Load model """ - fullpath = fullpath if fullpath else self.filename - logger.debug("Loading model: '%s'", fullpath) - try: - network = load_model(self.filename, custom_objects=get_custom_objects()) - except ValueError as err: - if str(err).lower().startswith("cannot create group in read only mode"): - self.convert_legacy_weights() - return True - if predict: - raise ValueError("Unable to load training data. Error: {}".format(str(err))) - logger.warning("Failed loading existing training data. Generating new models") - logger.debug("Exception: %s", str(err)) - return False - except OSError as err: # pylint: disable=broad-except - if predict: - raise ValueError("Unable to load training data. Error: {}".format(str(err))) - logger.warning("Failed loading existing training data. 
Generating new models") - logger.debug("Exception: %s", str(err)) - return False - self.network = network # Update network with saved model - self.network.name = self.type - return True - - def save(self, fullpath=None, should_backup=False): - """ Save model """ - fullpath = fullpath if fullpath else self.filename - if should_backup: - self.backup(fullpath=fullpath) - logger.debug("Saving model: '%s'", fullpath) - self.network.save(fullpath) - - def backup(self, fullpath=None): - """ Backup Model """ - origfile = fullpath if fullpath else self.filename - backupfile = origfile + ".bk" - logger.debug("Backing up: '%s' to '%s'", origfile, backupfile) - if os.path.exists(backupfile): - os.remove(backupfile) - if os.path.exists(origfile): - os.rename(origfile, backupfile) - - def convert_legacy_weights(self): - """ Convert legacy weights files to hold the model topology """ - logger.info("Adding model topology to legacy weights file: '%s'", self.filename) - self.network.load_weights(self.filename) - self.save(should_backup=False) - self.network.name = self.type - - -class State(): - """ Class to hold the model's current state and autoencoder structure """ - def __init__(self, model_dir, model_name, no_logs, training_image_size): - logger.debug("Initializing %s: (model_dir: '%s', model_name: '%s', no_logs: %s, " - "training_image_size: '%s'", self.__class__.__name__, model_dir, - model_name, no_logs, training_image_size) - self.serializer = Serializer.get_serializer("json") - filename = "{}_state.{}".format(model_name, self.serializer.ext) - self.filename = str(model_dir / filename) - self.name = model_name - self.iterations = 0 - self.session_iterations = 0 - self.training_size = training_image_size - self.sessions = dict() - self.lowest_avg_loss = dict() - self.inputs = dict() - self.config = dict() - self.load() - self.session_id = self.new_session_id() - self.create_new_session(no_logs) - logger.debug("Initialized %s:", self.__class__.__name__) - - @property - def face_shapes(self): - """ Return a list of stored face shape inputs """ - return [tuple(val) for key, val in self.inputs.items() if key.startswith("face")] - - @property - def mask_shapes(self): - """ Return a list of stored mask shape inputs """ - return [tuple(val) for key, val in self.inputs.items() if key.startswith("mask")] - - @property - def loss_names(self): - """ Return the loss names for this session """ - return self.sessions[self.session_id]["loss_names"] - - @property - def current_session(self): - """ Return the current session dict """ - return self.sessions[self.session_id] - - def new_session_id(self): - """ Return new session_id """ - if not self.sessions: - session_id = 1 - else: - session_id = max(int(key) for key in self.sessions.keys()) + 1 - logger.debug(session_id) - return session_id - - def create_new_session(self, no_logs): - """ Create a new session """ - logger.debug("Creating new session. id: %s", self.session_id) - self.sessions[self.session_id] = {"timestamp": time.time(), - "no_logs": no_logs, - "loss_names": dict(), - "batchsize": 0, - "iterations": 0} - - def add_session_loss_names(self, side, loss_names): - """ Add the session loss names to the sessions dictionary """ - logger.debug("Adding session loss_names. 
(side: '%s', loss_names: %s", side, loss_names) - self.sessions[self.session_id]["loss_names"][side] = loss_names - - def add_session_batchsize(self, batchsize): - """ Add the session batchsize to the sessions dictionary """ - logger.debug("Adding session batchsize: %s", batchsize) - self.sessions[self.session_id]["batchsize"] = batchsize - - def increment_iterations(self): - """ Increment total and session iterations """ - self.iterations += 1 - self.sessions[self.session_id]["iterations"] += 1 - - def load(self): - """ Load state file """ - logger.debug("Loading State") - try: - with open(self.filename, "rb") as inp: - state = self.serializer.unmarshal(inp.read().decode("utf-8")) - self.name = state.get("name", self.name) - self.sessions = state.get("sessions", dict()) - self.lowest_avg_loss = state.get("lowest_avg_loss", dict()) - self.iterations = state.get("iterations", 0) - self.training_size = state.get("training_size", 256) - self.inputs = state.get("inputs", dict()) - self.config = state.get("config", dict()) - logger.debug("Loaded state: %s", state) - self.replace_config() - except IOError as err: - logger.warning("No existing state file found. Generating.") - logger.debug("IOError: %s", str(err)) - except JSONDecodeError as err: - logger.debug("JSONDecodeError: %s:", str(err)) - - def save(self, should_backup=False): - """ Save iteration number to state file """ - logger.debug("Saving State") - if should_backup: - self.backup() - try: - with open(self.filename, "wb") as out: - state = {"name": self.name, - "sessions": self.sessions, - "lowest_avg_loss": self.lowest_avg_loss, - "iterations": self.iterations, - "inputs": self.inputs, - "training_size": self.training_size, - "config": _CONFIG} - state_json = self.serializer.marshal(state) - out.write(state_json.encode("utf-8")) - except IOError as err: - logger.error("Unable to save model state: %s", str(err.strerror)) - logger.debug("Saved State") - - def backup(self): - """ Backup state file """ - origfile = self.filename - backupfile = origfile + ".bk" - logger.debug("Backing up: '%s' to '%s'", origfile, backupfile) - if os.path.exists(backupfile): - os.remove(backupfile) - if os.path.exists(origfile): - os.rename(origfile, backupfile) - - def replace_config(self): - """ Replace the loaded config with the one contained within the state file """ - global _CONFIG # pylint: disable=global-statement - # Add any new items to state config for legacy purposes - for key, val in _CONFIG.items(): - if key not in self.config.keys(): - logger.info("Adding new config item to state file: '%s': '%s'", key, val) - self.config[key] = val - logger.debug("Replacing config. Old config: %s", _CONFIG) - _CONFIG = self.config - logger.debug("Replaced config. New config: %s", _CONFIG) - logger.info("Using configuration saved in state file") diff --git a/plugins/train/model/_base/__init__.py b/plugins/train/model/_base/__init__.py new file mode 100644 index 0000000000..c26c15103e --- /dev/null +++ b/plugins/train/model/_base/__init__.py @@ -0,0 +1,4 @@ +#!/usr/bin/env python3 +""" Base class for Models plugins ALL Models should at least inherit from this class. """ + +from .model import get_all_sub_models, ModelBase diff --git a/plugins/train/model/_base/inference.py b/plugins/train/model/_base/inference.py new file mode 100644 index 0000000000..d5fa6ef97e --- /dev/null +++ b/plugins/train/model/_base/inference.py @@ -0,0 +1,283 @@ +#! 
/usr/bin/env python3
+""" Handles the recompilation of a Faceswap model into a version that can be used for inference """
+from __future__ import annotations
+import logging
+import typing as T
+
+import keras
+
+from lib.logger import parse_class_init
+from lib.utils import get_module_objects
+
+if T.TYPE_CHECKING:
+    import keras.src.ops.node
+
+logger = logging.getLogger(__name__)
+
+
+class Inference():
+    """ Calculates required layers and compiles a saved model for inference.
+
+    Parameters
+    ----------
+    saved_model: :class:`keras.Model`
+        The saved trained Faceswap model
+    switch_sides: bool
+        ``True`` if the swap should be performed "B" > "A", ``False`` if the swap should be
+        "A" > "B"
+    """
+    def __init__(self, saved_model: keras.Model, switch_sides: bool) -> None:
+        logger.debug(parse_class_init(locals()))
+
+        self._layers: list[keras.Layer] = [lyr for lyr in saved_model.layers
+                                           if not isinstance(lyr, keras.layers.InputLayer)]
+        """list[:class:`keras.layers.Layer`]: All the layers that exist within the model excluding
+        input layers """
+
+        self._input = self._get_model_input(saved_model, switch_sides)
+        """list[:class:`keras.KerasTensor`]: The correct input for the inference model """
+
+        self._name = f"{saved_model.name}_inference"
+        """str: The name for the final inference model"""
+
+        self._model = self._build()
+        logger.debug("Initialized: %s", self.__class__.__name__)
+
+    @property
+    def model(self) -> keras.Model:
+        """ :class:`keras.Model`: The Faceswap model, compiled for inference. """
+        return self._model
+
+    def _get_model_input(self, model: keras.Model, switch_sides: bool) -> list[keras.KerasTensor]:
+        """ Obtain the inputs for the requested swap direction.
+
+        Parameters
+        ----------
+        model: :class:`keras.Model`
+            The saved trained Faceswap model
+        switch_sides: bool
+            ``True`` if the swap should be performed "B" > "A", ``False`` if the swap should be
+            "A" > "B"
+
+        Returns
+        -------
+        list[:class:`keras.KerasTensor`]
+            The input tensor to feed the model for the requested swap direction
+        """
+        inputs: list[keras.KerasTensor] = model.input
+        assert len(inputs) == 2, "Faceswap models should have exactly 2 inputs"
+        idx = 0 if switch_sides else 1
+        retval = inputs[idx]
+        logger.debug("model inputs: %s, idx: %s, inference_input: '%s'",
+                     [(i.name, i.shape[1:]) for i in inputs], idx, retval.name)
+        return [retval]
+
+    def _get_candidates(self, input_tensors: list[keras.KerasTensor | keras.Layer]
+                        ) -> T.Generator[tuple[keras.Layer, list[keras.src.ops.node.KerasHistory]],
+                                         None, None]:
+        """ Given a list of input tensors, get all layers from the main model which have the given
+        input tensors marked as inbound nodes for the model
+
+        Parameters
+        ----------
+        input_tensors: list[:class:`keras.KerasTensor` | :class:`keras.Layer`]
+            List of Tensors that act as an input to a layer within the model
+
+        Yields
+        ------
+        tuple[:class:`keras.Layer`, list[:class:`keras.src.ops.node.KerasHistory`]]
+            Any layer in the main model that uses the given input tensors as an input, along with
+            the corresponding keras inbound history
+        """
+        unique_input_names = set(i.name for i in input_tensors)
+        for layer in self._layers:
+
+            history = [tensor._keras_history  # pylint:disable=protected-access
+                       for node in layer._inbound_nodes  # pylint:disable=protected-access
+                       for parent in node.parent_nodes
+                       for tensor in parent.outputs]
+
+            unique_inbound_names = set(h.operation.name for h in history)
+            if not unique_input_names.issubset(unique_inbound_names):
+                logger.debug("%s: Skipping 
candidate '%s' unmatched inputs: %s",
+                             unique_input_names, layer.name, unique_inbound_names)
+                continue
+
+            logger.debug("%s: Yielding candidate '%s'. History: %s",
+                         unique_input_names, layer.name, [(h.operation.name, h.node_index)
+                                                          for h in history])
+            yield layer, history
+
+    @T.overload
+    def _group_inputs(self, layer: keras.Layer, inputs: list[tuple[keras.Layer, int]]
+                      ) -> list[list[tuple[keras.Layer, int]]]:
+        ...
+
+    @T.overload
+    def _group_inputs(self, layer: keras.Layer, inputs: list[keras.src.ops.node.KerasHistory]
+                      ) -> list[list[keras.src.ops.node.KerasHistory]]:
+        ...
+
+    def _group_inputs(self, layer, inputs):
+        """ Layers can have more than one input. In these instances we need to group the inputs,
+        and the layer's inbound nodes, so that they correspond to the inputs for each call of the
+        layer.
+
+        Parameters
+        ----------
+        layer: :class:`keras.Layer`
+            The current layer being processed
+        inputs: list[:class:`keras.KerasTensor`] | list[:class:`keras.src.ops.node.KerasHistory`]
+            List of input tensors (with their node indices) or inbound keras histories to be
+            grouped per layer input
+
+        Returns
+        -------
+        list[list[tuple[:class:`keras.Layer`, int]]] |
+        list[list[:class:`keras.src.ops.node.KerasHistory`]]
+            A list of lists of input layers with their corresponding node indices, or of inbound
+            keras histories, grouped per call of the layer
+        """
+        layer_inputs = 1 if isinstance(layer.input, keras.KerasTensor) else len(layer.input)
+        num_inputs = len(inputs)
+
+        total_calls = num_inputs / layer_inputs
+        assert total_calls.is_integer()
+        total_calls = int(total_calls)
+
+        retval = [inputs[i * layer_inputs: i * layer_inputs + layer_inputs]
+                  for i in range(total_calls)]
+
+        return retval
+
+    def _layers_from_inputs(self,
+                            input_tensors: list[keras.KerasTensor | keras.Layer],
+                            node_indices: list[int]
+                            ) -> tuple[list[keras.Layer],
+                                       list[keras.src.ops.node.KerasHistory],
+                                       list[int]]:
+        """ Given a list of input tensors and their corresponding inbound node ids, return all of
+        the layers for the model that use the given nodes as their input
+
+        Parameters
+        ----------
+        input_tensors: list[:class:`keras.KerasTensor` | :class:`keras.Layer`]
+            List of Tensors that act as an input to a layer within the model
+        node_indices: list[int]
+            The list of node indices corresponding to the inbound node index of each input tensor
+
+        Returns
+        -------
+        list[:class:`keras.layers.Layer`]
+            Any layers from the model that use the given inputs as their input. Empty list if
+            there are no matches
+        list[:class:`keras.src.ops.node.KerasHistory`]
+            The keras inbound history for the layers
+        list[int]
+            The output node index for the layer, used for the inbound node index of the next layer
+        """
+        retval: tuple[list[keras.Layer],
+                      list[keras.src.ops.node.KerasHistory],
+                      list[int]] = ([], [], [])
+        for layer, history in self._get_candidates(input_tensors):
+            grp_inputs = self._group_inputs(layer, list(zip(input_tensors, node_indices)))
+            grp_hist = self._group_inputs(layer, history)
+
+            for input_group in grp_inputs:  # pylint:disable=not-an-iterable
+                have = [(i[0].name, i[1]) for i in input_group]
+                for out_idx, hist in enumerate(grp_hist):
+                    requires = [(h.operation.name, h.node_index) for h in hist]
+                    if sorted(have) != sorted(requires):
+                        logger.debug("%s: Skipping '%s'. Requires %s. 
Output node index: %s",
+                                     have, layer.name, requires, out_idx)
+                        continue
+                    retval[0].append(layer)
+                    retval[1].append(hist)
+                    retval[2].append(out_idx)
+
+        logger.debug("Got layers %s for input_tensors: %s",
+                     [x.name for x in retval[0]], [t.name for t in input_tensors])
+        return retval
+
+    def _build_layers(self,
+                      layers: list[keras.Layer],
+                      history: list[keras.src.ops.node.KerasHistory],
+                      inputs: list[keras.KerasTensor]) -> list[keras.KerasTensor]:
+        """ Compile the given layers with the given inputs
+
+        Parameters
+        ----------
+        layers: list[:class:`keras.Layer`]
+            The layers to be called with the given inputs
+        history: list[:class:`keras.src.ops.node.KerasHistory`]
+            The corresponding keras inbound history for the layers
+        inputs: list[:class:`keras.KerasTensor`]
+            The inputs for the given layers
+
+        Returns
+        -------
+        list[:class:`keras.KerasTensor`]
+            The output tensors from the compiled layers
+        """
+        retval = []
+        given_order = [i._keras_history.operation.name  # pylint:disable=protected-access
+                       for i in inputs]
+        for layer, hist in zip(layers, history):
+            layer_input = [inputs[given_order.index(h.operation.name)]
+                           for h in hist if h.operation.name in given_order]
+            if layer_input != inputs:
+                logger.debug("Sorted layer inputs %s to %s",
+                             given_order,
+                             [i._keras_history.operation.name  # pylint:disable=protected-access
+                              for i in layer_input])
+
+            if isinstance(layer_input, list) and len(layer_input) == 1:
+                # Flatten single inputs to stop Keras warnings
+                actual_input = layer_input[0]
+            else:
+                actual_input = layer_input
+
+            built = layer(actual_input)
+            built = built if isinstance(built, list) else [built]
+            logger.debug(
+                "Compiled layer '%s' from input(s) %s",
+                layer.name,
+                [i._keras_history.operation.name  # pylint:disable=protected-access
+                 for i in layer_input])
+            retval.extend(built)
+
+        logger.debug(
+            "Compiled layers %s from input %s",
+            [x._keras_history.operation.name for x in retval],  # pylint:disable=protected-access
+            [x._keras_history.operation.name for x in inputs])  # pylint:disable=protected-access
+        return retval
+
+    def _build(self):
+        """ Extract the sub-models from the saved model that are required for inference.
+
+        Returns
+        -------
+        :class:`keras.Model`
+            The model compiled for inference
+        """
+        logger.debug("Compiling inference model")
+
+        layers = self._input
+        node_index = [0]
+        built = layers
+
+        while True:
+            layers, history, node_index = self._layers_from_inputs(layers, node_index)
+            if not layers:
+                break
+
+            built = self._build_layers(layers, history, built)
+
+        assert len(self._input) == 1
+        assert len(built) in (1, 2)
+        out = built[0] if len(built) == 1 else built
+        retval = keras.Model(inputs=self._input[0], outputs=out, name=self._name)
+        logger.debug("Compiled inference model '%s': %s", retval.name, retval)
+
+        return retval
+
+
+__all__ = get_module_objects(__name__)
diff --git a/plugins/train/model/_base/io.py b/plugins/train/model/_base/io.py
new file mode 100644
index 0000000000..a4635f2a7d
--- /dev/null
+++ b/plugins/train/model/_base/io.py
@@ -0,0 +1,553 @@
+#!/usr/bin/env python3
+"""
+IO handling for the model base plugin.
+
+The objects in this module should not be called directly, but are called from
+:class:`~plugins.train.model._base.ModelBase`
+
+This module handles:
+    - The loading, saving and backing up of keras models to and from disk.
+    - The loading and freezing of weights for model plugins. 
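+
+Example
+-------
+A minimal sketch of how a model plugin typically drives these helpers. The ``plugin``
+instance and the example paths are assumptions for illustration; this is not a public
+API::
+
+    io = IO(plugin, model_dir="/path/to/model", is_predict=False, save_optimizer="exit")
+    if io.model_exists:
+        model = io.load()    # load the saved .keras file from disk (uncompiled)
+    io.save(is_exit=False)   # save; only backs up when the average loss has dropped
+    io.snapshot()            # snapshot the model files from the latest save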
+""" +from __future__ import annotations +import logging +import os +import sys +import typing as T +from keras import layers, models as kmodels + +from lib.logger import parse_class_init +from lib.model.backup_restore import Backup +from lib.utils import get_module_objects, FaceswapError + +from .update import Legacy, PatchKerasConfig + +if T.TYPE_CHECKING: + from .model import ModelBase + from keras import Optimizer + +logger = logging.getLogger(__name__) + + +def get_all_sub_models( + model: kmodels.Model, + models: list[kmodels.Model] | None = None) -> list[kmodels.Model]: + """ For a given model, return all sub-models that occur (recursively) as children. + + Parameters + ---------- + model: :class:`keras.models.Model` + A Keras model to scan for sub models + models: `None` + Do not provide this parameter. It is used for recursion + + Returns + ------- + list + A list of all :class:`keras.models.Model` objects found within the given model. + The provided model will always be returned in the first position + """ + if models is None: + models = [model] + else: + models.append(model) + for layer in model.layers: + if isinstance(layer, kmodels.Model): + get_all_sub_models(layer, models=models) + return models + + +class IO(): + """ Model saving and loading functions. + + Handles the loading and saving of the plugin model from disk as well as the model backup and + snapshot functions. + + Parameters + ---------- + plugin: :class:`Model` + The parent plugin class that owns the IO functions. + model_dir: str + The full path to the model save location + is_predict: bool + ``True`` if the model is being loaded for inference. ``False`` if the model is being loaded + for training. + save_optimizer: ["never", "always", "exit"] + When to save the optimizer weights. `"never"` never saves the optimizer weights. `"always"` + always saves the optimizer weights. `"exit"` only saves the optimizer weights on an exit + request. + """ + def __init__(self, + plugin: ModelBase, + model_dir: str, + is_predict: bool, + save_optimizer: T.Literal["never", "always", "exit"]) -> None: + logger.debug(parse_class_init(locals())) + self._plugin = plugin + self._is_predict = is_predict + self._model_dir = model_dir + self._save_optimizer = save_optimizer + self._history: list[float] = [] + """list[float]: Loss history for current save iteration """ + self._backup = Backup(self._model_dir, self._plugin.name) + self._update_legacy() + logger.debug("Initialized %s", self.__class__.__name__) + + @property + def model_dir(self) -> str: + """ str: The full path to the model folder """ + return self._model_dir + + @property + def filename(self) -> str: + """str: The filename for this model.""" + return os.path.join(self._model_dir, f"{self._plugin.name}.keras") + + @property + def model_exists(self) -> bool: + """ bool: ``True`` if a model of the type being loaded exists within the model folder + location otherwise ``False``. + """ + return os.path.isfile(self.filename) + + @property + def history(self) -> list[float]: + """ list[float]: list of loss history for the current save iteration. 
""" + return self._history + + @property + def multiple_models_in_folder(self) -> list[str] | None: + """ :list: or ``None`` If there are multiple model types in the requested folder, or model + types that don't correspond to the requested plugin type, then returns the list of plugin + names that exist in the folder, otherwise returns ``None`` """ + plugins = [fname.replace(".keras", "") + for fname in os.listdir(self._model_dir) + if fname.endswith(".keras")] + test_names = plugins + [self._plugin.name] + test = False if not test_names else os.path.commonprefix(test_names) == "" + retval = None if not test else plugins + logger.debug("plugin name: %s, plugins: %s, test result: %s, retval: %s", + self._plugin.name, plugins, test, retval) + return retval + + def _update_legacy(self) -> None: + """ Look for faceswap 2.x .h5 files in the model folder. If exists, then update to Faceswap + 3 .keras file and backup the original model .h5 file + + Note: Currently disabled as keras hangs trying to load old faceswap models + """ + if self.model_exists: + logger.debug("Existing model file is current: '%s'", os.path.basename(self.filename)) + return + + old_fname = f"{os.path.splitext(self.filename)[0]}.h5" + if not os.path.isfile(old_fname): + logger.debug("No legacy model file to update") + return + + Legacy(old_fname) + + def load(self) -> kmodels.Model: + """ Loads the model from disk + + If the predict function is to be called and the model cannot be found in the model folder + then an error is logged and the process exits. + + When loading the model, the plugin model folder is scanned for custom layers which are + added to Keras' custom objects. + + Returns + ------- + :class:`keras.models.Model` + The saved model loaded from disk + """ + logger.debug("Loading model: %s", self.filename) + if self._is_predict and not self.model_exists: + logger.error("Model could not be found in folder '%s'. Exiting", self._model_dir) + sys.exit(1) + + try: + model = kmodels.load_model(self.filename, compile=False) + except RuntimeError as err: + if "unable to get link info" in str(err).lower(): + msg = (f"Unable to load the model from '{self.filename}'. This may be a " + "temporary error but most likely means that your model has corrupted.\n" + "You can try to load the model again but if the problem persists you " + "should use the Restore Tool to restore your model from backup.\n" + f"Original error: {str(err)}") + raise FaceswapError(msg) from err + raise err + except KeyError as err: + if "unable to open object" in str(err).lower(): + msg = (f"Unable to load the model from '{self.filename}'. This may be a " + "temporary error but most likely means that your model has corrupted.\n" + "You can try to load the model again but if the problem persists you " + "should use the Restore Tool to restore your model from backup.\n" + f"Original error: {str(err)}") + raise FaceswapError(msg) from err + if 'parameter name can\\\'t contain "."' in str(err).lower(): + PatchKerasConfig(self.filename)() + return self.load() + raise err + except TypeError as err: + if any(x in str(err) for x in ("Could not locate class 'Conv2D'", + "Could not locate class 'DepthwiseConv2D'")): + PatchKerasConfig(self.filename)() + return self.load() + raise err + + logger.info("Loaded model from disk: '%s'", self.filename) + return model # pyright:ignore[reportReturnType] + + def _remove_optimizer(self) -> Optimizer: + """ Keras 3 `.keras` format ignores the `save_optimizer` kwarg. 
To hack around this we
+        remove the optimizer from the model prior to saving and then re-attach it to the model
+
+        Returns
+        -------
+        :class:`keras.optimizers.Optimizer`
+            The optimizer that has been detached from the model, so that it can be re-attached
+            after saving
+        """
+        retval = self._plugin.model.optimizer
+        del self._plugin.model.optimizer
+        logger.debug("Removed optimizer for saving: %s", retval)
+        return retval
+
+    def _save_model(self, is_exit: bool, force_save_optimizer: bool) -> None:
+        """ Save the model either with or without the optimizer weights
+
+        Keras 3 ignores ``save_optimizer``, so if the optimizer should not be saved, we remove it
+        from the model for saving, then re-attach it
+
+        Parameters
+        ----------
+        is_exit: bool
+            ``True`` if the save request has come from an exit process request otherwise ``False``.
+        force_save_optimizer: bool
+            ``True`` to force saving the optimizer weights with the model, otherwise ``False``.
+        """
+        include_optimizer = (force_save_optimizer or
+                             self._save_optimizer == "always" or
+                             (self._save_optimizer == "exit" and is_exit))
+
+        optimizer = None
+        if not include_optimizer:
+            optimizer = self._remove_optimizer()
+
+        self._plugin.model.save(self.filename)
+        self._plugin.state.save()
+
+        if not include_optimizer:
+            assert optimizer is not None
+            logger.debug("Re-attaching optimizer: %s", optimizer)
+            setattr(self._plugin.model, "optimizer", optimizer)
+
+    def _get_save_average(self) -> float:
+        """ Return the average loss since the last save iteration and reset historical loss
+
+        Returns
+        -------
+        float
+            The average loss since the last save iteration
+        """
+        logger.debug("Getting save averages")
+        if not self._history:
+            logger.debug("No loss in history")
+            retval = 0.0
+        else:
+            retval = sum(self._history) / len(self._history)
+            self._history = []  # Reset historical loss
+        logger.debug("Average loss since last save: %s", round(retval, 5))
+        return retval
+
+    def _should_backup(self, save_average: float) -> bool:
+        """ Check whether the loss average for this save iteration is the lowest that has been
+        seen.
+
+        This protects against model corruption by only backing up the model if the sum of all loss
+        functions has fallen.
+
+        Notes
+        -----
+        This is by no means a perfect system. If the model corrupts at an iteration close
+        to a save iteration, then the averages may still be pushed lower than a previous
+        save average, resulting in backing up a corrupted model. 
Changing loss weighting can also
+            artificially impact this.
+
+        Parameters
+        ----------
+        save_average: float
+            The average loss since the last save iteration
+
+        Returns
+        -------
+        bool
+            ``True`` if the model should be backed up, otherwise ``False``
+        """
+        if not self._plugin.state.lowest_avg_loss:
+            logger.debug("Set initial save iteration loss average: %s", save_average)
+            self._plugin.state.lowest_avg_loss = save_average
+            return False
+
+        old_average = self._plugin.state.lowest_avg_loss
+        backup = save_average < old_average
+
+        if backup:  # Update lowest loss values to the state file
+            self._plugin.state.lowest_avg_loss = save_average
+            logger.debug("Updated lowest historical save iteration average from: %s to: %s",
+                         old_average, save_average)
+
+        logger.debug("Should backup: %s", backup)
+        return backup
+
+    def _maybe_backup(self) -> tuple[float, bool]:
+        """ Backup the model if total average loss has dropped for the save iteration
+
+        Returns
+        -------
+        float
+            The total loss average since the last save iteration
+        bool
+            ``True`` if the model was backed up
+        """
+        save_average = self._get_save_average()
+        should_backup = self._should_backup(save_average)
+        if not save_average or not should_backup:
+            logger.debug("Not backing up model (save_average: %s, should_backup: %s)",
+                         save_average, should_backup)
+            return save_average, False
+
+        logger.debug("Backing up model")
+        self._backup.backup_model(self.filename)
+        self._backup.backup_model(self._plugin.state.filename)
+        return save_average, True
+
+    def save(self,
+             is_exit: bool = False,
+             force_save_optimizer: bool = False) -> None:
+        """ Backup and save the model and state file.
+
+        Parameters
+        ----------
+        is_exit: bool, optional
+            ``True`` if the save request has come from an exit process request otherwise ``False``.
+            Default: ``False``
+        force_save_optimizer: bool, optional
+            ``True`` to force saving the optimizer weights with the model, otherwise ``False``.
+            Default: ``False``
+        """
+        logger.debug("Backing up and saving models")
+        print("\x1b[2K", end="\r")  # Clear last line
+        logger.info("Saving Model...")
+
+        self._save_model(is_exit, force_save_optimizer)
+        save_average, backed_up = self._maybe_backup()
+
+        msg = "[Saved optimizer state for Snapshot]" if force_save_optimizer else "[Saved model]"
+        if save_average:
+            msg += f" - Average total loss since last save: {save_average:.5f}"
+        if backed_up:
+            msg += " [Model backed up]"
+        logger.info(msg)
+
+    def snapshot(self) -> None:
+        """ Perform a model snapshot.
+
+        Notes
+        -----
+        The snapshot function is called 1 iteration after the model was saved, so that it is built
+        from the latest save, hence the iteration count being reduced by 1.
+        """
+        logger.debug("Performing snapshot. Iterations: %s", self._plugin.iterations)
+        self._backup.snapshot_models(self._plugin.iterations - 1)
+        logger.debug("Performed snapshot")
+
+
+class Weights():
+    """ Handling of freezing and loading model weights
+
+    Parameters
+    ----------
+    plugin: :class:`Model`
+        The parent plugin class that owns these weights functions. 
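+
+    Example
+    -------
+    A minimal sketch of the expected call order, assuming a fully initialized model
+    plugin instance (illustrative only, not a public API)::
+
+        weights = Weights(plugin)
+        weights.load(model_exists=plugin.io.model_exists)  # only loads for new models
+        weights.freeze()  # freeze any sub-models named in the plugin's freeze_layers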
+ """ + def __init__(self, plugin: ModelBase) -> None: + logger.debug("Initializing %s: (plugin: %s)", self.__class__.__name__, plugin) + self._model = plugin.model + self._name = plugin.model_name + self._do_freeze = plugin._args.freeze_weights + self._weights_file = self._check_weights_file(plugin._args.load_weights) + + self._freeze_layers = plugin.freeze_layers + self._load_layers = plugin.load_layers + logger.debug("Initialized %s", self.__class__.__name__) + + @classmethod + def _check_weights_file(cls, weights_file: str) -> str | None: + """ Validate that we have a valid path to a .keras file. + + Parameters + ---------- + weights_file: str + The full path to a weights file + + Returns + ------- + str + The full path to a weights file + """ + if not weights_file: + logger.debug("No weights file selected.") + return None + + msg = "" + if not os.path.exists(weights_file): + msg = f"Load weights selected, but the path '{weights_file}' does not exist." + elif not os.path.splitext(weights_file)[-1].lower() == ".keras": + msg = (f"Load weights selected, but the path '{weights_file}' is not a valid Keras " + f"model (.keras) file.") + + if msg: + msg += " Please check and try again." + raise FaceswapError(msg) + + logger.verbose("Using weights file: %s", weights_file) # type:ignore + return weights_file + + def freeze(self) -> None: + """ If freeze has been selected in the cli arguments, then freeze those models indicated + in the plugin's configuration. """ + # Blanket unfreeze layers, as checking the value of :attr:`layer.trainable` appears to + # return ``True`` even when the weights have been frozen + for layer in get_all_sub_models(self._model): + layer.trainable = True + + if not self._do_freeze: + logger.debug("Freeze weights deselected. Not freezing") + return + + for layer in get_all_sub_models(self._model): + if layer.name in self._freeze_layers: + logger.info("Freezing weights for '%s' in model '%s'", layer.name, self._name) + layer.trainable = False + self._freeze_layers.remove(layer.name) + if self._freeze_layers: + logger.warning("The following layers were set to be frozen but do not exist in the " + "model: %s", self._freeze_layers) + + def load(self, model_exists: bool) -> None: + """ Load weights for newly created models, or output warning for pre-existing models. + + Parameters + ---------- + model_exists: bool + ``True`` if a model pre-exists and is being resumed, ``False`` if this is a new model + """ + if not self._weights_file: + logger.debug("No weights file provided. Not loading weights.") + return + if model_exists and self._weights_file: + logger.warning("Ignoring weights file '%s' as this model is resuming.", + self._weights_file) + return + + weights_models = self._get_weights_model() + all_models = get_all_sub_models(self._model) + loaded_ops = 0 + skipped_ops = 0 + + for model_name in self._load_layers: + sub_model = next((lyr for lyr in all_models if lyr.name == model_name), None) + sub_weights = next((lyr for lyr in weights_models if lyr.name == model_name), None) + + if not sub_model or not sub_weights: + msg = f"Skipping layer {model_name} as not in " + msg += "current_model." 
if not sub_model else f"weights '{self._weights_file}'."
+                logger.warning(msg)
+                continue
+
+            logger.info("Loading weights for layer '%s'", model_name)
+            skipped_ops = 0
+            loaded_ops = 0
+            for layer in sub_model.layers:
+                success = self._load_layer_weights(layer, sub_weights, model_name)
+                if success == 0:
+                    skipped_ops += 1
+                elif success == 1:
+                    loaded_ops += 1
+
+        del weights_models
+
+        if loaded_ops == 0:
+            raise FaceswapError(f"No weights were successfully loaded from your weights file: "
+                                f"'{self._weights_file}'. Please check and try again.")
+        if skipped_ops > 0:
+            logger.warning("%s weight(s) were unable to be loaded for your model. This is most "
+                           "likely because the weights you are trying to load were trained with "
+                           "different settings than you have set for your current model.",
+                           skipped_ops)
+
+    def _get_weights_model(self) -> list[kmodels.Model]:
+        """ Obtain a list of all sub-models contained within the weights model.
+
+        Returns
+        -------
+        list
+            List of all models contained within the .keras file
+
+        Raises
+        ------
+        FaceswapError
+            In the event of a failure to load the weights, or the weights belonging to a different
+            model
+        """
+        retval = get_all_sub_models(kmodels.load_model(  # pyright:ignore[reportArgumentType]
+            self._weights_file,
+            compile=False))
+        if not retval:
+            raise FaceswapError(f"Error loading weights file {self._weights_file}.")
+
+        if retval[0].name != self._name:
+            raise FaceswapError(f"You are attempting to load weights from a '{retval[0].name}' "
+                                f"model into a '{self._name}' model. This is not supported.")
+        return retval
+
+    def _load_layer_weights(self,
+                            layer: layers.Layer,
+                            sub_weights: layers.Layer,
+                            model_name: str) -> T.Literal[-1, 0, 1]:
+        """ Load the weights for a single layer.
+
+        Parameters
+        ----------
+        layer: :class:`keras.layers.Layer`
+            The layer to set the weights for
+        sub_weights: :class:`keras.layers.Layer`
+            The sub-model from the weights file to load weights from
+        model_name: str
+            The name of the current sub-model that is having its weights loaded
+
+        Returns
+        -------
+        int
+            `-1` if the layer has no weights to load. `0` if weights loading was unsuccessful. `1`
+            if weights loading was successful
+        """
+        old_weights = layer.get_weights()
+        if not old_weights:
+            logger.debug("Skipping layer without weights: %s", layer.name)
+            return -1
+
+        layer_weights = next((lyr for lyr in sub_weights.layers
+                              if lyr.name == layer.name), None)
+        if not layer_weights:
+            logger.warning("The weights file '%s' for layer '%s' does not contain weights for "
+                           "'%s'. Skipping", self._weights_file, model_name, layer.name)
+            return 0
+
+        new_weights = layer_weights.get_weights()
+        if old_weights[0].shape != new_weights[0].shape:
+            logger.warning("The weights for layer '%s' are of incompatible shapes. Skipping.",
+                           layer.name)
+            return 0
+        logger.verbose("Setting weights for '%s'", layer.name)  # type:ignore
+        layer.set_weights(new_weights)
+        return 1
+
+
+__all__ = get_module_objects(__name__)
diff --git a/plugins/train/model/_base/model.py b/plugins/train/model/_base/model.py
new file mode 100644
index 0000000000..e8274a0548
--- /dev/null
+++ b/plugins/train/model/_base/model.py
@@ -0,0 +1,349 @@
+#!/usr/bin/env python3
+"""
+Base class for Models. ALL Models should at least inherit from this class.
+
+See :mod:`~plugins.train.model.original` for an annotated example of how to create model plugins. 
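+
+Example
+-------
+A minimal, hypothetical sketch of a plugin subclass. Setting :attr:`input_shape` in
+``__init__`` is the documented requirement; the actual architecture definition is plugin
+specific and omitted here::
+
+    from plugins.train.model._base import ModelBase
+
+    class Model(ModelBase):
+        # An illustrative 64px model plugin
+        def __init__(self, *args, **kwargs):
+            super().__init__(*args, **kwargs)
+            self.input_shape = (64, 64, 3)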
+""" +from __future__ import annotations +import logging +import os +import sys +import typing as T + +import keras + +from lib.logger import parse_class_init +from lib.utils import get_module_objects, FaceswapError +from plugins.train import train_config as cfg + +from .inference import Inference +from .io import IO, get_all_sub_models, Weights +from .settings import Loss, Optimizer, Settings +from .state import State + +if T.TYPE_CHECKING: + import argparse + import numpy as np + + +logger = logging.getLogger(__name__) + + +class ModelBase(): # pylint:disable=too-many-instance-attributes + """ Base class that all model plugins should inherit from. + + Parameters + ---------- + model_dir: str + The full path to the model save location + arguments: :class:`argparse.Namespace` + The arguments that were passed to the train or convert process as generated from + Faceswap's command line arguments + predict: bool, optional + ``True`` if the model is being loaded for inference, ``False`` if the model is being loaded + for training. Default: ``False`` + + Attributes + ---------- + input_shape: tuple or list + A `tuple` of `ints` defining the shape of the faces that the model takes as input. This + should be overridden by model plugins in their :func:`__init__` function. If the input size + is the same for both sides of the model, then this can be a single 3 dimensional `tuple`. + If the inputs have different sizes for `"A"` and `"B"` this should be a `list` of 2 3 + dimensional shape `tuples`, 1 for each side respectively. + """ + def __init__(self, + model_dir: str, + arguments: argparse.Namespace, + predict: bool = False) -> None: + logger.debug(parse_class_init(locals())) + # Input shape must be set within the plugin after initializing + self.input_shape: tuple[int, ...] = () + self.color_order: T.Literal["bgr", "rgb"] = "bgr" # Override for image color channel order + + self._args = arguments + self._is_predict = predict + self._model: keras.Model | None = None + + cfg.load_config(config_file=arguments.configfile) + + if cfg.Loss.penalized_mask_loss() and cfg.Loss.mask_type() == "none": + raise FaceswapError("Penalized Mask Loss has been selected but you have not chosen a " + "Mask to use. Please select a mask or disable Penalized Mask " + "Loss.") + + if cfg.Loss.learn_mask() and cfg.Loss.mask_type() == "none": + raise FaceswapError("'Learn Mask' has been selected but you have not chosen a Mask to " + "use. Please select a mask or disable 'Learn Mask'.") + + self._mixed_precision = cfg.mixed_precision() + self._io = IO(self, model_dir, + self._is_predict, + T.cast(T.Literal["never", "always", "exit"], cfg.Optimizer.save_optimizer())) + self._check_multiple_models() + + self._state = State(model_dir, + self.name, + False if self._is_predict else self._args.no_logs) + self._settings = Settings(self._args, + self._mixed_precision, + self._is_predict) + self._loss = Loss(self.color_order) + + logger.debug("Initialized ModelBase (%s)", self.__class__.__name__) + + @property + def model(self) -> keras.Model: + """:class:`keras.Model`: The compiled model for this plugin. """ + return self._model + + @property + def command_line_arguments(self) -> argparse.Namespace: + """ :class:`argparse.Namespace`: The command line arguments passed to the model plugin from + either the train or convert script """ + return self._args + + @property + def coverage_ratio(self) -> float: + """ float: The ratio of the training image to crop out and train on as defined in user + configuration options. 
+ + NB: The coverage ratio is a raw float, but will be applied to integer pixel images. + + To ensure consistent rounding and guaranteed even image size, the calculation for coverage + should always be: :math:`(original_size * coverage_ratio // 2) * 2` + """ + return cfg.coverage() / 100. + + @property + def io(self) -> IO: # pylint:disable=invalid-name + """ :class:`~plugins.train.model.io.IO`: Input/Output operations for the model """ + return self._io + + @property + def name(self) -> str: + """ str: The name of this model based on the plugin name. """ + _name = sys.modules[self.__module__].__file__ + assert isinstance(_name, str) + return os.path.splitext(os.path.basename(_name))[0].lower() + + @property + def model_name(self) -> str: + """ str: The name of the keras model. Generally this will be the same as :attr:`name` + but some plugins will override this when they contain multiple architectures """ + return self.name + + @property + def input_shapes(self) -> list[tuple[None, int, int, int]]: + """ list: A flattened list corresponding to all of the inputs to the model. """ + shapes = [T.cast(tuple[None, int, int, int], inputs.shape) + for inputs in self.model.inputs] + return shapes + + @property + def output_shapes(self) -> list[tuple[None, int, int, int]]: + """ list: A flattened list corresponding to all of the outputs of the model. """ + shapes = [T.cast(tuple[None, int, int, int], output.shape) + for output in self.model.outputs] + return shapes + + @property + def iterations(self) -> int: + """ int: The total number of iterations that the model has trained. """ + return self._state.iterations + + @property + def warmup_steps(self) -> int: + """ int : The number of steps to perform learning rate warmup """ + return self._args.warmup + + @property + def freeze_layers(self) -> list[str]: + """ list[str] : Override to set plugin specific layers that can be frozen. Defaults to + ["encoder"] """ + return ["encoder"] + + @property + def load_layers(self) -> list[str]: + """ list[str] : Override to set plugin specific layers that can be loaded. Defaults to + ["encoder"] """ + return ["encoder"] + + # Private properties + @property + def _config_section(self) -> str: + """ str: The section name for the current plugin for loading configuration options from the + config file. """ + return ".".join(self.__module__.split(".")[-2:]) + + @property + def state(self) -> "State": + """:class:`State`: The state settings for the current plugin. """ + return self._state + + def _check_multiple_models(self) -> None: + """ Check whether multiple models exist in the model folder, and that no models exist that + were trained with a different plugin than the requested plugin. + + Raises + ------ + FaceswapError + If multiple model files, or models for a different plugin from that requested exists + within the model folder + """ + multiple_models = self._io.multiple_models_in_folder + if multiple_models is None: + logger.debug("Contents of model folder are valid") + return + + if len(multiple_models) == 1: + msg = (f"You have requested to train with the '{self.name}' plugin, but a model file " + f"for the '{multiple_models[0]}' plugin already exists in the folder " + f"'{self.io.model_dir}'.\nPlease select a different model folder.") + else: + ptypes = "', '".join(multiple_models) + msg = (f"There are multiple plugin types ('{ptypes}') stored in the model folder '" + f"{self.io.model_dir}'. 
This is not supported.\nPlease split the model files " + "into their own folders before proceeding") + raise FaceswapError(msg) + + def build(self) -> None: + """ Build the model and assign to :attr:`model`. + + Within the defined strategy scope, either builds the model from scratch or loads an + existing model if one exists. + + If running inference, then the model is built only for the required side to perform the + swap function, otherwise the model is then compiled with the optimizer and chosen + loss function(s). + + Finally, a model summary is outputted to the logger at verbose level. + """ + is_summary = hasattr(self._args, "summary") and self._args.summary + if self._io.model_exists: + model = self.io.load() + if self._is_predict: + inference = Inference(model, self._args.swap_model) + self._model = inference.model + else: + self._model = model + else: + self._validate_input_shape() + inputs = self._get_inputs() + if not self._settings.use_mixed_precision and not is_summary: + # Store layer names which can be switched to mixed precision + model, mp_layers = self._settings.get_mixed_precision_layers(self.build_model, + inputs) + self._state.add_mixed_precision_layers(mp_layers) + self._model = model + else: + self._model = self.build_model(inputs) + if not is_summary and not self._is_predict: + self._compile_model() + self._output_summary() + + def _validate_input_shape(self) -> None: + """ Validate that the input shape is either a single shape tuple of 3 dimensions or + a list of 2 shape tuples of 3 dimensions. """ + assert len(self.input_shape) == 3, "Input shape should be a 3 dimensional shape tuple" + + def _get_inputs(self) -> list[keras.layers.Input]: + """ Obtain the standardized inputs for the model. + + The inputs will be returned for the "A" and "B" sides in the shape as defined by + :attr:`input_shape`. + + Returns + ------- + list + A list of :class:`keras.layers.Input` tensors. This will be a list of 2 tensors (one + for each side) each of shapes :attr:`input_shape`. + """ + logger.debug("Getting inputs") + input_shapes = [self.input_shape, self.input_shape] + inputs = [keras.layers.Input(shape=shape, name=f"face_in_{side}") + for side, shape in zip(("a", "b"), input_shapes)] + logger.debug("inputs: %s", inputs) + return inputs + + def build_model(self, inputs: list[keras.layers.Input]) -> keras.Model: + """ Override for Model Specific autoencoder builds. + + Parameters + ---------- + inputs: list + A list of :class:`keras.layers.Input` tensors. This will be a list of 2 tensors (one + for each side) each of shapes :attr:`input_shape`. + + Returns + ------- + :class:`keras.Model` + See Keras documentation for the correct structure, but note that parameter :attr:`name` + is a required rather than an optional argument in Faceswap. You should assign this to + the attribute ``self.name`` that is automatically generated from the plugin's filename. + """ + raise NotImplementedError + + def _summary_to_log(self, summary: str) -> None: + """ Function to output Keras model summary to log file at verbose log level + + Parameters + ---------- + summary, str + The model summary output from keras + """ + for line in summary.splitlines(): + logger.verbose(line) # type:ignore[attr-defined] + + def _output_summary(self) -> None: + """ Output the summary of the model and all sub-models to the verbose logger. 
""" + if hasattr(self._args, "summary") and self._args.summary: + print_fn = None # Print straight to stdout + else: + # print to logger + print_fn = self._summary_to_log + parent = self.model + for idx, model in enumerate(get_all_sub_models(self.model)): + if idx == 0: + parent = model + continue + model.summary(print_fn=print_fn) + parent.summary(print_fn=print_fn) + + def _compile_model(self) -> None: + """ Compile the model to include the Optimizer and Loss Function(s). """ + logger.debug("Compiling Model") + + if self.state.model_needs_rebuild: + self._model = self._settings.check_model_precision(self._model, self._state) + + optimizer = Optimizer().optimizer + if self._settings.use_mixed_precision: + optimizer = self._settings.loss_scale_optimizer(optimizer) + + weights = Weights(self) + weights.load(self._io.model_exists) + weights.freeze() + + self._loss.configure(self.model) + losses = list(self._loss.functions.values()) + self.model.compile(optimizer=optimizer, loss=losses) + self._state.add_session_loss_names(self._loss.names) + logger.debug("Compiled Model: %s", self.model) + + def add_history(self, loss: np.ndarray) -> None: + """ Add the current iteration's loss history to :attr:`_io.history`. + + Called from the trainer after each iteration, for tracking loss drop over time between + save iterations. + + Parameters + ---------- + loss : :class:`numpy.ndarray` + The loss values for the A and B side for the current iteration. This should be the + collated loss values for each side. + """ + self._io.history.append(float(sum(loss))) + + +__all__ = get_module_objects(__name__) diff --git a/plugins/train/model/_base/settings.py b/plugins/train/model/_base/settings.py new file mode 100644 index 0000000000..1875a1c8bf --- /dev/null +++ b/plugins/train/model/_base/settings.py @@ -0,0 +1,767 @@ +#!/usr/bin/env python3 +""" +Settings for the model base plugins. + +The objects in this module should not be called directly, but are called from +:class:`~plugins.train.model._base.ModelBase` + +Handles configuration of model plugins for: + - Loss configuration + - Optimizer settings + - General global model configuration settings +""" +from __future__ import annotations +from dataclasses import dataclass, field +import logging +import typing as T + +import keras +from keras import config as k_config, dtype_policies, losses as k_losses, optimizers + +from lib.model import losses +from lib.model.optimizers import AdaBelief +from lib.model.autoclip import AutoClipper +from lib.model.nn_blocks import reset_naming +from lib.logger import parse_class_init +from lib.utils import get_module_objects +from plugins.train.train_config import Loss as cfg_loss, Optimizer as cfg_opt + +if T.TYPE_CHECKING: + from collections.abc import Callable + from argparse import Namespace + from keras import KerasTensor + from .state import State + +logger = logging.getLogger(__name__) + + +@dataclass +class LossClass: + """ Typing class for holding loss functions. + + Parameters + ---------- + function: Callable + The function that takes in the true/predicted images and returns the loss + init: bool, Optional + Whether the loss object ``True`` needs to be initialized (i.e. it's a class) or + ``False`` it does not require initialization (i.e. it's a function). + Default ``True`` + kwargs: dict + Any keyword arguments to supply to the loss function at initialization. 
+ """ + function: Callable[[KerasTensor, KerasTensor], + KerasTensor] | T.Any = k_losses.MeanSquaredError + init: bool = True + kwargs: dict[str, T.Any] = field(default_factory=dict) + + +class Loss(): + """ Holds loss names and functions for an Autoencoder. + + Parameters + ---------- + color_order: str + Color order of the model. One of `"BGR"` or `"RGB"` + """ + def __init__(self, color_order: T.Literal["bgr", "rgb"]) -> None: + logger.debug(parse_class_init(locals())) + self._mask_channels = self._get_mask_channels() + self._inputs: list[keras.layers.Layer] = [] + self._names: list[str] = [] + self._funcs: dict[str, losses.LossWrapper | T.Callable[[KerasTensor, KerasTensor], + KerasTensor]] = {} + + self._loss_dict = {"ffl": LossClass(function=losses.FocalFrequencyLoss), + "flip": LossClass(function=losses.LDRFLIPLoss, + kwargs={"color_order": color_order}), + "gmsd": LossClass(function=losses.GMSDLoss), + "l_inf_norm": LossClass(function=losses.LInfNorm), + "laploss": LossClass(function=losses.LaplacianPyramidLoss), + "logcosh": LossClass(function=k_losses.LogCosh), + "lpips_alex": LossClass(function=losses.LPIPSLoss, + kwargs={"trunk_network": "alex"}), + "lpips_squeeze": LossClass(function=losses.LPIPSLoss, + kwargs={"trunk_network": "squeeze"}), + "lpips_vgg16": LossClass(function=losses.LPIPSLoss, + kwargs={"trunk_network": "vgg16"}), + "ms_ssim": LossClass(function=losses.MSSIMLoss), + "mae": LossClass(function=k_losses.MeanAbsoluteError), + "mse": LossClass(function=k_losses.MeanSquaredError), + "pixel_gradient_diff": LossClass(function=losses.GradientLoss), + "ssim": LossClass(function=losses.DSSIMObjective), + "smooth_loss": LossClass(function=losses.GeneralizedLoss)} + + logger.debug("Initialized: %s", self.__class__.__name__) + + @property + def names(self) -> list[str]: + """ list: The list of loss names for the model. """ + return self._names + + @property + def functions(self) -> dict[str, losses.LossWrapper | T.Callable[[KerasTensor, KerasTensor], + KerasTensor]]: + """ dict[str, :class:`~lib.model.losses.LossWrapper` | | Callable[[KerasTensor, + KerasTensor], KerasTensor]]]: The loss functions that apply to each model output. """ + return self._funcs + + @property + def _mask_inputs(self) -> list | None: + """ list: The list of input tensors to the model that contain the mask. Returns ``None`` + if there is no mask input to the model. """ + mask_inputs = [inp for inp in self._inputs if inp.name.startswith("mask")] + return None if not mask_inputs else mask_inputs + + @property + def _mask_shapes(self) -> list[tuple] | None: + """ list: The list of shape tuples for the mask input tensors for the model. Returns + ``None`` if there is no mask input. """ + if self._mask_inputs is None: + return None + return [mask_input.shape for mask_input in self._mask_inputs] + + def configure(self, model: keras.models.Model) -> None: + """ Configure the loss functions for the given inputs and outputs. + + Parameters + ---------- + model: :class:`keras.models.Model` + The model that is to be trained + """ + self._inputs = model.inputs + self._set_loss_names(model.outputs) + self._set_loss_functions(model.output_names) + self._names.insert(0, "total") + + def _set_loss_names(self, outputs: list[KerasTensor]) -> None: + """ Name the losses based on model output. + + This is used for correct naming in the state file, for display purposes only. 
+ + Adds the loss names to :attr:`names` + + Parameters + ---------- + outputs: list[:class:`keras.KerasTensor`] + A list of output tensors from the model plugin + """ + # TODO Use output names if/when these are fixed upstream + split_outputs = [outputs[:len(outputs) // 2], outputs[len(outputs) // 2:]] + for side, side_output in zip(("a", "b"), split_outputs): + output_names = [output.name for output in side_output] + output_shapes = [output.shape[1:] for output in side_output] + output_types = ["mask" if shape[-1] == 1 else "face" for shape in output_shapes] + logger.debug("side: %s, output names: %s, output_shapes: %s, output_types: %s", + side, output_names, output_shapes, output_types) + for idx, name in enumerate(output_types): + suffix = "" if output_types.count(name) == 1 else f"_{idx}" + self._names.append(f"{name}_{side}{suffix}") + logger.debug(self._names) + + def _get_function(self, name: str) -> Callable[[KerasTensor, KerasTensor], KerasTensor]: + """ Obtain the requested Loss function + + Parameters + ---------- + name: str + The name of the loss function from the training configuration file + + Returns + ------- + Keras Loss Function + The requested loss function + """ + func = self._loss_dict[name] + retval = func.function(**func.kwargs) if func.init else func.function # type:ignore + logger.debug("Obtained loss function `%s` (%s)", name, retval) + return retval + + def _set_loss_functions(self, output_names: list[str]) -> None: + """ Set the loss functions and their associated weights. + + Adds the loss functions to the :attr:`functions` dictionary. + + Parameters + ---------- + output_names: list[str] + The output names from the model + """ + loss_funcs = [cfg_loss.loss_function(), + cfg_loss.loss_function_2(), + cfg_loss.loss_function_3(), + cfg_loss.loss_function_4()] + loss_amount = [100, + cfg_loss.loss_weight_2(), + cfg_loss.loss_weight_3(), + cfg_loss.loss_weight_4()] + face_losses = [(name, weight) for name, weight in zip(loss_funcs, loss_amount) + if name != "none" and weight > 0] + + for name, output_name in zip(self._names, output_names): + if name.startswith("mask"): + loss_func = self._get_function(cfg_loss.mask_loss_function()) + else: + loss_func = losses.LossWrapper() + for func, weight in face_losses: + self._add_face_loss_function(loss_func, func, weight / 100.) + + logger.debug("%s: (output_name: '%s', function: %s)", name, output_name, loss_func) + self._funcs[name] = loss_func + logger.debug("functions: %s", self._funcs) + + def _add_face_loss_function(self, + loss_wrapper: losses.LossWrapper, + loss_function: str, + weight: float) -> None: + """ Add the given face loss function at the given weight and apply any mouth and eye + multipliers + + Parameters + ---------- + loss_wrapper: :class:`lib.model.losses.LossWrapper` + The wrapper loss function that holds the face losses + loss_function: str + The loss function to add to the loss wrapper + weight: float + The amount of weight to apply to the given loss function + """ + logger.debug("Adding loss function: %s, weight: %s", loss_function, weight) + loss_wrapper.add_loss(self._get_function(loss_function), + weight=weight, + mask_channel=self._mask_channels[0]) + + channel_idx = 1 + for section, multiplier in zip( + ("eye_multiplier", "mouth_multiplier"), + (float(cfg_loss.eye_multiplier()), float(cfg_loss.mouth_multiplier()))): + mask_channel = self._mask_channels[channel_idx] + multiplier *= 1. 
+ if multiplier > 1.: + logger.debug("Adding section loss %s: %s", section, multiplier) + loss_wrapper.add_loss(self._get_function(loss_function), + weight=weight * multiplier, + mask_channel=mask_channel) + channel_idx += 1 + + def _get_mask_channels(self) -> list[int]: + """ Obtain the channels from the face targets that the masks reside in from the training + data generator. + + Returns + ------- + list: + A list of channel indices that contain the mask for the corresponding config item + """ + eye_multiplier = cfg_loss.eye_multiplier() + mouth_multiplier = cfg_loss.mouth_multiplier() + if not cfg_loss.penalized_mask_loss() and (eye_multiplier > 1 or mouth_multiplier > 1): + logger.warning("You have selected eye/mouth loss multipliers greater than 1x, but " + "Penalized Mask Loss is disabled. Disabling all multipliers.") + eye_multiplier = 1 + mouth_multiplier = 1 + uses_masks = (cfg_loss.penalized_mask_loss(), eye_multiplier > 1, mouth_multiplier > 1) + mask_channels = [-1 for _ in range(len(uses_masks))] + current_channel = 3 + for idx, mask_required in enumerate(uses_masks): + if mask_required: + mask_channels[idx] = current_channel + current_channel += 1 + logger.debug("uses_masks: %s, mask_channels: %s", uses_masks, mask_channels) + return mask_channels + + +class Optimizer(): + """ Obtain the selected optimizer with the appropriate keyword arguments. """ + def __init__(self) -> None: + logger.debug(parse_class_init(locals())) + betas = {"ada_beta_1": "beta_1", "ada_beta_2": "beta_2"} + amsgrad = {"ada_amsgrad": "amsgrad"} + self._valid: dict[str, tuple[T.Type[Optimizer], dict[str, T.Any]]] = { + "adabelief": (AdaBelief, betas | amsgrad), + "adam": (optimizers.Adam, betas | amsgrad), + "adamax": (optimizers.Adamax, betas), + "adamw": (optimizers.AdamW, betas | amsgrad), + "lion": (optimizers.Lion, betas), + "nadam": (optimizers.Nadam, betas), + "rms-prop": (optimizers.RMSprop, {})} + + self._optimizer = self._valid[cfg_opt.optimizer()][0] + self._kwargs: dict[str, T.Any] = {"learning_rate": cfg_opt.learning_rate()} + if cfg_opt.optimizer() != "lion": + self._kwargs["epsilon"] = 10 ** int(cfg_opt.epsilon_exponent()) + + self._configure() + logger.info("Using %s optimizer", self._optimizer.__name__) + logger.debug("Initialized: %s", self.__class__.__name__) + + @property + def optimizer(self) -> optimizers.Optimizer: + """ :class:`keras.optimizers.Optimizer`: The requested optimizer. """ + return T.cast(optimizers.Optimizer, self._optimizer(**self._kwargs)) + + def _configure_clipping(self, + method: T.Literal["autoclip", "norm", "value", "none"], + value: float, + history: int) -> None: + """ Configure optimizer clipping related kwargs, if selected + + Parameters + ---------- + method: Literal["autoclip", "norm", "value", "none"] + The clipping method to use. ``None`` for no clipping + value: float + The value to clip by norm/value by. 
For autoclip, this is the clip percentile + (a value of 1.0 is a clip percentile of 10%) + history: int + autoclip only: The number of iterations to keep for calculating the normalized value + """ + logger.debug("method: '%s', value: %s, history: %s", method, value, history) + if method == "none": + logger.debug("clipping disabled") + return + + logger.info("Enabling Clipping: %s", method.replace("_", " ").replace("_", " ").title()) + clip_types = {"global_norm": "global_clipnorm", "norm": "clipnorm", "value": "clipvalue"} + if method in clip_types: + self._kwargs[clip_types[method]] = value + logger.debug("Setting clipping kwargs for '%s': %s", + method, {k: v for k, v in self._kwargs.items() + if k == clip_types[method]}) + return + + assert method == "autoclip" + # Test for if keras optimizer changes its structure to no longer have _clip_gradients. + # Ensures any tests fails in this situation + assert hasattr(self._optimizer, + "_clip_gradients"), "keras.BaseOptimizer._clip_gradients no longer exists" + + # TODO Keras3 has removed the ""gradient_transformers" kwarg, and there now appears to be + # no standardised method to add custom gradent transformers. Currently, we monkey patch its + # _clip_gradients function, which feels hacky and potentially problematic + setattr(self._optimizer, "_clip_gradients", AutoClipper(int(value * 10), + history_size=history)) + + def _configure_ema(self, enable: bool, momentum: float, frequency: int) -> None: + """ Confihure the optimizer kwargs for exponential moving average updates + + Parameters + ---------- + enable: bool + ``False`` to disable + momentum: float + the momentum to use when computing the EMA of the model's weights: new_average = + momentum * old_average + (1 - momentum) * current_variable_value + frequency: int + the number of iterations, to overwrite the model variable by its moving average. + """ + self._kwargs["use_ema"] = enable + if not enable: + logger.debug("ema disabled.") + return + + logger.info("Enabling EMA") + self._kwargs["ema_momentum"] = momentum + self._kwargs["ema_overwrite_frequency"] = frequency + logger.debug("ema enabled (momentum: %s, frequency: %s)", momentum, frequency) + + def _configure_kwargs(self, weight_decay: float, gradient_accumulation_steps: int) -> None: + """ Configure the remaining global optimizer kwargs + + Parameters + ---------- + weight_decay: float + The amount of weight decay to apply + gradient_accumulation_steps: int + The number of steps to accumulate gradients for before applying the average + """ + if weight_decay > 0.0: + logger.info("Enabling Weight Decay: %s", weight_decay) + self._kwargs["weight_decay"] = weight_decay + else: + logger.debug("weight decay disabled") + + if gradient_accumulation_steps > 1: + logger.info("Enabling Gradient Accumulation: %s", gradient_accumulation_steps) + self._kwargs["gradient_accumulation_steps"] = gradient_accumulation_steps + else: + logger.debug("gradient accumulation disabled") + + def _configure_specific(self) -> None: + """ Configure keyword optimizer specific keyword arguments based on user settings. """ + opts = self._valid[cfg_opt.optimizer()][1] + if not opts: + logger.debug("No additional kwargs to set for '%s'", cfg_opt.optimizer()) + return + + for key, val in opts.items(): + opt_val = getattr(cfg_opt, key)() + logger.debug("Setting kwarg '%s' from '%s' to: %s", val, key, opt_val) + self._kwargs[val] = opt_val + + def _configure(self) -> None: + """ Process the user configuration options into Keras Optimizer kwargs. 
""" + self._configure_clipping(T.cast(T.Literal["autoclip", "norm", "value", "none"], + cfg_opt.gradient_clipping()), + cfg_opt.clipping_value(), + cfg_opt.autoclip_history()) + + self._configure_ema(cfg_opt.use_ema(), + cfg_opt.ema_momentum(), + cfg_opt.ema_frequency()) + + self._configure_kwargs(cfg_opt.weight_decay(), + cfg_opt.gradient_accumulation()) + + self._configure_specific() + + logger.debug("Configured '%s' optimizer. kwargs: %s", cfg_opt.optimizer(), self._kwargs) + + +class Settings(): + """ Tensorflow core training settings. + + Sets backend tensorflow settings prior to launching the model. + + Tensorflow 2 uses distribution strategies for multi-GPU/system training. These are context + managers. + + Parameters + ---------- + arguments: :class:`argparse.Namespace` + The arguments that were passed to the train or convert process as generated from + Faceswap's command line arguments + mixed_precision: bool + ``True`` if Mixed Precision training should be used otherwise ``False`` + is_predict: bool, optional + ``True`` if the model is being loaded for inference, ``False`` if the model is being loaded + for training. Default: ``False`` + """ + def __init__(self, + arguments: Namespace, + mixed_precision: bool, + is_predict: bool) -> None: + logger.debug("Initializing %s: (arguments: %s, mixed_precision: %s, is_predict: %s)", + self.__class__.__name__, arguments, mixed_precision, is_predict) + use_mixed_precision = not is_predict and mixed_precision + self._use_mixed_precision = use_mixed_precision + if use_mixed_precision: + logger.info("Enabling Mixed Precision Training.") + + self._set_keras_mixed_precision(use_mixed_precision) + logger.debug("Initialized %s", self.__class__.__name__) + + @property + def use_mixed_precision(self) -> bool: + """ bool: ``True`` if mixed precision training has been enabled, otherwise ``False``. """ + return self._use_mixed_precision + + @classmethod + def loss_scale_optimizer( + cls, + optimizer: optimizers.Optimizer) -> optimizers.LossScaleOptimizer: + """ Optimize loss scaling for mixed precision training. + + Parameters + ---------- + optimizer: :class:`keras.optimizers.Optimizer` + The optimizer instance to wrap + + Returns + -------- + :class:`keras.optimizers.LossScaleOptimizer` + The original optimizer with loss scaling applied + """ + return optimizers.LossScaleOptimizer(optimizer) + + @classmethod + def _set_keras_mixed_precision(cls, enable: bool) -> None: + """ Enable or disable Keras Mixed Precision. + + Parameters + ---------- + enable: bool + ``True`` to enable mixed precision. ``False`` to disable. + + Enables or disables the Keras Mixed Precision API if requested in the user configuration + file. + """ + policy = dtype_policies.DTypePolicy("mixed_float16" if enable else "float32") + k_config.set_dtype_policy(policy) + logger.debug("%s mixed precision. (Compute dtype: %s, variable_dtype: %s)", + "Enabling" if enable else "Disabling", + policy.compute_dtype, policy.variable_dtype) + +# def _get_strategy(self, +# strategy: T.Literal["default", "central-storage", "mirrored"] +# ) -> tf.distribute.Strategy | None: +# """ If we are running on Nvidia backend and the strategy is not ``None`` then return +# the correct tensorflow distribution strategy, otherwise return ``None``. +# +# Notes +# ----- +# By default Tensorflow defaults mirrored strategy to use the Nvidia NCCL method for +# reductions, however this is only available in Linux, so the method used falls back to +# `Hierarchical Copy All Reduce` if the OS is not Linux. 
+# +# Central Storage strategy is not compatible with Mixed Precision. However, in testing it +# worked fine when using a single GPU, so we monkey-patch out the tests for Mixed-Precision +# when using this strategy with a single GPU +# +# Parameters +# ---------- +# strategy: str +# One of 'default', 'central-storage' or 'mirrored'. +# +# Returns +# ------- +# :class:`tensorflow.distribute.Strategy` or `None` +# The request Tensorflow Strategy if the backend is Nvidia and the strategy is not +# `"Default"` otherwise ``None`` +# """ +# if get_backend() not in ("nvidia", "rocm"): +# retval = None +# elif strategy == "mirrored": +# retval = self._get_mirrored_strategy() +# elif strategy == "central-storage": +# retval = self._get_central_storage_strategy() +# else: +# retval = tf.distribute.get_strategy() +# logger.debug("Using strategy: %s", retval) +# return retval + +# @classmethod +# def _get_mirrored_strategy(cls) -> tf.distribute.MirroredStrategy: +# """ Obtain an instance of a Tensorflow Mirrored Strategy, setting the cross device +# operations appropriate for the OS in use. +# +# Returns +# ------- +# :class:`tensorflow.distribute.MirroredStrategy` +# The Mirrored Distribution Strategy object with correct cross device operations set +# """ +# if platform.system().lower() == "linux": +# cross_device_ops = tf.distribute.NcclAllReduce() +# else: +# cross_device_ops = tf.distribute.HierarchicalCopyAllReduce() +# logger.debug("cross_device_ops: %s", cross_device_ops) +# return tf.distribute.MirroredStrategy(cross_device_ops=cross_device_ops) + +# @classmethod +# def _get_central_storage_strategy(cls) -> tf.distribute.experimental.CentralStorageStrategy: +# """ Obtain an instance of a Tensorflow Central Storage Strategy. If the strategy is being +# run on a single GPU then monkey patch Tensorflows mixed-precision strategy checks to pass +# successfully. +# +# Returns +# ------- +# :class:`tensorflow.distribute.experimental.CentralStorageStrategy` +# The Central Storage Distribution Strategy object +# """ +# gpus = tf.config.get_visible_devices("GPU") +# if len(gpus) == 1: +# # TODO Remove these monkey patches when Strategy supports mixed-precision +# # pylint:disable=import-outside-toplevel +# from keras.mixed_precision import loss_scale_optimizer +# +# # Force a return of True on Loss Scale Optimizer Stategy check +# loss_scale_optimizer.strategy_supports_loss_scaling = lambda: True +# +# # As LossScaleOptimizer aggregates gradients internally, it passes `False` as the value +# # for `experimental_aggregate_gradients` in `OptimizerV2.apply_gradients`. This causes +# # the optimizer to fail when checking against this strategy. We could monkey patch +# # `Optimizer.apply_gradients`, but it is a lot more code to check, so we just switch +# # the `experimental_aggregate_gradients` back to `True`. In brief testing this does not +# # appear to have a negative impact. 
+# func = lambda s, grads, wvars, name: s._optimizer.apply_gradients( # noqa pylint:disable=protected-access,unnecessary-lambda-assignment +# list(zip(grads, wvars.value)), name, experimental_aggregate_gradients=True) +# loss_scale_optimizer.LossScaleOptimizer._apply_gradients = func # noqa pylint:disable=protected-access + +# return tf.distribute.experimental.CentralStorageStrategy(parameter_device="/cpu:0") + + @classmethod + def _dtype_from_config(cls, config: dict[str, T.Any]) -> str: + """ Obtain the dtype of a layer from the given layer config + + Parameters + ---------- + config: dict[str, Any] : The Keras layer configuration dictionary + + Returns + ------- + str + The datatype of the layer + """ + dtype = config["dtype"] + logger.debug("Obtaining layer dtype from config: %s", dtype) + if isinstance(dtype, str): + return dtype + # Fail tests if Keras changes the way it stores dtypes + assert isinstance(dtype, dict) and "config" in dtype, ( + "Keras config dtype storage method has changed") + + dtype_conf = dtype["config"] + # Fail tests if Keras changes the way it stores dtypes + assert isinstance(dtype_conf, dict) and "name" in dtype_conf, ( + "Keras config dtype storage method has changed") + + retval = dtype_conf["name"] + return retval + + def _get_mixed_precision_layers(self, layers: list[dict]) -> list[str]: + """ Obtain the names of the layers in a mixed precision model that have their dtype policy + explicitly set to mixed-float16. + + Parameters + ---------- + layers: List + The list of layers that appear in a keras's model configuration `dict` + + Returns + ------- + list + A list of layer names within the model that are assigned a float16 policy + """ + retval = [] + for layer in layers: + config = layer["config"] + + if layer["class_name"] in ("Functional", "Sequential"): # Recurse into sub-models + retval.extend(self._get_mixed_precision_layers(config["layers"])) + continue + + if "dtype" not in config: + logger.debug("Skipping unsupported layer: %s %s", + layer.get("name", f"class_name: {layer['class_name']}"), config) + continue + dtype = self._dtype_from_config(config) + logger.debug("layer: '%s', dtype: '%s'", config["name"], dtype) + + if dtype == "mixed_float16": + logger.debug("Adding supported mixed precision layer: %s %s", + layer["config"]["name"], dtype) + retval.append(layer["config"]["name"]) + else: + logger.debug("Skipping unsupported layer: %s %s", + layer["config"].get("name", f"class_name: {layer['class_name']}"), + dtype) + return retval + + def _switch_precision(self, layers: list[dict], compatible: list[str]) -> None: + """ Switch a model's datatype between mixed-float16 and float32. 
+ + Parameters + ---------- + layers: List + The list of layers that appear in a keras's model configuration `dict` + compatible: List + A list of layer names that are compatible to have their datatype switched + """ + dtype = "mixed_float16" if self.use_mixed_precision else "float32" + + for layer in layers: + config = layer["config"] + + if layer["class_name"] in ["Functional", "Sequential"]: # Recurse into sub-models + self._switch_precision(config["layers"], compatible) + continue + + if layer["config"]["name"] not in compatible: + logger.debug("Skipping incompatible layer: %s", layer["config"]["name"]) + continue + + logger.debug("Updating dtype for %s from: %s to: %s", + layer["config"]["name"], config["dtype"], dtype) + config["dtype"] = dtype + + def get_mixed_precision_layers(self, + build_func: Callable[[list[keras.layers.Layer]], + keras.models.Model], + inputs: list[keras.layers.Layer] + ) -> tuple[keras.models.Model, list[str]]: + """ Get and store the mixed precision layers from a full precision enabled model. + + Parameters + ---------- + build_func: Callable + The function to be called to compile the newly created model + inputs: + The inputs to the model to be compiled + + Returns + ------- + model: :class:`keras.model` + The built model in fp32 + list + The list of layer names within the full precision model that can be switched + to mixed precision + """ + logger.debug("Storing Mixed Precision compatible layers.") + self._set_keras_mixed_precision(True) + with keras.device("CPU"): + model = build_func(inputs) + layers = self._get_mixed_precision_layers(model.get_config()["layers"]) + + del model + keras.backend.clear_session() + + self._set_keras_mixed_precision(False) + reset_naming() + model = build_func(inputs) + + logger.debug("model: %s, mixed precision layers: %s", model, layers) + return model, layers + + def check_model_precision(self, + model: keras.models.Model, + state: "State") -> keras.models.Model: + """ Check the model's precision. + + If this is a new model, then + Rewrite an existing model's training precsion mode from mixed-float16 to float32 or + vice versa. + + This is not easy to do in keras, so we edit the model's config to change the dtype policy + for compatible layers. Create a new model from this config, then port the weights from the + old model to the new model. + + Parameters + ---------- + model: :class:`keras.models.Model` + The original saved keras model to rewrite the dtype + state: ~:class:`plugins.train.model._base.model.State` + The State information for the model + + Returns + ------- + :class:`keras.models.Model` + The original model with the datatype updated + """ + if self.use_mixed_precision and not state.mixed_precision_layers: + # Switching to mixed precision on a model which was started in FP32 prior to the + # ability to switch between precisions on a saved model is not supported as we + # do not have the compatible layer names + logger.warning("Switching from Full Precision to Mixed Precision is not supported on " + "older model files. 
Reverting to Full Precision.") + return model + + config = model.get_config() + weights = model.get_weights() + + if not self.use_mixed_precision and not state.mixed_precision_layers: + # Switched to Full Precision, get compatible layers from model if not already stored + state.add_mixed_precision_layers(self._get_mixed_precision_layers(config["layers"])) + + self._switch_precision(config["layers"], state.mixed_precision_layers) + + del model + keras.backend.clear_session() + new_model = keras.models.Model().from_config(config) + + new_model.set_weights(weights) + logger.info("Mixed precision has been %s", + "enabled" if self.use_mixed_precision else "disabled") + return new_model + + +__all__ = get_module_objects(__name__) diff --git a/plugins/train/model/_base/state.py b/plugins/train/model/_base/state.py new file mode 100644 index 0000000000..54f119ffe3 --- /dev/null +++ b/plugins/train/model/_base/state.py @@ -0,0 +1,446 @@ +#! /usr/env/bin/python3 +""" Handles the loading and saving of a model's state file """ +from __future__ import annotations + +import logging +import os +import time +import typing as T +from importlib import import_module +from inspect import isclass + +from lib.logger import parse_class_init +from lib.serializer import get_serializer +from lib.utils import get_module_objects + +from lib.config.objects import ConfigItem, GlobalSection +from plugins.train import train_config as cfg + +if T.TYPE_CHECKING: + from lib.config import ConfigValueType + + +logger = logging.getLogger(__name__) + + +class State(): # pylint:disable=too-many-instance-attributes + """ Holds state information relating to the plugin's saved model. + + Parameters + ---------- + model_dir: str + The full path to the model save location + model_name: str + The name of the model plugin + no_logs: bool + ``True`` if Tensorboard logs should not be generated, otherwise ``False`` + """ + def __init__(self, + model_dir: str, + model_name: str, + no_logs: bool) -> None: + logger.debug(parse_class_init(locals())) + self._serializer = get_serializer("json") + filename = f"{model_name}_state.{self._serializer.file_extension}" + self._filename = os.path.join(model_dir, filename) + self._name = model_name + self._iterations = 0 + self._mixed_precision_layers: list[str] = [] + self._lr_finder = -1.0 + self._rebuild_model = False + self._sessions: dict[int, dict] = {} + self.lowest_avg_loss: float = 0.0 + """float: The lowest average loss seen between save intervals. """ + + self._config: dict[str, ConfigValueType] = {} + self._updateable_options: list[str] = [] + + self._load() + self._session_id = self._new_session_id() + self._create_new_session(no_logs) + logger.debug("Initialized %s:", self.__class__.__name__) + + @property + def filename(self) -> str: + """ str: Full path to the state filename """ + return self._filename + + @property + def loss_names(self) -> list[str]: + """ list: The loss names for the current session """ + return self._sessions[self._session_id]["loss_names"] + + @property + def current_session(self) -> dict: + """ dict: The state dictionary for the current :attr:`session_id`. """ + return self._sessions[self._session_id] + + @property + def iterations(self) -> int: + """ int: The total number of iterations that the model has trained. """ + return self._iterations + + @property + def session_id(self) -> int: + """ int: The current training session id. 
""" + return self._session_id + + @property + def sessions(self) -> dict[int, dict[str, T.Any]]: + """ dict[int, dict[str, Any]]: The session information for each session in the state + file """ + return {int(k): v for k, v in self._sessions.items()} + + @property + def mixed_precision_layers(self) -> list[str]: + """list: Layers that can be switched between mixed-float16 and float32. """ + return self._mixed_precision_layers + + @property + def lr_finder(self) -> float: + """ The value discovered from the learning rate finder. -1 if no value stored """ + return self._lr_finder + + @property + def model_needs_rebuild(self) -> bool: + """bool: ``True`` if mixed precision policy has changed so model needs to be rebuilt + otherwise ``False`` """ + return self._rebuild_model + + def _new_session_id(self) -> int: + """ Generate a new session id. Returns 1 if this is a new model, or the last session id + 1 + if it is a pre-existing model. + + Returns + ------- + int + The newly generated session id + """ + if not self._sessions: + session_id = 1 + else: + session_id = max(int(key) for key in self._sessions.keys()) + 1 + logger.debug(session_id) + return session_id + + def _create_new_session(self, no_logs: bool) -> None: + """ Initialize a new session, creating the dictionary entry for the session in + :attr:`_sessions`. + + Parameters + ---------- + no_logs: bool + ``True`` if Tensorboard logs should not be generated, otherwise ``False`` + """ + logger.debug("Creating new session. id: %s", self._session_id) + self._sessions[self._session_id] = {"timestamp": time.time(), + "no_logs": no_logs, + "loss_names": [], + "batchsize": 0, + "iterations": 0, + "config": {k: v for k, v in self._config.items() + if k in self._updateable_options}} + + def update_session_config(self, key: str, value: T.Any) -> None: + """ Update a configuration item of the currently loaded session. + + Parameters + ---------- + key: str + The configuration item to update for the current session + value: any + The value to update to + """ + old_val = self.current_session["config"][key] + assert isinstance(value, type(old_val)) + logger.debug("Updating configuration item '%s' from '%s' to '%s'", key, old_val, value) + self.current_session["config"][key] = value + + def add_session_loss_names(self, loss_names: list[str]) -> None: + """ Add the session loss names to the sessions dictionary. + + The loss names are used for Tensorboard logging + + Parameters + ---------- + loss_names: list + The list of loss names for this session. + """ + logger.debug("Adding session loss_names: %s", loss_names) + self._sessions[self._session_id]["loss_names"] = loss_names + + def add_session_batchsize(self, batch_size: int) -> None: + """ Add the session batch size to the sessions dictionary. + + Parameters + ---------- + batch_size: int + The batch size for the current training session + """ + logger.debug("Adding session batch size: %s", batch_size) + self._sessions[self._session_id]["batchsize"] = batch_size + + def increment_iterations(self) -> None: + """ Increment :attr:`iterations` and session iterations by 1. 
""" + self._iterations += 1 + self._sessions[self._session_id]["iterations"] += 1 + + def add_mixed_precision_layers(self, layers: list[str]) -> None: + """ Add the list of model's layers that are compatible for mixed precision to the + state dictionary """ + logger.debug("Storing mixed precision layers: %s", layers) + self._mixed_precision_layers = layers + + def add_lr_finder(self, learning_rate: float) -> None: + """ Add the optimal discovered learning rate from the learning rate finder + + Parameters + ---------- + learning_rate : float + The discovered learning rate + """ + logger.debug("Storing learning rate from LR Finder: %s", learning_rate) + self._lr_finder = learning_rate + + def save(self) -> None: + """ Save the state values to the serialized state file. """ + state = {"name": self._name, + "sessions": {k: v for k, v in self._sessions.items() + if v.get("iterations", 0) > 0}, + "lowest_avg_loss": self.lowest_avg_loss, + "iterations": self._iterations, + "mixed_precision_layers": self._mixed_precision_layers, + "lr_finder": self._lr_finder, + "config": self._config} + logger.debug("Saving State: %s", state) + self._serializer.save(self._filename, state) + logger.debug("Saved State: '%s'", self._filename) + + def _update_legacy_config(self) -> bool: + """ Legacy updates for new config additions. + + When new config items are added to the Faceswap code, existing model state files need to be + updated to handle these new items. + + Current existing legacy update items: + + * loss - If old `dssim_loss` is ``true`` set new `loss_function` to `ssim` otherwise + set it to `mae`. Remove old `dssim_loss` item + + * l2_reg_term - If this exists, set loss_function_2 to ``mse`` and loss_weight_2 to + the value held in the old ``l2_reg_term`` item + + * masks - If `learn_mask` does not exist then it is set to ``True`` if `mask_type` is + not ``None`` otherwise it is set to ``False``. + + * masks type - Replace removed masks 'dfl_full' and 'facehull' with `components` mask + + * clipnorm - Only existed in 2 models (DFL-SAE + Unbalanced). Replaced with global + option autoclip + + * Clip model - layer names have had to be changed to replace dots with underscores, so + replace these + + Returns + ------- + bool + ``True`` if legacy items exist and state file has been updated, otherwise ``False`` + """ + logger.debug("Checking for legacy state file update") + priors = ["dssim_loss", "mask_type", "mask_type", "l2_reg_term", "clipnorm", "autoclip"] + new_items = ["loss_function", "learn_mask", "mask_type", "loss_function_2", + "gradient_clipping", "clipping"] + updated = False + for old, new in zip(priors, new_items): + if old not in self._config: + logger.debug("Legacy item '%s' not in state config. Skipping update", old) + continue + + # dssim_loss > loss_function + if old == "dssim_loss": + self._config[new] = "ssim" if self._config[old] else "mae" + del self._config[old] + updated = True + logger.info("Updated state config from legacy dssim format. New config loss " + "function: '%s'", self._config[new]) + continue + + # Add learn mask option and set to True if model has "penalized_mask_loss" specified + if old == "mask_type" and new == "learn_mask" and new not in self._config: + self._config[new] = self._config["mask_type"] is not None + updated = True + logger.info("Added new 'learn_mask' state config item for this model. 
Value set " + "to: %s", self._config[new]) + continue + + # Replace removed masks with most similar equivalent + if old == "mask_type" and new == "mask_type" and self._config[old] in ("facehull", + "dfl_full"): + old_mask = self._config[old] + self._config[new] = "components" + updated = True + logger.info("Updated 'mask_type' from '%s' to '%s' for this model", + old_mask, self._config[new]) + + # Replace l2_reg_term with the correct loss_2_function and update the value of + # loss_2_weight + if old == "l2_reg_term": + self._config[new] = "mse" + self._config["loss_weight_2"] = self._config[old] + del self._config[old] + updated = True + logger.info("Updated state config from legacy 'l2_reg_term' to 'loss_function_2'") + + # Replace clipnorm with correct gradient clipping type and value + if old == "clipnorm": + self._config[new] = "norm" + del self._config[old] + updated = True + logger.info("Updated state config from legacy '%s' to '%s: %s'", old, new, old) + + # Replace autoclip with correct gradient clipping type + if old == "autoclip": + self._config[new] = old + del self._config[old] + updated = True + logger.info("Updated state config from legacy '%s' to '%s: %s'", old, new, old) + + # Update Clip layer names from dots to underscores + mixed_precision = self._mixed_precision_layers + if any("." in name for name in mixed_precision): + self._mixed_precision_layers = [x.replace(".", "_") for x in mixed_precision] + updated = True + logger.info("Updated state config for legacy 'mixed_precision' storage of Clip layers") + + logger.debug("State file updated for legacy config: %s", updated) + return updated + + def _get_global_options(self) -> dict[str, ConfigItem]: + """ Obtain all of the current global user config options + + Returns + ------- + dict[str, :class:`lib.config.objects.ConfigItem`] + All of the current global user configuration options + """ + objects = {key: val for key, val in vars(cfg).items() + if isinstance(val, ConfigItem) + or isclass(val) and issubclass(val, GlobalSection) and val != GlobalSection} + + retval: dict[str, ConfigItem] = {} + for key, obj in objects.items(): + if isinstance(obj, ConfigItem): + retval[key] = obj + continue + for name, opt in obj.__dict__.items(): + if isinstance(opt, ConfigItem): + retval[name] = opt + logger.debug("Loaded global config options: %s", {k: v.value for k, v in retval.items()}) + return retval + + def _get_model_options(self) -> dict[str, ConfigItem]: + """ Obtain all of the currently configured model user config options """ + mod_name = f"plugins.train.model.{self._name}_defaults" + try: + mod = import_module(mod_name) + except ModuleNotFoundError: + logger.debug("No plugin specific defaults file found at '%s'", mod_name) + return {} + + retval = {k: v for k, v in vars(mod).items() if isinstance(v, ConfigItem)} + logger.debug("Loaded '%s' config options: %s", + self._name, {k: v.value for k, v in retval.items()}) + return retval + + def _update_config(self) -> None: + """ Update the loaded training config with the one contained within the values loaded + from the state file. + + Check for any `fixed`=``False`` parameter changes and log info changes. + + Update any legacy config items to their current versions. + """ + legacy_update = self._update_legacy_config() + # Add any new items to state config for legacy purposes where the new default may be + # detrimental to an existing model. 
+ legacy_defaults: dict[str, str | int | bool | float] = {"centering": "legacy", + "coverage": 62.5, + "mask_loss_function": "mse", + "optimizer": "adam", + "mixed_precision": False} + rebuild_tasks = ["mixed_precision"] + options = self._get_global_options() | self._get_model_options() + for key, opt in options.items(): + val: ConfigValueType = opt() + + if key not in self._config: + val = legacy_defaults.get(key, val) + logger.info("Adding new config item to state file: '%s': %s", key, repr(val)) + self._config[key] = val + + old_val = self._config[key] + old_val = "none" if old_val is None else old_val # We used to allow NoneType. No more + + if not opt.fixed: + self._updateable_options.append(key) + + if not opt.fixed and val != old_val: + self._config[key] = val + logger.info("Config item: '%s' has been updated from %s to %s", + key, repr(old_val), repr(val)) + self._rebuild_model = self._rebuild_model or key in rebuild_tasks + continue + + if val != old_val: + logger.debug("Fixed config item '%s' Updated from %s to %s from state file", + key, repr(val), repr(old_val)) + opt.set(old_val) + + if legacy_update: + self.save() + logger.info("Using configuration saved in state file") + logger.debug("Updateable items: %s", self._updateable_options) + + def _generate_config(self) -> None: + """ Generate an initial state config based on the currently selected user config """ + options = self._get_global_options() | self._get_model_options() + for key, val in options.items(): + self._config[key] = val.value + if not val.fixed: + self._updateable_options.append(key) + + logger.debug("Generated initial state config for '%s': %s", self._name, self._config) + logger.debug("Updateable items: %s", self._updateable_options) + + def _load(self) -> None: + """ Load a state file and set the serialized values to the class instance. + + Updates the model's config with the values stored in the state file. + """ + logger.debug("Loading State") + + if not os.path.exists(self._filename): + logger.info("No existing state file found. Generating.") + self._generate_config() + return + + state = self._serializer.load(self._filename) + self._name = state.get("name", self._name) + self._sessions = state.get("sessions", {}) + + self.lowest_avg_loss = state.get("lowest_avg_loss", 0.0) + if isinstance(self.lowest_avg_loss, dict): + lowest_avg_loss = sum(self.lowest_avg_loss.values()) + logger.debug("Collating legacy lowest_avg_loss from %s to %s", + self.lowest_avg_loss, lowest_avg_loss) + self.lowest_avg_loss = lowest_avg_loss + + self._iterations = state.get("iterations", 0) + self._mixed_precision_layers = state.get("mixed_precision_layers", []) + self._lr_finder = state.get("lr_finder", -1.0) + self._config = state.get("config", {}) + logger.debug("Loaded state: %s", state) + self._update_config() + + +__all__ = get_module_objects(__name__) diff --git a/plugins/train/model/_base/update.py b/plugins/train/model/_base/update.py new file mode 100644 index 0000000000..2130c0b464 --- /dev/null +++ b/plugins/train/model/_base/update.py @@ -0,0 +1,520 @@ +#! 
/usr/env/bin/python3 +""" Updating legacy faceswap models to the current version """ +import json +import logging +import os +import typing as T +import zipfile +from shutil import copyfile, copytree + +import h5py +import numpy as np +from keras import models as kmodels + +from lib.logger import parse_class_init +from lib.model.layers import ScalarOp +from lib.model.networks import TypeModelsViT, ViT +from lib.utils import get_module_objects, FaceswapError + +logger = logging.getLogger(__name__) + + +class Legacy: # pylint:disable=too-few-public-methods + """ Handles the updating of Keras 2.x models to Keras 3.x + + Generally Keras 2.x models will open in Keras 3.x. There are a couple of bugs in Keras 3 + legacy loading code which impacts Faceswap models: + - When a model receives a shared functional model as an inbound node, the node index needs + reducing by 1 (non-trivial to fix upstream) + - Keras 3 does not accept nested outputs, so Keras 2 FS models need to have the outputs + flattened + + Parameters + ---------- + model_path: str + Full path to the legacy Keras 2.x model h5 file to upgrade + """ + def __init__(self, model_path: str): + logger.debug(parse_class_init(locals())) + self._old_model_file = model_path + """str: Full path to the old .h5 model file""" + self._new_model_file = f"{os.path.splitext(model_path)[0]}.keras" + """str: Full path to the new .keras model file""" + self._functionals: set[str] = set() + """set[str]: The name of any Functional models discovered in the keras 2 model config""" + + self._upgrade_model() + logger.debug("Initialized %s", self.__class__.__name__) + + def _get_model_config(self) -> dict[str, T.Any]: + """ Obtain a keras 2.x config from a keras 2.x .h5 file. + + As keras 3.x will error out loading the file, we collect it directly from the .h5 file + + Returns + ------- + dict[str, Any] + A keras 2.x model configuration dictionary + + Raises + ------ + FaceswapError + If the file is not a valid Faceswap 2 .h5 model file + """ + h5file = h5py.File(self._old_model_file, "r") + s_version = T.cast(str | None, h5file.attrs.get("keras_version")) + s_config = T.cast(str | None, h5file.attrs.get("model_config")) + if not s_version or not s_config: + raise FaceswapError(f"'{self._old_model_file}' is not a valid Faceswap 2 model file") + + version = s_version.split(".")[:2] + if len(version) != 2 or version[0] != "2": + raise FaceswapError(f"'{self._old_model_file}' is not a valid Faceswap 2 model file") + + retval = json.loads(s_config) + logger.debug("Loaded keras 2.x model config: %s", retval) + return retval + + @classmethod + def _unwrap_outputs(cls, outputs: list[list[T.Any]]) -> list[list[str | int]]: + """ Unwrap nested output tensors from a config dict to be a single list of output tensor + + Parameters + ---------- + outputs: list[list[Any]] + The outputs that exist within the Keras 2 config dict that may be nested + + Returns + ------- + list[list[str | int]] + The output configuration formatted to be compatible with Keras 3 + """ + retval = np.array(outputs).reshape(-1, 3).tolist() + for item in retval: + item[1] = int(item[1]) + item[2] = int(item[2]) + logger.debug("Unwrapped outputs: %s to: %s", outputs, retval) + return retval + + def _get_clip_config(self) -> dict[str, T.Any]: + """ Build a clip model from the configuration information stored in the legacy state file + + Returns + ------- + dict[str, T.Any] + The new keras configuration for a Clip model + + Raises + ------ + FaceswapError + If the clip model cannot be built + """ + 
state_file = f"{os.path.splitext(self._old_model_file)[0]}_state.json" + if not os.path.isfile(state_file): + raise FaceswapError( + f"The state file '{state_file}' does not exist. This model cannot be ported") + + with open(state_file, "r", encoding="utf-8") as ifile: + config = json.load(ifile) + + logger.debug("Loaded legacy config '%s': %s", state_file, config) + net_name = config.get("config", {}).get("enc_architecture", "") + scaling = config.get("config", {}).get("enc_scaling", 0) / 100 + + # Import here to prevent circular imports + from plugins.train.model.phaze_a import _MODEL_MAPPING # pylint:disable=C0415 + vit_info = _MODEL_MAPPING.get(net_name) + + if not scaling or not vit_info: + raise FaceswapError( + f"Clip network could not be found in '{state_file}'. Discovered network is " + f"'{net_name}' with encoder scaling: {scaling}. This model cannot be ported") + + input_size = int(max(vit_info.min_size, ((vit_info.default_size * scaling) // 16) * 16)) + vit_model = ViT(T.cast(TypeModelsViT, vit_info.keras_name), input_size=input_size)() + + retval = vit_model.get_config() + del vit_model + logger.debug("Got new config for '%s' at input size: %s: %s", net_name, input_size, retval) + return retval + + def _convert_lambda_config(self, layer: dict[str, T.Any]): + """ Keras 2 TFLambdaOps are not compatible with Keras 3. Scalar operations can be + relatively easily substituted with a :class:`~lib.model.layers.ScalarOp` layer + + Parameters + ---------- + layer: dict[str, Any] + An existing Keras 2 TFLambdaOp layer + + Raises + ------ + FaceswapError + If the TFLambdaOp is not currently supported + """ + name = layer["config"]["name"] + operation = name.rsplit(".", maxsplit=1)[-1] + if operation not in ("multiply", "truediv", "add", "subtract"): + raise FaceswapError(f"The TFLambdaOp '{name}' is not supported") + value = layer["inbound_nodes"][0][-1]["y"] + + if isinstance(layer["config"]["dtype"], str): + dtype = layer["config"]["dtype"] + else: + dtype = layer["config"]["dtype"]["config"]["name"] + new_layer = ScalarOp(operation, value, name=name, dtype=dtype) + + logger.debug("Converting legacy TFLambdaOp: %s", layer) + + layer["class_name"] = "ScalarOp" + layer["config"] = new_layer.get_config() + for n in layer["inbound_nodes"]: + n[-1] = {} + layer["inbound_nodes"] = [layer["inbound_nodes"]] + logger.debug("Converted legacy TFLambdaOp to %s", layer) + + def _process_deprecations(self, layer: dict[str, T.Any]) -> None: + """ Some layer kwargs are deprecated between Keras 2 and Keras 3. Some are not mission + critical, but updating these here prevents Keras from outputting warnings about deprecated + arguments. Others will fail to load the legacy model (eg Clip) so are replaced with a new + config. 
Operation is performed in place + + Parameters + ---------- + layer: dict[str, T.Any] + A keras model config item representing a keras layer + """ + if layer["class_name"] == "LeakyReLU": + # Non mission-critical, but prevents scary deprecation messages + config = layer["config"] + old, new = "alpha", "negative_slope" + if old in config: + logger.debug("Updating '%s' kwarg '%s' to '%s'", layer["name"], old, new) + config[new] = config[old] + del config[old] + + if layer["name"] == "visual": + # MultiHeadAttention is not backwards compatible, so get new config for Clip models + logger.debug("Getting new config for 'visual' model") + layer["config"] = self._get_clip_config() + + if layer["class_name"] == "TFOpLambda": + # TFLambdaOp are not supported + self._convert_lambda_config(layer) + + if layer["class_name"] in ("DepthwiseConv2D", + "Conv2DTranspose") and "groups" in layer["config"]: + # groups parameter doesn't exist in Keras 3. Hopefully it still works the same + logger.debug("Removing groups from %s '%s'", layer["class_name"], layer["name"]) + del layer["config"]["groups"] + + if "dtype" in layer["config"]: + # Incorrectly stored dtypes error when deserializing the new config. May be a Keras bug + actual_dtype = None + old_dtype = layer["config"]["dtype"] + if isinstance(old_dtype, str): + actual_dtype = layer["config"]["dtype"] + if isinstance(old_dtype, dict) and old_dtype.get("class_name") == "Policy": + actual_dtype = old_dtype["config"]["name"] + + if actual_dtype is not None: + new_dtype = {"module": "keras", + "class_name": "DTypePolicy", + "config": {"name": actual_dtype}, + "registered_name": None} + logger.debug("Updating dtype for '%s' from %s to %s", layer["name"], + old_dtype, new_dtype) + layer["config"]["dtype"] = new_dtype + + def _process_inbounds(self, + layer_name: str, + inbound_nodes: list[list[list[str | int]]] | list[list[str | int]] + ) -> None: + """ If the inbound nodes are from a shared functional model, decrement the node index by + one. Operation is performed in place + + Parameters + ---------- + layer_name: str + The name of the layer (for logging) + inbound_nodes: list[list[list[str | int]]] | list[list[str | int]] + The inbound nodes from a Keras 2 config dict to process + """ + to_process = T.cast( + list[list[list[str | int]]], + inbound_nodes if isinstance(inbound_nodes[0][0], list) else [inbound_nodes]) + + for inbound in to_process: + for node in inbound: + name, node_index = node[0], node[1] + assert isinstance(name, str) and isinstance(node_index, int) + if name in self._functionals and node_index > 0: + logger.debug("Updating '%s' inbound node index for '%s' from %s to %s", + layer_name, name, node_index, node_index - 1) + node[1] = node_index - 1 + + def _update_layers(self, layer_list: list[dict[str, T.Any]]) -> None: + """ Given a list of keras layers from a keras 2 config dict, decrement the indices for + any inbound nodes that come from a shared Functional model. Flatten any nested output + tensor lists. Operations are performed in place + + Parameters + ---------- + layer_list: list[dict[str, Any]] + A list of layers that belong to a keras 2 functional model config dictionary + """ + for layer in layer_list: + if layer["class_name"] == "Functional": + logger.debug("Found Functional layer.
Keys: %s", list(layer)) + + if layer.get("name"): + logger.debug("Storing layer: '%s'", layer["name"]) + self._functionals.add(layer["name"]) + + layer["config"]["output_layers"] = self._unwrap_outputs( + layer["config"]["output_layers"]) + + self._update_layers(layer["config"]["layers"]) + + if not layer.get("inbound_nodes"): + continue + + self._process_deprecations(layer) + self._process_inbounds(layer["name"], layer["inbound_nodes"]) + + def _archive_model(self) -> str: + """ Archive an existing Keras 2 model to a new archive location + + Raises + ------ + FaceswapError + If the destination archive folder exists and is not empty + + Returns + ------- + str + The path to the archived keras 2 model folder + """ + model_dir = os.path.dirname(self._old_model_file) + dst_path = f"{model_dir}_fs2_backup" + if os.path.exists(dst_path) and os.listdir(dst_path): + raise FaceswapError( + f"The destination archive folder '{dst_path}' already exists. Either delete this " + "folder, select a different model folder, or remove the legacy model files from " + f"your model folder '{model_dir}'.") + + if os.path.exists(dst_path): + logger.info("Removing pre-existing empty folder '%s'", dst_path) + os.rmdir(dst_path) + + logger.info("Archiving model folder '%s' to '%s'", model_dir, dst_path) + os.rename(model_dir, dst_path) + return dst_path + + def _restore_files(self, archive_dir: str) -> None: + """ Copy the state.json file and the logs folder from the archive folder to the new model + folder + + Parameters + ---------- + archive_dir: str + The full path to the archived Keras 2 model + """ + model_dir = os.path.dirname(self._new_model_file) + model_name = os.path.splitext(os.path.basename(self._new_model_file))[0] + logger.debug("Restoring required '%s' files from '%s' to '%s'", + model_name, archive_dir, model_dir) + + for fname in os.listdir(archive_dir): + fullpath = os.path.join(archive_dir, fname) + new_path = os.path.join(model_dir, fname) + + if fname == f"{model_name}_logs" and os.path.isdir(fullpath): + logger.debug("Restoring '%s' to '%s'", fullpath, new_path) + copytree(fullpath, new_path) + continue + + if fname == f"{model_name}_state.json" and os.path.isfile(fullpath): + logger.debug("Restoring '%s' to '%s'", fullpath, new_path) + copyfile(fullpath, new_path) + continue + + logger.debug("Skipping file: '%s'", fname) + + def _upgrade_model(self) -> None: + """ Get the model configuration of a Faceswap 2 model and upgrade it to be Faceswap 3 + compatible """ + logger.info("Upgrading model file from Faceswap 2 to Faceswap 3...") + config = self._get_model_config() + self._update_layers([config]) + + logger.debug("Migrating data to new model...") + model = kmodels.Model.from_config(config["config"]) + model.load_weights(self._old_model_file) + + archive_dir = self._archive_model() + + dirname = os.path.dirname(self._new_model_file) + logger.debug("Saving model '%s'", self._new_model_file) + os.mkdir(dirname) + model.save(self._new_model_file) + logger.debug("Saved model '%s'", self._new_model_file) + + self._restore_files(archive_dir) + logger.info("Model upgraded: '%s'", dirname) + + +class PatchKerasConfig: + """ This class exists to patch breaking changes when moving from older keras 3.x models to + newer versions + + Parameters + ---------- + model_path : str + Full path to the keras model to be patched for the current version + """ + def __init__(self, model_path: str) -> None: + logger.debug(parse_class_init(locals())) + self._model_path = model_path + self._items, self._config =
self._load_model() + metadata = json.loads(self._items["metadata.json"]) + self._version = tuple(int(x) for x in metadata['keras_version'].split(".")[:2]) + logger.debug("Initialized: %s", self.__class__.__name__) + + def _load_model(self) -> tuple[dict[str, bytes], dict[str, T.Any]]: + """ Load the objects from the compressed keras model + + Returns + ------- + items : dict[str, bytes] + The filename and file objects within the keras 3 model file that are not the model + config + config : dict[str, Any] + The model configuration dictionary from the keras 3 model file + """ + with zipfile.ZipFile(self._model_path, "r") as zf: + items = {f.filename: zf.read(f) for f in zf.filelist if f.filename != "config.json"} + config = json.loads(zf.read("config.json")) + + logger.debug("Loaded existing items %s and 'config.json' from model '%s'", + list(items), self._model_path) + return items, config + + def _update_nn_blocks(self, layer: dict[str, T.Any]): + """ In older versions of keras our :class:`lib.model.nn_blocks.Conv2D` and + :class:`lib.model.nn_blocks.DepthwiseConv2D` inherited from their respective Keras layers. + Sometime between 3.3.3 and 3.12 (during beta testing) this stopped working, raising a + TypeError. Subsequently we have refactored those classes to no longer inherit, and call the + underlying keras layer directly instead. The keras config needs to be rewritten to reflect + this. + + Parameters + ---------- + layer: dict[str, Any] + A layer config dictionary from a keras 3 model + """ + if (layer.get("module") == "lib.model.nn_blocks" and + layer.get("class_name") in ("Conv2D", "DepthwiseConv2D")): + new_module = "keras.layers" + logger.debug("Updating Keras %s layer '%s' to '%s': %s", + ".".join(str(x) for x in self._version), + f"{layer['module']}.{layer['class_name']}", + f"{new_module}.{layer['class_name']}", + layer["name"]) + layer["module"] = new_module + + def _parse_inbound_args(self, inbound: list | dict[str, T.Any]) -> None: + """ Recurse through keras inbound node args until we arrive at a dictionary + + Parameters + ---------- + inbound: list | dict[str, Any] + A Keras inbound nodes args entry or the nested dictionary + """ + if not isinstance(inbound, (list, dict)): + return + + if isinstance(inbound, list): + for arg in inbound: + self._parse_inbound_args(arg) + return + + arg_conf = inbound["config"] + if "keras_history" not in arg_conf: + return + + if "." in arg_conf["keras_history"][0]: + new_hist = arg_conf["keras_history"][:] + new_hist[0] = new_hist[0].replace(".", "_") + logger.debug("Updating Inbound Keras history from '%s' to '%s'", + arg_conf["keras_history"], new_hist) + arg_conf["keras_history"] = new_hist + + def _update_dot_naming(self, layer: dict[str, T.Any]): + """ Sometime between 3.3.3 and 3.12 (during beta testing) layers with "." in the name + started generating a KeyError. This is odd as the error comes from Torch, but dot naming is + standard. To work around this all dots (.) in layer names have been converted to + underscores (_). The keras config needs to be rewritten to reflect this. This only impacts + FS models that used the CLiP encoder + + Parameters + ---------- + layer: dict[str, Any] + A layer config dictionary from a keras 3 model + """ + if "." in layer["name"]: + new_name = layer["name"].replace(".", "_") + logger.debug("Updating Keras layer name from '%s' to '%s'", layer["name"], new_name) + layer["name"] = new_name + + config = layer["config"] + if "."
in config["name"]: + new_name = config["name"].replace(".", "_") + logger.debug("Updating Keras config layer name from '%s' to '%s'", + config["name"], new_name) + config["name"] = new_name + + inbound = layer["inbound_nodes"] + for in_ in inbound: + for arg in in_["args"]: + self._parse_inbound_args(arg) + + def _update_config(self, config: dict[str, T.Any]) -> dict[str, T.Any]: + """ Recursively update the `config` dictionary from a full keras config in place + + Parameters + ---------- + config : dict[str, Any] + A 'config' section of keras config + + Returns + ------- + dict[str, Any] + The updated `config` section of a keras config + """ + layer: dict[str, T.Any] + for layer in config["layers"]: + if layer.get("class_name") == "Functional": + self._update_config(layer["config"]) + if self._version <= (3, 3): + self._update_nn_blocks(layer) + self._update_dot_naming(layer) + return config + + def _save_model(self) -> None: + """ Save the updated keras model """ + logger.info("Updating Keras model '%s'...", self._model_path) + with zipfile.ZipFile(self._model_path, "w", compression=zipfile.ZIP_DEFLATED) as zf: + for filename, data in self._items.items(): + zf.writestr(filename, data) + zf.writestr("config.json", json.dumps(self._config).encode("utf-8")) + + def __call__(self) -> None: + """ Update the keras configuration saved in a keras model file and save over the original + model """ + logger.debug("Updating saved config for keras version %s", self._version) + self._config["config"] = self._update_config(self._config["config"]) + self._save_model() + + +__all__ = get_module_objects(__name__) diff --git a/plugins/train/model/dfaker.py b/plugins/train/model/dfaker.py index d097e20ac4..fe6fb88711 100644 --- a/plugins/train/model/dfaker.py +++ b/plugins/train/model/dfaker.py @@ -1,62 +1,62 @@ #!/usr/bin/env python3 """ DFaker Model Based on the dfaker model: https://github.com/dfaker """ +import logging +import sys +from keras import initializers, Input, layers, Model as KModel -from keras.initializers import RandomNormal -from keras.layers import Conv2D, Input -from keras.models import Model as KerasModel +from lib.model.nn_blocks import Conv2DOutput, UpscaleBlock, ResidualBlock +from plugins.train.train_config import Loss as cfg_loss +from .original import Model as OriginalModel +from . 
import dfaker_defaults as cfg -from .original import logger, Model as OriginalModel +logger = logging.getLogger(__name__) +# pylint:disable=duplicate-code class Model(OriginalModel): - """ Improved Autoeencoder Model """ + """ Dfaker Model """ def __init__(self, *args, **kwargs): - logger.debug("Initializing %s: (args: %s, kwargs: %s", - self.__class__.__name__, args, kwargs) - kwargs["input_shape"] = (64, 64, 3) - kwargs["encoder_dim"] = 1024 - self.kernel_initializer = RandomNormal(0, 0.02) super().__init__(*args, **kwargs) - logger.debug("Initialized %s", self.__class__.__name__) - - def build_autoencoders(self): - """ Initialize Dfaker model """ - logger.debug("Initializing model") - inputs = [Input(shape=self.input_shape, name="face")] - if self.config.get("mask_type", None): - mask_shape = (self.input_shape[0] * 2, self.input_shape[1] * 2, 1) - inputs.append(Input(shape=mask_shape, name="mask")) - - for side in ("a", "b"): - decoder = self.networks["decoder_{}".format(side)].network - output = decoder(self.networks["encoder"].network(inputs[0])) - autoencoder = KerasModel(inputs, output) - self.add_predictor(side, autoencoder) - logger.debug("Initialized model") - - def decoder(self): + self._output_size = cfg.output_size() + if self._output_size not in (128, 256): + logger.error("Dfaker output shape should be 128 or 256 px") + sys.exit(1) + self.input_shape = (self._output_size // 2, self._output_size // 2, 3) + self.encoder_dim = 1024 + self.kernel_initializer = initializers.RandomNormal(0, 0.02) + + def decoder(self, side): """ Decoder Network """ input_ = Input(shape=(8, 8, 512)) var_x = input_ - var_x = self.blocks.upscale(var_x, 512, res_block_follows=True) - var_x = self.blocks.res_block(var_x, 512, kernel_initializer=self.kernel_initializer) - var_x = self.blocks.upscale(var_x, 256, res_block_follows=True) - var_x = self.blocks.res_block(var_x, 256, kernel_initializer=self.kernel_initializer) - var_x = self.blocks.upscale(var_x, 128, res_block_follows=True) - var_x = self.blocks.res_block(var_x, 128, kernel_initializer=self.kernel_initializer) - var_x = self.blocks.upscale(var_x, 64) - var_x = Conv2D(3, kernel_size=5, padding='same', activation='sigmoid')(var_x) + if self._output_size == 256: + var_x = UpscaleBlock(1024, activation=None)(var_x) + var_x = layers.LeakyReLU(negative_slope=0.2)(var_x) + var_x = ResidualBlock(1024, kernel_initializer=self.kernel_initializer)(var_x) + var_x = UpscaleBlock(512, activation=None)(var_x) + var_x = layers.LeakyReLU(negative_slope=0.2)(var_x) + var_x = ResidualBlock(512, kernel_initializer=self.kernel_initializer)(var_x) + var_x = UpscaleBlock(256, activation=None)(var_x) + var_x = layers.LeakyReLU(negative_slope=0.2)(var_x) + var_x = ResidualBlock(256, kernel_initializer=self.kernel_initializer)(var_x) + var_x = UpscaleBlock(128, activation=None)(var_x) + var_x = layers.LeakyReLU(negative_slope=0.2)(var_x) + var_x = ResidualBlock(128, kernel_initializer=self.kernel_initializer)(var_x) + var_x = UpscaleBlock(64, activation="leakyrelu")(var_x) + var_x = Conv2DOutput(3, 5, name=f"face_out_{side}")(var_x) outputs = [var_x] - if self.config.get("mask_type", None): + if cfg_loss.learn_mask(): var_y = input_ - var_y = self.blocks.upscale(var_y, 512) - var_y = self.blocks.upscale(var_y, 256) - var_y = self.blocks.upscale(var_y, 128) - var_y = self.blocks.upscale(var_y, 64) - var_y = Conv2D(1, kernel_size=5, padding='same', activation='sigmoid')(var_y) + if self._output_size == 256: + var_y = UpscaleBlock(1024, activation="leakyrelu")(var_y) + 
var_y = UpscaleBlock(512, activation="leakyrelu")(var_y) + var_y = UpscaleBlock(256, activation="leakyrelu")(var_y) + var_y = UpscaleBlock(128, activation="leakyrelu")(var_y) + var_y = UpscaleBlock(64, activation="leakyrelu")(var_y) + var_y = Conv2DOutput(1, 5, name=f"mask_out_{side}")(var_y) outputs.append(var_y) - return KerasModel([input_], outputs=outputs) + return KModel([input_], outputs=outputs, name=f"decoder_{side}") diff --git a/plugins/train/model/dfaker_defaults.py b/plugins/train/model/dfaker_defaults.py new file mode 100644 index 0000000000..3ece2836a5 --- /dev/null +++ b/plugins/train/model/dfaker_defaults.py @@ -0,0 +1,43 @@ +#!/usr/bin/env python3 +""" The default options for the faceswap Dfaker Model plugin. + +Defaults files should be named `<plugin_name>_defaults.py` + +Any qualifying items placed into this file will automatically get added to the relevant config +.ini files within the faceswap/config folder and added to the relevant GUI settings page. + +The following variable should be defined: + + Parameters + ---------- + HELPTEXT: str + A string describing what this plugin does + +Further plugin configuration options are assigned using: +>>> <option_name> = ConfigItem(...) + +where <option_name> is the name of the configuration option to be added (lower-case, alpha-numeric ++ underscore only) and ConfigItem(...) is the [`~lib.config.objects.ConfigItem`] data for the +option. + +See the docstring/ReadtheDocs documentation for the required parameters for the ConfigItem object. +Items will be grouped together as per their `group` parameter, but otherwise will be processed in +the order that they are added to this module. +""" +from lib.config import ConfigItem + + +HELPTEXT = "Dfaker Model (Adapted from https://github.com/dfaker/df)" + + +output_size = ConfigItem( + datatype=int, + default=128, + group="size", + info="Resolution (in pixels) of the output image to generate.\n" + "BE AWARE Larger resolution will dramatically increase VRAM requirements.\n" + "Must be 128 or 256.", + rounding=128, + min_max=(128, 256), + fixed=True) diff --git a/plugins/train/model/dfl_h128.py b/plugins/train/model/dfl_h128.py index 7e780c78cc..c55bc46a4b 100644 --- a/plugins/train/model/dfl_h128.py +++ b/plugins/train/model/dfl_h128.py @@ -1,53 +1,51 @@ #!/usr/bin/env python3 -""" DeepFakesLab H128 Model +""" DeepFaceLab H128 Model Based on https://github.com/iperov/DeepFaceLab """ -from keras.layers import Conv2D, Dense, Flatten, Input, Reshape -from keras.models import Model as KerasModel +from keras import Input, layers, Model as KModel -from .original import logger, Model as OriginalModel +from lib.model.nn_blocks import Conv2DOutput, Conv2DBlock, UpscaleBlock +from plugins.train.train_config import Loss as cfg_loss +from .original import Model as OriginalModel +from .
import dfl_h128_defaults as cfg class Model(OriginalModel): - """ Low Memory version of Original Faceswap Model """ + """ H128 Model from DFL """ def __init__(self, *args, **kwargs): - logger.debug("Initializing %s: (args: %s, kwargs: %s", - self.__class__.__name__, args, kwargs) - - kwargs["input_shape"] = (128, 128, 3) - kwargs["encoder_dim"] = 256 if self.config["lowmem"] else 512 - super().__init__(*args, **kwargs) - logger.debug("Initialized %s", self.__class__.__name__) + super().__init__(*args, **kwargs) + self.input_shape = (128, 128, 3) + self.encoder_dim = 256 if cfg.lowmem() else 512 def encoder(self): """ DFL H128 Encoder """ input_ = Input(shape=self.input_shape) - var_x = input_ - var_x = self.blocks.conv(var_x, 128) - var_x = self.blocks.conv(var_x, 256) - var_x = self.blocks.conv(var_x, 512) - var_x = self.blocks.conv(var_x, 1024) - var_x = Dense(self.encoder_dim)(Flatten()(var_x)) - var_x = Dense(8 * 8 * self.encoder_dim)(var_x) - var_x = Reshape((8, 8, self.encoder_dim))(var_x) - var_x = self.blocks.upscale(var_x, self.encoder_dim) - return KerasModel(input_, var_x) - - def decoder(self): + var_x = Conv2DBlock(128, activation="leakyrelu")(input_) + var_x = Conv2DBlock(256, activation="leakyrelu")(var_x) + var_x = Conv2DBlock(512, activation="leakyrelu")(var_x) + var_x = Conv2DBlock(1024, activation="leakyrelu")(var_x) + var_x = layers.Dense(self.encoder_dim)(layers.Flatten()(var_x)) + var_x = layers.Dense(8 * 8 * self.encoder_dim)(var_x) + var_x = layers.Reshape((8, 8, self.encoder_dim))(var_x) + var_x = UpscaleBlock(self.encoder_dim, activation="leakyrelu")(var_x) + return KModel(input_, var_x, name="encoder") + + def decoder(self, side): """ DFL H128 Decoder """ input_ = Input(shape=(16, 16, self.encoder_dim)) - var = input_ - var = self.blocks.upscale(var, self.encoder_dim) - var = self.blocks.upscale(var, self.encoder_dim // 2) - var = self.blocks.upscale(var, self.encoder_dim // 4) - - # Face - var_x = Conv2D(3, kernel_size=5, padding="same", activation="sigmoid")(var) + var_x = input_ + var_x = UpscaleBlock(self.encoder_dim, activation="leakyrelu")(var_x) + var_x = UpscaleBlock(self.encoder_dim // 2, activation="leakyrelu")(var_x) + var_x = UpscaleBlock(self.encoder_dim // 4, activation="leakyrelu")(var_x) + var_x = Conv2DOutput(3, 5, name=f"face_out_{side}")(var_x) outputs = [var_x] - # Mask - if self.config.get("mask_type", None): - var_y = Conv2D(1, kernel_size=5, padding="same", activation="sigmoid")(var) + + if cfg_loss.learn_mask(): + var_y = input_ + var_y = UpscaleBlock(self.encoder_dim, activation="leakyrelu")(var_y) + var_y = UpscaleBlock(self.encoder_dim // 2, activation="leakyrelu")(var_y) + var_y = UpscaleBlock(self.encoder_dim // 4, activation="leakyrelu")(var_y) + var_y = Conv2DOutput(1, 5, name=f"mask_out_{side}")(var_y) outputs.append(var_y) - return KerasModel(input_, outputs=outputs) + return KModel(input_, outputs=outputs, name=f"decoder_{side}") diff --git a/plugins/train/model/dfl_h128_defaults.py b/plugins/train/model/dfl_h128_defaults.py new file mode 100755 index 0000000000..d1edce7789 --- /dev/null +++ b/plugins/train/model/dfl_h128_defaults.py @@ -0,0 +1,41 @@ +#!/usr/bin/env python3 +""" The default options for the faceswap Dfl_H128 Model plugin. + +Defaults files should be named `<plugin_name>_defaults.py` + +Any qualifying items placed into this file will automatically get added to the relevant config +.ini files within the faceswap/config folder and added to the relevant GUI settings page.
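+ +For example, a hypothetical boolean option (the name is illustrative only) could be +registered as: + +>>> my_option = ConfigItem( +...     datatype=bool, +...     default=False, +...     group="settings", +...     info="An example option.", +...     fixed=True)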
+ + The following variable should be defined: + + Parameters + ---------- + HELPTEXT: str + A string describing what this plugin does + +Further plugin configuration options are assigned using: +>>> <option_name> = ConfigItem(...) + +where <option_name> is the name of the configuration option to be added (lower-case, alpha-numeric ++ underscore only) and ConfigItem(...) is the [`~lib.config.objects.ConfigItem`] data for the +option. + +See the docstring/ReadtheDocs documentation for the required parameters for the ConfigItem object. +Items will be grouped together as per their `group` parameter, but otherwise will be processed in +the order that they are added to this module. +""" +# pylint:disable=duplicate-code +from lib.config import ConfigItem + + +HELPTEXT = "DFL H128 Model (Adapted from https://github.com/iperov/DeepFaceLab)" + + +lowmem = ConfigItem( + datatype=bool, + default=False, + group="settings", + info="Lower memory mode. Set to 'True' if having issues with VRAM usage.\n" + "NB: Models with a changed lowmem mode are not compatible with each other.", + fixed=True) diff --git a/plugins/train/model/dfl_sae.py b/plugins/train/model/dfl_sae.py new file mode 100644 index 0000000000..f6b6686a76 --- /dev/null +++ b/plugins/train/model/dfl_sae.py @@ -0,0 +1,150 @@ +#!/usr/bin/env python3 +""" DeepFaceLab SAE Model + Based on https://github.com/iperov/DeepFaceLab +""" +import logging + +import numpy as np + +from keras import Input, layers, Model as KModel + +from lib.model.nn_blocks import Conv2DOutput, Conv2DBlock, ResidualBlock, UpscaleBlock +from plugins.train.train_config import Loss as cfg_loss + +from ._base import ModelBase +from . import dfl_sae_defaults as cfg + +logger = logging.getLogger(__name__) + + +class Model(ModelBase): + """ SAE Model from DFL """ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.input_shape = (cfg.input_size(), cfg.input_size(), 3) + self.architecture = cfg.architecture().lower() + self.use_mask = cfg_loss.learn_mask() + self.multiscale_count = 3 if cfg.multiscale_decoder() else 1 + self.encoder_dim = cfg.encoder_dims() + self.decoder_dim = cfg.decoder_dims() + + @property + def model_name(self): + """ str: The name of the keras model. Varies depending on selected architecture.
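+ For example (illustrative), the "liae" architecture yields: + + >>> self.model_name + 'dfl_sae_liae' +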
""" + return f"{self.name}_{self.architecture}" + + @property + def ae_dims(self): + """ Set the Autoencoder Dimensions or set to default """ + retval = cfg.autoencoder_dims() + if retval == 0: + retval = 256 if self.architecture == "liae" else 512 + return retval + + @property + def freeze_layers(self) -> list[str]: + """ list[str] : The layer name for freezing based on the configured architecture """ + return [f"encoder_{self.architecture}"] + + @property + def load_layers(self) -> list[str]: + """ list[str] : The layer name for loading based on the configured architecture """ + return [f"encoder_{self.architecture}"] + + def build_model(self, inputs): + """ Build the DFL-SAE Model """ + encoder = getattr(self, f"encoder_{self.architecture}")() + enc_output_shape = encoder.output_shape[1:] + encoder_a = encoder(inputs[0]) + encoder_b = encoder(inputs[1]) + + if self.architecture == "liae": + inter_both = self.inter_liae("both", enc_output_shape) + int_output_shape = (np.array(inter_both.output_shape[1:]) * (1, 1, 2)).tolist() + + inter_a = layers.Concatenate()([inter_both(encoder_a), inter_both(encoder_a)]) + inter_b = layers.Concatenate()([self.inter_liae("b", enc_output_shape)(encoder_b), + inter_both(encoder_b)]) + + decoder = self.decoder("both", int_output_shape) + outputs = decoder(inter_a) + decoder(inter_b) + else: + outputs = (self.decoder("a", enc_output_shape)(encoder_a) + + self.decoder("b", enc_output_shape)(encoder_b)) + autoencoder = KModel(inputs, outputs, name=self.model_name) + return autoencoder + + def encoder_df(self): + """ DFL SAE DF Encoder Network""" + input_ = Input(shape=self.input_shape) + dims = self.input_shape[-1] * self.encoder_dim + lowest_dense_res = self.input_shape[0] // 16 + var_x = Conv2DBlock(dims, activation="leakyrelu")(input_) + var_x = Conv2DBlock(dims * 2, activation="leakyrelu")(var_x) + var_x = Conv2DBlock(dims * 4, activation="leakyrelu")(var_x) + var_x = Conv2DBlock(dims * 8, activation="leakyrelu")(var_x) + var_x = layers.Dense(self.ae_dims)(layers.Flatten()(var_x)) + var_x = layers.Dense(lowest_dense_res * lowest_dense_res * self.ae_dims)(var_x) + var_x = layers.Reshape((lowest_dense_res, lowest_dense_res, self.ae_dims))(var_x) + var_x = UpscaleBlock(self.ae_dims, activation="leakyrelu")(var_x) + return KModel(input_, var_x, name="encoder_df") + + def encoder_liae(self): + """ DFL SAE LIAE Encoder Network """ + input_ = Input(shape=self.input_shape) + dims = self.input_shape[-1] * self.encoder_dim + var_x = Conv2DBlock(dims, activation="leakyrelu")(input_) + var_x = Conv2DBlock(dims * 2, activation="leakyrelu")(var_x) + var_x = Conv2DBlock(dims * 4, activation="leakyrelu")(var_x) + var_x = Conv2DBlock(dims * 8, activation="leakyrelu")(var_x) + var_x = layers.Flatten()(var_x) + return KModel(input_, var_x, name="encoder_liae") + + def inter_liae(self, side, input_shape): + """ DFL SAE LIAE Intermediate Network """ + input_ = Input(shape=input_shape) + lowest_dense_res = self.input_shape[0] // 16 + var_x = input_ + var_x = layers.Dense(self.ae_dims)(var_x) + var_x = layers.Dense(lowest_dense_res * lowest_dense_res * self.ae_dims * 2)(var_x) + var_x = layers.Reshape((lowest_dense_res, lowest_dense_res, self.ae_dims * 2))(var_x) + var_x = UpscaleBlock(self.ae_dims * 2, activation="leakyrelu")(var_x) + return KModel(input_, var_x, name=f"intermediate_{side}") + + def decoder(self, side, input_shape): + """ DFL SAE Decoder Network""" + input_ = Input(shape=input_shape) + outputs = [] + + dims = self.input_shape[-1] * self.decoder_dim + var_x = 
input_ + + var_x1 = UpscaleBlock(dims * 8, activation=None)(var_x) + var_x1 = layers.LeakyReLU(negative_slope=0.2)(var_x1) + var_x1 = ResidualBlock(dims * 8)(var_x1) + var_x1 = ResidualBlock(dims * 8)(var_x1) + if self.multiscale_count >= 3: + outputs.append(Conv2DOutput(3, 5, name=f"face_out_32_{side}")(var_x1)) + + var_x2 = UpscaleBlock(dims * 4, activation=None)(var_x1) + var_x2 = layers.LeakyReLU(negative_slope=0.2)(var_x2) + var_x2 = ResidualBlock(dims * 4)(var_x2) + var_x2 = ResidualBlock(dims * 4)(var_x2) + if self.multiscale_count >= 2: + outputs.append(Conv2DOutput(3, 5, name=f"face_out_64_{side}")(var_x2)) + + var_x3 = UpscaleBlock(dims * 2, activation=None)(var_x2) + var_x3 = layers.LeakyReLU(negative_slope=0.2)(var_x3) + var_x3 = ResidualBlock(dims * 2)(var_x3) + var_x3 = ResidualBlock(dims * 2)(var_x3) + + outputs.append(Conv2DOutput(3, 5, name=f"face_out_128_{side}")(var_x3)) + + if self.use_mask: + var_y = input_ + var_y = UpscaleBlock(self.decoder_dim * 8, activation="leakyrelu")(var_y) + var_y = UpscaleBlock(self.decoder_dim * 4, activation="leakyrelu")(var_y) + var_y = UpscaleBlock(self.decoder_dim * 2, activation="leakyrelu")(var_y) + var_y = Conv2DOutput(1, 5, name=f"mask_out_{side}")(var_y) + outputs.append(var_y) + return KModel(input_, outputs=outputs, name=f"decoder_{side}") diff --git a/plugins/train/model/dfl_sae_defaults.py b/plugins/train/model/dfl_sae_defaults.py new file mode 100644 index 0000000000..36564143a5 --- /dev/null +++ b/plugins/train/model/dfl_sae_defaults.py @@ -0,0 +1,93 @@ +#!/usr/bin/env python3 +""" The default options for the faceswap Dfl_SAE Model plugin. + +Defaults files should be named `<plugin_name>_defaults.py` + +Any qualifying items placed into this file will automatically get added to the relevant config +.ini files within the faceswap/config folder and added to the relevant GUI settings page. + +The following variable should be defined: + + Parameters + ---------- + HELPTEXT: str + A string describing what this plugin does + +Further plugin configuration options are assigned using: +>>> <option_name> = ConfigItem(...) + +where <option_name> is the name of the configuration option to be added (lower-case, alpha-numeric ++ underscore only) and ConfigItem(...) is the [`~lib.config.objects.ConfigItem`] data for the +option. + +See the docstring/ReadtheDocs documentation for the required parameters for the ConfigItem object. +Items will be grouped together as per their `group` parameter, but otherwise will be processed in +the order that they are added to this module. +""" +from lib.config import ConfigItem + + +HELPTEXT = "DFL SAE Model (Adapted from https://github.com/iperov/DeepFaceLab)" + + +input_size = ConfigItem( + datatype=int, + default=128, + group="size", + info="Resolution (in pixels) of the input image to train on.\n" + "BE AWARE Larger resolution will dramatically increase VRAM requirements.\n" + "\nMust be divisible by 16.", + rounding=16, + min_max=(64, 256), + fixed=True) + +architecture = ConfigItem( + datatype=str, + default="df", + group="network", + info="Model architecture:" + "\n\t'df': Keeps the faces more natural." + "\n\t'liae': Can help fix overly different face shapes.", + choices=["df", "liae"], + gui_radio=True, + fixed=True) + +autoencoder_dims = ConfigItem( + datatype=int, + default=0, + group="network", + info="Face information is stored in AutoEncoder dimensions. If there are not enough " + "dimensions then certain facial features may not be recognized."
+ "\nHigher number of dimensions are better, but require more VRAM." + "\nSet to 0 to use the architecture defaults (256 for liae, 512 for df).", + rounding=32, + min_max=(0, 1024), + fixed=True) + +encoder_dims = ConfigItem( + datatype=int, + default=42, + group="network", + info="Encoder dimensions per channel. Higher number of encoder dimensions will help " + "the model to recognize more facial features, but will require more VRAM.", + rounding=1, + min_max=(21, 85), + fixed=True) + +decoder_dims = ConfigItem( + datatype=int, + default=21, + group="network", + info="Decoder dimensions per channel. Higher number of decoder dimensions will help " + "the model to improve details, but will require more VRAM.", + rounding=1, + min_max=(10, 85), + fixed=True) + +multiscale_decoder = ConfigItem( + datatype=bool, + default=False, + group="network", + info="Multiscale decoder can help to obtain better details.", + fixed=True) diff --git a/plugins/train/model/dlight.py b/plugins/train/model/dlight.py new file mode 100644 index 0000000000..07f402bfc1 --- /dev/null +++ b/plugins/train/model/dlight.py @@ -0,0 +1,222 @@ +#!/usr/bin/env python3 +""" A lightweight variant of DFaker Model + By AnDenix, 2018-2019 + Based on the dfaker model: https://github.com/dfaker + + Acknowledgments: + kvrooman for numerous insights and invaluable aid + DeepHomage for lots of testing + """ +import logging + +from keras import layers, Input, Model as KModel + +from lib.model.nn_blocks import (Conv2DOutput, Conv2DBlock, ResidualBlock, UpscaleBlock, + Upscale2xBlock) +from lib.utils import FaceswapError +from plugins.train.train_config import Loss as cfg_loss + +from ._base import ModelBase +from . import dlight_defaults as cfg + + +logger = logging.getLogger(__name__) + + +class Model(ModelBase): + """ DLight Autoencoder Model """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.input_shape = (128, 128, 3) + + self.features = {"lowmem": 0, "fair": 1, "best": 2}[cfg.features()] + self.encoder_filters = 64 if self.features > 0 else 48 + + bonum_fortunam = 128 + self.encoder_dim = {0: 512 + bonum_fortunam, + 1: 1024 + bonum_fortunam, + 2: 1536 + bonum_fortunam}[self.features] + self.details = {"fast": 0, "good": 1}[cfg.details()] + try: + self.upscale_ratio = {128: 2, + 256: 4, + 384: 6}[cfg.output_size()] + except KeyError as err: + logger.error("Config error: output_size must be one of: 128, 256, or 384.") + raise FaceswapError("Config error: output_size must be one of: " + "128, 256, or 384.") from err + + logger.debug("output_size: %s, features: %s, encoder_filters: %s, encoder_dim: %s, " + " details: %s, upscale_ratio: %s", cfg.output_size(), self.features, + self.encoder_filters, self.encoder_dim, self.details, self.upscale_ratio) + + def build_model(self, inputs): + """ Build the Dlight Model. 
""" + encoder = self.encoder() + encoder_a = encoder(inputs[0]) + encoder_b = encoder(inputs[1]) + + decoder_b = self.decoder_b if self.details > 0 else self.decoder_b_fast + + outputs = self.decoder_a()(encoder_a) + decoder_b()(encoder_b) + + autoencoder = KModel(inputs, outputs, name=self.model_name) + return autoencoder + + def encoder(self): + """ DeLight Encoder Network """ + input_ = Input(shape=self.input_shape) + var_x = input_ + + var_x1 = Conv2DBlock(self.encoder_filters // 2, activation="leakyrelu")(var_x) + var_x2 = layers.AveragePooling2D(pool_size=(2, 2))(var_x) + var_x2 = layers.LeakyReLU(0.1)(var_x2) + var_x = layers.Concatenate()([var_x1, var_x2]) + + var_x1 = Conv2DBlock(self.encoder_filters, activation="leakyrelu")(var_x) + var_x2 = layers.AveragePooling2D(pool_size=(2, 2))(var_x) + var_x2 = layers.LeakyReLU(0.1)(var_x2) + var_x = layers.Concatenate()([var_x1, var_x2]) + + var_x1 = Conv2DBlock(self.encoder_filters * 2, activation="leakyrelu")(var_x) + var_x2 = layers.AveragePooling2D(pool_size=(2, 2))(var_x) + var_x2 = layers.LeakyReLU(0.1)(var_x2) + var_x = layers.Concatenate()([var_x1, var_x2]) + + var_x1 = Conv2DBlock(self.encoder_filters * 4, activation="leakyrelu")(var_x) + var_x2 = layers.AveragePooling2D(pool_size=(2, 2))(var_x) + var_x2 = layers.LeakyReLU(0.1)(var_x2) + var_x = layers.Concatenate()([var_x1, var_x2]) + + var_x1 = Conv2DBlock(self.encoder_filters * 8, activation="leakyrelu")(var_x) + var_x2 = layers.AveragePooling2D(pool_size=(2, 2))(var_x) + var_x2 = layers.LeakyReLU(0.1)(var_x2) + var_x = layers.Concatenate()([var_x1, var_x2]) + + var_x = layers.Dense(self.encoder_dim)(layers.Flatten()(var_x)) + var_x = layers.Dropout(0.05)(var_x) + var_x = layers.Dense(4 * 4 * 1024)(var_x) + var_x = layers.Dropout(0.05)(var_x) + var_x = layers.Reshape((4, 4, 1024))(var_x) + + return KModel(input_, var_x, name="encoder") + + def decoder_a(self): + """ DeLight Decoder A(old face) Network """ + input_ = Input(shape=(4, 4, 1024)) + dec_a_complexity = 256 + mask_complexity = 128 + + var_xy = input_ + var_xy = layers.UpSampling2D(self.upscale_ratio, interpolation='bilinear')(var_xy) + + var_x = var_xy + var_x = Upscale2xBlock(dec_a_complexity, activation="leakyrelu", fast=False)(var_x) + var_x = Upscale2xBlock(dec_a_complexity // 2, activation="leakyrelu", fast=False)(var_x) + var_x = Upscale2xBlock(dec_a_complexity // 4, activation="leakyrelu", fast=False)(var_x) + var_x = Upscale2xBlock(dec_a_complexity // 8, activation="leakyrelu", fast=False)(var_x) + + var_x = Conv2DOutput(3, 5, name="face_out")(var_x) + + outputs = [var_x] + + if cfg_loss.learn_mask(): + var_y = var_xy # mask decoder + var_y = Upscale2xBlock(mask_complexity, activation="leakyrelu", fast=False)(var_y) + var_y = Upscale2xBlock(mask_complexity // 2, activation="leakyrelu", fast=False)(var_y) + var_y = Upscale2xBlock(mask_complexity // 4, activation="leakyrelu", fast=False)(var_y) + var_y = Upscale2xBlock(mask_complexity // 8, activation="leakyrelu", fast=False)(var_y) + + var_y = Conv2DOutput(1, 5, name="mask_out")(var_y) + + outputs.append(var_y) + + return KModel([input_], outputs=outputs, name="decoder_a") + + def decoder_b_fast(self): + """ DeLight Fast Decoder B(new face) Network """ + input_ = Input(shape=(4, 4, 1024)) + + dec_b_complexity = 512 + mask_complexity = 128 + + var_xy = input_ + + var_xy = UpscaleBlock(512, scale_factor=self.upscale_ratio, activation="leakyrelu")(var_xy) + var_x = var_xy + + var_x = Upscale2xBlock(dec_b_complexity, activation="leakyrelu", fast=True)(var_x) + var_x 
= Upscale2xBlock(dec_b_complexity // 2, activation="leakyrelu", fast=True)(var_x) + var_x = Upscale2xBlock(dec_b_complexity // 4, activation="leakyrelu", fast=True)(var_x) + var_x = Upscale2xBlock(dec_b_complexity // 8, activation="leakyrelu", fast=True)(var_x) + + var_x = Conv2DOutput(3, 5, name="face_out")(var_x) + + outputs = [var_x] + + if cfg_loss.learn_mask(): + var_y = var_xy # mask decoder + + var_y = Upscale2xBlock(mask_complexity, activation="leakyrelu", fast=False)(var_y) + var_y = Upscale2xBlock(mask_complexity // 2, activation="leakyrelu", fast=False)(var_y) + var_y = Upscale2xBlock(mask_complexity // 4, activation="leakyrelu", fast=False)(var_y) + var_y = Upscale2xBlock(mask_complexity // 8, activation="leakyrelu", fast=False)(var_y) + + var_y = Conv2DOutput(1, 5, name="mask_out")(var_y) + + outputs.append(var_y) + + return KModel([input_], outputs=outputs, name="decoder_b_fast") + + def decoder_b(self): + """ DeLight Decoder B(new face) Network """ + input_ = Input(shape=(4, 4, 1024)) + + dec_b_complexity = 512 + mask_complexity = 128 + + var_xy = input_ + + var_xy = Upscale2xBlock(512, + scale_factor=self.upscale_ratio, + activation=None, + fast=False)(var_xy) + var_x = var_xy + + var_x = layers.LeakyReLU(negative_slope=0.2)(var_x) + var_x = ResidualBlock(512, use_bias=True)(var_x) + var_x = ResidualBlock(512, use_bias=False)(var_x) + var_x = ResidualBlock(512, use_bias=False)(var_x) + var_x = Upscale2xBlock(dec_b_complexity, activation=None, fast=False)(var_x) + var_x = layers.LeakyReLU(negative_slope=0.2)(var_x) + var_x = ResidualBlock(dec_b_complexity, use_bias=True)(var_x) + var_x = ResidualBlock(dec_b_complexity, use_bias=False)(var_x) + var_x = layers.BatchNormalization()(var_x) + var_x = Upscale2xBlock(dec_b_complexity // 2, activation=None, fast=False)(var_x) + var_x = layers.LeakyReLU(negative_slope=0.2)(var_x) + var_x = ResidualBlock(dec_b_complexity // 2, use_bias=True)(var_x) + var_x = Upscale2xBlock(dec_b_complexity // 4, activation=None, fast=False)(var_x) + var_x = layers.LeakyReLU(negative_slope=0.2)(var_x) + var_x = ResidualBlock(dec_b_complexity // 4, use_bias=False)(var_x) + var_x = layers.BatchNormalization()(var_x) + var_x = Upscale2xBlock(dec_b_complexity // 8, activation="leakyrelu", fast=False)(var_x) + + var_x = Conv2DOutput(3, 5, name="face_out")(var_x) + + outputs = [var_x] + + if cfg_loss.learn_mask(): + var_y = var_xy # mask decoder + var_y = layers.LeakyReLU(negative_slope=0.1)(var_y) + + var_y = Upscale2xBlock(mask_complexity, activation="leakyrelu", fast=False)(var_y) + var_y = Upscale2xBlock(mask_complexity // 2, activation="leakyrelu", fast=False)(var_y) + var_y = Upscale2xBlock(mask_complexity // 4, activation="leakyrelu", fast=False)(var_y) + var_y = Upscale2xBlock(mask_complexity // 8, activation="leakyrelu", fast=False)(var_y) + + var_y = Conv2DOutput(1, 5, name="mask_out")(var_y) + + outputs.append(var_y) + + return KModel([input_], outputs=outputs, name="decoder_b") diff --git a/plugins/train/model/dlight_defaults.py b/plugins/train/model/dlight_defaults.py new file mode 100644 index 0000000000..4f9d51b280 --- /dev/null +++ b/plugins/train/model/dlight_defaults.py @@ -0,0 +1,63 @@ +#!/usr/bin/env python3 +""" The default options for the faceswap DLight Model plugin. + +Defaults files should be named `<plugin_name>_defaults.py` + +Any qualifying items placed into this file will automatically get added to the relevant config +.ini files within the faceswap/config folder and added to the relevant GUI settings page.
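+ +As a back-of-envelope check (derived from the decoder code above), the three legal output +sizes map to integer upscale ratios: the decoder starts from a 4px feature map, applies one +upscale by the ratio, then four 2x upscale blocks, so output_size = 4 * ratio * 16: + +>>> [(size, size // 64) for size in (128, 256, 384)] +[(128, 2), (256, 4), (384, 6)]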
+ + The following variable should be defined: + + Parameters + ---------- + HELPTEXT: str + A string describing what this plugin does + +Further plugin configuration options are assigned using: +>>> <option_name> = ConfigItem(...) + +where <option_name> is the name of the configuration option to be added (lower-case, alpha-numeric ++ underscore only) and ConfigItem(...) is the [`~lib.config.objects.ConfigItem`] data for the +option. + +See the docstring/ReadtheDocs documentation for the required parameters for the ConfigItem object. +Items will be grouped together as per their `group` parameter, but otherwise will be processed in +the order that they are added to this module. +""" +from lib.config import ConfigItem + + +HELPTEXT = ("A lightweight, high resolution Dfaker variant " + "(Adapted from https://github.com/dfaker/df)") + + +features = ConfigItem( + datatype=str, + default="best", + group="settings", + info="Higher settings will allow learning more features such as tattoos, piercings and " + "wrinkles.\nStrongly affects VRAM usage.", + choices=["lowmem", "fair", "best"], + gui_radio=True, + fixed=True) + +details = ConfigItem( + datatype=str, + default="good", + group="settings", + info="Defines detail fidelity. A lower setting can appear 'rugged' while 'good' might take " + "a longer time to train.\nAffects VRAM usage.", + choices=["fast", "good"], + gui_radio=True, + fixed=True) + +output_size = ConfigItem( + datatype=int, + default=256, + group="settings", + info="Output image resolution (in pixels).\nBe aware that larger resolution will increase " + "VRAM requirements.\nNB: Must be either 128, 256, or 384.", + rounding=128, + min_max=(128, 384), + fixed=True) diff --git a/plugins/train/model/iae.py b/plugins/train/model/iae.py index 667cbe3ba0..320ec71bee 100644 --- a/plugins/train/model/iae.py +++ b/plugins/train/model/iae.py @@ -1,84 +1,75 @@ #!/usr/bin/env python3 """ Improved autoencoder for faceswap """ -from keras.layers import Concatenate, Conv2D, Dense, Flatten, Input, Reshape -from keras.models import Model as KerasModel +from keras import Input, layers, Model as KModel -from ._base import ModelBase, logger +from lib.model.nn_blocks import Conv2DOutput, Conv2DBlock, UpscaleBlock +from plugins.train.train_config import Loss as cfg_loss + +from ._base import ModelBase +# pylint:disable=duplicate-code class Model(ModelBase): - """ Improved Autoeencoder Model """ + """ Improved Autoencoder Model """ def __init__(self, *args, **kwargs): - logger.debug("Initializing %s: (args: %s, kwargs: %s", - self.__class__.__name__, args, kwargs) - kwargs["input_shape"] = (64, 64, 3) - kwargs["encoder_dim"] = 1024 super().__init__(*args, **kwargs) - logger.debug("Initialized %s", self.__class__.__name__) + self.input_shape = (64, 64, 3) + self.encoder_dim = 1024 - def add_networks(self): - """ Add the IAE model weights """ - logger.debug("Adding networks") - self.add_network("encoder", None, self.encoder()) - self.add_network("decoder", None, self.decoder()) - self.add_network("intermediate", "a", self.intermediate()) - self.add_network("intermediate", "b", self.intermediate()) - self.add_network("inter", None, self.intermediate()) - logger.debug("Added networks") + def build_model(self, inputs): + """ Build the IAE Model """ + encoder = self.encoder() + decoder = self.decoder() + inter_a = self.intermediate("a") + inter_b = self.intermediate("b") + inter_both = self.intermediate("both") - def build_autoencoders(self): - """ Initialize IAE model """ - logger.debug("Initializing model") -
inputs = [Input(shape=self.input_shape, name="face")] - if self.config.get("mask_type", None): - mask_shape = (self.input_shape[:2] + (1, )) - inputs.append(Input(shape=mask_shape, name="mask")) + encoder_a = encoder(inputs[0]) + encoder_b = encoder(inputs[1]) - decoder = self.networks["decoder"].network - encoder = self.networks["encoder"].network - inter_both = self.networks["inter"].network - for side in ("a", "b"): - inter_side = self.networks["intermediate_{}".format(side)].network - output = decoder(Concatenate()([inter_side(encoder(inputs[0])), - inter_both(encoder(inputs[0]))])) + outputs = (decoder(layers.Concatenate()([inter_a(encoder_a), inter_both(encoder_a)])) + + decoder(layers.Concatenate()([inter_b(encoder_b), inter_both(encoder_b)]))) - autoencoder = KerasModel(inputs, output) - self.add_predictor(side, autoencoder) - logger.debug("Initialized model") + autoencoder = KModel(inputs, outputs, name=self.model_name) + return autoencoder def encoder(self): """ Encoder Network """ input_ = Input(shape=self.input_shape) var_x = input_ - var_x = self.blocks.conv(var_x, 128) - var_x = self.blocks.conv(var_x, 256) - var_x = self.blocks.conv(var_x, 512) - var_x = self.blocks.conv(var_x, 1024) - var_x = Flatten()(var_x) - return KerasModel(input_, var_x) + var_x = Conv2DBlock(128, activation="leakyrelu")(var_x) + var_x = Conv2DBlock(256, activation="leakyrelu")(var_x) + var_x = Conv2DBlock(512, activation="leakyrelu")(var_x) + var_x = Conv2DBlock(1024, activation="leakyrelu")(var_x) + var_x = layers.Flatten()(var_x) + return KModel(input_, var_x, name="encoder") - def intermediate(self): + def intermediate(self, side): """ Intermediate Network """ - input_ = Input(shape=(None, 4 * 4 * 1024)) - var_x = input_ - var_x = Dense(self.encoder_dim)(var_x) - var_x = Dense(4 * 4 * int(self.encoder_dim/2))(var_x) - var_x = Reshape((4, 4, int(self.encoder_dim/2)))(var_x) - return KerasModel(input_, var_x) + input_ = Input(shape=(4 * 4 * 1024, )) + var_x = layers.Dense(self.encoder_dim)(input_) + var_x = layers.Dense(4 * 4 * int(self.encoder_dim/2))(var_x) + var_x = layers.Reshape((4, 4, int(self.encoder_dim/2)))(var_x) + return KModel(input_, var_x, name=f"inter_{side}") def decoder(self): """ Decoder Network """ input_ = Input(shape=(4, 4, self.encoder_dim)) var_x = input_ - var_x = self.blocks.upscale(var_x, 512) - var_x = self.blocks.upscale(var_x, 256) - var_x = self.blocks.upscale(var_x, 128) - var_x = self.blocks.upscale(var_x, 64) - var_x = Conv2D(3, kernel_size=5, padding="same", activation="sigmoid")(var_x) + var_x = UpscaleBlock(512, activation="leakyrelu")(var_x) + var_x = UpscaleBlock(256, activation="leakyrelu")(var_x) + var_x = UpscaleBlock(128, activation="leakyrelu")(var_x) + var_x = UpscaleBlock(64, activation="leakyrelu")(var_x) + var_x = Conv2DOutput(3, 5, name="face_out")(var_x) outputs = [var_x] - if self.config.get("mask_type", None): - var_y = Conv2D(1, kernel_size=5, padding="same", activation="sigmoid")(var_x) + if cfg_loss.learn_mask(): + var_y = input_ + var_y = UpscaleBlock(512, activation="leakyrelu")(var_y) + var_y = UpscaleBlock(256, activation="leakyrelu")(var_y) + var_y = UpscaleBlock(128, activation="leakyrelu")(var_y) + var_y = UpscaleBlock(64, activation="leakyrelu")(var_y) + var_y = Conv2DOutput(1, 5, name="mask_out")(var_y) outputs.append(var_y) - return KerasModel(input_, outputs=outputs) + return KModel(input_, outputs=outputs, name="decoder") diff --git a/plugins/train/model/lightweight.py b/plugins/train/model/lightweight.py new file mode 100644 index 
0000000000..c1fc6acf53 --- /dev/null +++ b/plugins/train/model/lightweight.py @@ -0,0 +1,54 @@ +#!/usr/bin/env python3 +""" Lightweight Model by torzdf + An extremely limited model for training on low-end graphics cards + Based on the original https://www.reddit.com/r/deepfakes/ + code sample + contributions """ + +from keras import Input, layers, Model as KModel + +from lib.model.nn_blocks import Conv2DOutput, Conv2DBlock, UpscaleBlock +from plugins.train.train_config import Loss as cfg_loss + +from .original import Model as OriginalModel +# pylint:disable=duplicate-code + + +class Model(OriginalModel): + """ Lightweight Model for ~2GB Graphics Cards """ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.encoder_dim = 512 + + def encoder(self): + """ Encoder Network """ + input_ = Input(shape=self.input_shape) + var_x = input_ + var_x = Conv2DBlock(128, activation="leakyrelu")(var_x) + var_x = Conv2DBlock(256, activation="leakyrelu")(var_x) + var_x = Conv2DBlock(512, activation="leakyrelu")(var_x) + var_x = layers.Dense(self.encoder_dim)(layers.Flatten()(var_x)) + var_x = layers.Dense(4 * 4 * 512)(var_x) + var_x = layers.Reshape((4, 4, 512))(var_x) + var_x = UpscaleBlock(256, activation="leakyrelu")(var_x) + return KModel(input_, var_x, name="encoder") + + def decoder(self, side): + """ Decoder Network """ + input_ = Input(shape=(8, 8, 256)) + var_x = input_ + var_x = UpscaleBlock(512, activation="leakyrelu")(var_x) + var_x = UpscaleBlock(256, activation="leakyrelu")(var_x) + var_x = UpscaleBlock(128, activation="leakyrelu")(var_x) + var_x = Conv2DOutput(3, 5, activation="sigmoid", name=f"face_out_{side}")(var_x) + outputs = [var_x] + + if cfg_loss.learn_mask(): + var_y = input_ + var_y = UpscaleBlock(512, activation="leakyrelu")(var_y) + var_y = UpscaleBlock(256, activation="leakyrelu")(var_y) + var_y = UpscaleBlock(128, activation="leakyrelu")(var_y) + var_y = Conv2DOutput(1, 5, + activation="sigmoid", + name=f"mask_out_{side}")(var_y) + outputs.append(var_y) + return KModel(input_, outputs=outputs, name=f"decoder_{side}") diff --git a/plugins/train/model/original.py b/plugins/train/model/original.py index e065aa00c8..3637a61157 100644 --- a/plugins/train/model/original.py +++ b/plugins/train/model/original.py @@ -1,83 +1,158 @@ #!/usr/bin/env python3 """ Original Model - Based on the original https://www.reddit.com/r/deepfakes/ - code sample + contribs """ +Based on the original https://www.reddit.com/r/deepfakes/ code sample + contributions. -from keras.layers import Conv2D, Dense, Flatten, Input, Reshape +This model is heavily documented as it acts as a template that other model plugins can be developed +from. +""" -from keras.models import Model as KerasModel +from keras import Input, layers, Model as KModel -from ._base import ModelBase, logger +from lib.model.nn_blocks import Conv2DOutput, Conv2DBlock, UpscaleBlock +from lib.utils import get_module_objects +from plugins.train.train_config import Loss as cfg_loss +from ._base import ModelBase +from . import original_defaults as cfg +# pylint:disable=duplicate-code class Model(ModelBase): - """ Original Faceswap Model """ - def __init__(self, *args, **kwargs): - logger.debug("Initializing %s: (args: %s, kwargs: %s", - self.__class__.__name__, args, kwargs) + """ Original Faceswap Model. + + This is the original faceswap model and acts as a template for plugin development. 
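+ + A minimal, hypothetical plugin built from this template might look as follows (the + encoder/decoder builders are the ones defined later in this class; all names are + illustrative): + + >>> class Model(ModelBase): + ...     def __init__(self, *args, **kwargs): + ...         super().__init__(*args, **kwargs) + ...         self.input_shape = (64, 64, 3) + ...     def build_model(self, inputs): + ...         encoder = self.encoder()  # weights shared between sides + ...         outputs = (self.decoder("a")(encoder(inputs[0])) + ...                    + self.decoder("b")(encoder(inputs[1]))) + ...         return KModel(inputs, outputs, name=self.model_name)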
+ + All plugins must define the following attribute override after calling the parent's + :func:`__init__` method: - if "input_shape" not in kwargs: - kwargs["input_shape"] = (64, 64, 3) - if "encoder_dim" not in kwargs: - kwargs["encoder_dim"] = 512 if self.config["lowmem"] else 1024 + * :attr:`input_shape` (`tuple` or `list`): a tuple of ints defining the shape of the \ + faces that the model takes as input. If the input size is the same for both sides, this \ + can be a single 3-dimensional tuple. If the inputs have different sizes for "A" and "B" \ + this should be a list of two 3-dimensional shape tuples, one for each side. + Any additional attributes used exclusively by this model should be defined here, but make sure + that you are not accidentally overriding any existing + :class:`~plugins.train.model._base.ModelBase` attributes. + + Parameters + ---------- + args: varies + The default command line arguments passed in from :class:`~scripts.train.Train` or + :class:`~scripts.train.Convert` + kwargs: varies + The default keyword arguments passed in from :class:`~scripts.train.Train` or + :class:`~scripts.train.Convert` + """ + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - logger.debug("Initialized %s", self.__class__.__name__) - - def add_networks(self): - """ Add the original model weights """ - logger.debug("Adding networks") - self.add_network("decoder", "a", self.decoder()) - self.add_network("decoder", "b", self.decoder()) - self.add_network("encoder", None, self.encoder()) - logger.debug("Added networks") - - def build_autoencoders(self): - """ Initialize original model """ - logger.debug("Initializing model") - inputs = [Input(shape=self.input_shape, name="face")] - if self.config.get("mask_type", None): - mask_shape = (self.input_shape[:2] + (1, )) - inputs.append(Input(shape=mask_shape, name="mask")) - - for side in ("a", "b"): - logger.debug("Adding Autoencoder. Side: %s", side) - decoder = self.networks["decoder_{}".format(side)].network - output = decoder(self.networks["encoder"].network(inputs[0])) - autoencoder = KerasModel(inputs, output) - self.add_predictor(side, autoencoder) - logger.debug("Initialized model") + self.input_shape = (64, 64, 3) + self.low_mem = cfg.lowmem() + self.learn_mask = cfg_loss.learn_mask() + self.encoder_dim = 512 if self.low_mem else 1024 + + def build_model(self, inputs): + """ Create the model's structure. + + This function is automatically called immediately after :func:`__init__` has been called if + a new model is being created. It is ignored if an existing model is being loaded from disk + as the model structure will be defined in the saved model file. + + The model's final structure is defined here. + + For the original model, an encoder instance is defined, then the same instance is + referenced twice, once for each input ("A" and "B"), so that the same model is used for + both inputs. + + 2 Decoders are then defined (one for each side) with the encoder instances passed in as + input to the corresponding decoders. + + The final output of the model should always call :class:`lib.model.nn_blocks.Conv2DOutput` + so that the correct data type is set for the final activation, to support Mixed Precision + Training. Failure to do so is likely to lead to issues when Mixed Precision is enabled. + + Parameters + ---------- + inputs: list + A list of input tensors for the model. This will be a list of 2 tensors of + shape :attr:`input_shape`, the first for side "a", the second for side "b".
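+ + For example (illustrative shapes), with an :attr:`input_shape` of (64, 64, 3) the + incoming symbolic tensors resolve to: + + >>> [tuple(inp.shape) for inp in inputs] + [(None, 64, 64, 3), (None, 64, 64, 3)]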
+ + Returns + ------- + :class:`keras.models.Model` + See Keras documentation for the correct + structure, but note that parameter :attr:`name` is a required rather than an optional + argument in Faceswap. You should assign this to the attribute ``self.model_name`` that is + automatically generated from the plugin's filename. + """ + input_a = inputs[0] + input_b = inputs[1] + + encoder = self.encoder() + encoder_a = encoder(input_a) + encoder_b = encoder(input_b) + + outputs = self.decoder("a")(encoder_a) + self.decoder("b")(encoder_b) + + autoencoder = KModel(inputs, outputs, name=self.model_name) + return autoencoder def encoder(self): - """ Encoder Network """ + """ The original Faceswap Encoder Network. + + The encoder for the original model has its weights shared between both the "A" and "B" + side of the model, so only one instance is created in :func:`build_model`. However, this same + instance is then used twice (once for A and once for B) meaning that the weights get + shared. + + Returns + ------- + :class:`keras.models.Model` + The Keras encoder model, for sharing between inputs from both sides. + """ input_ = Input(shape=self.input_shape) var_x = input_ - var_x = self.blocks.conv(var_x, 128) - var_x = self.blocks.conv(var_x, 256) - var_x = self.blocks.conv(var_x, 512) - if not self.config.get("lowmem", False): - var_x = self.blocks.conv(var_x, 1024) - var_x = Dense(self.encoder_dim)(Flatten()(var_x)) - var_x = Dense(4 * 4 * 1024)(var_x) - var_x = Reshape((4, 4, 1024))(var_x) - var_x = self.blocks.upscale(var_x, 512) - return KerasModel(input_, var_x) - - def decoder(self): - """ Decoder Network """ + var_x = Conv2DBlock(128, activation="leakyrelu")(var_x) + var_x = Conv2DBlock(256, activation="leakyrelu")(var_x) + var_x = Conv2DBlock(512, activation="leakyrelu")(var_x) + if not self.low_mem: + var_x = Conv2DBlock(1024, activation="leakyrelu")(var_x) + var_x = layers.Dense(self.encoder_dim)(layers.Flatten()(var_x)) + var_x = layers.Dense(4 * 4 * 1024)(var_x) + var_x = layers.Reshape((4, 4, 1024))(var_x) + var_x = UpscaleBlock(512, activation="leakyrelu")(var_x) + return KModel(input_, var_x, name="encoder") + + def decoder(self, side): + """ The original Faceswap Decoder Network. + + The decoders for the original model have separate weights for each side "A" and "B", so two + instances are created in :func:`build_model`, one for each side. + + Parameters + ---------- + side: str + Either `"a"` or `"b"`. This is used for naming the decoder model. + + Returns + ------- + :class:`keras.models.Model` + The Keras decoder model. This will be called twice, once for each side.
+ """ input_ = Input(shape=(8, 8, 512)) var_x = input_ - var_x = self.blocks.upscale(var_x, 256) - var_x = self.blocks.upscale(var_x, 128) - var_x = self.blocks.upscale(var_x, 64) - var_x = Conv2D(3, kernel_size=5, padding="same", activation="sigmoid")(var_x) + var_x = UpscaleBlock(256, activation="leakyrelu")(var_x) + var_x = UpscaleBlock(128, activation="leakyrelu")(var_x) + var_x = UpscaleBlock(64, activation="leakyrelu")(var_x) + var_x = Conv2DOutput(3, 5, name=f"face_out_{side}")(var_x) outputs = [var_x] - if self.config.get("mask_type", None): + if self.learn_mask: var_y = input_ - var_y = self.blocks.upscale(var_y, 256) - var_y = self.blocks.upscale(var_y, 128) - var_y = self.blocks.upscale(var_y, 64) - var_y = Conv2D(1, kernel_size=5, padding='same', activation='sigmoid')(var_y) + var_y = UpscaleBlock(256, activation="leakyrelu")(var_y) + var_y = UpscaleBlock(128, activation="leakyrelu")(var_y) + var_y = UpscaleBlock(64, activation="leakyrelu")(var_y) + var_y = Conv2DOutput(1, 5, name=f"mask_out_{side}")(var_y) outputs.append(var_y) - return KerasModel(input_, outputs=outputs) + return KModel(input_, outputs=outputs, name=f"decoder_{side}") + + +__all__ = get_module_objects(__name__) diff --git a/plugins/train/model/original_defaults.py b/plugins/train/model/original_defaults.py new file mode 100755 index 0000000000..3a519ac2f2 --- /dev/null +++ b/plugins/train/model/original_defaults.py @@ -0,0 +1,41 @@ +#!/usr/bin/env python3 +""" The default options for the faceswap Original Model plugin. + +Defaults files should be named `_defaults.py` + +Any qualifying items placed into this file will automatically get added to the relevant config +.ini files within the faceswap/config folder and added to the relevant GUI settings page. + +The following variable should be defined: + + Parameters + ---------- + HELPTEXT: str + A string describing what this plugin does + +Further plugin configuration options are assigned using: +>>> = ConfigItem(...) + +where is the name of the configuration option to be added (lower-case, alpha-numeric ++ underscore only) and ConfigItem(...) is the [`~lib.config.objects.ConfigItem`] data for the +option. + +See the docstring/ReadtheDocs documentation required parameters for the ConfigItem object. +Items will be grouped together as per their `group` parameter, but otherwise will be processed in +the order that they are added to this module. +from lib.config import ConfigItem +""" +# pylint:disable=duplicate-code +from lib.config import ConfigItem + + +HELPTEXT = "Original Faceswap Model." + + +lowmem = ConfigItem( + datatype=bool, + default=False, + group="settings", + info="Lower memory mode. Set to 'True' if having issues with VRAM useage.\n" + "NB: Models with a changed lowmem mode are not compatible with each other.", + fixed=True) diff --git a/plugins/train/model/phaze_a.py b/plugins/train/model/phaze_a.py new file mode 100644 index 0000000000..21903cd619 --- /dev/null +++ b/plugins/train/model/phaze_a.py @@ -0,0 +1,1332 @@ +#!/usr/bin/env python3 +""" Phaze-A Model by TorzDF with thanks to BirbFakes and the myriad of testers. 
""" + +# pylint:disable=too-many-lines +from __future__ import annotations +import logging +import typing as T +from dataclasses import dataclass + +import numpy as np +import keras +from keras import applications as kapp, layers as kl + +from lib.logger import parse_class_init +from lib.model.nn_blocks import ( + Conv2D, Conv2DBlock, Conv2DOutput, ResidualBlock, UpscaleBlock, Upscale2xBlock, + UpscaleResizeImagesBlock, UpscaleDNYBlock) +from lib.model.normalization import ( + AdaInstanceNormalization, GroupNormalization, InstanceNormalization, RMSNormalization) +from lib.model.networks import ViT, TypeModelsViT +from lib.utils import get_keras_version, FaceswapError +from plugins.train.train_config import Loss as cfg_loss + +from ._base import ModelBase, get_all_sub_models +from . import phaze_a_defaults as cfg + +if T.TYPE_CHECKING: + from keras import KerasTensor + +logger = logging.getLogger(__name__) + + +@dataclass +class _EncoderInfo: + """ Contains model configuration options for various Phaze-A Encoders. + + Parameters + ---------- + keras_name: str + The name of the encoder in Keras Applications. Empty string `""` if the encoder does not + exist in Keras Applications + default_size: int + The default input size of the encoder + keras_min: float, optional + The lowest version of Keras that the encoder can be used for. Default: `3.0` + scaling: tuple, optional + The float scaling that the encoder expects. Default: `(0, 1)` + min_size: int, optional + The minimum input size that the encoder will allow. Default: 32 + enforce_for_weights: bool, optional + ``True`` if the input size for the model must be forced to the default size when loading + imagenet weights, otherwise ``False``. Default: ``False`` + color_order: str, optional + The color order that the model expects (`"bgr"` or `"rgb"`). 
Default: `"rgb"` + """ + keras_name: str + default_size: int + keras_min: tuple[int, int] = (3, 0) + scaling: tuple[int, int] = (0, 1) + min_size: int = 32 + enforce_for_weights: bool = False + color_order: T.Literal["bgr", "rgb"] = "rgb" + + +_MODEL_MAPPING: dict[str, _EncoderInfo] = { + "clipv_farl-b-16-16": _EncoderInfo( + keras_name="FaRL-B-16-16", default_size=224), + "clipv_farl-b-16-64": _EncoderInfo( + keras_name="FaRL-B-16-64", default_size=224), + "clipv_vit-b-16": _EncoderInfo( + keras_name="ViT-B-16", default_size=224), + "clipv_vit-b-32": _EncoderInfo( + keras_name="ViT-B-32", default_size=224), + "clipv_vit-l-14": _EncoderInfo( + keras_name="ViT-L-14", default_size=224), + "clipv_vit-l-14-336px": _EncoderInfo( + keras_name="ViT-L-14-336px", default_size=336), + "convnext_tiny": _EncoderInfo( + keras_name="ConvNeXtTiny", scaling=(0, 255), default_size=224), + "convnext_small": _EncoderInfo( + keras_name="ConvNeXtSmall", scaling=(0, 255), default_size=224), + "convnext_base": _EncoderInfo( + keras_name="ConvNeXtBase", scaling=(0, 255), default_size=224), + "convnext_large": _EncoderInfo( + keras_name="ConvNeXtLarge", scaling=(0, 255), default_size=224), + "convnext_extra_large": _EncoderInfo( + keras_name="ConvNeXtXLarge", scaling=(0, 255), default_size=224), + "densenet121": _EncoderInfo( + keras_name="DenseNet121", default_size=224), + "densenet169": _EncoderInfo( + keras_name="DenseNet169", default_size=224), + "densenet201": _EncoderInfo( + keras_name="DenseNet201", default_size=224), + "efficientnet_b0": _EncoderInfo( + keras_name="EfficientNetB0", scaling=(0, 255), default_size=224), + "efficientnet_b1": _EncoderInfo( + keras_name="EfficientNetB1", scaling=(0, 255), default_size=240), + "efficientnet_b2": _EncoderInfo( + keras_name="EfficientNetB2", scaling=(0, 255), default_size=260), + "efficientnet_b3": _EncoderInfo( + keras_name="EfficientNetB3", scaling=(0, 255), default_size=300), + "efficientnet_b4": _EncoderInfo( + keras_name="EfficientNetB4", scaling=(0, 255), default_size=380), + "efficientnet_b5": _EncoderInfo( + keras_name="EfficientNetB5", scaling=(0, 255), default_size=456), + "efficientnet_b6": _EncoderInfo( + keras_name="EfficientNetB6", scaling=(0, 255), default_size=528), + "efficientnet_b7": _EncoderInfo( + keras_name="EfficientNetB7", scaling=(0, 255), default_size=600), + "efficientnet_v2_b0": _EncoderInfo( + keras_name="EfficientNetV2B0", scaling=(-1, 1), default_size=224), + "efficientnet_v2_b1": _EncoderInfo( + keras_name="EfficientNetV2B1", scaling=(-1, 1), default_size=240), + "efficientnet_v2_b2": _EncoderInfo( + keras_name="EfficientNetV2B2", scaling=(-1, 1), default_size=260), + "efficientnet_v2_b3": _EncoderInfo( + keras_name="EfficientNetV2B3", scaling=(-1, 1), default_size=300), + "efficientnet_v2_s": _EncoderInfo( + keras_name="EfficientNetV2S", scaling=(-1, 1), default_size=384), + "efficientnet_v2_m": _EncoderInfo( + keras_name="EfficientNetV2M", scaling=(-1, 1), default_size=480), + "efficientnet_v2_l": _EncoderInfo( + keras_name="EfficientNetV2L", scaling=(-1, 1), default_size=480), + "inception_resnet_v2": _EncoderInfo( + keras_name="InceptionResNetV2", scaling=(-1, 1), min_size=75, default_size=299), + "inception_v3": _EncoderInfo( + keras_name="InceptionV3", scaling=(-1, 1), min_size=75, default_size=299), + "mobilenet": _EncoderInfo( + keras_name="MobileNet", scaling=(-1, 1), default_size=224), + "mobilenet_v2": _EncoderInfo( + keras_name="MobileNetV2", scaling=(-1, 1), default_size=224), + "mobilenet_v3_large": _EncoderInfo( + 
keras_name="MobileNetV3Large", scaling=(-1, 1), default_size=224), + "mobilenet_v3_small": _EncoderInfo( + keras_name="MobileNetV3Small", scaling=(-1, 1), default_size=224), + "nasnet_large": _EncoderInfo( + keras_name="NASNetLarge", scaling=(-1, 1), default_size=331, enforce_for_weights=True), + "nasnet_mobile": _EncoderInfo( + keras_name="NASNetMobile", scaling=(-1, 1), default_size=224, enforce_for_weights=True), + "resnet50": _EncoderInfo( + keras_name="ResNet50", scaling=(-1, 1), min_size=32, default_size=224), + "resnet50_v2": _EncoderInfo( + keras_name="ResNet50V2", scaling=(-1, 1), default_size=224), + "resnet101": _EncoderInfo( + keras_name="ResNet101", scaling=(-1, 1), default_size=224), + "resnet101_v2": _EncoderInfo( + keras_name="ResNet101V2", scaling=(-1, 1), default_size=224), + "resnet152": _EncoderInfo( + keras_name="ResNet152", scaling=(-1, 1), default_size=224), + "resnet152_v2": _EncoderInfo( + keras_name="ResNet152V2", scaling=(-1, 1), default_size=224), + "vgg16": _EncoderInfo( + keras_name="VGG16", color_order="bgr", scaling=(0, 255), default_size=224), + "vgg19": _EncoderInfo( + keras_name="VGG19", color_order="bgr", scaling=(0, 255), default_size=224), + "xception": _EncoderInfo( + keras_name="Xception", scaling=(-1, 1), min_size=71, default_size=299), + "fs_original": _EncoderInfo( + keras_name="", color_order="bgr", min_size=32, default_size=1024)} + + +class Model(ModelBase): + """ Phaze-A Faceswap Model. + + An highly adaptable and configurable model by torzDF + + Parameters + ----------513 + args: varies + The default command line arguments passed in from :class:`~scripts.train.Train` or + :class:`~scripts.train.Convert` + kwargs: varies + The default keyword arguments passed in from :class:`~scripts.train.Train` or + :class:`~scripts.train.Convert` + """ + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + if cfg.output_size() % 16 != 0: + raise FaceswapError("Phaze-A output shape must be a multiple of 16") + + self._validate_encoder_architecture() + + self.input_shape: tuple[int, int, int] = self._get_input_shape() + self.color_order = _MODEL_MAPPING[cfg.enc_architecture()].color_order + + @property + def freeze_layers(self) -> list[str]: + """ list[str] : Valid layers to freeze based on configured options """ + return self._select_real_layers(cfg.freeze_layers()) + + @property + def load_layers(self) -> list[str]: + """ list[str] : Valid layers to load based on configured options """ + return self._select_real_layers(cfg.load_layers()) + + def build(self) -> None: + """ Build the model and assign to :attr:`model`. + + Override's the default build function for allowing the setting of dropout rate for pre- + existing models. + """ + is_summary = hasattr(self._args, "summary") and self._args.summary + if not self._io.model_exists or self._is_predict or is_summary: + logger.debug("New model, inference or summary. Falling back to default build: " + "(exists: %s, inference: %s, is_summary: %s)", + self._io.model_exists, self._is_predict, is_summary) + super().build() + return + model = self.io.load() + model = self._update_dropouts(model) + self._model = model + self._compile_model() + self._output_summary() + + def _update_dropouts(self, model: keras.models.Model) -> keras.models.Model: + """ Update the saved model with new dropout rates. + + Keras, annoyingly, does not actually change the dropout of the underlying layer, so we need + to update the rate, then clone the model into a new model and reload weights. 
+ + Parameters + ---------- + model: :class:`keras.models.Model` + The loaded saved Keras Model to update the dropout rates for + + Returns + ------- + :class:`keras.models.Model` + The loaded Keras Model with the dropout rates updated + """ + dropouts = {"fc": cfg.fc_dropout(), "gblock": cfg.fc_gblock_dropout()} + logger.debug("Config dropouts: %s", dropouts) + updated = False + for mod in get_all_sub_models(model): + if not mod.name.startswith("fc_"): + continue + key = "gblock" if "gblock" in mod.name else mod.name.split("_")[0] + rate = dropouts[key] + log_once = False + for layer in mod.layers: + if not isinstance(layer, kl.Dropout): + continue + if layer.rate != rate: + logger.debug("Updating dropout rate for %s from %s to %s", + f"{mod.name} - {layer.name}", layer.rate, rate) + if not log_once: + logger.info("Updating Dropout Rate for '%s' from %s to %s", + mod.name, layer.rate, rate) + log_once = True + layer.rate = rate + updated = True + if updated: + logger.debug("Dropout rate updated. Cloning model") + new_model = keras.models.clone_model(model) + new_model.set_weights(model.get_weights()) + del model + model = new_model + return model + + def _select_real_layers(self, layers: list[str]) -> list[str]: + """ Process the selected freeze or load layers configuration options and replace the + `keras_encoder` option with the actual keras model name for the configured architecture + + Parameters + ---------- + layers: list + The configured freeze or load layer names to process + + Returns + ------- + list + The selected layers for weight freezing or loading + """ + arch = cfg.enc_architecture() + # EfficientNetV2 is inconsistent with other models' naming conventions + keras_name = _MODEL_MAPPING[arch].keras_name.replace("EfficientNetV2", "EfficientNetV2-") + # CLIPv model is always called 'visual' regardless of weights/format loaded + keras_name = "visual" if arch.startswith("clipv_") else keras_name + + if "keras_encoder" not in layers: + retval = layers + elif keras_name: + retval = [layer.replace("keras_encoder", keras_name) for layer in layers] + logger.debug("Substituting 'keras_encoder' for '%s'", keras_name) + else: + retval = [layer for layer in layers if layer != "keras_encoder"] + logger.debug("Removing 'keras_encoder' for '%s'", keras_name) + + return retval + + def _get_input_shape(self) -> tuple[int, int, int]: + """ Obtain the input shape for the model. + + Input shape is calculated from the selected Encoder's input size, scaled to the user + selected Input Scaling, rounded down to the nearest 16 pixels. + + Notes + ----- + Some models (NasNet) require the input size to be of a certain dimension if loading + imagenet weights. In these instances the input size is forced to the default size and a + warning is logged + + Returns + ------- + tuple + The shape tuple for the input size to the Phaze-A model + """ + arch = cfg.enc_architecture() + enforce_size = _MODEL_MAPPING[arch].enforce_for_weights + default_size = _MODEL_MAPPING[arch].default_size + scaling = cfg.enc_scaling() / 100 + + min_size = _MODEL_MAPPING[arch].min_size + size = int(max(min_size, ((default_size * scaling) // 16) * 16)) + + if cfg.enc_load_weights() and enforce_size and scaling != 1.0: + logger.warning("%s requires input size to be %spx when loading imagenet weights. 
" + "Adjusting input size from %spx to %spx", + arch, default_size, size, default_size) + retval = (default_size, default_size, 3) + else: + retval = (size, size, 3) + + logger.debug("Encoder input set to: %s", retval) + return retval + + def _validate_encoder_architecture(self) -> None: + """ Validate that the requested architecture is a valid choice for the running system + configuration. + + If the selection is not valid, an error is logged and system exits. + """ + arch = cfg.enc_architecture() + model = _MODEL_MAPPING.get(arch) + if not model: + raise FaceswapError(f"'{arch}' is not a valid choice for encoder architecture. Choose " + f"one of {list(_MODEL_MAPPING.keys())}.") + + keras_ver = get_keras_version() + keras_min = model.keras_min + if keras_ver < keras_min: + raise FaceswapError(f"{arch}' is not compatible with your version of Keras. The " + f"minimum version required is {keras_min} whilst you have version " + f"{keras_ver} installed.") + + def build_model(self, inputs: list[KerasTensor]) -> keras.models.Model: + """ Create the model's structure. + + Parameters + ---------- + inputs: list[:class:`keras.KerasTensor`] + A list of input tensors for the model. This will be a list of 2 tensors of + shape :attr:`input_shape`, the first for side "a", the second for side "b". + + Returns + ------- + :class:`keras.models.Model` + The generated model + """ + # Create sub-Models + encoders = self._build_encoders(inputs) + inters = self._build_fully_connected(encoders) + g_blocks = self._build_g_blocks(inters) + decoders = self._build_decoders(g_blocks) + + # Create Autoencoder + outputs = decoders["a"] + decoders["b"] + autoencoder = keras.models.Model(inputs, outputs, name=self.model_name) + return autoencoder + + def _build_encoders(self, inputs: list[KerasTensor]) -> dict[str, keras.models.Model]: + """ Build the encoders for Phaze-A + + Parameters + ---------- + inputs: list[:class:`keras.KerasTensor`] + A list of input tensors for the model. This will be a list of 2 tensors of + shape :attr:`input_shape`, the first for side "a", the second for side "b". 
+ + Returns + ------- + dict + side as key ('a' or 'b'), encoder for side as value + """ + encoder = Encoder(self.input_shape)() + retval = {"a": encoder(inputs[0]), "b": encoder(inputs[1])} + logger.debug("Encoders: %s", retval) + return retval + + def _build_fully_connected( + self, + inputs: dict[str, keras.models.Model]) -> dict[str, list[keras.models.Model]]: + """ Build the fully connected layers for Phaze-A + + Parameters + ---------- + inputs: dict + The compiled encoder models that act as inputs to the fully connected layers + + Returns + ------- + dict + side as key ('a' or 'b'), fully connected model for side as value + """ + input_shapes = inputs["a"].shape[1:] + + fc_a = fc_both = None + if cfg.split_fc(): + fc_a = FullyConnected("a", input_shapes)() + inter_a = [fc_a(inputs["a"])] + inter_b = [FullyConnected("b", input_shapes)()(inputs["b"])] + else: + fc_both = FullyConnected("both", input_shapes)() + inter_a = [fc_both(inputs["a"])] + inter_b = [fc_both(inputs["b"])] + + shared_fc = None if cfg.shared_fc() == "none" else cfg.shared_fc() + if shared_fc: + if shared_fc == "full": + fc_shared = FullyConnected("shared", input_shapes)() + elif cfg.split_fc(): + assert fc_a is not None + fc_shared = fc_a + else: + assert fc_both is not None + fc_shared = fc_both + inter_a = [kl.Concatenate(name="inter_a")([inter_a[0], fc_shared(inputs["a"])])] + inter_b = [kl.Concatenate(name="inter_b")([inter_b[0], fc_shared(inputs["b"])])] + + if cfg.enable_gblock(): + fc_gblock = FullyConnected("gblock", input_shapes)() + inter_a.append(fc_gblock(inputs["a"])) + inter_b.append(fc_gblock(inputs["b"])) + + inter_a = inter_a[0] if len(inter_a) == 1 else inter_a + inter_b = inter_b[0] if len(inter_b) == 1 else inter_b + retval = {"a": inter_a, "b": inter_b} + logger.debug("Fully Connected: %s", retval) + return retval + + def _build_g_blocks( + self, + inputs: dict[str, list[keras.models.Model]] + ) -> dict[str, list[keras.models.Model] | keras.models.Model]: + """ Build the g-block layers for Phaze-A. + + If a g-block has not been selected for this model, then the original `inters` models are + returned for passing straight to the decoder + + Parameters + ---------- + inputs: dict + The compiled inter models that act as inputs to the g_blocks + + Returns + ------- + dict + side as key ('a' or 'b'), g-block model for side as value. If g-block has been disabled + then the values will be the fully connected layers + """ + if not cfg.enable_gblock(): + logger.debug("No G-Block selected, returning Inters: %s", inputs) + return inputs + + input_shapes = [inter.shape[1:] for inter in inputs["a"]] + if cfg.split_gblock(): + retval = {"a": GBlock("a", input_shapes)()(inputs["a"]), + "b": GBlock("b", input_shapes)()(inputs["b"])} + else: + g_block = GBlock("both", input_shapes)() + retval = {"a": g_block((inputs["a"])), "b": g_block((inputs["b"]))} + + logger.debug("G-Blocks: %s", retval) + return retval + + def _build_decoders(self, + inputs: dict[str, list[keras.models.Model] | keras.models.Model] + ) -> dict[str, keras.models.Model]: + """ Build the decoders for Phaze-A + + Parameters + ---------- + inputs: dict + A dict of inputs to the decoder. This will either be g-block output (if g-block is + enabled) or fully connected layers output (if g-block is disabled). + + Returns + ------- + dict + side as key ('a' or 'b'), decoder for side as value + """ + input_ = inputs["a"] + # If input is inters, shapes will be a list. + # There will only ever be 1 input.
For inters: either inter out, or concatenate of inters + # For g-block, this only ever has one output + input_ = input_[0] if isinstance(input_, list) else input_ + + # If learning a mask and upscales have been placed into FC layer, then the mask will also + # come as an input + if cfg_loss.learn_mask() and cfg.dec_upscales_in_fc(): + input_ = input_[0] + + input_shape = input_.shape[1:] + + if cfg.split_decoders(): + retval = {"a": Decoder("a", input_shape)()(inputs["a"]), + "b": Decoder("b", input_shape)()(inputs["b"])} + else: + decoder = Decoder("both", input_shape)() + retval = {"a": decoder(inputs["a"]), "b": decoder(inputs["b"])} + + logger.debug("Decoders: %s", retval) + return retval + + +def _bottleneck(inputs: KerasTensor, bottleneck: str, size: int, normalization: str + ) -> KerasTensor: + """ The bottleneck fully connected layer. Can be called from Encoder or FullyConnected layers. + + Parameters + ---------- + inputs: :class:`keras.KerasTensor` + The input to the bottleneck layer + bottleneck: str + The type of layer to use for the bottleneck. One of `"average_pooling"`, `"dense"`, + `"max_pooling"` or `"flatten"` (`"flatten"` applies no bottleneck) + size: int + The number of nodes for the dense layer (if selected) + normalization: str + The normalization method to use prior to the bottleneck layer + + Returns + ------- + :class:`keras.KerasTensor` + The output from the bottleneck + """ + norm = None if normalization == "none" else normalization + norms = {"layer": kl.LayerNormalization, + "rms": RMSNormalization, + "instance": InstanceNormalization} + bottlenecks = {"average_pooling": kl.GlobalAveragePooling2D(), + "dense": kl.Dense(size), + "max_pooling": kl.GlobalMaxPooling2D()} + var_x = inputs + if norm: + var_x = norms[norm]()(var_x) + if bottleneck == "dense" and var_x.ndim > 2: # Flatten non-1D inputs for dense + var_x = kl.Flatten()(var_x) + if bottleneck != "flatten": + var_x = bottlenecks[bottleneck](var_x) + if var_x.ndim > 2: + # Flatten prior to fc layers + var_x = kl.Flatten()(var_x) + return var_x + + +def _get_upscale_layer(method: T.Literal["resize_images", "subpixel", "upscale_dny", + "upscale_fast", "upscale_hybrid", "upsample2d"], + filters: int, + activation: str | None = None, + upsamples: int | None = None, + interpolation: str | None = None) -> keras.layers.Layer: + """ Obtain an instance of the requested upscale method. + + Parameters + ---------- + method: str + The user selected upscale method to use. One of `"resize_images"`, `"subpixel"`, + `"upscale_dny"`, `"upscale_fast"`, `"upscale_hybrid"`, `"upsample2d"` + filters: int + The number of filters to use in the upscale layer + activation: str, optional + The activation function to use in the upscale layer. ``None`` to use no activation. + Default: ``None`` + upsamples: int, optional + Only used for UpSampling2D. If provided, then this is passed to the layer as the ``size`` + parameter. Default: ``None`` + interpolation: str, optional + Only used for UpSampling2D. If provided, then this is passed to the layer as the + ``interpolation`` parameter.
Default: ``None`` + + Returns + ------- + :class:`keras.layers.Layer` + The selected configured upscale layer + """ + if method == "upsample2d": + kwargs: dict[str, str | int] = {} + if upsamples: + kwargs["size"] = upsamples + if interpolation: + kwargs["interpolation"] = interpolation + return kl.UpSampling2D(**kwargs) + if method == "subpixel": + return UpscaleBlock(filters, activation=activation) + if method == "upscale_fast": + return Upscale2xBlock(filters, activation=activation, fast=True) + if method == "upscale_hybrid": + return Upscale2xBlock(filters, activation=activation, fast=False) + if method == "upscale_dny": + return UpscaleDNYBlock(filters, activation=activation) + return UpscaleResizeImagesBlock(filters, activation=activation) + + +def _get_curve(start_y: int, + end_y: int, + num_points: int, + scale: float, + mode: T.Literal["full", "cap_max", "cap_min"] = "full") -> list[int]: + """ Obtain a curve. + + For the given start and end y values, return the y co-ordinates of a curve for the given + number of points. The points are rounded down to the nearest 8. + + Parameters + ---------- + start_y: int + The y co-ordinate for the starting point of the curve + end_y: int + The y co-ordinate for the end point of the curve + num_points: int + The number of data points to plot on the x-axis + scale: float + The scale of the curve (from -.99 to 0.99) + mode: str, optional + The method to generate the curve. One of `"full"`, `"cap_max"` or `"cap_min"`. `"full"` + mode generates a curve from the `"start_y"` to the `"end_y"` values. `"cap_max"` pads the + earlier points with the `"start_y"` value before filling out the remaining points at a + fixed divider to the `"end_y"` value. `"cap_min"` starts at the `"start_y"` value, filling points + at a fixed divider until the `"end_y"` value is reached and pads the remaining points with + the `"end_y"` value. Default: `"full"` + + Returns + ------- + list + List of ints of points for the given curve + """ + scale = min(.99, max(-.99, scale)) + logger.debug("Obtaining curve: (start_y: %s, end_y: %s, num_points: %s, scale: %s, mode: %s)", + start_y, end_y, num_points, scale, mode) + if mode == "full": + x_axis = np.linspace(0., 1., num=num_points) + y_axis = (x_axis - x_axis * scale) / (scale - abs(x_axis) * 2 * scale + 1) + y_axis = y_axis * (end_y - start_y) + start_y + retval = [int((y // 8) * 8) for y in y_axis] + else: + y_axis = [start_y] + scale = 1. - abs(scale) + for _ in range(num_points - 1): + current_value = max(end_y, int(((y_axis[-1] * scale) // 8) * 8)) + y_axis.append(current_value) + if current_value == end_y: + break + pad = [start_y if mode == "cap_max" else end_y for _ in range(num_points - len(y_axis))] + retval = pad + y_axis if mode == "cap_max" else y_axis + pad + logger.debug("Returning curve: %s", retval) + return retval + + +def _scale_dim(target_resolution: int, original_dim: int) -> int: + """ Scale a given `original_dim` so that it is a factor of the target resolution. + + Parameters + ---------- + target_resolution: int + The output resolution that is being targeted + original_dim: int + The dimension that needs to be checked for compatibility for upscaling to the + target resolution + + Returns + ------- + int + The highest dimension below or equal to `original_dim` that is a factor of the + target resolution.
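A worked example of :func:`_get_curve` in `"full"` mode: with `scale=0` the curve degenerates to a straight line, and each point is rounded down to a multiple of 8. The equivalent standalone arithmetic:

import numpy as np

x_axis = np.linspace(0., 1., num=4)
y_axis = x_axis * (64 - 512) + 512          # start_y=512, end_y=64, scale=0
print([int((y // 8) * 8) for y in y_axis])  # [512, 360, 208, 64]

:func:`_scale_dim` is mechanically similar: `_scale_dim(128, 9)` halves 128 -> 64 -> 32 -> 16 -> 8 and returns 8, the largest factor of 128 that does not exceed 9.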
+ """ + new_dim = target_resolution + while new_dim > original_dim: + next_dim = new_dim / 2 + if not next_dim.is_integer(): + break + new_dim = int(next_dim) + logger.debug("target_resolution: %s, original_dim: %s, new_dim: %s", + target_resolution, original_dim, new_dim) + return new_dim + + +class Encoder(): + """ Encoder. Uses one of pre-existing Keras/Faceswap models or custom encoder. + + Parameters + ---------- + input_shape: tuple + The shape tuple for the input tensor + """ + def __init__(self, input_shape: tuple[int, int, int]) -> None: + logger.debug(parse_class_init(locals())) + self.input_shape = input_shape + self._input_shape = input_shape + + @property + def _model_kwargs(self) -> dict[str, dict[str, float | int | bool]]: + """ dict: Configuration option for architecture mapped to optional kwargs. """ + return {"mobilenet": {"alpha": cfg.mobilenet_width(), + "depth_multiplier": cfg.mobilenet_depth(), + "dropout": cfg.mobilenet_dropout()}, + "mobilenet_v2": {"alpha": cfg.mobilenet_width()}, + "mobilenet_v3": {"alpha": cfg.mobilenet_width(), + "minimalist": cfg.mobilenet_minimalistic(), + "include_preprocessing": False}} + + @property + def _selected_model(self) -> tuple[_EncoderInfo, dict]: + """ tuple(dict, :class:`_EncoderInfo`): The selected encoder model and it's associated + keyword arguments """ + arch = cfg.enc_architecture() + model = _MODEL_MAPPING[arch] + kwargs = self._model_kwargs.get(arch, {}) + if arch.startswith("efficientnet_v2"): + kwargs["include_preprocessing"] = False + return model, kwargs + + def __call__(self) -> keras.models.Model: + """ Create the Phaze-A Encoder Model. + + Returns + ------- + :class:`keras.models.Model` + The selected Encoder Model + """ + input_ = T.cast("KerasTensor", kl.Input(shape=self._input_shape)) + var_x = input_ + + scaling = self._selected_model[0].scaling + + if scaling: + # Some models expect different scaling. + logger.debug("Scaling to %s for '%s'", scaling, cfg.enc_architecture()) + if scaling == (0, 255): + # models expecting inputs from 0 to 255. + var_x = var_x * 255. + if scaling == (-1, 1): + # models expecting inputs from -1 to 1. + var_x = var_x * 2. + var_x = var_x - 1.0 + + var_x = self._get_encoder_model()(var_x) + + if cfg.bottleneck_in_encoder(): + var_x = _bottleneck(var_x, + cfg.bottleneck_type(), + cfg.bottleneck_size(), + cfg.bottleneck_norm()) + + return keras.models.Model(input_, var_x, name="encoder") + + def _get_encoder_model(self) -> keras.models.Model: + """ Return the model defined by the selected architecture. + + Returns + ------- + :class:`keras.Model` + The selected keras model for the chosen encoder architecture + """ + model, kwargs = self._selected_model + if model.keras_name and cfg.enc_architecture().startswith("clipv_"): + assert model.keras_name in T.get_args(TypeModelsViT) + kwargs["input_shape"] = self._input_shape + kwargs["load_weights"] = cfg.enc_load_weights() + retval = ViT(T.cast(TypeModelsViT, model.keras_name), + input_size=self._input_shape[0], + load_weights=cfg.enc_load_weights())() + elif model.keras_name: + kwargs["input_shape"] = self._input_shape + kwargs["include_top"] = False + kwargs["weights"] = "imagenet" if cfg.enc_load_weights() else None + retval = getattr(kapp, model.keras_name)(**kwargs) + else: + retval = _EncoderFaceswap() + return retval + + +class _EncoderFaceswap(): + """ A configurable standard Faceswap encoder based off Original model. 
""" + def __init__(self) -> None: + logger.debug(parse_class_init(locals())) + self._type = cfg.enc_architecture() + self._depth = getattr(cfg, f"{self._type}_depth")() + self._min_filters = cfg.fs_original_min_filters() + self._max_filters = cfg.fs_original_max_filters() + self._is_alt = cfg.fs_original_use_alt() + self._relu_alpha = 0.2 if self._is_alt else 0.1 + self._kernel_size = 3 if self._is_alt else 5 + self._strides = 1 if self._is_alt else 2 + + def __call__(self, inputs: KerasTensor) -> KerasTensor: + """ Call the original Faceswap Encoder + + Parameters + ---------- + inputs: :class:`keras.KerasTensor` + The input tensor to the Faceswap Encoder + + Returns + ------- + :class:`keras.KerasTensor` + The output tensor from the Faceswap Encoder + """ + var_x = inputs + filters = cfg.fs_original_min_filters() + + if self._is_alt: + var_x = Conv2DBlock(filters, + kernel_size=1, + strides=self._strides, + relu_alpha=self._relu_alpha)(var_x) + + for i in range(self._depth): + name = f"fs_{'dny_' if self._is_alt else ''}enc" + var_x = Conv2DBlock(filters, + kernel_size=self._kernel_size, + strides=self._strides, + relu_alpha=self._relu_alpha, + name=f"{name}_convblk_{i}")(var_x) + filters = min(cfg.fs_original_max_filters(), filters * 2) + if self._is_alt and i == self._depth - 1: + var_x = Conv2DBlock(filters, + kernel_size=4, + strides=self._strides, + padding="valid", + relu_alpha=self._relu_alpha, + name=f"{name}_convblk_{i}_1")(var_x) + elif self._is_alt: + var_x = Conv2DBlock(filters, + kernel_size=self._kernel_size, + strides=self._strides, + relu_alpha=self._relu_alpha, + name=f"{name}_convblk_{i}_1")(var_x) + var_x = kl.MaxPooling2D(2, name=f"{name}_pool_{i}")(var_x) + return var_x + + +class FullyConnected(): + """ Intermediate Fully Connected layers for Phaze-A Model. + + Parameters + ---------- + side: ["a", "b", "both", "gblock", "shared"] + The side of the model that the fully connected layers belong to. Used for naming + input_shape: tuple + The input shape for the fully connected layers + """ + def __init__(self, + side: T.Literal["a", "b", "both", "gblock", "shared"], + input_shape: tuple) -> None: + logger.debug(parse_class_init(locals())) + self._side = side + self._input_shape = input_shape + self._final_dims = cfg.fc_dimensions() * (cfg.fc_upsamples() + 1) + self._prefix = "fc_gblock" if self._side == "gblock" else "fc" + + logger.debug("Initialized: %s (side: %s, min_nodes: %s, max_nodes: %s)", + self.__class__.__name__, self._side, self._min_nodes, self._max_nodes) + + @property + def _min_nodes(self) -> int: + """ int: The number of nodes for the first Dense. For non g-block layers this will be the + given minimum filters multiplied by the dimensions squared. For g-block layers, this is the + given value """ + if self._side == "gblock": + return cfg.fc_gblock_min_nodes() + retval = self._scale_filters(cfg.fc_min_filters()) + retval = int(retval * cfg.fc_dimensions() ** 2) + return retval + + @property + def _max_nodes(self) -> int: + """ int: The number of nodes for the final Dense. For non g-block layers this will be the + given maximum filters multiplied by the dimensions squared. This number will be scaled down + if the final shape can not be mapped to the requested output size. + + For g-block layers, this is the given config value. 
+ """ + if self._side == "gblock": + return cfg.fc_gblock_max_nodes() + retval = self._scale_filters(cfg.fc_max_filters()) + retval = int(retval * cfg.fc_dimensions() ** 2) + return retval + + def _scale_filters(self, original_filters: int) -> int: + """ Scale the filters to be compatible with the model's selected output size. + + Parameters + ---------- + original_filters: int + The original user selected number of filters + + Returns + ------- + int + The number of filters scaled down for output size + """ + scaled_dim = _scale_dim(cfg.output_size(), self._final_dims) + if scaled_dim == self._final_dims: + logger.debug("filters don't require scaling. Returning: %s", original_filters) + return original_filters + + flat = self._final_dims ** 2 * original_filters + modifier = self._final_dims ** 2 * scaled_dim ** 2 + retval = int((flat // modifier) * modifier) + retval = int(retval / self._final_dims ** 2) + logger.debug("original_filters: %s, scaled_filters: %s", original_filters, retval) + return retval + + def _do_upsampling(self, inputs: KerasTensor) -> KerasTensor: + """ Perform the upsampling at the end of the fully connected layers. + + Parameters + ---------- + inputs: :class:`keras.KerasTensor` + The input to the upsample layers + + Returns + ------- + :class:`keras.KerasTensor` + The output from the upsample layers + """ + upsample_filts = self._scale_filters(cfg.fc_upsample_filters()) + upsampler = T.cast(T.Literal["resize_images", "subpixel", "upscale_dny", "upscale_fast", + "upscale_hybrid", "upsample2d"], + cfg.fc_upsampler().lower()) + num_upsamples = cfg.fc_upsamples() + var_x = inputs + if upsampler == "upsample2d" and num_upsamples > 1: + upscaler = _get_upscale_layer(upsampler, + upsample_filts, # Not used but required + upsamples=2 ** num_upsamples, + interpolation="bilinear") + var_x = upscaler(var_x) + else: + for _ in range(num_upsamples): + upscaler = _get_upscale_layer(upsampler, + upsample_filts, + activation="leakyrelu") + var_x = upscaler(var_x) + if upsampler == "upsample2d": + var_x = kl.LeakyReLU(negative_slope=0.1)(var_x) + return var_x + + def __call__(self) -> keras.models.Model: + """ Call the intermediate layer. + + Returns + ------- + :class:`keras.models.Model` + The Fully connected model + """ + input_ = kl.Input(shape=self._input_shape) + var_x = T.cast("KerasTensor", input_) + + node_curve = _get_curve(self._min_nodes, + self._max_nodes, + getattr(cfg, f"{self._prefix}_depth")(), + getattr(cfg, f"{self._prefix}_filter_slope")()) + + if not cfg.bottleneck_in_encoder(): + var_x = _bottleneck(var_x, + cfg.bottleneck_type(), + cfg.bottleneck_size(), + cfg.bottleneck_norm()) + + dropout = getattr(cfg, f"{self._prefix}_dropout")() + for idx, nodes in enumerate(node_curve): + var_x = kl.Dropout(dropout, name=f"{dropout}_{idx + 1}")(var_x) + var_x = kl.Dense(nodes)(var_x) + + if self._side != "gblock": + dim = cfg.fc_dimensions() + var_x = kl.Reshape((dim, dim, int(self._max_nodes / (dim ** 2))))(var_x) + var_x = self._do_upsampling(var_x) + + num_upscales = cfg.dec_upscales_in_fc() + if num_upscales: + var_x = UpscaleBlocks(self._side, + layer_indicies=(0, num_upscales))(var_x) + + return keras.models.Model(input_, var_x, name=f"fc_{self._side}") + + +class UpscaleBlocks(): + """ Obtain a block of upscalers. 
+ + This class exists outside of the :class:`Decoder` model, as it is possible to place some of + the upscalers at the end of the Fully Connected Layers, so the upscale chain needs to be able + to be calculated by both the Fully Connected Layers and by the Decoder if required. + + For this reason, the Upscale Filter list is created as a class attribute of the + :class:`UpscaleBlocks` class for reference by either the Decoder or Fully Connected models + + Parameters + ---------- + side: ["a", "b", "both", "shared"] + The side of the model that the Decoder belongs to. Used for naming + layer_indicies: tuple, optional + The tuple of indices indicating the starting layer index and the ending layer index to + generate upscales for. Used when splitting upscales between the Fully Connected Layers + and the Decoder. ``None`` will generate the full Upscale chain. An end index of -1 will + generate the layers from the starting index to the final upscale. Default: ``None`` + """ + _filters: list[int] = [] + + def __init__(self, + side: T.Literal["a", "b", "both", "shared"], + layer_indicies: tuple[int, int] | None = None) -> None: + logger.debug(parse_class_init(locals())) + self._side = side + self._is_dny = cfg.dec_upscale_method().lower() == "upscale_dny" + self._layer_indicies = layer_indicies + logger.debug("Initialized: %s", self.__class__.__name__,) + + def _reshape_for_output(self, inputs: KerasTensor) -> KerasTensor: + """ Reshape the input for arbitrary output sizes. + + The number of filters in the input will have been scaled to the model output size allowing + us to scale the dimensions to the requested output size. + + Parameters + ---------- + inputs: :class:`keras.KerasTensor` + The tensor that is to be reshaped + + Returns + ------- + :class:`keras.KerasTensor` + The tensor shaped correctly to upscale to output size + """ + var_x = inputs + old_dim = inputs.shape[1] + new_dim = _scale_dim(cfg.output_size(), old_dim) + if new_dim != old_dim: + old_shape = inputs.shape[1:] + new_shape = (new_dim, new_dim, np.prod(old_shape) // new_dim ** 2) + logger.debug("Reshaping tensor from %s to %s for output size %s", + inputs.shape[1:], new_shape, cfg.output_size()) + var_x = kl.Reshape(new_shape)(var_x) + return var_x + + def _upscale_block(self, + inputs: KerasTensor, + filters: int, + skip_residual: bool = False, + is_mask: bool = False) -> KerasTensor: + """ Upscale block for Phaze-A Decoder. + + Uses requested upscale method, adds requested regularization and activation function. + + Parameters + ---------- + inputs: :class:`keras.KerasTensor` + The input tensor for the upscale block + filters: int + The number of filters to use for the upscale + skip_residual: bool, optional + ``True`` if a residual block should not be placed in the upscale block, otherwise + ``False``. Default ``False`` + is_mask: bool, optional + ``True`` if the input is a mask. ``False`` if the input is a face.
Default: ``False`` + + Returns + ------- + :class:`keras.KerasTensor` + The output tensor from the upscale block + """ + upscaler = _get_upscale_layer(T.cast(T.Literal["resize_images", "subpixel", "upscale_dny", + "upscale_fast", "upscale_hybrid", + "upsample2d"], + cfg.dec_upscale_method()), + filters, + activation="leakyrelu", + upsamples=2, + interpolation="bilinear") + + var_x = upscaler(inputs) + if not is_mask and cfg.dec_gaussian(): + var_x = kl.GaussianNoise(1.0)(var_x) + if not is_mask and cfg.dec_res_blocks() and not skip_residual: + var_x = self._normalization(var_x) + var_x = kl.LeakyReLU(negative_slope=0.2)(var_x) + for _ in range(cfg.dec_res_blocks()): + var_x = ResidualBlock(filters)(var_x) + else: + var_x = self._normalization(var_x) + if not self._is_dny: + var_x = kl.LeakyReLU(negative_slope=0.1)(var_x) + return var_x + + def _normalization(self, inputs: KerasTensor) -> KerasTensor: + """ Add a normalization layer if requested. + + Parameters + ---------- + inputs: :class:`keras.KerasTensor` + The input tensor to apply normalization to. + + Returns + ------- + :class:`keras.KerasTensor` + The tensor with any normalization applied + """ + dec_norm: str | None = cfg.dec_norm() + dec_norm = None if dec_norm == "none" else dec_norm + if not dec_norm: + return inputs + norms = {"batch": kl.BatchNormalization, + "group": GroupNormalization, + "instance": InstanceNormalization, + "layer": kl.LayerNormalization, + "rms": RMSNormalization} + return norms[dec_norm]()(inputs) + + def _dny_entry(self, inputs: KerasTensor) -> KerasTensor: + """ Entry convolutions for using the upscale_dny method. + + Parameters + ---------- + inputs: :class:`keras.KerasTensor` + The inputs to the dny entry block + + Returns + ------- + :class:`keras.KerasTensor` + The output from the dny entry block + """ + var_x = Conv2DBlock(cfg.dec_max_filters(), + kernel_size=4, + strides=1, + padding="same", + relu_alpha=0.2)(inputs) + var_x = Conv2DBlock(cfg.dec_max_filters(), + kernel_size=3, + strides=1, + padding="same", + relu_alpha=0.2)(var_x) + return var_x + + def __call__(self, inputs: KerasTensor | list[KerasTensor]) -> KerasTensor | list[KerasTensor]: + """ Upscale Network. + + Parameters + ---------- + inputs: :class:`keras.KerasTensor` | list[:class:`keras.KerasTensor`] + Input tensor(s) to upscale block. This will be a single tensor if learn mask is not + selected or if this is the first call to the upscale blocks. If learn mask is selected + and this is not the first call to upscale blocks, then this will be a list of the face + and mask tensors. + + Returns + ------- + :class:`keras.KerasTensor` | list[:class:`keras.KerasTensor`] + The output of the upscale blocks.
Either a single tensor (if learn mask is not enabled) or + list of tensors (if learn mask is enabled) + """ + start_idx, end_idx = (0, None) if self._layer_indicies is None else self._layer_indicies + end_idx = None if end_idx == -1 else end_idx + + var_x: KerasTensor + var_y: KerasTensor + if cfg_loss.learn_mask() and start_idx == 0: + # Mask needs to be created + var_x = inputs + var_y = inputs + elif cfg_loss.learn_mask(): + # Mask has already been created and is an input to upscale blocks + var_x, var_y = inputs + else: + # No mask required + var_x = inputs + + if start_idx == 0: + var_x = self._reshape_for_output(var_x) + + if cfg_loss.learn_mask(): + var_y = self._reshape_for_output(var_y) + + if self._is_dny: + var_x = self._dny_entry(var_x) + if self._is_dny and cfg_loss.learn_mask(): + var_y = self._dny_entry(var_y) + + # De-convolve + if not self._filters: + upscales = int(np.log2(cfg.output_size() / var_x.shape[1])) + self._filters.extend(_get_curve(cfg.dec_max_filters(), + cfg.dec_min_filters(), + upscales, + cfg.dec_filter_slope(), + mode=T.cast(T.Literal["full", "cap_min", "cap_max"], + cfg.dec_slope_mode()))) + logger.debug("Generated class filters: %s", self._filters) + + filters = self._filters[start_idx: end_idx] + + for idx, filts in enumerate(filters): + skip_res = idx == len(filters) - 1 and cfg.dec_skip_last_residual() + var_x = self._upscale_block(var_x, filts, skip_residual=skip_res) + if cfg_loss.learn_mask(): + var_y = self._upscale_block(var_y, filts, is_mask=True) + retval = [var_x, var_y] if cfg_loss.learn_mask() else var_x + return retval + + +class GBlock(): + """ G-Block model, borrowing from Adain StyleGAN. + + Parameters + ---------- + side: ["a", "b", "both"] + The side of the model that the fully connected layers belong to. Used for naming + input_shapes: list or tuple + The shape tuples for the input to the G-Block. The first item is the input from each side's + fully connected model, the second item is the input shape from the combined fully connected + model. + """ + def __init__(self, side: T.Literal["a", "b", "both"], input_shapes: list | tuple) -> None: + logger.debug(parse_class_init(locals())) + self._side = side + self._inputs = [kl.Input(shape=shape) for shape in input_shapes] + self._dense_nodes = 512 + self._dense_recursions = 3 + logger.debug("Initialized: %s", self.__class__.__name__) + + @classmethod + def _g_block(cls, + inputs: KerasTensor, + style: KerasTensor, + filters: int, + recursions: int = 2) -> KerasTensor: + """ G_block adapted from ADAIN StyleGAN. + + Parameters + ---------- + inputs: :class:`keras.KerasTensor` + The input tensor to the G-Block model + style: :class:`keras.KerasTensor` + The input combined 'style' tensor to the G-Block model + filters: int + The number of filters to use for the G-Block Convolutional layers + recursions: int, optional + The number of recursive Convolutions to process. Default: `2` + + Returns + ------- + :class:`keras.KerasTensor` + The output tensor from the G-Block model + """ + var_x = inputs + for i in range(recursions): + styles = [kl.Reshape([1, 1, filters])(kl.Dense(filters)(style)) for _ in range(2)] + noise = kl.Conv2D(filters, 1, padding="same")(kl.GaussianNoise(1.0)(var_x)) + + if i == recursions - 1: + var_x = kl.Conv2D(filters, 3, padding="same")(var_x) + + var_x = AdaInstanceNormalization(dtype="float32")([var_x, *styles]) + var_x = kl.Add()([var_x, noise]) + var_x = kl.LeakyReLU(0.2)(var_x) + + return var_x + + def __call__(self) -> keras.models.Model: + """ G-Block Network. 
+ + Returns + ------- + :class:`keras.models.Model` + The G-Block model + """ + var_x, style = self._inputs + for i in range(self._dense_recursions): + style = kl.Dense(self._dense_nodes, kernel_initializer="he_normal")(style) + if i != self._dense_recursions - 1: # Don't add LeakyReLU to the final output + style = kl.LeakyReLU(0.1)(style) + + # Scale g_block filters to side dense + g_filts = var_x.shape[-1] + var_x = Conv2D(g_filts, 3, strides=1, padding="same")(var_x) + var_x = kl.GaussianNoise(1.0)(var_x) + var_x = self._g_block(var_x, style, g_filts) + return keras.models.Model(self._inputs, var_x, name=f"g_block_{self._side}") + + +class Decoder(): + """ Decoder Network. + + Parameters + ---------- + side: ["a", "b", "both"] + The side of the model that the Decoder belongs to. Used for naming + input_shape: tuple + The shape tuple for the input to the decoder. + """ + def __init__(self, + side: T.Literal["a", "b", "both"], + input_shape: tuple[int, int, int]) -> None: + logger.debug(parse_class_init(locals())) + self._side = side + self._input_shape = input_shape + logger.debug("Initialized: %s", self.__class__.__name__,) + + def __call__(self) -> keras.models.Model: + """ Decoder Network. + + Returns + ------- + :class:`keras.models.Model` + The Decoder model + """ + inputs = T.cast("KerasTensor", kl.Input(shape=self._input_shape)) + + num_ups_in_fc = cfg.dec_upscales_in_fc() + + if cfg_loss.learn_mask() and num_ups_in_fc: + # Mask has already been created in FC and is an output of that model + inputs = [inputs, kl.Input(shape=self._input_shape)] + + indicies = None if not num_ups_in_fc else (num_ups_in_fc, -1) + upscales = UpscaleBlocks(self._side, layer_indicies=indicies)(inputs) + + if cfg_loss.learn_mask(): + var_x, var_y = upscales + else: + var_x = upscales + + outputs = [Conv2DOutput(3, cfg.dec_output_kernel(), name="face_out")(var_x)] + if cfg_loss.learn_mask(): + outputs.append(Conv2DOutput(1, + cfg.dec_output_kernel(), + name="mask_out")(var_y)) + + return keras.models.Model(inputs, outputs=outputs, name=f"decoder_{self._side}") diff --git a/plugins/train/model/phaze_a_defaults.py b/plugins/train/model/phaze_a_defaults.py new file mode 100644 index 0000000000..0650d79064 --- /dev/null +++ b/plugins/train/model/phaze_a_defaults.py @@ -0,0 +1,714 @@ +#!/usr/bin/env python3 +""" The default options for the faceswap Phaze-A Model plugin. + +Defaults files should be named `<plugin_name>_defaults.py` + +Any qualifying items placed into this file will automatically get added to the relevant config +.ini files within the faceswap/config folder and added to the relevant GUI settings page. + +The following variable should be defined: + + Parameters + ---------- + HELPTEXT: str + A string describing what this plugin does + +Further plugin configuration options are assigned using: +>>> <option_name> = ConfigItem(...) + +where <option_name> is the name of the configuration option to be added (lower-case, alpha-numeric ++ underscore only) and ConfigItem(...) is the [`~lib.config.objects.ConfigItem`] data for the +option. + +See the docstring/ReadTheDocs documentation for the required parameters of the ConfigItem object. +Items will be grouped together as per their `group` parameter, but otherwise will be processed in +the order that they are added to this module.
+from lib.config import ConfigItem +""" +from lib.config import ConfigItem + + +HELPTEXT: str = ( + "Phaze-A Model by TorzDF, with thanks to BirbFakes.\n" + "Allows for the experimentation of various standard Networks as the encoder and takes " + "inspiration from Nvidia's StyleGAN for the Decoder. It is highly recommended to research to " + "understand the parameters better.") + +_ENCODERS: list[str] = sorted([ + "clipv_vit-b-16", "clipv_vit-b-32", "clipv_vit-l-14", "clipv_vit-l-14-336px", + "clipv_farl-b-16-16", "clipv_farl-b-16-64", + "convnext_tiny", "convnext_small", "convnext_base", "convnext_large", "convnext_extra_large", + "densenet121", "densenet169", "densenet201", "efficientnet_b0", "efficientnet_b1", + "efficientnet_b2", "efficientnet_b3", "efficientnet_b4", "efficientnet_b5", "efficientnet_b6", + "efficientnet_b7", "efficientnet_v2_b0", "efficientnet_v2_b1", "efficientnet_v2_b2", + "efficientnet_v2_b3", "efficientnet_v2_l", "efficientnet_v2_m", "efficientnet_v2_s", + "inception_resnet_v2", "inception_v3", "mobilenet", "mobilenet_v2", "mobilenet_v3_large", + "mobilenet_v3_small", "nasnet_large", "nasnet_mobile", "resnet50", "resnet50_v2", "resnet101", + "resnet101_v2", "resnet152", "resnet152_v2", "vgg16", "vgg19", "xception", "fs_original"]) + + +# General +output_size = ConfigItem( + datatype=int, + default=128, + group="general", + info="Resolution (in pixels) of the output image to generate.\n" + "BE AWARE Larger resolution will dramatically increase VRAM requirements.", + rounding=16, + min_max=(64, 2048), + fixed=True) + +shared_fc = ConfigItem( + datatype=str, + default="none", + group="general", + info="Whether to create a shared fully connected layer. This layer will have the same " + "structure as the fully connected layers used for each side of the model. A shared " + "fully connected layer looks for patterns that are common to both sides. NB: " + "Enabling this option only makes sense if 'split fc' is selected." + "\n\tnone - Do not create a Fully Connected layer for shared data. (Original method)" + "\n\tfull - Create an exclusive Fully Connected layer for shared data. (IAE method)" + "\n\thalf - Use the 'fc_a' layer for shared data. This saves VRAM by re-using the " + "'A' side's fully connected model for the shared data. However, this will lead to " + "an 'unbalanced' model and can lead to more identity bleed (DFL method)", + choices=["none", "full", "half"], + gui_radio=True, + fixed=True) + +enable_gblock = ConfigItem( + datatype=bool, + default=True, + group="general", + info="Whether to enable the G-Block. If enabled, this will create a shared fully " + "connected layer (configurable in the 'G-Block hidden layers' section) to look for " + "patterns in the combined data, before feeding a block prior to the decoder for " + "merging this shared and combined data." + "\n\tTrue - Use the G-Block in the Decoder. A combined fully connected layer will be " + "created to feed this block which can be configured below." + "\n\tFalse - Don't use the G-Block in the decoder. No combined fully connected layer " + "will be created.", + fixed=True) + +split_fc = ConfigItem( + datatype=bool, + default=True, + group="general", + info="Whether to use a single shared Fully Connected layer or separate Fully Connected " + "layers for each side." + "\n\tTrue - Use separate Fully Connected layers for Face A and Face B. This is more " + "similar to the 'IAE' style of model." + "\n\tFalse - Use combined Fully Connected layers for both sides. 
This is more " + "similar to the original Faceswap architecture.", + fixed=True) + +split_gblock = ConfigItem( + datatype=bool, + default=False, + group="general", + info="If the G-Block is enabled, whether to use a single G-Block shared between both " + "sides, or whether to have a separate G-Block (one for each side). NB: The Fully " + "Connected layer that feeds the G-Block will always be shared." + "\n\tTrue - Use separate G-Blocks for Face A and Face B." + "\n\tFalse - Use a combined G-Block for both sides.", + fixed=True) + +split_decoders = ConfigItem( + datatype=bool, + default=False, + group="general", + info="Whether to use a single decoder or split decoders." + "\n\tTrue - Use a separate decoder for Face A and Face B. This is more similar to " + "the original Faceswap architecture." + "\n\tFalse - Use a combined Decoder. This is more similar to 'IAE' style " + "architecture.", + fixed=True) + +# Encoder +enc_architecture = ConfigItem( + datatype=str, + default="fs_original", + group="encoder", + info="The encoder architecture to use. See the relevant config sections for specific " + "architecture tweaking.\nNB: For keras based pre-built models, the global " + "initializers and padding options will be ignored for the selected encoder." + "\n\n\tCLIPv: This is an implementation of the Visual encoder from the CLIP " + "transformer. The ViT weights are trained on imagenet whilst the FaRL weights are " + "trained on face related tasks. All have a default input size of 224px except for " + "ViT-L-14-336px that has an input size of 336px. Ref: Learning Transferable Visual " + "Models From Natural Language Supervision (2021): https://arxiv.org/abs/2103.00020" + "\n\n\tconvnext: There are 5 variations of increasing complexity. All have a default " + "input size of 224px. Ref: A ConvNet for the 2020s (2022): " + "https://arxiv.org/abs/2201.03545" + "\n\n\tdensenet: (32px-224px). Ref: Densely Connected Convolutional Networks " + "(2016): https://arxiv.org/abs/1608.06993" + "\n\n\tefficientnet: EfficientNet has numerous variants (B0 - B7) that increase the " + "model width, depth and dimensional space at each step. The minimum input resolution " + "is 32px for all variants. The maximum input resolution for each variant is: b0: " + "224px, b1: 240px, b2: 260px, b3: 300px, b4: 380px, b5: 456px, b6: 528px, b7: 600px. " + "Ref: Rethinking Model Scaling for Convolutional Neural Networks (2020): " + "https://arxiv.org/abs/1905.11946" + "\n\n\tefficientnet_v2: EfficientNetV2 is the follow up to efficientnet. It has " + "numerous variants (B0 - B3 and Small, Medium and Large) that increase the model " + "width, depth and dimensional space at each step. The minimum input resolution is " + "32px for all variants. The maximum input resolution for each variant is: b0: 224px, " + "b1: 240px, b2: 260px, b3: 300px, s: 384px, m: 480px, l: 480px. Ref: EfficientNetV2: " + "Smaller Models and Faster Training (2021): https://arxiv.org/abs/2104.00298" + "\n\n\tfs_original: (32px - 1024px). A configurable variant of the original faceswap " + "encoder. ImageNet weights cannot be loaded for this model. Additional parameters " + "can be configured with the 'fs_enc' options. A version of this encoder is used in " + "the following models: Original, Original (lowmem), Dfaker, DFL-H128, DFL-SAE, IAE, " + "Lightweight." + "\n\n\tinception_resnet_v2: (75px - 299px). 
+         "\n\n\tinception_resnet_v2: (75px - 299px). Ref: Inception-ResNet and the Impact of "
+         "Residual Connections on Learning (2016): https://arxiv.org/abs/1602.07261"
+         "\n\n\tinception_v3: (75px - 299px). Ref: Rethinking the Inception Architecture for "
+         "Computer Vision (2015): https://arxiv.org/abs/1512.00567"
+         "\n\n\tmobilenet: (32px - 224px). Additional MobileNet parameters can be set with "
+         "the 'mobilenet' options. Ref: MobileNets: Efficient Convolutional Neural Networks "
+         "for Mobile Vision Applications (2017): https://arxiv.org/abs/1704.04861"
+         "\n\n\tmobilenet_v2: (32px - 224px). Additional MobileNet parameters can be set with "
+         "the 'mobilenet' options. Ref: MobileNetV2: Inverted Residuals and Linear "
+         "Bottlenecks (2018): https://arxiv.org/abs/1801.04381"
+         "\n\n\tmobilenet_v3: (32px - 224px). Additional MobileNet parameters can be set with "
+         "the 'mobilenet' options. Ref: Searching for MobileNetV3 (2019): "
+         "https://arxiv.org/pdf/1905.02244.pdf"
+         "\n\n\tnasnet: (32px - 331px (large) or 224px (mobile)). Ref: Learning Transferable "
+         "Architectures for Scalable Image Recognition (2017): "
+         "https://arxiv.org/abs/1707.07012"
+         "\n\n\tresnet: (32px - 224px). Ref: Deep Residual Learning for Image Recognition "
+         "(2015): https://arxiv.org/abs/1512.03385"
+         "\n\n\tvgg: (32px - 224px). Ref: Very Deep Convolutional Networks for Large-Scale "
+         "Image Recognition (2014): https://arxiv.org/abs/1409.1556"
+         "\n\n\txception: (71px - 299px). Ref: Xception: Deep Learning with Depthwise "
+         "Separable Convolutions (2017): https://arxiv.org/abs/1610.02357\n",
+    choices=_ENCODERS,
+    gui_radio=False,
+    fixed=True)
+
+enc_scaling = ConfigItem(
+    datatype=int,
+    default=7,
+    group="encoder",
+    info="Input scaling for the encoder. Some of the encoders have large input sizes, which "
+         "often are not helpful for Faceswap. This setting scales the dimensional space that "
+         "the encoder works in. For example, an encoder with a maximum input size of 224px "
+         "will receive a 112px input image at 50%% scaling. See the Architecture tooltip for "
+         "the minimum and maximum sizes for each encoder. NB: The input size will be rounded "
+         "down to the nearest 16 pixels.",
+    min_max=(0, 200),
+    rounding=1,
+    fixed=True)
+
+enc_load_weights = ConfigItem(
+    datatype=bool,
+    default=True,
+    group="encoder",
+    info="Load pre-trained weights trained on ImageNet data. Only available for non-"
+         "Faceswap encoders (i.e. those not beginning with 'fs'). NB: If you use the global "
+         "'load weights' option and have selected to load weights from a previous model's "
+         "'encoder' or 'keras_encoder' then the weights loaded here will be replaced by the "
+         "weights loaded from your saved model.",
+    fixed=True)
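For orientation, the four bottleneck choices defined below map naturally onto standard Keras layers. A simplified, illustrative stand-in, not the faceswap implementation itself:

    from keras import layers

    def apply_bottleneck(inputs, kind: str, size: int = 1024):
        """ Reduce a 4D encoder output as per the 'bottleneck_type' choices below. """
        if kind == "average_pooling":
            return layers.GlobalAveragePooling2D()(inputs)
        if kind == "max_pooling":
            return layers.GlobalMaxPooling2D()(inputs)
        if kind == "dense":  # traditional faceswap method; 'size' is bottleneck_size
            return layers.Dense(size)(layers.Flatten()(inputs))
        return layers.Flatten()(inputs)  # "flatten": no bottleneck operation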
+# Bottleneck
+bottleneck_type = ConfigItem(
+    datatype=str,
+    default="dense",
+    group="bottleneck",
+    info="The type of layer to use for the bottleneck."
+         "\n\taverage_pooling: Use a Global Average Pooling 2D layer for the bottleneck."
+         "\n\tdense: Use a Dense layer for the bottleneck (the traditional Faceswap method). "
+         "You can set the size of the Dense layer with the 'bottleneck_size' parameter."
+         "\n\tmax_pooling: Use a Global Max Pooling 2D layer for the bottleneck."
+         "\n\tflatten: Don't use a bottleneck at all. Some encoders output in a size that "
+         "makes a bottleneck unnecessary. This option flattens the output from the encoder, "
+         "with no further operations.",
+    gui_radio=True,
+    choices=["average_pooling", "dense", "max_pooling", "flatten"],
+    fixed=True)
+
+bottleneck_norm = ConfigItem(
+    datatype=str,
+    default="none",
+    group="bottleneck",
+    info="Apply a normalization layer after encoder output and prior to the bottleneck."
+         "\n\tnone - Do not apply a normalization layer"
+         "\n\tinstance - Apply Instance Normalization"
+         "\n\tlayer - Apply Layer Normalization (Ba et al., 2016)"
+         "\n\trms - Apply Root Mean Squared Layer Normalization (Zhang et al., 2019). A "
+         "simplified version of Layer Normalization with reduced overhead.",
+    gui_radio=True,
+    choices=["none", "instance", "layer", "rms"],
+    fixed=True)
+
+bottleneck_size = ConfigItem(
+    datatype=int,
+    default=1024,
+    group="bottleneck",
+    info="If using a Dense layer for the bottleneck, then this is the number of nodes to "
+         "use.",
+    rounding=128,
+    min_max=(128, 4096),
+    fixed=True)
+
+bottleneck_in_encoder = ConfigItem(
+    datatype=bool,
+    default=True,
+    group="bottleneck",
+    info="Whether to place the bottleneck in the Encoder or to place it with the other "
+         "hidden layers. Placing the bottleneck in the encoder means that both sides will "
+         "share the same bottleneck. Placing it with the other fully connected layers means "
+         "that each fully connected layer will get its own bottleneck. This may be "
+         "combined or split depending on your overall architecture configuration settings.",
+    fixed=True)
+
+# Intermediate Layers
+fc_depth = ConfigItem(
+    datatype=int,
+    default=1,
+    group="hidden layers",
+    info="The number of consecutive Dense (fully connected) layers to include in each "
+         "side's intermediate layer.",
+    rounding=1,
+    min_max=(0, 16),
+    fixed=True)
+
+fc_min_filters = ConfigItem(
+    datatype=int,
+    default=1024,
+    group="hidden layers",
+    info="The number of filters to use for the initial fully connected layer. The number of "
+         "nodes actually used is: fc_min_filters x fc_dimensions x fc_dimensions.\nNB: This "
+         "value may be scaled down, depending on output resolution.",
+    rounding=16,
+    min_max=(16, 5120),
+    fixed=True)
+
+fc_max_filters = ConfigItem(
+    datatype=int,
+    default=1024,
+    group="hidden layers",
+    info="This is the number of filters to be used in the final reshape layer at the end of "
+         "the fully connected layers. The actual number of nodes used for the final fully "
+         "connected layer is: fc_max_filters x fc_dimensions x fc_dimensions.\nNB: This value "
+         "may be scaled down, depending on output resolution.",
+    rounding=64,
+    min_max=(128, 5120),
+    fixed=True)
+
+fc_dimensions = ConfigItem(
+    datatype=int,
+    default=4,
+    group="hidden layers",
+    info="The height and width dimension for the final reshape layer at the end of the "
+         "fully connected layers.\nNB: The total number of nodes within the final fully "
+         "connected layer will be: fc_dimensions x fc_dimensions x fc_max_filters.",
+    rounding=1,
+    min_max=(1, 16),
+    fixed=True)
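A quick worked example of the node arithmetic described in the tooltips above, using the defaults (fc_min_filters=1024, fc_max_filters=1024, fc_dimensions=4):

    fc_min_filters, fc_max_filters, fc_dimensions = 1024, 1024, 4
    initial_nodes = fc_min_filters * fc_dimensions * fc_dimensions  # -> 16384
    final_nodes = fc_max_filters * fc_dimensions * fc_dimensions    # -> 16384
    # so the final reshape at the end of the fully connected layers is (4, 4, 1024)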
+fc_filter_slope = ConfigItem(
+    datatype=float,
+    default=-0.5,
+    group="hidden layers",
+    info="The rate that the filters move from the minimum number of filters to the maximum "
+         "number of filters. EG:\n"
+         "Negative numbers will change the number of filters quicker at first and slow down "
+         "each layer.\n"
+         "Positive numbers will change the number of filters slower at first but then speed "
+         "up each layer.\n"
+         "0.0 - This will change at a linear rate (i.e. the same number of filters will be "
+         "changed at each layer).",
+    min_max=(-.99, .99),
+    rounding=2,
+    fixed=True)
+
+fc_dropout = ConfigItem(
+    datatype=float,
+    default=0.0,
+    group="hidden layers",
+    info="Dropout is a form of regularization that can prevent a model from over-fitting "
+         "and help to keep neurons 'alive'. 0.5 will dropout half the connections between "
+         "each fully connected layer, 0.25 will dropout a quarter of the connections etc. Set "
+         "to 0.0 to disable.",
+    rounding=2,
+    min_max=(0.0, 0.99),
+    fixed=False)
+
+fc_upsampler = ConfigItem(
+    datatype=str,
+    default="upsample2d",
+    group="hidden layers",
+    info="The type of dimensional upsampling to perform at the end of the fully connected "
+         "layers, if upsamples > 0. The number of filters used for the upscale layers will be "
+         "the value given in 'fc_upsample_filters'."
+         "\n\tupsample2d - A lightweight and VRAM friendly method. 'Quick and dirty', but "
+         "does not learn any parameters."
+         "\n\tsubpixel - Sub-pixel upscaler using depth-to-space which may require more "
+         "VRAM."
+         "\n\tresize_images - Uses the Keras resize_image function to save about half as much "
+         "VRAM as the heaviest methods."
+         "\n\tupscale_fast - Developed by Andenixa. Focuses on upscaling speed, but "
+         "requires more VRAM."
+         "\n\tupscale_hybrid - Developed by Andenixa. Uses a combination of PixelShuffler and "
+         "Upsampling2D to upscale, saving about 1/3rd of the VRAM of the heaviest methods.",
+    choices=["resize_images", "subpixel", "upscale_fast", "upscale_hybrid", "upsample2d"],
+    gui_radio=False,
+    fixed=True)
+
+fc_upsamples = ConfigItem(
+    datatype=int,
+    default=1,
+    group="hidden layers",
+    info="Some upsampling can occur within the Fully Connected layers rather than in the "
+         "Decoder to increase the dimensional space. Set how many upscale layers should occur "
+         "within the Fully Connected layers.",
+    min_max=(0, 4),
+    rounding=1,
+    fixed=True)
+
+fc_upsample_filters = ConfigItem(
+    datatype=int,
+    default=512,
+    group="hidden layers",
+    info="If you have selected an upsampler which requires filters (i.e. any upsampler with "
+         "the exception of Upsampling2D), then this is the number of filters to be used for "
+         "the upsamplers within the fully connected layers. NB: This value may be scaled "
+         "down, depending on output resolution. Also note that this figure will dictate the "
+         "number of filters used for the G-Block, if selected.",
+    rounding=64,
+    min_max=(128, 5120),
+    fixed=True)
+
+# G-Block
+fc_gblock_depth = ConfigItem(
+    datatype=int,
+    default=3,
+    group="g-block hidden layers",
+    info="The number of consecutive Dense (fully connected) layers to include in the "
+         "G-Block shared layer.",
+    rounding=1,
+    min_max=(1, 16),
+    fixed=True)
+
+fc_gblock_min_nodes = ConfigItem(
+    datatype=int,
+    default=512,
+    group="g-block hidden layers",
+    info="The number of nodes to use for the initial G-Block shared fully connected layer.",
+    rounding=64,
+    min_max=(128, 5120),
+    fixed=True)
+
+fc_gblock_max_nodes = ConfigItem(
+    datatype=int,
+    default=512,
+    group="g-block hidden layers",
+    info="The number of nodes to use for the final G-Block shared fully connected layer.",
+    rounding=64,
+    min_max=(128, 5120),
+    fixed=True)
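The filter slope options ('fc_filter_slope' above and the G-Block equivalent that follows) bend the min-to-max filter progression. An illustrative interpolation that matches the behaviour described in the tooltips; the exact faceswap formula may differ:

    def filters_for_layer(idx: int, n_layers: int, f_min: int, f_max: int, slope: float) -> int:
        """ slope < 0: change quickly at first then slow down; slope > 0: the
            reverse; 0.0: linear. An assumed curve, for illustration only. """
        position = idx / max(n_layers - 1, 1)   # 0.0 .. 1.0 through the layers
        curve = position ** (1.0 + slope)       # bend the progression
        return int(f_min + (f_max - f_min) * curve)

    print([filters_for_layer(i, 4, 128, 1024, -0.5) for i in range(4)])
    # -> [128, 645, 859, 1024]: fast change early, slowing at each layer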
+fc_gblock_filter_slope = ConfigItem(
+    datatype=float,
+    default=-0.5,
+    group="g-block hidden layers",
+    info="The rate that the filters move from the minimum number of filters to the maximum "
+         "number of filters for the G-Block shared layers. EG:\n"
+         "Negative numbers will change the number of filters quicker at first and slow down "
+         "each layer.\n"
+         "Positive numbers will change the number of filters slower at first but then speed "
+         "up each layer.\n"
+         "0.0 - This will change at a linear rate (i.e. the same number of filters will be "
+         "changed at each layer).",
+    min_max=(-.99, .99),
+    rounding=2,
+    fixed=True)
+
+fc_gblock_dropout = ConfigItem(
+    datatype=float,
+    default=0.0,
+    group="g-block hidden layers",
+    info="Dropout is a regularization technique that can prevent a model from over-fitting "
+         "and help to keep neurons 'alive'. 0.5 will dropout half the connections between "
+         "each fully connected layer, 0.25 will dropout a quarter of the connections etc. Set "
+         "to 0.0 to disable.",
+    rounding=2,
+    min_max=(0.0, 0.99),
+    fixed=False)
+
+# Decoder
+dec_upscale_method = ConfigItem(
+    datatype=str,
+    default="subpixel",
+    group="decoder",
+    info="The method to use for the upscales within the decoder. Images are upscaled "
+         "multiple times within the decoder as the network learns to reconstruct the face."
+         "\n\tsubpixel - Sub-pixel upscaler using depth-to-space which requires more "
+         "VRAM."
+         "\n\tresize_images - Uses the Keras resize_image function to save about half as much "
+         "VRAM as the heaviest methods."
+         "\n\tupscale_fast - Developed by Andenixa. Focuses on upscaling speed, but "
+         "requires more VRAM."
+         "\n\tupscale_hybrid - Developed by Andenixa. Uses a combination of PixelShuffler and "
+         "Upsampling2D to upscale, saving about 1/3rd of the VRAM of the heaviest methods."
+         "\n\tupscale_dny - An alternative upscale implementation using Upsampling2D to "
+         "upscale.",
+    choices=["subpixel", "resize_images", "upscale_fast", "upscale_hybrid", "upscale_dny"],
+    gui_radio=True,
+    fixed=True)
+
+dec_upscales_in_fc = ConfigItem(
+    datatype=int,
+    default=0,
+    min_max=(0, 6),
+    rounding=1,
+    group="decoder",
+    info="It is possible to place some of the upscales at the end of the fully connected "
+         "model. For models with split decoders, but a shared fully connected layer, this "
+         "would have the effect of saving some VRAM but possibly at the cost of introducing "
+         "artefacts. For models with a shared decoder but split fully connected layers, this "
+         "would have the effect of increasing VRAM usage by processing some of the upscales "
+         "for each side rather than together.",
+    fixed=True)
+
+dec_norm = ConfigItem(
+    datatype=str,
+    default="none",
+    group="decoder",
+    info="Normalization to apply after each upscale."
+         "\n\tnone - Do not apply a normalization layer"
+         "\n\tbatch - Apply Batch Normalization"
+         "\n\tgroup - Apply Group Normalization"
+         "\n\tinstance - Apply Instance Normalization"
+         "\n\tlayer - Apply Layer Normalization (Ba et al., 2016)"
+         "\n\trms - Apply Root Mean Squared Layer Normalization (Zhang et al., 2019). A "
+         "simplified version of Layer Normalization with reduced overhead.",
+    gui_radio=True,
+    choices=["none", "batch", "group", "instance", "layer", "rms"],
+    fixed=True)
+
+dec_min_filters = ConfigItem(
+    datatype=int,
+    default=64,
+    group="decoder",
+    info="The minimum number of filters to use in decoder upscalers (i.e. the number of "
+         "filters to use for the final upscale layer).",
+    min_max=(16, 512),
+    rounding=16,
+    fixed=True)
+dec_max_filters = ConfigItem(
+    datatype=int,
+    default=512,
+    group="decoder",
+    info="The maximum number of filters to use in decoder upscalers (i.e. the number of "
+         "filters to use for the first upscale layer).",
+    min_max=(256, 5120),
+    rounding=64,
+    fixed=True)
+
+dec_slope_mode = ConfigItem(
+    datatype=str,
+    default="full",
+    group="decoder",
+    info="Alters the action of the filter slope.\n"
+         "\n\tfull: The number of filters at each upscale layer will reduce from the chosen "
+         "max_filters at the first layer to the chosen min_filters at the last layer as "
+         "dictated by the dec_filter_slope."
+         "\n\tcap_max: The filters will decline at a fixed rate from each upscale to the next "
+         "based on the filter_slope setting. If there are more upscale layers than the slope "
+         "allows for, then the earliest upscales will be capped at the max_filter value until "
+         "the filters can reduce to the min_filters value at the final upscale. (EG: 512 -> "
+         "512 -> 512 -> 256 -> 128 -> 64)."
+         "\n\tcap_min: The filters will decline at a fixed rate from each upscale to the next "
+         "based on the filter_slope setting. If there are more upscale layers than the slope "
+         "allows for, then the earliest upscales will drop their filters until the min_filter "
+         "value is met and repeat the min_filter value for the remaining upscales. (EG: 512 "
+         "-> 256 -> 128 -> 64 -> 64 -> 64).",
+    choices=["full", "cap_max", "cap_min"],
+    fixed=True,
+    gui_radio=True)
+
+dec_filter_slope = ConfigItem(
+    datatype=float,
+    default=-0.45,
+    group="decoder",
+    info="The rate that the filters reduce at each upscale layer.\n"
+         "\n\tFull Slope Mode: Negative numbers will drop the number of filters quicker at "
+         "first and slow down each upscale. Positive numbers will drop the number of filters "
+         "slower at first but then speed up each upscale. A value of 0.0 will reduce at a "
+         "linear rate (i.e. the same number of filters will be reduced at each upscale).\n"
+         "\n\tCap Min/Max Slope Mode: Only positive values will work here. Negative values "
+         "will automatically be converted to their positive counterpart. A value of 0.5 will "
+         "halve the number of filters at each upscale until the minimum value is reached. A "
+         "value of 0.33 will reduce the number of filters to a third until the minimum "
+         "value is reached, etc.",
+    min_max=(-.99, .99),
+    rounding=2,
+    fixed=True)
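The two capped modes can be made concrete. The sketch below reproduces the example sequences from the tooltip above, under the assumption that the filter count is simply scaled by the slope value at each upscale:

    def capped_filters(n_upscales: int, f_min: int, f_max: int, slope: float, mode: str):
        """ Reproduce the cap_max / cap_min sequences from the tooltip above. """
        rate = abs(slope)              # negative slopes are treated as positive
        values = []
        if mode == "cap_min":          # decline from the max, floor at the min
            filters = float(f_max)
            for _ in range(n_upscales):
                values.append(max(int(filters), f_min))
                filters *= rate
        else:                          # "cap_max": work backwards up from the min
            filters = float(f_min)
            for _ in range(n_upscales):
                values.append(min(int(filters), f_max))
                filters /= rate
            values.reverse()
        return values

    print(capped_filters(6, 64, 512, 0.5, "cap_max"))  # [512, 512, 512, 256, 128, 64]
    print(capped_filters(6, 64, 512, 0.5, "cap_min"))  # [512, 256, 128, 64, 64, 64]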
+dec_res_blocks = ConfigItem(
+    datatype=int,
+    default=1,
+    group="decoder",
+    info="The number of Residual Blocks to apply to each upscale layer. Set to 0 to disable "
+         "residual blocks entirely.",
+    rounding=1,
+    min_max=(0, 8),
+    fixed=True)
+
+dec_output_kernel = ConfigItem(
+    datatype=int,
+    default=5,
+    group="decoder",
+    info="The kernel size to apply to the final Convolution layer.",
+    rounding=2,
+    min_max=(1, 9),
+    fixed=True)
+
+dec_gaussian = ConfigItem(
+    datatype=bool,
+    default=True,
+    group="decoder",
+    info="Gaussian Noise acts as a regularization technique for preventing overfitting of "
+         "data."
+         "\n\tTrue - Apply a Gaussian Noise layer to each upscale."
+         "\n\tFalse - Don't apply a Gaussian Noise layer to each upscale.",
+    fixed=True)
+
+dec_skip_last_residual = ConfigItem(
+    datatype=bool,
+    default=True,
+    group="decoder",
+    info="If Residual blocks have been enabled, enabling this option will not apply a "
+         "Residual block to the final upscaler."
+         "\n\tTrue - Don't apply a Residual block to the final upscale."
+         "\n\tFalse - Apply a Residual block to all upscale layers.",
+    fixed=True)
+
+# Weight management
+freeze_layers = ConfigItem(
+    datatype=list,
+    default=["keras_encoder"],
+    group="weights",
+    info="If the command line option 'freeze-weights' is enabled, then the layers indicated "
+         "here will be frozen the next time the model starts up. NB: Not all architectures "
+         "contain all of the layers listed here, so any layers marked for freezing that are "
+         "not within your chosen architecture will be ignored. EG:\n If 'split fc' has "
+         "been selected, then 'fc_a' and 'fc_b' are available for freezing. If it has "
+         "not been selected then 'fc_both' is available for freezing.",
+    choices=["encoder", "keras_encoder", "fc_a", "fc_b", "fc_both", "fc_shared",
+             "fc_gblock", "g_block_a", "g_block_b", "g_block_both", "decoder_a",
+             "decoder_b", "decoder_both"],
+    fixed=False)
+
+load_layers = ConfigItem(
+    datatype=list,
+    default=["encoder"],
+    group="weights",
+    info="If the command line option 'load-weights' is populated, then the layers indicated "
+         "here will be loaded from the given weights file if starting a new model. NB: Not "
+         "all architectures contain all of the layers listed here, so any layers marked for "
+         "loading that are not within your chosen architecture will be ignored. EG:\n If "
+         "'split fc' has been selected, then 'fc_a' and 'fc_b' are available for loading. If "
+         "it has not been selected then 'fc_both' is available for loading.",
+    choices=["encoder", "fc_a", "fc_b", "fc_both", "fc_shared", "fc_gblock", "g_block_a",
+             "g_block_b", "g_block_both", "decoder_a", "decoder_b", "decoder_both"],
+    fixed=True)
+
+# # SPECIFIC ENCODER SETTINGS # #
+# Faceswap Original
+fs_original_depth = ConfigItem(
+    datatype=int,
+    default=4,
+    group="faceswap encoder configuration",
+    info="Faceswap Encoder only: The number of convolutions to perform within the encoder.",
+    min_max=(2, 10),
+    rounding=1,
+    fixed=True)
+
+fs_original_min_filters = ConfigItem(
+    datatype=int,
+    default=128,
+    group="faceswap encoder configuration",
+    info="Faceswap Encoder only: The minimum number of filters to use for encoder "
+         "convolutions. (i.e. the number of filters to use for the first encoder layer).",
+    min_max=(16, 2048),
+    rounding=64,
+    fixed=True)
+
+fs_original_max_filters = ConfigItem(
+    datatype=int,
+    default=1024,
+    group="faceswap encoder configuration",
+    info="Faceswap Encoder only: The maximum number of filters to use for encoder "
+         "convolutions. (i.e. the number of filters to use for the final encoder layer).",
+    min_max=(256, 8192),
+    rounding=128,
+    fixed=True)
+
+fs_original_use_alt = ConfigItem(
+    datatype=bool,
+    default=False,
+    group="faceswap encoder configuration",
+    info="Use a slightly alternate version of the Faceswap Encoder."
+         "\n\tTrue - Use the alternate variation of the Faceswap Encoder."
+         "\n\tFalse - Use the original Faceswap Encoder.",
+    fixed=True)
+# MobileNet
+mobilenet_width = ConfigItem(
+    datatype=float,
+    default=1.0,
+    group="mobilenet encoder configuration",
+    info="The width multiplier for mobilenet encoders. Controls the width of the "
+         "network. Values less than 1.0 proportionally decrease the number of filters within "
+         "each layer. Values greater than 1.0 proportionally increase the number of filters "
+         "within each layer. 1.0 is the default width used within the paper.\n"
+         "NB: This option is ignored for any non-mobilenet encoders.\n"
+         "NB: If loading ImageNet weights, then for MobilenetV1 only values of '0.25', "
+         "'0.5', '0.75' or '1.0' can be selected. For MobilenetV2 only values of '0.35', "
+         "'0.50', '0.75', '1.0', '1.3' or '1.4' can be selected. For MobilenetV3 only values "
+         "of '0.75' or '1.0' can be selected.",
+    min_max=(0.1, 2.0),
+    rounding=2,
+    fixed=True)
+
+mobilenet_depth = ConfigItem(
+    datatype=int,
+    default=1,
+    group="mobilenet encoder configuration",
+    info="The depth multiplier for the MobilenetV1 encoder. This is the depth multiplier "
+         "for depthwise convolution (known as the resolution multiplier within the original "
+         "paper).\n"
+         "NB: This option is only used for MobilenetV1 and is ignored for all other "
+         "encoders.\n"
+         "NB: If loading ImageNet weights, this must be set to 1.",
+    min_max=(1, 10),
+    rounding=1,
+    fixed=True)
+
+mobilenet_dropout = ConfigItem(
+    datatype=float,
+    default=0.001,
+    group="mobilenet encoder configuration",
+    info="The dropout rate for the MobilenetV1 encoder.\n"
+         "NB: This option is only used for MobilenetV1 and is ignored for all other "
+         "encoders.",
+    min_max=(0.001, 2.0),
+    rounding=3,
+    fixed=True)
+
+mobilenet_minimalistic = ConfigItem(
+    datatype=bool,
+    default=False,
+    group="mobilenet encoder configuration",
+    info="Use a minimalist version of MobilenetV3.\n"
+         "In addition to large and small models, MobilenetV3 also contains so-called "
+         "minimalistic models. These models have the same per-layer dimension "
+         "characteristics as MobilenetV3; however, they don't utilize any of the advanced "
+         "blocks (squeeze-and-excite units, hard-swish, and 5x5 convolutions). While these "
+         "models are less efficient on CPU, they are much more performant on GPU/DSP.\n"
+         "NB: This option is only used for MobilenetV3 and is ignored for all other "
+         "encoders.\n",
+    fixed=True)
diff --git a/plugins/train/model/realface.py b/plugins/train/model/realface.py
new file mode 100644
index 0000000000..c772d63264
--- /dev/null
+++ b/plugins/train/model/realface.py
@@ -0,0 +1,187 @@
+#!/usr/bin/env python3
+""" RealFaceRC1, codenamed 'Pegasus'
+    Based on the original https://www.reddit.com/r/deepfakes/ code sample + contributions
+    Major thanks go to BryanLyon, as it is vastly powered by his ideas and insights.
+    Without him it would not have been possible to come up with the model.
+    Additional thanks: Birb - source of inspiration, great Encoder ideas
+                       Kvrooman - additional counseling on auto-encoders and practical advice
+    """
+import logging
+import sys
+
+from keras import initializers, Input, layers, Model as KModel
+
+from lib.model.nn_blocks import Conv2DOutput, Conv2DBlock, ResidualBlock, UpscaleBlock
+from plugins.train.train_config import Loss as cfg_loss
+
+from ._base import ModelBase
+from . import realface_defaults as cfg
+# pylint:disable=duplicate-code
+
+logger = logging.getLogger(__name__)
+
+
+class Model(ModelBase):
+    """ RealFace(tm) Faceswap Model """
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.input_shape = (cfg.input_size(), cfg.input_size(), 3)
+        self.check_input_output()
+        self.dense_width, self.upscalers_no = self.get_dense_width_upscalers_numbers()
+        self.kernel_initializer = initializers.RandomNormal(0, 0.02)
+
+    @property
+    def downscalers_no(self):
+        """ Number of downscale blocks. Don't change! """
+        return 4
+
+    @property
+    def _downscale_ratio(self):
+        """ Downscale Ratio """
+        return 2**self.downscalers_no
+
+    @property
+    def dense_filters(self):
+        """ Dense Filters. Don't change! """
+        return (int(1024 - (self.dense_width - 4) * 64) // 16) * 16
+
+    def check_input_output(self):
+        """ Confirm valid input and output sizes have been provided """
+        if not 64 <= cfg.input_size() <= 128 or cfg.input_size() % 16 != 0:
+            logger.error("Config error: input_size must be between 64 and 128 and be "
+                         "divisible by 16.")
+            sys.exit(1)
+        if not 64 <= cfg.output_size() <= 256 or cfg.output_size() % 32 != 0:
+            logger.error("Config error: output_size must be between 64 and 256 and be "
+                         "divisible by 32.")
+            sys.exit(1)
+        logger.debug("Input and output sizes are valid")
+
+    def get_dense_width_upscalers_numbers(self):
+        """ Return the dense width and number of upscale blocks """
+        output_size = cfg.output_size()
+        sides = [(output_size // 2**n, n) for n in [4, 5] if (output_size // 2**n) < 10]
+        closest = min([x * self._downscale_ratio for x, _ in sides],
+                      key=lambda x: abs(x - cfg.input_size()))
+        dense_width, upscalers_no = [(s, n) for s, n in sides
+                                     if s * self._downscale_ratio == closest][0]
+        logger.debug("dense_width: %s, upscalers_no: %s", dense_width, upscalers_no)
+        return dense_width, upscalers_no
+
+    def build_model(self, inputs):
+        """ Build the RealFace model. """
+        encoder = self.encoder()
+        encoder_a = encoder(inputs[0])
+        encoder_b = encoder(inputs[1])
+
+        outputs = self.decoder_a()(encoder_a) + self.decoder_b()(encoder_b)
+
+        autoencoder = KModel(inputs, outputs, name=self.model_name)
+        return autoencoder
+
+    def encoder(self):
+        """ RealFace Encoder Network """
+        input_ = Input(shape=self.input_shape)
+        var_x = input_
+
+        encoder_complexity = cfg.complexity_encoder()
+
+        for idx in range(self.downscalers_no - 1):
+            var_x = Conv2DBlock(encoder_complexity * 2**idx, activation=None)(var_x)
+            var_x = layers.LeakyReLU(negative_slope=0.2)(var_x)
+            var_x = ResidualBlock(encoder_complexity * 2**idx, use_bias=True)(var_x)
+            var_x = ResidualBlock(encoder_complexity * 2**idx, use_bias=True)(var_x)
+
+        var_x = Conv2DBlock(encoder_complexity * 2**(idx + 1), activation="leakyrelu")(var_x)
+
+        return KModel(input_, var_x, name="encoder")
+
+    def decoder_b(self):
+        """ RealFace Decoder (B) Network """
+        input_filters = cfg.complexity_encoder() * 2**(self.downscalers_no-1)
+        input_width = cfg.input_size() // self._downscale_ratio
+        input_ = Input(shape=(input_width, input_width, input_filters))
+
+        var_xy = input_
+
+        var_xy = layers.Dense(cfg.dense_nodes())(layers.Flatten()(var_xy))
+        var_xy = layers.Dense(self.dense_width * self.dense_width * self.dense_filters)(var_xy)
+        var_xy = layers.Reshape((self.dense_width, self.dense_width, self.dense_filters))(var_xy)
+        var_xy = UpscaleBlock(self.dense_filters, activation=None)(var_xy)
+
+        var_x = var_xy
+        var_x = layers.LeakyReLU(negative_slope=0.2)(var_x)
+        var_x = ResidualBlock(self.dense_filters, use_bias=False)(var_x)
+
+        decoder_b_complexity = cfg.complexity_decoder()
+        for idx in range(self.upscalers_no - 2):
+            var_x = UpscaleBlock(decoder_b_complexity // 2**idx, activation=None)(var_x)
+            var_x = layers.LeakyReLU(negative_slope=0.2)(var_x)
+            var_x = ResidualBlock(decoder_b_complexity // 2**idx, use_bias=False)(var_x)
+            var_x = ResidualBlock(decoder_b_complexity // 2**idx, use_bias=True)(var_x)
+        var_x = UpscaleBlock(decoder_b_complexity // 2**(idx + 1), activation="leakyrelu")(var_x)
+
+        var_x = Conv2DOutput(3, 5, name="face_out_b")(var_x)
+
+        outputs = [var_x]
+
+        if cfg_loss.learn_mask():
+            var_y = var_xy
+            var_y = layers.LeakyReLU(negative_slope=0.1)(var_y)
+
+            mask_b_complexity = 384
+            for idx in range(self.upscalers_no-2):
+                var_y = UpscaleBlock(mask_b_complexity // 2**idx, activation="leakyrelu")(var_y)
+            var_y = UpscaleBlock(mask_b_complexity // 2**(idx + 1), activation="leakyrelu")(var_y)
+
+            var_y = Conv2DOutput(1, 5, name="mask_out_b")(var_y)
+
+            outputs += [var_y]
+
+        return KModel(input_, outputs=outputs, name="decoder_b")
+
+    def decoder_a(self):
+        """ RealFace Decoder (A) Network """
+        input_filters = cfg.complexity_encoder() * 2**(self.downscalers_no-1)
+        input_width = cfg.input_size() // self._downscale_ratio
+        input_ = Input(shape=(input_width, input_width, input_filters))
+
+        var_xy = input_
+
+        dense_nodes = int(cfg.dense_nodes()/1.5)
+        dense_filters = int(self.dense_filters/1.5)
+
+        var_xy = layers.Dense(dense_nodes)(layers.Flatten()(var_xy))
+        var_xy = layers.Dense(self.dense_width * self.dense_width * dense_filters)(var_xy)
+        var_xy = layers.Reshape((self.dense_width, self.dense_width, dense_filters))(var_xy)
+
+        var_xy = UpscaleBlock(dense_filters, activation=None)(var_xy)
+
+        var_x = var_xy
+        var_x = layers.LeakyReLU(negative_slope=0.2)(var_x)
+        var_x = ResidualBlock(dense_filters, use_bias=False)(var_x)
+
+        decoder_a_complexity = int(cfg.complexity_decoder() / 1.5)
+        for idx in range(self.upscalers_no-2):
+            var_x = UpscaleBlock(decoder_a_complexity // 2**idx, activation="leakyrelu")(var_x)
+        var_x = UpscaleBlock(decoder_a_complexity // 2**(idx + 1), activation="leakyrelu")(var_x)
+
+        var_x = Conv2DOutput(3, 5, name="face_out_a")(var_x)
+
+        outputs = [var_x]
+
+        if cfg_loss.learn_mask():
+            var_y = var_xy
+            var_y = layers.LeakyReLU(negative_slope=0.1)(var_y)
+
+            mask_a_complexity = 384
+            for idx in range(self.upscalers_no-2):
+                var_y = UpscaleBlock(mask_a_complexity // 2**idx, activation="leakyrelu")(var_y)
+            var_y = UpscaleBlock(mask_a_complexity // 2**(idx + 1), activation="leakyrelu")(var_y)
+
+            var_y = Conv2DOutput(1, 5, name="mask_out_a")(var_y)
+
+            outputs += [var_y]
+
+        return KModel(input_, outputs=outputs, name="decoder_a")
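A quick worked example of the sizing logic above, using the realface defaults that follow (input_size=64, output_size=128):

    downscale_ratio = 2 ** 4                     # downscalers_no is fixed at 4 -> 16
    sides = [(128 // 2 ** n, n) for n in (4, 5) if (128 // 2 ** n) < 10]
    # -> [(8, 4), (4, 5)]; candidate input widths are 8 * 16 = 128 and 4 * 16 = 64
    # the candidate closest to input_size (64) wins: dense_width=4, upscalers_no=5
    dense_filters = (int(1024 - (4 - 4) * 64) // 16) * 16
    # -> 1024 filters in the dense reshape, i.e. a (4, 4, 1024) tensor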
diff --git a/plugins/train/model/realface_defaults.py b/plugins/train/model/realface_defaults.py
new file mode 100755
index 0000000000..2727fb9034
--- /dev/null
+++ b/plugins/train/model/realface_defaults.py
@@ -0,0 +1,89 @@
+#!/usr/bin/env python3
+""" The default options for the faceswap Realface Model plugin.
+
+Defaults files should be named `<plugin_name>_defaults.py`
+
+Any qualifying items placed into this file will automatically get added to the relevant config
+.ini files within the faceswap/config folder and added to the relevant GUI settings page.
+
+The following variable should be defined:
+
+    Parameters
+    ----------
+    HELPTEXT: str
+        A string describing what this plugin does
+
+Further plugin configuration options are assigned using:
+>>> <option_name> = ConfigItem(...)
+
+where <option_name> is the name of the configuration option to be added (lower-case,
+alpha-numeric + underscore only) and ConfigItem(...) is the [`~lib.config.objects.ConfigItem`]
+data for the option.
+
+See the docstring/ReadtheDocs documentation for the required parameters of the ConfigItem
+object. Items will be grouped together as per their `group` parameter, but otherwise will be
+processed in the order that they are added to this module.
+from lib.config import ConfigItem
+"""
+# pylint:disable=duplicate-code
+from lib.config import ConfigItem
+
+
+HELPTEXT = (
+    "An extra detailed variant of the Original model.\n"
+    "Incorporates ideas from Bryanlyon and inspiration from the Villain model.\n"
+    "Requires about 6GB-8GB of VRAM (batchsize 8-16).\n"
+)
+
+
+input_size = ConfigItem(
+    datatype=int,
+    default=64,
+    group="size",
+    info="Resolution (in pixels) of the input image to train on.\n"
+         "BE AWARE Larger resolution will dramatically increase VRAM requirements.\n"
+         "Higher resolutions may increase prediction accuracy, but do not affect the "
+         "resulting output size.\nMust be between 64 and 128 and be divisible by 16.",
+    rounding=16,
+    min_max=(64, 128),
+    fixed=True)
+
+output_size = ConfigItem(
+    datatype=int,
+    default=128,
+    group="size",
+    info="Output image resolution (in pixels).\nBe aware that larger resolution will "
+         "increase VRAM requirements.\nNB: Must be between 64 and 256 and be divisible "
+         "by 32.",
+    rounding=16,
+    min_max=(64, 256),
+    fixed=True)
+
+dense_nodes = ConfigItem(
+    datatype=int,
+    default=1536,
+    group="network",
+    info="Number of nodes for the decoder. Might affect your model's ability to learn in "
+         "general.\nNote that lower values will affect the ability to predict "
+         "details.",
+    rounding=64,
+    min_max=(768, 2048),
+    fixed=True)
+
+complexity_encoder = ConfigItem(
+    datatype=int,
+    default=128,
+    group="network",
+    info="Encoder Convolution Layer Complexity. Sensible range: 128 to 150.",
+    rounding=4,
+    min_max=(96, 160),
+    fixed=True)
+
+complexity_decoder = ConfigItem(
+    datatype=int,
+    default=512,
+    group="network",
+    info="Decoder Complexity.",
+    rounding=4,
+    min_max=(512, 544),
+    fixed=True)
diff --git a/plugins/train/model/unbalanced.py b/plugins/train/model/unbalanced.py
index 85ede86fc9..2756ed73c9 100644
--- a/plugins/train/model/unbalanced.py
+++ b/plugins/train/model/unbalanced.py
@@ -1,130 +1,137 @@
 #!/usr/bin/env python3
 """ Unbalanced Model
     Based on the original https://www.reddit.com/r/deepfakes/
-    code sample + contribs """
+    code sample + contributions """
 
-from keras.initializers import RandomNormal
-from keras.layers import Conv2D, Dense, Flatten, Input, Reshape, SpatialDropout2D
-from keras.models import Model as KerasModel
+from keras import initializers, Input, layers, Model as KModel
 
-from .original import logger, Model as OriginalModel
+from lib.model.nn_blocks import Conv2DOutput, Conv2DBlock, ResidualBlock, UpscaleBlock
+from plugins.train.train_config import Loss as cfg_loss
 
+from ._base import ModelBase
+from . import unbalanced_defaults as cfg
+# pylint:disable=duplicate-code
 
-class Model(OriginalModel):
+
+class Model(ModelBase):
     """ Unbalanced Faceswap Model """
     def __init__(self, *args, **kwargs):
-        logger.debug("Initializing %s: (args: %s, kwargs: %s",
-                     self.__class__.__name__, args, kwargs)
+        super().__init__(*args, **kwargs)
+        self.input_shape = (cfg.input_size(), cfg.input_size(), 3)
+        self.low_mem = cfg.lowmem()
+        self.encoder_dim = 512 if self.low_mem else cfg.nodes()
+        self.kernel_initializer = initializers.RandomNormal(0, 0.02)
 
-        self.lowmem = self.config.get("lowmem", False)
-        kwargs["input_shape"] = (self.config["input_size"], self.config["input_size"], 3)
-        kwargs["encoder_dim"] = 512 if self.lowmem else self.config["nodes"]
-        self.kernel_initializer = RandomNormal(0, 0.02)
+    def build_model(self, inputs):
+        """ Build the Unbalanced Model. """
+        encoder = self.encoder()
+        encoder_a = encoder(inputs[0])
+        encoder_b = encoder(inputs[1])
 
-        super().__init__(*args, **kwargs)
-        logger.debug("Initialized %s", self.__class__.__name__)
+        outputs = self.decoder_a()(encoder_a) + self.decoder_b()(encoder_b)
 
-    def add_networks(self):
-        """ Add the original model weights """
-        logger.debug("Adding networks")
-        self.add_network("decoder", "a", self.decoder_a())
-        self.add_network("decoder", "b", self.decoder_b())
-        self.add_network("encoder", None, self.encoder())
-        logger.debug("Added networks")
+        autoencoder = KModel(inputs, outputs, name=self.model_name)
+        return autoencoder
 
     def encoder(self):
         """ Unbalanced Encoder """
-        kwargs = dict(kernel_initializer=self.kernel_initializer)
-        encoder_complexity = 128 if self.lowmem else self.config["complexity_encoder"]
-        dense_dim = 384 if self.lowmem else 512
+        kwargs = {"kernel_initializer": self.kernel_initializer}
+        encoder_complexity = 128 if self.low_mem else cfg.complexity_encoder()
+        dense_dim = 384 if self.low_mem else 512
         dense_shape = self.input_shape[0] // 16
         input_ = Input(shape=self.input_shape)
 
         var_x = input_
-        var_x = self.blocks.conv(var_x, encoder_complexity, use_instance_norm=True, **kwargs)
-        var_x = self.blocks.conv(var_x, encoder_complexity * 2, use_instance_norm=True, **kwargs)
-        var_x = self.blocks.conv(var_x, encoder_complexity * 4, **kwargs)
-        var_x = self.blocks.conv(var_x, encoder_complexity * 6, **kwargs)
-        var_x = self.blocks.conv(var_x, encoder_complexity * 8, **kwargs)
-        var_x = Dense(self.encoder_dim,
-                      kernel_initializer=self.kernel_initializer)(Flatten()(var_x))
-        var_x = Dense(dense_shape * dense_shape * dense_dim,
-                      kernel_initializer=self.kernel_initializer)(var_x)
-        var_x = Reshape((dense_shape, dense_shape, dense_dim))(var_x)
-        return KerasModel(input_, var_x)
+        var_x = Conv2DBlock(encoder_complexity,
+                            normalization="instance",
+                            activation="leakyrelu",
+                            **kwargs)(var_x)
+        var_x = Conv2DBlock(encoder_complexity * 2,
+                            normalization="instance",
+                            activation="leakyrelu",
+                            **kwargs)(var_x)
+        var_x = Conv2DBlock(encoder_complexity * 4, **kwargs, activation="leakyrelu")(var_x)
+        var_x = Conv2DBlock(encoder_complexity * 6, **kwargs, activation="leakyrelu")(var_x)
+        var_x = Conv2DBlock(encoder_complexity * 8, **kwargs, activation="leakyrelu")(var_x)
+        var_x = layers.Dense(self.encoder_dim,
+                             kernel_initializer=self.kernel_initializer)(layers.Flatten()(var_x))
+        var_x = layers.Dense(dense_shape * dense_shape * dense_dim,
+                             kernel_initializer=self.kernel_initializer)(var_x)
+        var_x = layers.Reshape((dense_shape, dense_shape, dense_dim))(var_x)
+        return KModel(input_, var_x, name="encoder")
 
     def decoder_a(self):
         """ Decoder for side A """
-        kwargs = dict(kernel_size=5, kernel_initializer=self.kernel_initializer)
-        decoder_complexity = 320 if self.lowmem else self.config["complexity_decoder_a"]
-        dense_dim = 384 if self.lowmem else 512
+        kwargs = {"kernel_size": 5, "kernel_initializer": self.kernel_initializer}
+        decoder_complexity = 320 if self.low_mem else cfg.complexity_decoder_a()
+        dense_dim = 384 if self.low_mem else 512
         decoder_shape = self.input_shape[0] // 16
         input_ = Input(shape=(decoder_shape, decoder_shape, dense_dim))
 
         var_x = input_
-        var_x = self.blocks.upscale(var_x, decoder_complexity, **kwargs)
-        var_x = SpatialDropout2D(0.25)(var_x)
-        var_x = self.blocks.upscale(var_x, decoder_complexity, **kwargs)
-        if self.lowmem:
-            var_x = SpatialDropout2D(0.15)(var_x)
+        var_x = UpscaleBlock(decoder_complexity, activation="leakyrelu", **kwargs)(var_x)
+        var_x = layers.SpatialDropout2D(0.25)(var_x)
+        var_x = UpscaleBlock(decoder_complexity, activation="leakyrelu", **kwargs)(var_x)
+        if self.low_mem:
+            var_x = layers.SpatialDropout2D(0.15)(var_x)
         else:
-            var_x = SpatialDropout2D(0.25)(var_x)
-        var_x = self.blocks.upscale(var_x, decoder_complexity // 2, **kwargs)
-        var_x = self.blocks.upscale(var_x, decoder_complexity // 4, **kwargs)
-        var_x = Conv2D(3, kernel_size=5, padding='same', activation='sigmoid')(var_x)
+            var_x = layers.SpatialDropout2D(0.25)(var_x)
+        var_x = UpscaleBlock(decoder_complexity // 2, activation="leakyrelu", **kwargs)(var_x)
+        var_x = UpscaleBlock(decoder_complexity // 4, activation="leakyrelu", **kwargs)(var_x)
+        var_x = Conv2DOutput(3, 5, name="face_out_a")(var_x)
 
         outputs = [var_x]
-        if self.config.get("mask_type", None):
+        if cfg_loss.learn_mask():
             var_y = input_
-            var_y = self.blocks.upscale(var_y, decoder_complexity)
-            var_y = self.blocks.upscale(var_y, decoder_complexity)
-            var_y = self.blocks.upscale(var_y, decoder_complexity // 2)
-            var_y = self.blocks.upscale(var_y, decoder_complexity // 4)
-            var_y = Conv2D(1, kernel_size=5, padding='same', activation='sigmoid')(var_y)
+            var_y = UpscaleBlock(decoder_complexity, activation="leakyrelu")(var_y)
+            var_y = UpscaleBlock(decoder_complexity, activation="leakyrelu")(var_y)
+            var_y = UpscaleBlock(decoder_complexity // 2, activation="leakyrelu")(var_y)
+            var_y = UpscaleBlock(decoder_complexity // 4, activation="leakyrelu")(var_y)
+            var_y = Conv2DOutput(1, 5, name="mask_out_a")(var_y)
             outputs.append(var_y)
-        return KerasModel(input_, outputs=outputs)
+        return KModel(input_, outputs=outputs, name="decoder_a")
 
     def decoder_b(self):
         """ Decoder for side B """
-        kwargs = dict(kernel_size=5, kernel_initializer=self.kernel_initializer)
-        dense_dim = 384 if self.lowmem else self.config["complexity_decoder_b"]
-        decoder_complexity = 384 if self.lowmem else 512
+        kwargs = {"kernel_size": 5, "kernel_initializer": self.kernel_initializer}
+        decoder_complexity = 384 if self.low_mem else cfg.complexity_decoder_b()
+        dense_dim = 384 if self.low_mem else 512
         decoder_shape = self.input_shape[0] // 16
         input_ = Input(shape=(decoder_shape, decoder_shape, dense_dim))
 
         var_x = input_
-        if self.lowmem:
-            var_x = self.blocks.upscale(var_x, decoder_complexity, **kwargs)
-            var_x = self.blocks.upscale(var_x, decoder_complexity // 2, **kwargs)
-            var_x = self.blocks.upscale(var_x, decoder_complexity // 4, **kwargs)
-            var_x = self.blocks.upscale(var_x, decoder_complexity // 8, **kwargs)
+        if self.low_mem:
+            var_x = UpscaleBlock(decoder_complexity, activation="leakyrelu", **kwargs)(var_x)
+            var_x = UpscaleBlock(decoder_complexity // 2, activation="leakyrelu", **kwargs)(var_x)
+            var_x = UpscaleBlock(decoder_complexity // 4, activation="leakyrelu", **kwargs)(var_x)
+            var_x = UpscaleBlock(decoder_complexity // 8, activation="leakyrelu", **kwargs)(var_x)
         else:
-            var_x = self.blocks.upscale(var_x, decoder_complexity,
-                                        res_block_follows=True, **kwargs)
-            var_x = self.blocks.res_block(var_x, decoder_complexity,
-                                          kernel_initializer=self.kernel_initializer)
-            var_x = self.blocks.upscale(var_x, decoder_complexity,
-                                        res_block_follows=True, **kwargs)
-            var_x = self.blocks.res_block(var_x, decoder_complexity,
-                                          kernel_initializer=self.kernel_initializer)
-            var_x = self.blocks.upscale(var_x, decoder_complexity // 2,
-                                        res_block_follows=True, **kwargs)
-            var_x = self.blocks.res_block(var_x, decoder_complexity // 2,
-                                          kernel_initializer=self.kernel_initializer)
-            var_x = self.blocks.upscale(var_x, decoder_complexity // 4, **kwargs)
-        var_x = Conv2D(3, kernel_size=5, padding='same', activation='sigmoid')(var_x)
+            var_x = UpscaleBlock(decoder_complexity, activation=None, **kwargs)(var_x)
+            var_x = layers.LeakyReLU(negative_slope=0.2)(var_x)
+            var_x = ResidualBlock(decoder_complexity,
+                                  kernel_initializer=self.kernel_initializer)(var_x)
+            var_x = UpscaleBlock(decoder_complexity, activation=None, **kwargs)(var_x)
+            var_x = layers.LeakyReLU(negative_slope=0.2)(var_x)
+            var_x = ResidualBlock(decoder_complexity,
+                                  kernel_initializer=self.kernel_initializer)(var_x)
+            var_x = UpscaleBlock(decoder_complexity // 2, activation=None, **kwargs)(var_x)
+            var_x = layers.LeakyReLU(negative_slope=0.2)(var_x)
+            var_x = ResidualBlock(decoder_complexity // 2,
+                                  kernel_initializer=self.kernel_initializer)(var_x)
+            var_x = UpscaleBlock(decoder_complexity // 4, activation="leakyrelu", **kwargs)(var_x)
+        var_x = Conv2DOutput(3, 5, name="face_out_b")(var_x)
 
         outputs = [var_x]
-        if self.config.get("mask_type", None):
+        if cfg_loss.learn_mask():
             var_y = input_
-            var_y = self.blocks.upscale(var_y, decoder_complexity)
-            if not self.lowmem:
-                var_y = self.blocks.upscale(var_y, decoder_complexity)
-            var_y = self.blocks.upscale(var_y, decoder_complexity // 2)
-            var_y = self.blocks.upscale(var_y, decoder_complexity // 4)
-            if self.lowmem:
-                var_y = self.blocks.upscale(var_y, decoder_complexity // 8)
-            var_y = Conv2D(1, kernel_size=5, padding='same', activation='sigmoid')(var_y)
+            var_y = UpscaleBlock(decoder_complexity, activation="leakyrelu")(var_y)
+            if not self.low_mem:
+                var_y = UpscaleBlock(decoder_complexity, activation="leakyrelu")(var_y)
+            var_y = UpscaleBlock(decoder_complexity // 2, activation="leakyrelu")(var_y)
+            var_y = UpscaleBlock(decoder_complexity // 4, activation="leakyrelu")(var_y)
+            if self.low_mem:
+                var_y = UpscaleBlock(decoder_complexity // 8, activation="leakyrelu")(var_y)
+            var_y = Conv2DOutput(1, 5, name="mask_out_b")(var_y)
             outputs.append(var_y)
-        return KerasModel(input_, outputs=outputs)
+        return KModel(input_, outputs=outputs, name="decoder_b")
diff --git a/plugins/train/model/unbalanced_defaults.py b/plugins/train/model/unbalanced_defaults.py
new file mode 100755
index 0000000000..52a2ca46f9
--- /dev/null
+++ b/plugins/train/model/unbalanced_defaults.py
@@ -0,0 +1,94 @@
+#!/usr/bin/env python3
+""" The default options for the faceswap Unbalanced Model plugin.
+
+Defaults files should be named `<plugin_name>_defaults.py`
+
+Any qualifying items placed into this file will automatically get added to the relevant config
+.ini files within the faceswap/config folder and added to the relevant GUI settings page.
+
+The following variable should be defined:
+
+    Parameters
+    ----------
+    HELPTEXT: str
+        A string describing what this plugin does
+
+Further plugin configuration options are assigned using:
+>>> <option_name> = ConfigItem(...)
+
+where <option_name> is the name of the configuration option to be added (lower-case,
+alpha-numeric + underscore only) and ConfigItem(...) is the [`~lib.config.objects.ConfigItem`]
+data for the option.
+
+See the docstring/ReadtheDocs documentation for the required parameters of the ConfigItem
+object. Items will be grouped together as per their `group` parameter, but otherwise will be
+processed in the order that they are added to this module.
+from lib.config import ConfigItem
+"""
+# pylint:disable=duplicate-code
+from lib.config import ConfigItem
+
+
+HELPTEXT = (
+    "An unbalanced model with adjustable input size options.\n"
+    "This is an unbalanced model, so b>a swaps may not work well\n"
+)
+
+
+input_size = ConfigItem(
+    datatype=int,
+    default=128,
+    group="size",
+    info="Resolution (in pixels) of the image to train on.\n"
+         "BE AWARE Larger resolution will dramatically increase VRAM requirements.\n"
+         "Make sure your resolution is divisible by 64 (e.g. 64, 128, 256 etc.).\n"
+         "NB: Your faceset must be at least 1.6x larger than your required input "
+         "size.\n(e.g. 160 is the maximum input size for a 256x256 faceset).",
+    rounding=64,
+    min_max=(64, 512),
+    fixed=True)
+
+lowmem = ConfigItem(
+    datatype=bool,
+    default=False,
+    group="settings",
+    info="Lower memory mode. Set to 'True' if having issues with VRAM usage.\n"
+         "NB: Models with a changed lowmem mode are not compatible with each other.\n"
+         "NB: lowmem will override custom nodes and complexity settings.",
+    fixed=True)
+
+nodes = ConfigItem(
+    datatype=int,
+    default=1024,
+    group="network",
+    info="Number of nodes for the decoder. Don't change this unless you know what you are "
+         "doing!",
+    rounding=64,
+    min_max=(512, 4096),
+    fixed=True)
+
+complexity_encoder = ConfigItem(
+    datatype=int,
+    default=128,
+    group="network",
+    info="Encoder Convolution Layer Complexity. Sensible range: 128 to 160.",
+    rounding=16,
+    min_max=(64, 1024),
+    fixed=True)
+
+complexity_decoder_a = ConfigItem(
+    datatype=int,
+    default=384,
+    group="network",
+    info="Decoder A Complexity.",
+    rounding=16,
+    min_max=(64, 1024),
+    fixed=True)
+
+complexity_decoder_b = ConfigItem(
+    datatype=int,
+    default=512,
+    group="network",
+    info="Decoder B Complexity.",
+    rounding=16,
+    min_max=(64, 1024),
+    fixed=True)
diff --git a/plugins/train/model/villain.py b/plugins/train/model/villain.py
index c4d18ac1cd..e9081d05f3 100644
--- a/plugins/train/model/villain.py
+++ b/plugins/train/model/villain.py
@@ -1,83 +1,88 @@
 #!/usr/bin/env python3
 """ Original - VillainGuy model
-    Based on the original https://www.reddit.com/r/deepfakes/ code sample + contribs
+    Based on the original https://www.reddit.com/r/deepfakes/ code sample + contributions
     Adapted from a model by VillainGuy (https://github.com/VillainGuy) """
 
-from keras.initializers import RandomNormal
-from keras.layers import add, Conv2D, Dense, Flatten, Input, Reshape
-from keras.models import Model as KerasModel
+from keras import initializers, Input, layers, Model as KModel
 
 from lib.model.layers import PixelShuffler
-from .original import logger, Model as OriginalModel
+from lib.model.nn_blocks import (Conv2DOutput, Conv2DBlock, ResidualBlock, SeparableConv2DBlock,
+                                 UpscaleBlock)
+from plugins.train.train_config import Loss as cfg_loss
+
+from .original import Model as OriginalModel
+from . 
import villain_defaults as cfg +# pylint:disable=duplicate-code class Model(OriginalModel): """ Villain Faceswap Model """ def __init__(self, *args, **kwargs): - logger.debug("Initializing %s: (args: %s, kwargs: %s", - self.__class__.__name__, args, kwargs) - - kwargs["input_shape"] = (128, 128, 3) - kwargs["encoder_dim"] = 512 if self.config["lowmem"] else 1024 - self.kernel_initializer = RandomNormal(0, 0.02) - super().__init__(*args, **kwargs) - logger.debug("Initialized %s", self.__class__.__name__) + self.input_shape = (128, 128, 3) + self.encoder_dim = 512 if self.low_mem else 1024 + self.kernel_initializer = initializers.RandomNormal(0, 0.02) def encoder(self): """ Encoder Network """ - kwargs = dict(kernel_initializer=self.kernel_initializer) + kwargs = {"kernel_initializer": self.kernel_initializer} input_ = Input(shape=self.input_shape) in_conv_filters = self.input_shape[0] if self.input_shape[0] > 128: in_conv_filters = 128 + (self.input_shape[0] - 128) // 4 dense_shape = self.input_shape[0] // 16 - var_x = self.blocks.conv(input_, in_conv_filters, res_block_follows=True, **kwargs) + var_x = Conv2DBlock(in_conv_filters, activation=None, **kwargs)(input_) tmp_x = var_x - res_cycles = 8 if self.config.get("lowmem", False) else 16 + + var_x = layers.LeakyReLU(negative_slope=0.2)(var_x) + res_cycles = 8 if cfg.lowmem() else 16 for _ in range(res_cycles): - nn_x = self.blocks.res_block(var_x, 128, **kwargs) + nn_x = ResidualBlock(in_conv_filters, **kwargs)(var_x) var_x = nn_x # consider adding scale before this layer to scale the residual chain - var_x = add([var_x, tmp_x]) - var_x = self.blocks.conv(var_x, 128, **kwargs) + tmp_x = layers.LeakyReLU(negative_slope=0.1)(tmp_x) + var_x = layers.add([var_x, tmp_x]) + var_x = Conv2DBlock(128, activation="leakyrelu", **kwargs)(var_x) var_x = PixelShuffler()(var_x) - var_x = self.blocks.conv(var_x, 128, **kwargs) + var_x = Conv2DBlock(128, activation="leakyrelu", **kwargs)(var_x) var_x = PixelShuffler()(var_x) - var_x = self.blocks.conv(var_x, 128, **kwargs) - var_x = self.blocks.conv_sep(var_x, 256, **kwargs) - var_x = self.blocks.conv(var_x, 512, **kwargs) - if not self.config.get("lowmem", False): - var_x = self.blocks.conv_sep(var_x, 1024, **kwargs) + var_x = Conv2DBlock(128, activation="leakyrelu", **kwargs)(var_x) + var_x = SeparableConv2DBlock(256, **kwargs)(var_x) + var_x = Conv2DBlock(512, activation="leakyrelu", **kwargs)(var_x) + if not cfg.lowmem(): + var_x = SeparableConv2DBlock(1024, **kwargs)(var_x) - var_x = Dense(self.encoder_dim, **kwargs)(Flatten()(var_x)) - var_x = Dense(dense_shape * dense_shape * 1024, **kwargs)(var_x) - var_x = Reshape((dense_shape, dense_shape, 1024))(var_x) - var_x = self.blocks.upscale(var_x, 512, **kwargs) - return KerasModel(input_, var_x) + var_x = layers.Dense(self.encoder_dim, **kwargs)(layers.Flatten()(var_x)) + var_x = layers.Dense(dense_shape * dense_shape * 1024, **kwargs)(var_x) + var_x = layers.Reshape((dense_shape, dense_shape, 1024))(var_x) + var_x = UpscaleBlock(512, activation="leakyrelu", **kwargs)(var_x) + return KModel(input_, var_x, name="encoder") - def decoder(self): + def decoder(self, side): """ Decoder Network """ - kwargs = dict(kernel_initializer=self.kernel_initializer) + kwargs = {"kernel_initializer": self.kernel_initializer} decoder_shape = self.input_shape[0] // 8 input_ = Input(shape=(decoder_shape, decoder_shape, 512)) var_x = input_ - var_x = self.blocks.upscale(var_x, 512, res_block_follows=True, **kwargs) - var_x = self.blocks.res_block(var_x, 512, **kwargs) - 
var_x = self.blocks.upscale(var_x, 256, res_block_follows=True, **kwargs)
-        var_x = self.blocks.res_block(var_x, 256, **kwargs)
-        var_x = self.blocks.upscale(var_x, self.input_shape[0], res_block_follows=True, **kwargs)
-        var_x = self.blocks.res_block(var_x, self.input_shape[0], **kwargs)
-        var_x = Conv2D(3, kernel_size=5, padding='same', activation='sigmoid')(var_x)
+        var_x = UpscaleBlock(512, activation=None, **kwargs)(var_x)
+        var_x = layers.LeakyReLU(negative_slope=0.2)(var_x)
+        var_x = ResidualBlock(512, **kwargs)(var_x)
+        var_x = UpscaleBlock(256, activation=None, **kwargs)(var_x)
+        var_x = layers.LeakyReLU(negative_slope=0.2)(var_x)
+        var_x = ResidualBlock(256, **kwargs)(var_x)
+        var_x = UpscaleBlock(self.input_shape[0], activation=None, **kwargs)(var_x)
+        var_x = layers.LeakyReLU(negative_slope=0.2)(var_x)
+        var_x = ResidualBlock(self.input_shape[0], **kwargs)(var_x)
+        var_x = Conv2DOutput(3, 5, name=f"face_out_{side}")(var_x)
 
         outputs = [var_x]
-        if self.config.get("mask_type", None):
+        if cfg_loss.learn_mask():
             var_y = input_
-            var_y = self.blocks.upscale(var_y, 512)
-            var_y = self.blocks.upscale(var_y, 256)
-            var_y = self.blocks.upscale(var_y, self.input_shape[0])
-            var_y = Conv2D(1, kernel_size=5, padding='same', activation='sigmoid')(var_y)
+            var_y = UpscaleBlock(512, activation="leakyrelu")(var_y)
+            var_y = UpscaleBlock(256, activation="leakyrelu")(var_y)
+            var_y = UpscaleBlock(self.input_shape[0], activation="leakyrelu")(var_y)
+            var_y = Conv2DOutput(1, 5, name=f"mask_out_{side}")(var_y)
             outputs.append(var_y)
-        return KerasModel(input_, outputs=outputs)
+        return KModel(input_, outputs=outputs, name=f"decoder_{side}")
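A shape walk-through of the villain model above, derived from the code (input_shape 128x128x3, lowmem False, and each UpscaleBlock doubling the spatial dimensions, as implied by the matching encoder and decoder shapes):

    input_size = 128
    dense_shape = input_size // 16   # -> 8: encoder Dense output reshaped to (8, 8, 1024)
    # the encoder's final UpscaleBlock(512) then yields (16, 16, 512)
    decoder_shape = input_size // 8  # -> 16: matches the decoder's Input shape
    # three decoder UpscaleBlocks take the face from 16 -> 32 -> 64 -> 128 pixels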
diff --git a/plugins/train/model/villain_defaults.py b/plugins/train/model/villain_defaults.py
new file mode 100755
index 0000000000..22946c252c
--- /dev/null
+++ b/plugins/train/model/villain_defaults.py
@@ -0,0 +1,44 @@
+#!/usr/bin/env python3
+""" The default options for the faceswap Villain Model plugin.
+
+Defaults files should be named `<plugin_name>_defaults.py`
+
+Any qualifying items placed into this file will automatically get added to the relevant config
+.ini files within the faceswap/config folder and added to the relevant GUI settings page.
+
+The following variable should be defined:
+
+    Parameters
+    ----------
+    HELPTEXT: str
+        A string describing what this plugin does
+
+Further plugin configuration options are assigned using:
+>>> <option_name> = ConfigItem(...)
+
+where <option_name> is the name of the configuration option to be added (lower-case,
+alpha-numeric + underscore only) and ConfigItem(...) is the [`~lib.config.objects.ConfigItem`]
+data for the option.
+
+See the docstring/ReadtheDocs documentation for the required parameters of the ConfigItem
+object. Items will be grouped together as per their `group` parameter, but otherwise will be
+processed in the order that they are added to this module.
+from lib.config import ConfigItem
+"""
+# pylint:disable=duplicate-code
+from lib.config import ConfigItem
+
+
+HELPTEXT = (
+    "A higher resolution version of the Original model by VillainGuy.\n"
+    "Extremely VRAM heavy. Don't try to run this if you have a small GPU.\n"
+)
+
+
+lowmem = ConfigItem(
+    datatype=bool,
+    default=False,
+    group="settings",
+    info="Lower memory mode. Set to 'True' if having issues with VRAM usage.\n"
+         "NB: Models with a changed lowmem mode are not compatible with each other.",
+    fixed=True)
diff --git a/plugins/train/train_config.py b/plugins/train/train_config.py
new file mode 100644
index 0000000000..814614f48e
--- /dev/null
+++ b/plugins/train/train_config.py
@@ -0,0 +1,805 @@
+#!/usr/bin/env python3
+""" Default configurations for models """
+
+import gettext
+import logging
+import os
+
+from dataclasses import dataclass
+
+from lib.config import ConfigItem, FaceswapConfig, GlobalSection
+from plugins.plugin_loader import PluginLoader
+from plugins.train.trainer import trainer_config
+
+# LOCALES
+_LANG = gettext.translation("plugins.train._config", localedir="locales", fallback=True)
+_ = _LANG.gettext
+
+logger = logging.getLogger(__name__)
+
+
+_ADDITIONAL_INFO = _("\nNB: Unless specifically stated, values changed here will only take "
+                     "effect when creating a new model.")
+
+
+class _Config(FaceswapConfig):
+    """ Config File for Models """
+    # pylint:disable=too-many-statements
+    def set_defaults(self, helptext="") -> None:
+        """ Set the default values for config """
+        super().set_defaults(helptext=_("Options that apply to all models") + _ADDITIONAL_INFO)
+        self._defaults_from_plugin(os.path.dirname(__file__))
+
+        train_helptext, section, train_opts = trainer_config.get_defaults()
+        self.add_section(section, train_helptext)
+        for k, v in train_opts.items():
+            self.add_item(section, k, v)
+
+
+centering = ConfigItem(
+    datatype=str,
+    default="face",
+    gui_radio=True,
+    group=_("face"),
+    info=_(
+        "How to center the training image. The extracted images are centered on the middle of "
+        "the skull based on the face's estimated pose. A subsection of these images is used "
+        "for training. The centering used dictates how this subsection will be cropped from "
+        "the aligned images."
+        "\n\tface: Centers the training image on the center of the face, adjusting for pitch "
+        "and yaw."
+        "\n\thead: Centers the training image on the center of the head, adjusting for pitch "
+        "and yaw. NB: You should only select head centering if you intend to include the full "
+        "head (including hair) in the final swap. This may give mixed results. Additionally, "
+        "it is only worth choosing head centering if you are training with a mask that "
+        "includes the hair (e.g. BiSeNet-FP-Head)."
+        "\n\tlegacy: The 'original' extraction technique. Centers the training image near the "
+        "tip of the nose with no adjustment. Can result in the edges of the face appearing "
+        "outside of the training area."),
+    choices=["face", "head", "legacy"],
+    fixed=True)
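Ahead of the coverage option that follows, a sketch of how a coverage percentage plausibly maps to a centred training crop (assumed arithmetic; any rounding faceswap applies is ignored here):

    def coverage_crop(image_size: int, coverage: float) -> tuple[int, int]:
        """ Return the start/end pixel of a centred crop of coverage%. """
        crop = int(image_size * coverage / 100)
        offset = (image_size - crop) // 2
        return offset, offset + crop

    print(coverage_crop(512, 62.5))  # -> (96, 416): eyebrow to eyebrow at legacy centering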
+ "\n\t100.0% is a mugshot."), + min_max=(62.5, 100.0), + rounding=2, + fixed=True) + + +vertical_offset = ConfigItem( + datatype=int, + default=0, + group=_("face"), + info=_( + "How much to adjust the vertical position of the aligned face as a percentage of face " + "image size. Negative values move the face up (expose more chin and less forehead). " + "Positive values move the face down (expose less chin and more forehead)"), + min_max=(-25, 25), + rounding=1, + fixed=True) + + +icnr_init = ConfigItem( + datatype=bool, + default=False, + group=_("initialization"), + info=_( + "Use ICNR to tile the default initializer in a repeating pattern. This strategy is " + "designed for pairing with sub-pixel / pixel shuffler to reduce the 'checkerboard effect' " + "in image reconstruction. \n\t https://arxiv.org/ftp/arxiv/papers/1707/1707.02937.pdf")) + + +conv_aware_init = ConfigItem( + datatype=bool, + default=False, + group=_("initialization"), + info=_( + "Use Convolution Aware Initialization for convolutional layers. This can help eradicate " + "the vanishing and exploding gradient problem as well as lead to higher accuracy, lower " + "loss and faster convergence.\nNB:\n\t This can use more VRAM when creating a new model " + "so you may want to lower the batch size for the first run. The batch size can be raised " + "again when reloading the model." + "\n\t Multi-GPU is not supported for this option, so you should start the model on a " + "single GPU. Once training has started, you can stop training, enable multi-GPU and " + "resume." + "\n\t Building the model will likely take several minutes as the calculations for this " + "initialization technique are expensive. This will only impact starting a new model.")) + + +lr_finder_iterations = ConfigItem( + datatype=int, + default=1000, + group=_("Learning Rate Finder"), + info=_( + "The number of iterations to process to find the optimal learning rate. Higher values " + "will take longer, but will be more accurate."), + min_max=(100, 10000), + rounding=100, + fixed=True) + + +lr_finder_mode = ConfigItem( + datatype=str, + default="set", + group=_("Learning Rate Finder"), + info=_( + "The operation mode for the learning rate finder. Only applicable to new models. For " + "existing models this will always default to 'set'." + "\n\tset - Train with the discovered optimal learning rate." + "\n\tgraph_and_set - Output a graph in the training folder showing the discovered " + "learning rates and train with the optimal learning rate." + "\n\tgraph_and_exit - Output a graph in the training folder with the discovered learning " + "rates and exit."), + gui_radio=True, + choices=["set", "graph_and_set", "graph_and_exit"], + fixed=True) + + +lr_finder_strength = ConfigItem( + datatype=str, + default="default", + group=_("Learning Rate Finder"), + info=_( + "How aggressively to set the Learning Rate. More aggressive can learn faster, but is more " + "likely to lead to exploding gradients." + "\n\tdefault - The default optimal learning rate. A safe choice for nearly all use cases." + "\n\taggressive - Set's a higher learning rate than the default. May learn faster but " + "with a higher chance of exploding gradients." + "\n\textreme - The highest optimal learning rate. 
+        "gradients."),
+    gui_radio=True,
+    choices=["default", "aggressive", "extreme"],
+    fixed=True)
+
+
+reflect_padding = ConfigItem(
+    datatype=bool,
+    default=False,
+    group=_("network"),
+    info=_(
+        "Use reflection padding rather than zero padding with convolutions. Each convolution must "
+        "pad the image boundaries to maintain the proper sizing. More complex padding schemes can "
+        "reduce artifacts at the border of the image."
+        "\n\t http://www-cs.engr.ccny.cuny.edu/~wolberg/cs470/hw/hw2_pad.txt"))
+
+
+mixed_precision = ConfigItem(
+    datatype=bool,
+    default=False,
+    group=_("network"),
+    info=_(
+        "NVIDIA GPUs can run operations in float16 faster than in float32. Mixed precision allows "
+        "you to use a mix of float16 with float32, to get the performance benefits from float16 "
+        "and the numeric stability benefits from float32.\n\nThis is untested on non-Nvidia "
+        "cards, but will run on most Nvidia models. It will only speed up training on more recent "
+        "GPUs. Those with compute capability 7.0 or higher will see the greatest performance "
+        "benefit from mixed precision because they have Tensor Cores. Older GPUs offer no math "
+        "performance benefit for using mixed precision, however memory and bandwidth savings can "
+        "enable some speedups. Generally RTX GPUs and later will offer the most benefit."),
+    fixed=False)
+
+
+nan_protection = ConfigItem(
+    datatype=bool,
+    default=True,
+    group=_("network"),
+    info=_(
+        "If a 'NaN' is generated in the model, this means that the model has become corrupted "
+        "and is likely to start deteriorating from this point on. Enabling NaN protection will "
+        "stop training immediately in the event of a NaN. The last save will not contain the NaN, "
+        "so you may still be able to rescue your model."),
+    fixed=False)
+
+
+convert_batchsize = ConfigItem(
+    datatype=int,
+    default=16,
+    group=_("convert"),
+    info=_(
+        "[GPU Only]. The number of faces to feed through the model at once when running the "
+        "Convert process.\n\nNB: Increasing this figure is unlikely to improve convert speed, "
+        "however, if you are getting Out of Memory errors, then you may want to reduce the batch "
+        "size."),
+    min_max=(1, 32),
+    rounding=1,
+    fixed=False)
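# The mixed_precision option above describes running float16 ops where it is safe while keeping
# float32 where it is numerically needed. A minimal, generic PyTorch sketch of that idea
# (illustrative only -- this is not faceswap's implementation; the model, data and optimizer
# below are hypothetical stand-ins):
import torch

model = torch.nn.Linear(64, 64).cuda()
optimizer = torch.optim.Adam(model.parameters())
scaler = torch.cuda.amp.GradScaler()    # scales the loss to avoid float16 gradient underflow

inputs = torch.randn(8, 64, device="cuda")
targets = torch.randn(8, 64, device="cuda")

with torch.autocast(device_type="cuda", dtype=torch.float16):
    loss = torch.nn.functional.mse_loss(model(inputs), targets)  # float16 where safe

scaler.scale(loss).backward()   # backward pass on the scaled loss
scaler.step(optimizer)          # unscales the gradients, then applies the update
scaler.update()                 # adapts the scale factor for the next iteration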
Ref: " + "Gradient Magnitude Similarity Deviation: An Highly Efficient Perceptual Image Quality " + "Index https://arxiv.org/ftp/arxiv/papers/1308/1308.3052.pdf"), + "l_inf_norm": _( + "The L_inf norm will reduce the largest individual pixel error in an image. As " + "each largest error is minimized sequentially, the overall error is improved. This loss " + "will be extremely focused on outliers."), + "laploss": _( + "Laplacian Pyramid Loss. Attempts to improve results by focussing on edges using " + "Laplacian Pyramids. As this loss function gives priority to edges over other low-" + "frequency information, like color, it should not be used on its own. The original " + "implementation uses this loss as a complimentary function to MSE. " + "Ref: Optimizing the Latent Space of Generative Networks " + "https://arxiv.org/abs/1707.05776"), + "lpips_alex": _( + "LPIPS is a perceptual loss that uses the feature outputs of other pretrained models as a " + "loss metric. Be aware that this loss function will use more VRAM. Used on its own and " + "this loss will create a distinct moire pattern on the output, however it can be helpful " + "as a complimentary loss function. The output of this function is strong, so depending " + "on your chosen primary loss function, you are unlikely going to want to set the weight " + "above about 25%. Ref: The Unreasonable Effectiveness of Deep Features as a Perceptual " + "Metric http://arxiv.org/abs/1801.03924\nThis variant uses the AlexNet backbone. A fairly " + "light and old model which performed best in the paper's original implementation.\nNB: " + "For AMD Users the final linear layer is not implemented."), + "lpips_squeeze": _( + "Same as lpips_alex, but using the SqueezeNet backbone. A more lightweight " + "version of AlexNet.\nNB: For AMD Users the final linear layer is not implemented."), + "lpips_vgg16": _( + "Same as lpips_alex, but using the VGG16 backbone. A more heavyweight model.\n" + "NB: For AMD Users the final linear layer is not implemented."), + "logcosh": _( + "log(cosh(x)) acts similar to MSE for small errors and to MAE for large errors. Like " + "MSE, it is very stable and prevents overshoots when errors are near zero. Like MAE, it " + "is robust to outliers."), + "mae": _( + "Mean absolute error will guide reconstructions of each pixel towards its median value in " + "the training dataset. Robust to outliers but as a median, it can potentially ignore some " + "infrequent image types in the dataset."), + "mse": _( + "Mean squared error will guide reconstructions of each pixel towards its average value in " + "the training dataset. As an avg, it will be susceptible to outliers and typically " + "produces slightly blurrier results. Ref: Multi-Scale Structural Similarity for Image " + "Quality Assessment https://www.cns.nyu.edu/pub/eero/wang03b.pdf"), + "ms_ssim": _( + "Multiscale Structural Similarity Index Metric is similar to SSIM except that it " + "performs the calculations along multiple scales of the input image."), + "smooth_loss": _( + "Smooth_L1 is a modification of the MAE loss to correct two of its disadvantages. " + "This loss has improved stability and guidance for small errors. Ref: A General and " + "Adaptive Robust Loss Function https://arxiv.org/pdf/1701.03077.pdf"), + "ssim": _( + "Structural Similarity Index Metric is a perception-based loss that considers changes in " + "texture, luminance, contrast, and local spatial statistics of an image. Potentially " + "delivers more realistic looking images. 
+
+_NON_PRIMARY_LOSS = ["flip", "lpips_alex", "lpips_squeeze", "lpips_vgg16", "none"]
+
+
+@dataclass
+class Loss(GlobalSection):
+    """ global.loss configuration section
+
+    Loss Documentation
+    MAE https://heartbeat.fritz.ai/5-regression-loss-functions-all-machine-learners-should-know-4fb140e9d4b0
+    MSE https://heartbeat.fritz.ai/5-regression-loss-functions-all-machine-learners-should-know-4fb140e9d4b0
+    LogCosh https://heartbeat.fritz.ai/5-regression-loss-functions-all-machine-learners-should-know-4fb140e9d4b0
+    L_inf_norm https://medium.com/@montjoile/l0-norm-l1-norm-l2-norm-l-infinity-norm-7a7d18a4f40c
+    """  # pylint:disable=line-too-long  # noqa[E501]
+
+    helptext = _(
+        "Loss configuration options\n"
+        "Loss is the mechanism by which a Neural Network judges how well it thinks that it "
+        "is recreating a face.") + _ADDITIONAL_INFO
+    loss_function = ConfigItem(
+        datatype=str,
+        default="ssim",
+        group=_("loss"),
+        info=(_("The loss function to use.")
+              + "\n\n\t" + "\n\n\t".join(f"{k}: {v}"
+                                         for k, v in sorted(_LOSS_HELP.items())
+                                         if k not in _NON_PRIMARY_LOSS)),
+        choices=[x for x in sorted(_LOSS_HELP) if x not in _NON_PRIMARY_LOSS],
+        fixed=False)
+    loss_function_2 = ConfigItem(
+        datatype=str,
+        default="mse",
+        group=_("loss"),
+        info=_(
+            "The second loss function to use. If using a structural based loss (such as "
+            "SSIM, MS-SSIM or GMSD) it is common to add an L1 regularization (MAE) or L2 "
+            "regularization (MSE) function. You can adjust the weighting of this loss "
+            "function with the loss_weight_2 option."
+            + "\n\n\t" + "\n\n\t".join(f"{k}: {v}" for k, v in sorted(_LOSS_HELP.items()))),
+        choices=list(sorted(_LOSS_HELP)),
+        fixed=False)
+    loss_weight_2 = ConfigItem(
+        datatype=int,
+        default=100,
+        group=_("loss"),
+        info=_(
+            "The amount of weight to apply to the second loss function.\n\n"
+            "The value given here is a percentage denoting how much the selected "
+            "function should contribute to the overall loss cost of the model. For "
+            "example:"
+            "\n\t 100 - The loss calculated for the second loss function will be applied "
+            "at its full amount towards the overall loss score. "
+            "\n\t 25 - The loss calculated for the second loss function will be reduced "
+            "by a quarter prior to adding to the overall loss score. "
+            "\n\t 400 - The loss calculated for the second loss function will be "
+            "multiplied 4 times prior to adding to the overall loss score. "
+            "\n\t 0 - Disables the second loss function altogether."),
+        min_max=(0, 400),
+        rounding=1,
+        fixed=False)
+    loss_function_3 = ConfigItem(
+        datatype=str,
+        default="none",
+        group=_("loss"),
+        info=_("The third loss function to use. You can adjust the weighting of this loss "
+               "function with the loss_weight_3 option."
+               + "\n\n\t"
+               + "\n\n\t".join(f"{k}: {v}" for k, v in sorted(_LOSS_HELP.items()))),
+        choices=list(sorted(_LOSS_HELP)),
+        fixed=False)
+ + "\n\n\t" + + "\n\n\t".join(f"{k}: {v}" for k, v in sorted(_LOSS_HELP.items()))), + choices=list(sorted(_LOSS_HELP)), + fixed=False) + loss_weight_3 = ConfigItem( + datatype=int, + default=0, + group=_("loss"), + info=_( + "The amount of weight to apply to the third loss function.\n\n" + "\n\nThe value given here is as a percentage denoting how much the selected " + "function should contribute to the overall loss cost of the model. For " + "example:" + "\n\t 100 - The loss calculated for the third loss function will be applied " + "at its full amount towards the overall loss score. " + "\n\t 25 - The loss calculated for the third loss function will be reduced " + "by a quarter prior to adding to the overall loss score. " + "\n\t 400 - The loss calculated for the third loss function will be " + "mulitplied 4 times prior to adding to the overall loss score. " + "\n\t 0 - Disables the third loss function altogether."), + min_max=(0, 400), + rounding=1, + fixed=False) + loss_function_4 = ConfigItem( + datatype=str, + default="none", + group=_("loss"), + info=_( + "The fourth loss function to use. You can adjust the weighting of this " + "loss function with the loss_weight_3 option." + + "\n\n\t" + + "\n\n\t".join(f"{k}: {v}" for k, v in sorted(_LOSS_HELP.items()))), + choices=list(sorted(_LOSS_HELP)), + fixed=False) + loss_weight_4 = ConfigItem( + datatype=int, + default=0, + group=_("loss"), + info=_( + "The amount of weight to apply to the fourth loss function.\n\n" + "\n\nThe value given here is as a percentage denoting how much the selected " + "function should contribute to the overall loss cost of the model. For " + "example:" + "\n\t 100 - The loss calculated for the fourth loss function will be applied " + "at its full amount towards the overall loss score. " + "\n\t 25 - The loss calculated for the fourth loss function will be reduced " + "by a quarter prior to adding to the overall loss score. " + "\n\t 400 - The loss calculated for the fourth loss function will be " + "mulitplied 4 times prior to adding to the overall loss score. " + "\n\t 0 - Disables the fourth loss function altogether."), + min_max=(0, 400), + rounding=1, + fixed=False) + mask_loss_function = ConfigItem( + datatype=str, + default="mse", + group=_("loss"), + info=_( + "The loss function to use when learning a mask." + "\n\t MAE - Mean absolute error will guide reconstructions of each pixel " + "towards its median value in the training dataset. Robust to outliers but as " + "a median, it can potentially ignore some infrequent image types in the " + "dataset." + "\n\t MSE - Mean squared error will guide reconstructions of each pixel " + "towards its average value in the training dataset. As an average, it will be " + "susceptible to outliers and typically produces slightly blurrier results."), + choices=["mae", "mse"], + fixed=False) + eye_multiplier = ConfigItem( + datatype=int, + default=3, + group=_("loss"), + info=_( + "The amount of priority to give to the eyes.\n\nThe value given here is as a " + "multiplier of the main loss score. For example:" + "\n\t 1 - The eyes will receive the same priority as the rest of the face. " + "\n\t 10 - The eyes will be given a score 10 times higher than the rest of " + "the face." 
+ "\n\nNB: Penalized Mask Loss must be enable to use this option."), + min_max=(1, 40), + rounding=1, + fixed=False) + mouth_multiplier = ConfigItem( + datatype=int, + default=2, + group=_("loss"), + info=_( + "The amount of priority to give to the mouth.\n\nThe value given here is as a " + "multiplier of the main loss score. For Example:" + "\n\t 1 - The mouth will receive the same priority as the rest of the face. " + "\n\t 10 - The mouth will be given a score 10 times higher than the rest of " + "the face." + "\n\nNB: Penalized Mask Loss must be enable to use this option."), + min_max=(1, 40), + rounding=1, + fixed=False) + penalized_mask_loss = ConfigItem( + datatype=bool, + default=True, + group=_("loss"), + info=_( + "Image loss function is weighted by mask presence. For areas of " + "the image without the facial mask, reconstruction errors will be " + "ignored while the masked face area is prioritized. May increase " + "overall quality by focusing attention on the core face area.")) + mask_type = ConfigItem( + datatype=str, + default="extended", + group=_("mask"), + info=_( + "The mask to be used for training. If you have selected 'Learn Mask' or " + "'Penalized Mask Loss' you must select a value other than 'none'. The " + "required mask should have been selected as part of the Extract process. If " + "it does not exist in the alignments file then it will be generated prior to " + "training commencing." + "\n\tnone: Don't use a mask." + "\n\tbisenet-fp_face: Relatively lightweight NN based mask that provides more " + "refined control over the area to be masked (configurable in mask settings). " + "Use this version of bisenet-fp if your model is trained with 'face' or " + "'legacy' centering." + "\n\tbisenet-fp_head: Relatively lightweight NN based mask that provides more " + "refined control over the area to be masked (configurable in mask settings). " + "Use this version of bisenet-fp if your model is trained with 'head' " + "centering." + "\n\tcomponents: Mask designed to provide facial segmentation based on the " + "positioning of landmark locations. A convex hull is constructed around the " + "exterior of the landmarks to create a mask." + "\n\tcustom_face: Custom user created, face centered mask." + "\n\tcustom_head: Custom user created, head centered mask." + "\n\textended: Mask designed to provide facial segmentation based on the " + "positioning of landmark locations. A convex hull is constructed around the " + "exterior of the landmarks and the mask is extended upwards onto the forehead." + "\n\tvgg-clear: Mask designed to provide smart segmentation of mostly frontal " + "faces clear of obstructions. Profile faces and obstructions may result in " + "sub-par performance." + "\n\tvgg-obstructed: Mask designed to provide smart segmentation of mostly " + "frontal faces. The mask model has been specifically trained to recognize " + "some facial obstructions (hands and eyeglasses). Profile faces may result in " + "sub-par performance." + "\n\tunet-dfl: Mask designed to provide smart segmentation of mostly frontal " + "faces. The mask model has been trained by community members and will need " + "testing for further description. Profile faces may result in sub-par " + "performance."), + choices=PluginLoader.get_available_extractors("mask", + add_none=True, extend_plugin=True), + gui_radio=True) + mask_dilation = ConfigItem( + datatype=float, + default=0.0, + group=_("mask"), + info=_( + "Dilate or erode the mask. Negative values erode the mask (make it smaller). 
" + "Positive values dilate the mask (make it larger). The value given is a " + "percentage of the total mask size."), + min_max=(-5.0, 5.0), + rounding=1, + fixed=False) + mask_blur_kernel = ConfigItem( + datatype=int, + default=3, + group=_("mask"), + info=_( + "Apply gaussian blur to the mask input. This has the effect of smoothing the " + "edges of the mask, which can help with poorly calculated masks and give less " + "of a hard edge to the predicted mask. The size is in pixels (calculated from " + "a 128px mask). Set to 0 to not apply gaussian blur. This value should be " + "odd, if an even number is passed in then it will be rounded to the next odd " + "number."), + min_max=(0, 9), + rounding=1, + fixed=False) + mask_threshold = ConfigItem( + datatype=int, + default=4, + group=_("mask"), + info=_( + "Sets pixels that are near white to white and near black to black. Set to 0 " + "for off."), + min_max=(0, 50), + rounding=1, + fixed=False) + learn_mask = ConfigItem( + datatype=bool, + default=False, + group=_("mask"), + info=_( + "Dedicate a portion of the model to learning how to duplicate the input " + "mask. Increases VRAM usage in exchange for learning a quick ability to try " + "to replicate more complex mask models.")) + + +@dataclass +class Optimizer(GlobalSection): + """ global.optimizer configuration section """ + helptext = (_("Optimizer configuration options\n" + "The optimizer applies the output of the loss function to the model.\n") + + _ADDITIONAL_INFO) + optimizer = ConfigItem( + datatype=str, + default="adam", + group=_("optimizer"), + info=_( + "The optimizer to use." + "\n\t adabelief - Adapting Stepsizes by the Belief in Observed Gradients. An " + "optimizer with the aim to converge faster, generalize better and remain more " + "stable. (https://arxiv.org/abs/2010.07468). NB: Epsilon for AdaBelief needs " + "to be set to a smaller value than other Optimizers. Generally setting the " + "'Epsilon Exponent' to around '-16' should work." + "\n\t adam - Adaptive Moment Optimization. A stochastic gradient descent " + "method that is based on adaptive estimation of first-order and second-order " + "moments." + "\n\t adamax - a variant of Adam based on the infinity norm. Due to its " + "capability of adjusting the learning rate based on data characteristics, it " + "is suited to learn time-variant process, " + "parameters follow those provided in the paper" + "\n\t adamw - Like 'adam' but with an added method to decay weights per the " + "techniques discussed in the paper (https://arxiv.org/abs/1711.05101). NB: " + "Weight decay should be set at 0.004 for default implementation." + "\n\t lion - A method that uses the sign operator to control the magnitude of " + "the update, rather than relying on second-order moments (Adam). saves VRAM " + "by only tracking the momentum. Performance gains should be better with " + "larger batch sizes. A suitable learning rate for Lion is typically 3-10x " + "smaller than that for AdamW. The weight decay for Lion should be 3-10x " + "larger than that for AdamW to maintain a similar strength." + "\n\t nadam - Adaptive Moment Optimization with Nesterov Momentum. Much like " + "Adam but uses a different formula for calculating momentum." + "\n\t rms-prop - Root Mean Square Propagation. Maintains a moving " + "(discounted) average of the square of the gradients. 
+    learning_rate = ConfigItem(
+        datatype=float,
+        default=5e-5,
+        group=_("optimizer"),
+        info=_(
+            "Learning rate - how fast your network will learn (how large are the "
+            "modifications to the model weights after one batch of training). Values that "
+            "are too large might result in model crashes and the inability of the model "
+            "to find the best solution. Values that are too small might be unable to "
+            "escape from dead-ends and find the best global minimum."),
+        min_max=(1e-6, 1e-4),
+        rounding=6,
+        fixed=False)
+    epsilon_exponent = ConfigItem(
+        datatype=int,
+        default=-7,
+        group=_("optimizer"),
+        info=_(
+            "The epsilon adds a small constant to weight updates to attempt to avoid "
+            "'divide by zero' errors. Unless you are using the AdaBelief optimizer, this "
+            "option should generally be left at its default value. For AdaBelief, "
+            "setting this to around '-16' should work.\n"
+            "In all instances if you are getting 'NaN' loss values, and have been unable "
+            "to resolve the issue any other way (for example, increasing batch size, or "
+            "lowering learning rate), then raising the epsilon can lead to a more stable "
+            "model. It may, however, come at the cost of slower training and a less "
+            "accurate final result.\n"
+            "Note: The value given here is the 'exponent' to the epsilon. For example, "
+            "choosing '-7' will set the epsilon to 1e-7. Choosing '-3' will set the "
+            "epsilon to 0.001 (1e-3).\n"
+            "Note: Not used by the Lion optimizer"),
+        min_max=(-20, 0),
+        rounding=1,
+        fixed=False)
+    save_optimizer = ConfigItem(
+        datatype=str,
+        default="exit",
+        group=_("optimizer"),
+        info=_(
+            "When to save the Optimizer Weights. Saving the optimizer weights is not "
+            "necessary and will increase the model file size 3x (and by extension the "
+            "amount of time it takes to save the model). However, it can be useful to "
+            "save these weights if you want to guarantee that a resumed model carries on "
+            "exactly from where it left off, rather than spending a few hundred "
+            "iterations catching up."
+            "\n\t never - Don't save optimizer weights."
+            "\n\t always - Save the optimizer weights at every save iteration. Model "
+            "saving will take longer, due to the increased file size, but you will always "
+            "have the last saved optimizer state in your model file."
+            "\n\t exit - Only save the optimizer weights when explicitly terminating a "
+            "model. This can be when the model is actively stopped or when the target "
+            "iterations are met. Note: If the training session ends because of another "
+            "reason (e.g. power outage, Out of Memory Error, NaN detected) then the "
+            "optimizer weights will NOT be saved."),
+        gui_radio=True,
+        choices=["never", "always", "exit"],
+        fixed=False)
+    gradient_clipping = ConfigItem(
+        datatype=str,
+        default="none",
+        group=_("clipping"),
+        info=_(
+            "Apply clipping to the gradients. Can help prevent NaNs and improve model "
+            "optimization at the expense of VRAM."
+            "\n\tautoclip: Analyzes the gradient weights and adjusts the normalization "
+            "value dynamically to fit the data"
+            "\n\tglobal_norm: Clips the gradient of each weight so that the global norm "
+            "is no higher than the given value."
+            "\n\tnorm: Clips the gradient of each weight so that its norm is no higher "
+            "than the given value."
+            "\n\tvalue: Clips the gradient of each weight so that it is no higher than "
+            "the given value."
+            "\n\tnone: Don't perform any clipping to the gradients."),
+        choices=["autoclip", "global_norm", "norm", "value", "none"],
+        gui_radio=True,
+        fixed=False)
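# The 'norm' and 'value' clipping modes above correspond directly to PyTorch's built-in
# helpers; a minimal sketch (autoclip and global_norm handling are not shown, and
# clipping_value refers to the option defined next):
import torch

model = torch.nn.Linear(4, 4)
model(torch.randn(2, 4)).sum().backward()

clipping, clipping_value = "norm", 1.0
if clipping == "norm":
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=clipping_value)
elif clipping == "value":
    torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=clipping_value)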
+ "\n\tvalue: Clips the gradient of each weight so that it is no higher than " + "the given value." + "\n\tnone: Don't perform any clipping to the gradients."), + choices=["autoclip", "global_norm", "norm", "value", "none"], + gui_radio=True, + fixed=False) + clipping_value = ConfigItem( + datatype=float, + default=1.0, + group=_("clipping"), + info=_( + "The amount of clipping to perform." + "\n\tautoclip: The percentile to clip at. A value of 1.0 will clip at the " + "10th percentile a value of 2.5 will clip at the 25th percentile etc. " + "Default: 1.0" + "\n\tglobal_norm: The gradient of each weight is clipped so that the global " + "norm is no higher than this value." + "\n\tnorm: The gradient of each weight is clipped so that its norm is no " + "higher than this value." + "\n\tvalue: The gradient of each weight is clipped to be no higher than this " + "value." + "\n\tnone: This option is ignored."), + min_max=(0.0, 10.0), + rounding=1, + fixed=False) + autoclip_history = ConfigItem( + datatype=int, + default=10000, + group=_("clipping"), + info=_( + "The maximum number of prior iterations for autoclipper to analyze when " + "calculating the normalization amount. 0 to always include all prior " + "iterations."), + min_max=(0, 100000), + rounding=1000, + fixed=False) + weight_decay = ConfigItem( + datatype=float, + default=0.0, + group=_("updates"), + info=_("If set, weight decay is applied. 0.0 for no weight decay. Default is 0.0 " + "for all optimizers except AdamW (0.004)"), + min_max=(0.0, 1.0), + rounding=4, + fixed=False) + gradient_accumulation = ConfigItem( + datatype=int, + default=1, + group=_("updates"), + info=_( + "Values above 1 will enable Gradient Accumulation. Updates will not be at " + "every iteration; instead they will occur every number of iterations given " + "here. The update will be the average value of the gradients since the last " + "update. Can be useful when your batch size is very small, in order to reduce " + "gradient noise at each update iteration."), + min_max=(1, 100), + rounding=1, + fixed=False) + use_ema = ConfigItem( + datatype=bool, + default=False, + group=_("exponential moving average"), + info=_( + "Enable exponential moving average (EMA). EMA consists of computing an " + "exponential moving average of the weights of the model (as the weight values " + "change after each training batch), and periodically overwriting the weights " + "with their moving average"), + fixed=True) + ema_momentum = ConfigItem( + datatype=float, + default=0.99, + group=_("exponential moving average"), + info=_( + "Only used if use_ema is enabled. This is the momentum to use when computing " + "the EMA of the model's weights: new_average = ema_momentum * old_average + " + "(1 - ema_momentum) * current_variable_value."), + min_max=(0.0, 1.0), + rounding=4, + fixed=True) + ema_frequency = ConfigItem( + datatype=int, + default=100, + group=_("exponential moving average"), + info=_( + "Only used if use_ema is enabled. Set the number of iterations, to overwrite " + "the model variable by its moving average. "), + min_max=(10, 10000), + rounding=10, + fixed=True) + ada_beta_1 = ConfigItem( + datatype=float, + default=0.9, + group=_("optimizer specific"), + info=_( + "The exponential decay rate for the 1st moment estimates. Used for the " + "following Optimizers: AdaBelief, Adam, Adamax, AdamW, Lion, nAdam. 
Ignored " + "for all others."), + min_max=(0.0, 1.0), + rounding=4, + fixed=True) + ada_beta_2 = ConfigItem( + datatype=float, + default=0.999, + group=_("optimizer specific"), + info=_( + "The exponential decay rate for the 2nd moment estimates. Used for the " + "following Optimizers: AdaBelief, Adam, Adamax, AdamW, Lion, nAdam. Ignored " + "for all others."), + min_max=(0.0, 1.0), + rounding=4, + fixed=True) + ada_amsgrad = ConfigItem( + datatype=bool, + default=False, + group=_("optimizer specific"), + info=_( + "Whether to apply AMSGrad variant of the algorithm from the paper 'On the " + "Convergence of Adam and beyond. Used for the following Optimizers: " + "AdaBelief, Adam, AdamW. Ignored for all others.'"), + fixed=True) + + +# pylint:disable=duplicate-code +_IS_LOADED: bool = False + + +def load_config(config_file: str | None = None) -> None: + """ Load the Train configuration .ini file + + Parameters + ---------- + config_file : str | None, optional + Path to a custom .ini configuration file to load. Default: ``None`` (use default + configuration file) + """ + global _IS_LOADED # pylint:disable=global-statement + if not _IS_LOADED: + _Config(configfile=config_file) + _IS_LOADED = True diff --git a/plugins/train/trainer/_base.py b/plugins/train/trainer/_base.py index c8ba688b8c..aebcce0810 100644 --- a/plugins/train/trainer/_base.py +++ b/plugins/train/trainer/_base.py @@ -1,576 +1,58 @@ #!/usr/bin/env python3 +""" Base Class for Faceswap Trainer plugins. All Trainer plugins should be inherited from +this class. - -""" Base Trainer Class for Faceswap - - Trainers should be inherited from this class. - - A training_opts dictionary can be set in the corresponding model. - Accepted values: - alignments: dict containing paths to alignments files for keys 'a' and 'b' - preview_scaling: How much to scale the preview out by - training_size: Size of the training images - coverage_ratio: Ratio of face to be cropped out for training - mask_type: Type of mask to use. See lib.model.masks for valid mask names. - Set to None for not used - no_logs: Disable tensorboard logging - warp_to_landmarks: Use random_warp_landmarks instead of random_warp - no_flip: Don't perform a random flip on the image +At present there is only the :class:`~plugins.train.trainer.original` plugin, so that entirely +inherits from this class. If further plugins are developed, then common code should be kept here, +with "original" unique code split out to the original plugin. 
""" - +from __future__ import annotations +import abc import logging -import os -import time - -import cv2 -import numpy as np - -from tensorflow import keras as tf_keras - -from lib.alignments import Alignments -from lib.faces_detect import DetectedFace -from lib.training_data import TrainingDataGenerator, stack_images -from lib.utils import get_folder, get_image_paths - -logger = logging.getLogger(__name__) # pylint: disable=invalid-name - - -class TrainerBase(): - """ Base Trainer """ - - def __init__(self, model, images, batch_size): - logger.debug("Initializing %s: (model: '%s', batch_size: %s)", - self.__class__.__name__, model, batch_size) - self.batch_size = batch_size - self.model = model - self.model.state.add_session_batchsize(batch_size) - self.images = images - - self.process_training_opts() - - self.batchers = {side: Batcher(side, - images[side], - self.model, - self.use_mask, - batch_size) - for side in images.keys()} - - self.tensorboard = self.set_tensorboard() - self.samples = Samples(self.model, - self.use_mask, - self.model.training_opts["coverage_ratio"], - self.model.training_opts["preview_scaling"]) - self.timelapse = Timelapse(self.model, - self.use_mask, - self.model.training_opts["coverage_ratio"], - self.batchers) - logger.debug("Initialized %s", self.__class__.__name__) - - @property - def timestamp(self): - """ Standardised timestamp for loss reporting """ - return time.strftime("%H:%M:%S") +import typing as T - @property - def landmarks_required(self): - """ Return True if Landmarks are required """ - opts = self.model.training_opts - retval = bool(opts.get("mask_type", None) or opts["warp_to_landmarks"]) - logger.debug(retval) - return retval +import torch - @property - def use_mask(self): - """ Return True if a mask is requested """ - retval = bool(self.model.training_opts.get("mask_type", None)) - logger.debug(retval) - return retval +if T.TYPE_CHECKING: + from plugins.train.model._base import ModelBase - def process_training_opts(self): - """ Override for processing model specific training options """ - logger.debug(self.model.training_opts) - if self.landmarks_required: - landmarks = Landmarks(self.model.training_opts).landmarks - self.model.training_opts["landmarks"] = landmarks +logger = logging.getLogger(__name__) - def set_tensorboard(self): - """ Set up tensorboard callback """ - if self.model.training_opts["no_logs"]: - logger.verbose("TensorBoard logging disabled") - return None - logger.debug("Enabling TensorBoard Logging") - tensorboard = dict() - for side in self.images.keys(): - logger.debug("Setting up TensorBoard Logging. Side: %s", side) - log_dir = os.path.join(str(self.model.model_dir), - "{}_logs".format(self.model.name), - side, - "session_{}".format(self.model.state.session_id)) - tbs = tf_keras.callbacks.TensorBoard(log_dir=log_dir, - histogram_freq=0, # Must be 0 or hangs - batch_size=self.batch_size, - write_graph=True, - write_grads=True) - tbs.set_model(self.model.predictors[side]) - tensorboard[side] = tbs - logger.info("Enabled TensorBoard Logging") - return tensorboard +class TrainerBase(abc.ABC): + """ A trainer plugin interface. It must implement the method "train_batch" which takes an input + of inputs to the model and target images for model output. 
It returns loss per side - def print_loss(self, loss): - """ Override for specific model loss formatting """ - output = list() - for side in sorted(list(loss.keys())): - display = ", ".join(["{}_{}: {:.5f}".format(self.model.state.loss_names[side][idx], - side.capitalize(), - this_loss) - for idx, this_loss in enumerate(loss[side])]) - output.append(display) - print("[{}] [#{:05d}] {}, {}".format( - self.timestamp, self.model.iterations, output[0], output[1]), end='\r') - - def train_one_step(self, viewer, timelapse_kwargs): - """ Train a batch """ - logger.trace("Training one step: (iteration: %s)", self.model.iterations) - is_preview_iteration = False if viewer is None else True - loss = dict() - for side, batcher in self.batchers.items(): - loss[side] = batcher.train_one_batch(is_preview_iteration) - if not is_preview_iteration: - continue - self.samples.images[side] = batcher.compile_sample(self.batch_size) - if timelapse_kwargs: - self.timelapse.get_sample(side, timelapse_kwargs) - - self.model.state.increment_iterations() - - for side, side_loss in loss.items(): - self.store_history(side, side_loss) - self.log_tensorboard(side, side_loss) - self.print_loss(loss) - - if viewer is not None: - viewer(self.samples.show_sample(), - "Training - 'S': Save Now. 'ENTER': Save and Quit") - - if timelapse_kwargs is not None: - self.timelapse.output_timelapse() - - def store_history(self, side, loss): - """ Store the history of this step """ - logger.trace("Updating loss history: '%s'", side) - self.model.history[side].append(loss[0]) # Either only loss or total loss - logger.trace("Updated loss history: '%s'", side) - - def log_tensorboard(self, side, loss): - """ Log loss to TensorBoard log """ - if not self.tensorboard: - return - logger.trace("Updating TensorBoard log: '%s'", side) - logs = {log[0]: log[1] - for log in zip(self.model.state.loss_names[side], loss)} - self.tensorboard[side].on_batch_end(self.model.state.iterations, logs) - logger.trace("Updated TensorBoard log: '%s'", side) - - def clear_tensorboard(self): - """ Indicate training end to Tensorboard """ - if not self.tensorboard: - return - for side, tensorboard in self.tensorboard.items(): - logger.debug("Ending Tensorboard. Side: '%s'", side) - tensorboard.on_train_end(None) - - -class Batcher(): - """ Batch images from a single side """ - def __init__(self, side, images, model, use_mask, batch_size): - logger.debug("Initializing %s: side: '%s', num_images: %s, batch_size: %s)", - self.__class__.__name__, side, len(images), batch_size) + Parameters + ---------- + model : :class:`plugins.train.model.Base.ModelBase` + The model plugin + batch_size : int + The requested batch size for each iteration to be trained through the model. 
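# A hypothetical minimal implementation of the train_batch interface described above. The
# forward pass and loss are illustrative placeholders (assuming one output per side), not the
# "original" plugin's real logic:
import torch
import torch.nn.functional as F

class SketchTrainer(TrainerBase):
    """ Minimal sketch of a trainer plugin. """
    def train_batch(self, inputs: torch.Tensor, targets: list[torch.Tensor]) -> torch.Tensor:
        out_a, out_b = self.model.model([inputs[0], inputs[1]])   # side A and side B forward
        losses = torch.stack([F.mse_loss(out_a, targets[0][0]),   # loss for side "A"
                              F.mse_loss(out_b, targets[0][1])])  # loss for side "B"
        losses.sum().backward()                                   # single backward pass
        return losses                                             # layout (A, B)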
+ """ + def __init__(self, model: ModelBase, batch_size: int) -> None: self.model = model - self.use_mask = use_mask - self.side = side - self.target = None - self.samples = None - self.mask = None - - self.feed = self.load_generator().minibatch_ab(images, batch_size, self.side) - self.timelapse_feed = None - - def load_generator(self): - """ Pass arguments to TrainingDataGenerator and return object """ - logger.debug("Loading generator: %s", self.side) - input_size = self.model.input_shape[0] - output_size = self.model.output_shape[0] - logger.debug("input_size: %s, output_size: %s", input_size, output_size) - generator = TrainingDataGenerator(input_size, output_size, self.model.training_opts) - return generator - - def train_one_batch(self, is_preview_iteration): - """ Train a batch """ - logger.trace("Training one step: (side: %s)", self.side) - batch = self.get_next(is_preview_iteration) - loss = self.model.predictors[self.side].train_on_batch(*batch) - loss = loss if isinstance(loss, list) else [loss] - return loss - - def get_next(self, is_preview_iteration): - """ Return the next batch from the generator - Items should come out as: (warped, target [, mask]) """ - batch = next(self.feed) - self.samples = batch[0] if is_preview_iteration else None - batch = batch[1:] # Remove full size samples from batch - if self.use_mask: - batch = self.compile_mask(batch) - self.target = batch[1] if is_preview_iteration else None - return batch - - def compile_mask(self, batch): - """ Compile the mask into training data """ - logger.trace("Compiling Mask: (side: '%s')", self.side) - mask = batch[-1] - retval = list() - for idx in range(len(batch) - 1): - image = batch[idx] - retval.append([image, mask]) - return retval - - def compile_sample(self, batch_size, samples=None, images=None): - """ Training samples to display in the viewer """ - num_images = self.model.training_opts.get("preview_images", 14) - num_images = min(batch_size, num_images) - logger.debug("Compiling samples: (side: '%s', samples: %s)", self.side, num_images) - images = images if images is not None else self.target - samples = [samples[0:num_images]] if samples is not None else [self.samples[0:num_images]] - if self.use_mask: - retval = [tgt[0:num_images] for tgt in images] - else: - retval = [images[0:num_images]] - retval = samples + retval - return retval - - def compile_timelapse_sample(self): - """ Timelapse samples """ - batch = next(self.timelapse_feed) - samples = batch[0] - batch = batch[1:] # Remove full size samples from batch - batchsize = len(samples) - if self.use_mask: - batch = self.compile_mask(batch) - images = batch[1] - sample = self.compile_sample(batchsize, samples=samples, images=images) - return sample - - def set_timelapse_feed(self, images, batchsize): - """ Set the timelapse dictionary """ - logger.debug("Setting timelapse feed: (side: '%s', input_images: '%s', batchsize: %s)", - self.side, images, batchsize) - self.timelapse_feed = self.load_generator().minibatch_ab(images[:batchsize], - batchsize, self.side, - do_shuffle=False, - is_timelapse=True) - logger.debug("Set timelapse feed") - - -class Samples(): - """ Display samples for preview and timelapse """ - def __init__(self, model, use_mask, coverage_ratio, scaling=1.0): - logger.debug("Initializing %s: model: '%s', use_mask: %s, coverage_ratio: %s)", - self.__class__.__name__, model, use_mask, coverage_ratio) - self.model = model - self.use_mask = use_mask - self.images = dict() - self.coverage_ratio = coverage_ratio - self.scaling = scaling - 
logger.debug("Initialized %s", self.__class__.__name__) - - def show_sample(self): - """ Display preview data """ - logger.debug("Showing sample") - feeds = dict() - figures = dict() - headers = dict() - for side, samples in self.images.items(): - faces = samples[1] - if self.model.input_shape[0] / faces.shape[1] != 1.0: - feeds[side] = self.resize_sample(side, faces, self.model.input_shape[0]) - feeds[side] = feeds[side].reshape((-1, ) + self.model.input_shape) - else: - feeds[side] = faces - if self.use_mask: - mask = samples[-1] - feeds[side] = [feeds[side], mask] - - preds = self.get_predictions(feeds["a"], feeds["b"]) - - for side, samples in self.images.items(): - other_side = "a" if side == "b" else "b" - predictions = [preds["{}_{}".format(side, side)], - preds["{}_{}".format(other_side, side)]] - display = self.to_full_frame(side, samples, predictions) - headers[side] = self.get_headers(side, other_side, display[0].shape[1]) - figures[side] = np.stack([display[0], display[1], display[2], ], axis=1) - if self.images[side][0].shape[0] % 2 == 1: - figures[side] = np.concatenate([figures[side], - np.expand_dims(figures[side][0], 0)]) - - width = 4 - side_cols = width // 2 - if side_cols != 1: - headers = self.duplicate_headers(headers, side_cols) - - header = np.concatenate([headers["a"], headers["b"]], axis=1) - figure = np.concatenate([figures["a"], figures["b"]], axis=0) - height = int(figure.shape[0] / width) - figure = figure.reshape((width, height) + figure.shape[1:]) - figure = stack_images(figure) - figure = np.vstack((header, figure)) - - logger.debug("Compiled sample") - return np.clip(figure * 255, 0, 255).astype('uint8') - - @staticmethod - def resize_sample(side, sample, target_size): - """ Resize samples where predictor expects different shape from processed image """ - scale = target_size / sample.shape[1] - if scale == 1.0: - return sample - logger.debug("Resizing sample: (side: '%s', sample.shape: %s, target_size: %s, scale: %s)", - side, sample.shape, target_size, scale) - interpn = cv2.INTER_CUBIC if scale > 1.0 else cv2.INTER_AREA # pylint: disable=no-member - retval = np.array([cv2.resize(img, # pylint: disable=no-member - (target_size, target_size), - interpn) - for img in sample]) - logger.debug("Resized sample: (side: '%s' shape: %s)", side, retval.shape) - return retval - - def get_predictions(self, feed_a, feed_b): - """ Return the sample predictions from the model """ - logger.debug("Getting Predictions") - preds = dict() - preds["a_a"] = self.model.predictors["a"].predict(feed_a) - preds["b_a"] = self.model.predictors["b"].predict(feed_a) - preds["a_b"] = self.model.predictors["a"].predict(feed_b) - preds["b_b"] = self.model.predictors["b"].predict(feed_b) - - # Get the returned image from predictors that emit multiple items - if not isinstance(preds["a_a"], np.ndarray): - for key, val in preds.items(): - preds[key] = val[0] - logger.debug("Returning predictions: %s", {key: val.shape for key, val in preds.items()}) - return preds - - def to_full_frame(self, side, samples, predictions): - """ Patch the images into the full frame """ - logger.debug("side: '%s', number of sample arrays: %s, prediction.shapes: %s)", - side, len(samples), [pred.shape for pred in predictions]) - full, faces = samples[:2] - images = [faces] + predictions - full_size = full.shape[1] - target_size = int(full_size * self.coverage_ratio) - if target_size != full_size: - frame = self.frame_overlay(full, target_size, (0, 0, 255)) - - if self.use_mask: - images = 
self.compile_masked(images, samples[-1]) - images = [self.resize_sample(side, image, target_size) for image in images] - if target_size != full_size: - images = [self.overlay_foreground(frame, image) for image in images] - if self.scaling != 1.0: - new_size = int(full_size * self.scaling) - images = [self.resize_sample(side, image, new_size) for image in images] - return images - - @staticmethod - def frame_overlay(images, target_size, color): - """ Add roi frame to a backfround image """ - logger.debug("full_size: %s, target_size: %s, color: %s", - images.shape[1], target_size, color) - new_images = list() - full_size = images.shape[1] - padding = (full_size - target_size) // 2 - length = target_size // 4 - t_l, b_r = (padding, full_size - padding) - for img in images: - cv2.rectangle(img, # pylint: disable=no-member - (t_l, t_l), - (t_l + length, t_l + length), - color, - 3) - cv2.rectangle(img, # pylint: disable=no-member - (b_r, t_l), - (b_r - length, t_l + length), - color, - 3) - cv2.rectangle(img, # pylint: disable=no-member - (b_r, b_r), - (b_r - length, - b_r - length), - color, - 3) - cv2.rectangle(img, # pylint: disable=no-member - (t_l, b_r), - (t_l + length, b_r - length), - color, - 3) - new_images.append(img) - retval = np.array(new_images) - logger.debug("Overlayed background. Shape: %s", retval.shape) - return retval - - @staticmethod - def compile_masked(faces, masks): - """ Add the mask to the faces for masked preview """ - retval = list() - masks3 = np.tile(1 - np.rint(masks), 3) - for mask in masks3: - mask[np.where((mask == [1., 1., 1.]).all(axis=2))] = [0., 0., 1.] - for previews in faces: - images = np.array([cv2.addWeighted(img, 1.0, # pylint: disable=no-member - masks3[idx], 0.3, - 0) - for idx, img in enumerate(previews)]) - retval.append(images) - logger.debug("masked shapes: %s", [faces.shape for faces in retval]) - return retval - - @staticmethod - def overlay_foreground(backgrounds, foregrounds): - """ Overlay the training images into the center of the background """ - offset = (backgrounds.shape[1] - foregrounds.shape[1]) // 2 - new_images = list() - for idx, img in enumerate(backgrounds): - img[offset:offset + foregrounds[idx].shape[0], - offset:offset + foregrounds[idx].shape[1]] = foregrounds[idx] - new_images.append(img) - retval = np.array(new_images) - logger.debug("Overlayed foreground. 
Shape: %s", retval.shape) - return retval - - def get_headers(self, side, other_side, width): - """ Set headers for images """ - logger.debug("side: '%s', other_side: '%s', width: %s", - side, other_side, width) - side = side.upper() - other_side = other_side.upper() - height = int(64 * self.scaling) - total_width = width * 3 - logger.debug("height: %s, total_width: %s", height, total_width) - font = cv2.FONT_HERSHEY_SIMPLEX # pylint: disable=no-member - texts = ["Target {}".format(side), - "{} > {}".format(side, side), - "{} > {}".format(side, other_side)] - text_sizes = [cv2.getTextSize(texts[idx], # pylint: disable=no-member - font, - self.scaling, - 1)[0] - for idx in range(len(texts))] - text_y = int((height + text_sizes[0][1]) / 2) - text_x = [int((width - text_sizes[idx][0]) / 2) + width * idx - for idx in range(len(texts))] - logger.debug("texts: %s, text_sizes: %s, text_x: %s, text_y: %s", - texts, text_sizes, text_x, text_y) - header_box = np.ones((height, total_width, 3), np.float32) - for idx, text in enumerate(texts): - cv2.putText(header_box, # pylint: disable=no-member - text, - (text_x[idx], text_y), - font, - self.scaling, - (0, 0, 0), - 1, - lineType=cv2.LINE_AA) # pylint: disable=no-member - logger.debug("header_box.shape: %s", header_box.shape) - return header_box - - @staticmethod - def duplicate_headers(headers, columns): - """ Duplicate headers for the number of columns displayed """ - for side, header in headers.items(): - duped = tuple([header for _ in range(columns)]) - headers[side] = np.concatenate(duped, axis=1) - logger.debug("side: %s header.shape: %s", side, header.shape) - return headers - - -class Timelapse(): - """ Create the timelapse """ - def __init__(self, model, use_mask, coverage_ratio, batchers): - logger.debug("Initializing %s: model: %s, use_mask: %s, coverage_ratio: %s, " - "batchers: '%s')", self.__class__.__name__, model, use_mask, - coverage_ratio, batchers) - self.samples = Samples(model, use_mask, coverage_ratio) - self.model = model - self.batchers = batchers - self.output_file = None - logger.debug("Initialized %s", self.__class__.__name__) - - def get_sample(self, side, timelapse_kwargs): - """ Perform timelapse """ - logger.debug("Getting timelapse samples: '%s'", side) - if not self.output_file: - self.setup(**timelapse_kwargs) - self.samples.images[side] = self.batchers[side].compile_timelapse_sample() - logger.debug("Got timelapse samples: '%s' - %s", side, len(self.samples.images[side])) - - def setup(self, input_a=None, input_b=None, output=None): - """ Set the timelapse output folder """ - logger.debug("Setting up timelapse") - if output is None: - output = str(get_folder(os.path.join(str(self.model.model_dir), - "{}_timelapse".format(self.model.name)))) - self.output_file = str(output) - logger.debug("Timelapse output set to '%s'", self.output_file) - - images = {"a": get_image_paths(input_a), "b": get_image_paths(input_b)} - batchsize = min(len(images["a"]), - len(images["b"]), - self.model.training_opts.get("preview_images", 14)) - for side, image_files in images.items(): - self.batchers[side].set_timelapse_feed(image_files, batchsize) - logger.debug("Set up timelapse") - - def output_timelapse(self): - """ Set the timelapse dictionary """ - logger.debug("Ouputting timelapse") - image = self.samples.show_sample() - filename = os.path.join(self.output_file, str(int(time.time())) + ".jpg") - - cv2.imwrite(filename, image) # pylint: disable=no-member - logger.debug("Created timelapse: '%s'", filename) - - -class Landmarks(): - """ 
Set Landmarks for training into the model's training options""" - def __init__(self, training_opts): - logger.debug("Initializing %s: (training_opts: '%s')", - self.__class__.__name__, training_opts) - self.size = training_opts.get("training_size", 256) - self.paths = training_opts["alignments"] - self.landmarks = self.get_alignments() - logger.debug("Initialized %s", self.__class__.__name__) - - def get_alignments(self): - """ Obtain the landmarks for each faceset """ - landmarks = dict() - for side, fullpath in self.paths.items(): - path, filename = os.path.split(fullpath) - filename, extension = os.path.splitext(filename) - serializer = extension[1:] - alignments = Alignments( - path, - filename=filename, - serializer=serializer) - landmarks[side] = self.transform_landmarks(alignments) - return landmarks - - def transform_landmarks(self, alignments): - """ For each face transform landmarks and return """ - landmarks = dict() - for _, faces, _, _ in alignments.yield_faces(): - for face in faces: - detected_face = DetectedFace() - detected_face.from_alignment(face) - detected_face.load_aligned(None, size=self.size, align_eyes=False) - landmarks[detected_face.hash] = detected_face.aligned_landmarks - return landmarks + """:class:`plugins.train.model.Base.ModelBase` : The model plugin to train the batch on""" + self.batch_size = batch_size + """int : The batch size for each iteration to be trained through the model.""" + + @abc.abstractmethod + def train_batch(self, inputs: torch.Tensor, targets: list[torch.Tensor]) -> torch.Tensor: + """Override to run a single forward and backwards pass through the model for a single + batch + + Parameters + ---------- + inputs : :class:`torch.Tensor` + The batch of input image tensors to the model in shape `(side, batch_size, + *dims)` with `side` 0 being input A and `side` 1 being input B + targets : list[:class:`torch.Tensor`] + The corresponding batch of target images for the model for each side's output(s). For + each model output an array should exist in the order of model outputs in the format `( + side, batch_size, *dims)` where `side` 0 is "A" and `side` 1 is "B" + + Returns + ------- + :class:`torch.Tensor` + The loss for each side of this batch in layout (A1, ..., An, B1, ..., Bn) + """ diff --git a/plugins/train/trainer/_display.py b/plugins/train/trainer/_display.py new file mode 100644 index 0000000000..0e3236cb34 --- /dev/null +++ b/plugins/train/trainer/_display.py @@ -0,0 +1,626 @@ +#!/usr/bin/env python3 +""" Handles the creation of display images for preview window and timelapses """ +from __future__ import annotations + +import logging +import time +import typing as T +import os + +import cv2 +import numpy as np +import torch + +from lib.image import hex_to_rgb +from lib.utils import get_folder, get_image_paths, get_module_objects +from plugins.train import train_config as cfg + +if T.TYPE_CHECKING: + from keras import KerasTensor + from lib.training import Feeder + from plugins.train.model._base import ModelBase + +logger = logging.getLogger(__name__) + + +class Samples(): + """ Compile samples for display for preview and time-lapse + + Parameters + ---------- + model: plugin from :mod:`plugins.train.model` + The selected model that will be running this trainer + coverage_ratio: float + Ratio of face to be cropped out of the training image. 
+    mask_opacity: int
+        The opacity (as a percentage) to use for the mask overlay
+    mask_color: str
+        The hex RGB value to use for the mask overlay
+
+    Attributes
+    ----------
+    images: dict
+        The :class:`numpy.ndarray` training images for generating previews on each side. The
+        dictionary should contain 2 keys ("a" and "b") with the values being the training images
+        for generating samples corresponding to each side.
+    """
+    def __init__(self,
+                 model: ModelBase,
+                 coverage_ratio: float,
+                 mask_opacity: int,
+                 mask_color: str) -> None:
+        logger.debug("Initializing %s: model: '%s', coverage_ratio: %s, mask_opacity: %s, "
+                     "mask_color: %s)",
+                     self.__class__.__name__, model, coverage_ratio, mask_opacity, mask_color)
+        self._model = model
+        self._display_mask = cfg.Loss.learn_mask() or cfg.Loss.penalized_mask_loss()
+        self.images: dict[T.Literal["a", "b"], list[np.ndarray]] = {}
+        self._coverage_ratio = coverage_ratio
+        self._mask_opacity = mask_opacity / 100.0
+        self._mask_color = np.array(hex_to_rgb(mask_color))[..., 2::-1] / 255.
+        logger.debug("Initialized %s", self.__class__.__name__)
+
+    def toggle_mask_display(self) -> None:
+        """ Toggle the mask overlay on or off depending on user input. """
+        if not (cfg.Loss.learn_mask() or cfg.Loss.penalized_mask_loss()):
+            return
+        display_mask = not self._display_mask
+        print("\x1b[2K", end="\r")  # Clear last line
+        logger.info("Toggling mask display %s...", "on" if display_mask else "off")
+        self._display_mask = display_mask
+
+    def show_sample(self) -> np.ndarray:
+        """ Compile a preview image.
+
+        Returns
+        -------
+        :class:`numpy.ndarray`
+            A compiled preview image ready for display or saving
+        """
+        logger.debug("Showing sample")
+        feeds: dict[T.Literal["a", "b"], np.ndarray] = {}
+        for idx, side in enumerate(T.get_args(T.Literal["a", "b"])):
+            feed = self.images[side][0]
+            input_shape = self._model.model.input_shape[idx][1:]
+            if input_shape[0] / feed.shape[1] != 1.0:
+                feeds[side] = self._resize_sample(side, feed, input_shape[0])
+            else:
+                feeds[side] = feed
+
+        preds = self._get_predictions(feeds["a"], feeds["b"])
+        return self._compile_preview(preds)
+
+    @classmethod
+    def _resize_sample(cls,
+                       side: T.Literal["a", "b"],
+                       sample: np.ndarray,
+                       target_size: int) -> np.ndarray:
+        """ Resize a given image to the target size.
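# The mask_color handling in __init__ above reverses the channel order because previews are
# composited with OpenCV, which works in BGR. A worked sketch, assuming hex_to_rgb returns a
# 0-255 (R, G, B) tuple as its name suggests (the helper below is a stand-in, not
# lib.image.hex_to_rgb itself):
import numpy as np

def hex_to_rgb(hexcode: str) -> tuple[int, ...]:
    return tuple(int(hexcode.lstrip("#")[i:i + 2], 16) for i in (0, 2, 4))

rgb = np.array(hex_to_rgb("#ff8000"))   # [255, 128, 0]
bgr = rgb[..., 2::-1] / 255.            # [0.0, 0.502, 1.0] -- normalized BGR for cv2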
+ + Parameters + ---------- + side: str + The side ("a" or "b") that the samples are being generated for + sample: :class:`numpy.ndarray` + The sample to be resized + target_size: int + The size that the sample should be resized to + + Returns + ------- + :class:`numpy.ndarray` + The sample resized to the target size + """ + scale = target_size / sample.shape[1] + if scale == 1.0: + # cv2 complains if we don't do this :/ + return np.ascontiguousarray(sample) + logger.debug("Resizing sample: (side: '%s', sample.shape: %s, target_size: %s, scale: %s)", + side, sample.shape, target_size, scale) + interpn = cv2.INTER_CUBIC if scale > 1.0 else cv2.INTER_AREA + retval = np.array([cv2.resize(img, (target_size, target_size), interpolation=interpn) + for img in sample]) + logger.debug("Resized sample: (side: '%s' shape: %s)", side, retval.shape) + return retval + + def _filter_multiscale_output(self, standard: list[KerasTensor], swapped: list[KerasTensor] + ) -> tuple[list[KerasTensor], list[KerasTensor]]: + """ Only return the largest predictions if the model has multi-scaled output + + Parameters + ---------- + standard: list[:class:`keras.KerasTensor`] + The standard output from the model + swapped: list[:class:`keras.KerasTensor`] + The swapped output from the model + + Returns + ------- + standard: list[:class:`keras.KerasTensor`] + The standard output from the model, filtered to just the largest output + swapped: list[:class:`keras.KerasTensor`] + The swapped output from the model, filtered to just the largest output + """ + sizes = T.cast(set[int], set(p.shape[1] for p in standard)) + if len(sizes) == 1: + return standard, swapped + logger.debug("Received outputs. standard: %s, swapped: %s", + [s.shape for s in standard], [s.shape for s in swapped]) + logger.debug("Stripping multi-scale outputs for sizes %s", sizes) + standard = [s for s in standard if s.shape[1] == max(sizes)] + swapped = [s for s in swapped if s.shape[1] == max(sizes)] + logger.debug("Stripped outputs. standard: %s, swapped: %s", + [s.shape for s in standard], [s.shape for s in swapped]) + return standard, swapped + + def _collate_output(self, standard: list[torch.Tensor], swapped: list[torch.Tensor] + ) -> tuple[list[np.ndarray], list[np.ndarray]]: + """ Merge the mask onto the preview image's 4th channel if learn mask is selected. + Return as numpy array + + Parameters + ---------- + standard: list[:class:`torch.Tensor`] + The standard output from the model + swapped: list[:class:`torch.Tensor`] + The swapped output from the model + + Returns + ------- + standard: list[:class:`numpy.ndarray`] + The standard output from the model, with mask merged + swapped: list[:class:`numpy.ndarray`] + The swapped output from the model, with mask merged + """ + logger.debug("Received tensors. standard: %s, swapped: %s", + [s.shape for s in standard], [s.shape for s in swapped]) + + # Pull down outputs + nstandard = [p.cpu().detach().numpy() for p in standard] + nswapped = [p.cpu().detach().numpy() for p in swapped] + + if cfg.Loss.learn_mask(): # Add mask to 4th channel of final output + nstandard = [np.concatenate(nstandard[idx * 2: (idx * 2) + 2], axis=-1) + for idx in range(2)] + nswapped = [np.concatenate(nswapped[idx * 2: (idx * 2) + 2], axis=-1) + for idx in range(2)] + logger.debug("Collated output. 
standard: %s, swapped: %s", + [(s.shape, s.dtype) for s in nstandard], + [(s.shape, s.dtype) for s in nswapped]) + return nstandard, nswapped + + def _get_predictions(self, feed_a: np.ndarray, feed_b: np.ndarray + ) -> dict[T.Literal["a_a", "a_b", "b_b", "b_a"], np.ndarray]: + """ Feed the samples to the model and return predictions + + Parameters + ---------- + feed_a: :class:`numpy.ndarray` + Feed images for the "a" side + feed_a: :class:`numpy.ndarray` + Feed images for the "b" side + + Returns + ------- + list: + List of :class:`numpy.ndarray` of predictions received from the model + """ + logger.debug("Getting Predictions") + preds: dict[T.Literal["a_a", "a_b", "b_b", "b_a"], np.ndarray] = {} + + with torch.inference_mode(): + standard = self._model.model([feed_a, feed_b]) + swapped = self._model.model([feed_b, feed_a]) + + standard, swapped = self._filter_multiscale_output(standard, swapped) + standard, swapped = self._collate_output(standard, swapped) + + preds["a_a"] = standard[0] + preds["b_b"] = standard[1] + preds["a_b"] = swapped[0] + preds["b_a"] = swapped[1] + + logger.debug("Returning predictions: %s", {key: val.shape for key, val in preds.items()}) + return preds + + def _compile_preview(self, predictions: dict[T.Literal["a_a", "a_b", "b_b", "b_a"], np.ndarray] + ) -> np.ndarray: + """ Compile predictions and images into the final preview image. + + Parameters + ---------- + predictions: dict[Literal["a_a", "a_b", "b_b", "b_a"], np.ndarray + The predictions from the model + + Returns + ------- + :class:`numpy.ndarry` + A compiled preview image ready for display or saving + """ + figures: dict[T.Literal["a", "b"], np.ndarray] = {} + headers: dict[T.Literal["a", "b"], np.ndarray] = {} + + for side, samples in self.images.items(): + other_side = "a" if side == "b" else "b" + preds = [predictions[T.cast(T.Literal["a_a", "a_b", "b_b", "b_a"], + f"{side}_{side}")], + predictions[T.cast(T.Literal["a_a", "a_b", "b_b", "b_a"], + f"{other_side}_{side}")]] + display = self._to_full_frame(side, samples, preds) + headers[side] = self._get_headers(side, display[0].shape[1]) + figures[side] = np.stack([display[0], display[1], display[2], ], axis=1) + if self.images[side][1].shape[0] % 2 == 1: + figures[side] = np.concatenate([figures[side], + np.expand_dims(figures[side][0], 0)]) + + width = 4 + if width // 2 != 1: + headers = self._duplicate_headers(headers, width // 2) + + header = np.concatenate([headers["a"], headers["b"]], axis=1) + figure = np.concatenate([figures["a"], figures["b"]], axis=0) + height = int(figure.shape[0] / width) + figure = figure.reshape((width, height) + figure.shape[1:]) + figure = _stack_images(figure) + figure = np.concatenate((header, figure), axis=0) + + logger.debug("Compiled sample") + return np.clip(figure * 255, 0, 255).astype('uint8') + + def _to_full_frame(self, + side: T.Literal["a", "b"], + samples: list[np.ndarray], + predictions: list[np.ndarray]) -> list[np.ndarray]: + """ Patch targets and prediction images into images of model output size. 
+ + Parameters + ---------- + side: {"a" or "b"} + The side that these samples are for + samples: list + List of :class:`numpy.ndarray` of feed images and sample images + predictions: list + List of :class:`numpy.ndarray` of predictions from the model + + Returns + ------- + list + The images resized and collated for display in the preview frame + """ + logger.debug("side: '%s', number of sample arrays: %s, prediction.shapes: %s)", + side, len(samples), [pred.shape for pred in predictions]) + faces, full = samples[:2] + + if self._model.color_order.lower() == "rgb": # Switch color order for RGB model display + full = full[..., ::-1] + faces = faces[..., ::-1] + predictions = [pred[..., 2::-1] for pred in predictions] + + full = self._process_full(side, full, predictions[0].shape[1], (0., 0., 1.0)) + images = [faces] + predictions + + if self._display_mask: + images = self._compile_masked(images, samples[-1]) + elif cfg.Loss.learn_mask(): + # Remove masks when learn mask is selected but mask toggle is off + images = [batch[..., :3] for batch in images] + + images = [self._overlay_foreground(full.copy(), image) for image in images] + + return images + + def _process_full(self, + side: T.Literal["a", "b"], + images: np.ndarray, + prediction_size: int, + color: tuple[float, float, float]) -> np.ndarray: + """ Add a frame overlay to preview images indicating the region of interest. + + This applies the red border that appears in the preview images. + + Parameters + ---------- + side: {"a" or "b"} + The side that these samples are for + images: :class:`numpy.ndarray` + The input training images to process + prediction_size: int + The size of the predicted output from the model + color: tuple + The (Blue, Green, Red) color to use for the frame + + Returns + ------- + :class:`numpy.ndarray` + The input training images, sized for output and annotated for coverage + """ + logger.debug("full_size: %s, prediction_size: %s, color: %s", + images.shape[1], prediction_size, color) + + display_size = int((prediction_size / self._coverage_ratio // 2) * 2) + images = self._resize_sample(side, images, display_size) # Resize targets to display size + padding = (display_size - prediction_size) // 2 + if padding == 0: + logger.debug("Resized background. Shape: %s", images.shape) + return images + + length = display_size // 4 + t_l, b_r = (padding - 1, display_size - padding) + for img in images: + cv2.rectangle(img, (t_l, t_l), (t_l + length, t_l + length), color, 1) + cv2.rectangle(img, (b_r, t_l), (b_r - length, t_l + length), color, 1) + cv2.rectangle(img, (b_r, b_r), (b_r - length, b_r - length), color, 1) + cv2.rectangle(img, (t_l, b_r), (t_l + length, b_r - length), color, 1) + logger.debug("Overlayed background. Shape: %s", images.shape) + return images + + def _compile_masked(self, faces: list[np.ndarray], masks: np.ndarray) -> list[np.ndarray]: + """ Add the mask to the faces for masked preview. + + Places an opaque red layer over areas of the face that are masked out. + + Parameters + ---------- + faces: list + The :class:`numpy.ndarray` sample faces and predictions that are to have the mask + applied + masks: :class:`numpy.ndarray` + The masks that are to be applied to the faces + + Returns + ------- + list + List of :class:`numpy.ndarray` faces with the opaque mask layer applied + """ + orig_masks = 1. - masks + masks3: list[np.ndarray] | np.ndarray = [] + + if faces[-1].shape[-1] == 4: # Mask contained in alpha channel of predictions + pred_masks = [1.
- face[..., -1][..., None] for face in faces[-2:]] + faces[-2:] = [face[..., :-1] for face in faces[-2:]] + masks3 = [orig_masks, *pred_masks] + else: + masks3 = np.repeat(np.expand_dims(orig_masks, axis=0), 3, axis=0) + + retval: list[np.ndarray] = [] + overlays3 = np.ones_like(faces) * self._mask_color + for previews, overlays, compiled_masks in zip(faces, overlays3, masks3): + compiled_masks *= self._mask_opacity + overlays *= compiled_masks + previews *= (1. - compiled_masks) + retval.append(previews + overlays) + logger.debug("masked shapes: %s", [faces.shape for faces in retval]) + return retval + + @classmethod + def _overlay_foreground(cls, backgrounds: np.ndarray, foregrounds: np.ndarray) -> np.ndarray: + """ Overlay the preview images into the center of the background images + + Parameters + ---------- + backgrounds: :class:`numpy.ndarray` + Background images for placing the preview images onto + foregrounds: :class:`numpy.ndarray` + Preview images for placing onto the background images + + Returns + ------- + :class:`numpy.ndarray` + The preview images compiled into the full frame size for each preview + """ + offset = (backgrounds.shape[1] - foregrounds.shape[1]) // 2 + for foreground, background in zip(foregrounds, backgrounds): + background[offset:offset + foreground.shape[0], + offset:offset + foreground.shape[1], :3] = foreground + logger.debug("Overlayed foreground. Shape: %s", backgrounds.shape) + return backgrounds + + @classmethod + def _get_headers(cls, side: T.Literal["a", "b"], width: int) -> np.ndarray: + """ Set header row for the final preview frame + + Parameters + ---------- + side: {"a" or "b"} + The side that the headers should be generated for + width: int + The width of each column in the preview frame + + Returns + ------- + :class:`numpy.ndarray` + The column headings for the given side + """ + logger.debug("side: '%s', width: %s", + side, width) + titles = ("Original", "Swap") if side == "a" else ("Swap", "Original") + height = int(width / 4.5) + total_width = width * 3 + logger.debug("height: %s, total_width: %s", height, total_width) + font = cv2.FONT_HERSHEY_SIMPLEX + texts = [f"{titles[0]} ({side.upper()})", + f"{titles[0]} > {titles[0]}", + f"{titles[0]} > {titles[1]}"] + scaling = (width / 144) * 0.45 + text_sizes = [cv2.getTextSize(texts[idx], font, scaling, 1)[0] + for idx in range(len(texts))] + text_y = int((height + text_sizes[0][1]) / 2) + text_x = [int((width - text_sizes[idx][0]) / 2) + width * idx + for idx in range(len(texts))] + logger.debug("texts: %s, text_sizes: %s, text_x: %s, text_y: %s", + texts, text_sizes, text_x, text_y) + header_box = np.ones((height, total_width, 3), np.float32) + for idx, text in enumerate(texts): + cv2.putText(header_box, + text, + (text_x[idx], text_y), + font, + scaling, + (0, 0, 0), + 1, + lineType=cv2.LINE_AA) + logger.debug("header_box.shape: %s", header_box.shape) + return header_box + + @classmethod + def _duplicate_headers(cls, + headers: dict[T.Literal["a", "b"], np.ndarray], + columns: int) -> dict[T.Literal["a", "b"], np.ndarray]: + """ Duplicate headers for the number of columns displayed for each side.
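+ + A minimal sketch; the header dimensions are illustrative assumptions: + + >>> import numpy as np + >>> hdrs = {"a": np.ones((42, 192, 3)), "b": np.ones((42, 192, 3))} + >>> Samples._duplicate_headers(hdrs, 2)["a"].shape + (42, 384, 3)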
+ + Parameters + ---------- + headers: dict + The headers to be duplicated for each side + columns: int + The number of columns that the header needs to be duplicated for + + Returns + ------- + dict + The original headers duplicated by the number of columns for each side + """ + for side, header in headers.items(): + duped = tuple(header for _ in range(columns)) + headers[side] = np.concatenate(duped, axis=1) + logger.debug("side: %s header.shape: %s", side, header.shape) + return headers + + +class Timelapse(): + """ Create a time-lapse preview image. + + Parameters + ---------- + model: plugin from :mod:`plugins.train.model` + The selected model that will be running this trainer + coverage_ratio: float + Ratio of face to be cropped out of the training image. + image_count: int + The number of preview images to be displayed in the time-lapse + mask_opacity: int + The opacity (as a percentage) to use for the mask overlay + mask_color: str + The hex RGB value to use for the mask overlay + feeder: :class:`~lib.training.generator.Feeder` + The feeder for generating the time-lapse images. + image_paths: dict + The full paths to the training images for each side of the model + """ + def __init__(self, + model: ModelBase, + coverage_ratio: float, + image_count: int, + mask_opacity: int, + mask_color: str, + feeder: Feeder, + image_paths: dict[T.Literal["a", "b"], list[str]]) -> None: + logger.debug("Initializing %s: model: %s, coverage_ratio: %s, image_count: %s, " + "mask_opacity: %s, mask_color: %s, feeder: %s, image_paths: %s)", + self.__class__.__name__, model, coverage_ratio, image_count, mask_opacity, + mask_color, feeder, len(image_paths)) + self._num_images = image_count + self._samples = Samples(model, coverage_ratio, mask_opacity, mask_color) + self._model = model + self._feeder = feeder + self._image_paths = image_paths + self._output_file = "" + logger.debug("Initialized %s", self.__class__.__name__) + + def _setup(self, input_a: str, input_b: str, output: str) -> None: + """ Set up the time-lapse folder locations and the time-lapse feed. + + Parameters + ---------- + input_a: str + The full path to the time-lapse input folder containing faces for the "a" side + input_b: str + The full path to the time-lapse input folder containing faces for the "b" side + output: str, optional + The full path to the time-lapse output folder. If an empty value is provided this will + default to the model folder + """ + logger.debug("Setting up time-lapse") + if not output: + output = get_folder(os.path.join(str(self._model.io.model_dir), + f"{self._model.name}_timelapse")) + self._output_file = output + logger.debug("Time-lapse output set to '%s'", self._output_file) + + # Rewrite paths to pull from the training images so mask and face data can be accessed + images: dict[T.Literal["a", "b"], list[str]] = {} + for side, input_ in zip(T.get_args(T.Literal["a", "b"]), (input_a, input_b)): + training_path = os.path.dirname(self._image_paths[side][0]) + images[side] = [os.path.join(training_path, os.path.basename(pth)) + for pth in get_image_paths(input_)] + + batchsize = min(len(images["a"]), + len(images["b"]), + self._num_images) + self._feeder.set_timelapse_feed(images, batchsize) + logger.debug("Set up time-lapse") + + def output_timelapse(self, timelapse_kwargs: dict[T.Literal["input_a", + "input_b", + "output"], str]) -> None: + """ Generate the time-lapse samples and output the created time-lapse to the specified + output folder.
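+ + A hedged usage sketch, where ``lapse`` is an existing :class:`Timelapse` instance and the + paths are hypothetical: + + >>> lapse.output_timelapse({"input_a": "/faces/a", + ... "input_b": "/faces/b", + ... "output": "/models/timelapse"})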
+ + Parameters + ---------- + timelapse_kwargs: dict + The keyword arguments for setting up the time-lapse. All values should be full paths, + with the keys being `input_a`, `input_b` and `output` + """ + logger.debug("Outputting time-lapse") + if not self._output_file: + self._setup(**T.cast(dict[str, str], timelapse_kwargs)) + + logger.debug("Getting time-lapse samples") + self._samples.images = self._feeder.generate_preview(is_timelapse=True) + logger.debug("Got time-lapse samples: %s", + {side: len(images) for side, images in self._samples.images.items()}) + + image = self._samples.show_sample() + if image is None: + return + filename = os.path.join(self._output_file, str(int(time.time())) + ".jpg") + + cv2.imwrite(filename, image) + logger.debug("Created time-lapse: '%s'", filename) + + +def _stack_images(images: np.ndarray) -> np.ndarray: + """ Stack images evenly for preview. + + Parameters + ---------- + images: :class:`numpy.ndarray` + The preview images to be stacked + + Returns + ------- + :class:`numpy.ndarray` + The stacked preview images + """ + logger.debug("Stack images") + + def get_transpose_axes(num): + if num % 2 == 0: + logger.debug("Even number of images to stack") + y_axes = list(range(1, num - 1, 2)) + x_axes = list(range(0, num - 1, 2)) + else: + logger.debug("Odd number of images to stack") + y_axes = list(range(0, num - 1, 2)) + x_axes = list(range(1, num - 1, 2)) + return y_axes, x_axes, [num - 1] + + images_shape = np.array(images.shape) + new_axes = get_transpose_axes(len(images_shape)) + new_shape = [np.prod(images_shape[x]) for x in new_axes] + logger.debug("Stacked images") + return np.transpose(images, axes=np.concatenate(new_axes)).reshape(new_shape) + + +__all__ = get_module_objects(__name__) diff --git a/plugins/train/trainer/distributed.py b/plugins/train/trainer/distributed.py new file mode 100644 index 0000000000..ee1a877b42 --- /dev/null +++ b/plugins/train/trainer/distributed.py @@ -0,0 +1,221 @@ +#!/usr/bin/env python3 +""" Distributed Trainer """ +from __future__ import annotations +import logging +import typing as T +import warnings + +from keras import ops +import torch + + +from lib.utils import get_module_objects +from .original import Trainer as OriginalTrainer + +if T.TYPE_CHECKING: + from plugins.train.model._base import ModelBase + import keras + +logger = logging.getLogger(__name__) + + +class WrappedModel(torch.nn.Module): + """ A torch module that wraps a dual input Faceswap model with a single input version that is + compatible with DataParallel training + + Parameters + ---------- + model : :class:`keras.Model` + The original faceswap model that is to be wrapped + """ + def __init__(self, model: keras.Model): + logger.debug("Wrapping keras model: %s", model.name) + super().__init__() + self._keras_model = model + logger.debug("Wrapped keras model: %s (%s)", model.name, self) + + def forward(self, + input_a: torch.Tensor, + input_b: torch.Tensor, + targets_a: torch.Tensor, + targets_b: torch.Tensor, + *targets: torch.Tensor) -> torch.Tensor: + """ Run the forward pass per GPU + + Parameters + ---------- + input_a : :class:`torch.Tensor` + The A batch of input images for 1 GPU + input_b : :class:`torch.Tensor` + The B batch of input images for 1 GPU + targets_a : :class:`torch.Tensor` | list[torch.Tensor] + The A batch of target images for 1 GPU. If this is a multi-output model then this list + will be the target images per output for all items in the current batch, regardless of + GPU.
If we have 1 output, this will be a Tensor for this GPU's current batch output + targets_b : :class:`torch.Tensor` | list[torch.Tensor] + The B batch of target images for 1 GPU. If this is a multi-output model then this list + will be the target images per output for all items in the current batch, regardless of + GPU. If we have 1 output, this will be a Tensor for this GPU's current batch output + targets : :class:`torch.Tensor` | list[torch.Tensor], optional + Used for multi-output models. Any additional outputs can be added here. They should be + added in A-B order + + + Returns + ------- + :class:`torch.Tensor` + The loss outputs for each side of the model for 1 GPU + """ + preds = self._keras_model((input_a, input_b), training=True) + self._keras_model.zero_grad() + + if targets: # Go from [A1, B1, A2, B2, A3, B3] to [A1, A2, A3, B1, B2, B3] + all_targets = [targets_a, targets_b, *targets] + assert len(all_targets) % 2 == 0 + loss_targets = all_targets[0::2] + all_targets[1::2] + else: + loss_targets = [targets_a, targets_b] + + losses = torch.stack([loss_fn(y_true, y_pred) + for loss_fn, y_true, y_pred in zip(self._keras_model.loss, + loss_targets, + preds)]) + logger.trace("Losses: %s", losses) # type:ignore[attr-defined] + return losses + + +class Trainer(OriginalTrainer): + """ Distributed training with torch.nn.DataParallel + + Parameters + ---------- + model : plugin from :mod:`plugins.train.model` + The model that will be running this trainer + batch_size : int + The requested batch size for iteration to be trained through the model. + """ + def __init__(self, model: ModelBase, batch_size: int) -> None: + + self._gpu_count = torch.cuda.device_count() + batch_size = self._validate_batch_size(batch_size) + self._is_multi_out: bool | None = None + + super().__init__(model, batch_size) + + self._distributed_model = self._set_distributed() + + def _validate_batch_size(self, batch_size: int) -> int: + """ Validate that the batch size is suitable for the number of GPUs and update accordingly. + + Parameters + ---------- + batch_size : int + The requested training batch size + + Returns + ------- + int + A valid batch size for the GPU configuration + """ + if batch_size < self._gpu_count: + logger.warning("Batch size (%s) is less than the number of GPUs (%s). Updating batch " + "size to: %s", batch_size, self._gpu_count, self._gpu_count) + batch_size = self._gpu_count + if batch_size % self._gpu_count: + new_batch_size = (batch_size // self._gpu_count) * self._gpu_count + logger.warning("Batch size %s is sub-optimal for %s GPUs.
You may want to adjust your " + "batch size to %s or %s.", + batch_size, + self._gpu_count, + new_batch_size, + new_batch_size + self._gpu_count) + return batch_size + + def _handle_torch_gpu_mismatch_warning( + self, warn_messages: list[warnings.WarningMessage] | None) -> None: + """ Handle the warning generated by Torch when significantly mismatched GPUs are used and + remove potentially confusing information not relevant for Faceswap + + Parameters + ---------- + warn_messages : list[:class:`warnings.WarningMessage`] + Any qualifying warning messages that may have been generated when wrapping the model + """ + if warn_messages is None or not warn_messages: + return + warn_msg = warn_messages[0] + terminate = "You can do so by" + msg = "" + for x in str(warn_msg.message).split("\n"): + x = x.strip() + if not x: + continue + if terminate in msg: + msg = msg[:msg.find(terminate)] + break + msg += f" {x}" + logger.warning(msg.strip()) + + def _set_distributed(self) -> torch.nn.DataParallel: + """Wrap the loaded model in a torch.nn.DataParallel instance + + Returns + ------- + :class:`torch.nn.DataParallel` + A wrapped version of the faceswap model compatible with distributed training + """ + name = self.model.model.name + logger.debug("Setting distributed training for '%s'", name) + + with warnings.catch_warnings(record=True) as w: + warnings.filterwarnings("default", + message="There is an imbalance between your GPUs", + category=UserWarning) + # We already set CUDA_VISIBLE_DEVICES from -X command line flag, so just need to wrap + wrapped = torch.nn.DataParallel(WrappedModel(model=self.model.model)) + self._handle_torch_gpu_mismatch_warning(w) + + logger.info("Distributed training enabled. Model: '%s', devices: %s", + name, wrapped.device_ids) + return wrapped + + def _forward(self, + inputs: torch.Tensor, + targets: list[torch.Tensor]) -> torch.Tensor: + """ Perform the forward pass on the model + + Parameters + ---------- + inputs : :class:`torch.Tensor` + The batch of input image tensors to the model in shape `(side, batch_size, + *dims)` with `side` 0 being input A and `side` 1 being input B + targets : list[:class:`torch.Tensor`] + The corresponding batch of target images for the model for each side's output(s).
For + each model output an array should exist in the order of model outputs in the format `( + side, batch_size, *dims)` with `side` 0 being input A and `side` 1 being input B + + Returns + ------- + :class:`torch.Tensor` + The loss for each side of this batch in layout (A1, ..., An, B1, ..., Bn) + """ + if self._is_multi_out is None: + self._is_multi_out = len(targets) > 1 + logger.debug("Setting multi-out to: %s", self._is_multi_out) + + if self._is_multi_out: + multi_targets = tuple(t[i] for t in targets[1:] for i in range(2)) + else: + multi_targets = () + + loss: torch.Tensor = self._distributed_model(inputs[0], + inputs[1], + targets[0][0], + targets[0][1], + *multi_targets) + scaled = T.cast(torch.Tensor, ops.sum(ops.reshape(loss, (self._gpu_count, 2, -1)), + axis=0) / self._gpu_count) + return scaled.flatten() + + +__all__ = get_module_objects(__name__) diff --git a/plugins/train/trainer/original.py b/plugins/train/trainer/original.py index 14cb9afbaf..0b5164eef2 100644 --- a/plugins/train/trainer/original.py +++ b/plugins/train/trainer/original.py @@ -1,4 +1,97 @@ #!/usr/bin/env python3 """ Original Trainer """ +from __future__ import annotations -from ._base import TrainerBase as Trainer +import logging +import typing as T + +from keras import ops +from keras.src.tree import flatten +import torch + +from lib.utils import get_module_objects +from ._base import TrainerBase + + +logger = logging.getLogger(__name__) + + +class Trainer(TrainerBase): + """ Original trainer """ + + def _forward(self, + inputs: torch.Tensor, + targets: list[torch.Tensor]) -> torch.Tensor: + """ Perform the forward pass on the model + + Parameters + ---------- + inputs : :class:`torch.Tensor` + The batch of input image tensors to the model in shape `(side, batch_size, + *dims)` with `side` 0 being input A and `side` 1 being input B + targets : list[:class:`torch.Tensor`] + The corresponding batch of target images for the model for each side's output(s). 
For + each model output an array should exist in the order of model outputs in the format `( + side, batch_size, *dims)` with `side` 0 being input A and `side` 1 being input B + + Returns + ------- + :class:`torch.Tensor` + The loss for each side of this batch in layout (A1, ..., An, B1, ..., Bn) + """ + feed_targets = [[t[i] for t in targets] for i in range(2)] + preds = self.model.model((inputs[0], inputs[1]), training=True) + self.model.model.zero_grad() + + losses = torch.stack([loss_fn(y_true, y_pred) + for loss_fn, y_true, y_pred in zip(self.model.model.loss, + flatten(feed_targets), + preds)]) + logger.trace("Losses: %s", losses) # type:ignore[attr-defined] + return losses + + def _backwards_and_apply(self, all_loss: torch.Tensor) -> None: + """ Perform the backwards pass on the model + + Parameters + ---------- + all_loss : :class:`torch.Tensor` + The loss for each output from the model + """ + total_loss = T.cast(torch.Tensor, + self.model.model.optimizer.scale_loss(ops.sum(all_loss))) + total_loss.backward() + + trainable_weights = self.model.model.trainable_weights[:] + gradients = [v.value.grad for v in trainable_weights] + + # Update weights + with torch.no_grad(): + self.model.model.optimizer.apply(gradients, trainable_weights) + + def train_batch(self, + inputs: torch.Tensor, + targets: list[torch.Tensor]) -> torch.Tensor: + """Run a single forward and backwards pass through the model for a single batch + + Parameters + ---------- + inputs : :class:`torch.Tensor` + The batch of input image tensors to the model in shape `(side, batch_size, + *dims)` with `side` 0 being input A and `side` 1 being input B + targets : list[:class:`torch.Tensor`] + The corresponding batch of target images for the model for each side's output(s). For + each model output an array should exist in the order of model outputs in the format `( + side, batch_size, *dims)` with `side` 0 being input A and `side` 1 being input B + + Returns + ------- + :class:`torch.Tensor` + The loss for each side of this batch in layout (A1, ..., An, B1, ..., Bn) + """ + loss_tensor = self._forward(inputs, targets) + self._backwards_and_apply(loss_tensor) + return loss_tensor + + +__all__ = get_module_objects(__name__) diff --git a/plugins/train/trainer/trainer_config.py b/plugins/train/trainer/trainer_config.py new file mode 100644 index 0000000000..08d49656f4 --- /dev/null +++ b/plugins/train/trainer/trainer_config.py @@ -0,0 +1,142 @@ +#!/usr/bin/env python3 +""" Default configurations for trainers """ +import gettext +import logging + +from lib.config import ConfigItem +from lib.utils import get_module_objects + +logger = logging.getLogger(__name__) + + +# LOCALES +_LANG = gettext.translation("plugins.train.trainer.train_config", + localedir="locales", fallback=True) +_ = _LANG.gettext + + +def get_defaults() -> tuple[str, str, dict[str, ConfigItem]]: + """ Obtain the default values for adding to the config.ini file + + Returns + ------- + helptext : str + The help text for the training config section + section : str + The section name for the config items + defaults : dict[str, :class:`lib.config.objects.ConfigItem`] + The option names and config items + """ + section = "trainer.augmentation" + helptext = _( + "Data Augmentation Options.\n" + "WARNING: The defaults for augmentation will be fine for 99.9% of use cases. " + "Only change them if you absolutely know what you are doing!") + defaults = {k: v for k, v in globals().items() + if isinstance(v, ConfigItem)} + logger.debug("Training config. 
Helptext: %s, options: %s", helptext, defaults) + return helptext, section, defaults + + +preview_images = ConfigItem( + datatype=int, + default=14, + group=_("evaluation"), + info=_("Number of sample faces to display for each side in the preview when training."), + rounding=2, + min_max=(2, 16)) + +mask_opacity = ConfigItem( + datatype=int, + default=30, + group=_("evaluation"), + info=_("The opacity of the mask overlay in the training preview. Lower values are more " + "transparent."), + rounding=2, + min_max=(0, 100)) + +mask_color = ConfigItem( + datatype=str, + default="#ff0000", + choices="colorchooser", + group=_("evaluation"), + info=_("The RGB hex color to use for the mask overlay in the training preview.")) + +zoom_amount = ConfigItem( + datatype=int, + default=5, + group=_("image augmentation"), + info=_("Percentage amount to randomly zoom each training image in and out."), + rounding=1, + min_max=(0, 25)) + +rotation_range = ConfigItem( + datatype=int, + default=10, + group=_("image augmentation"), + info=_("Percentage amount to randomly rotate each training image."), + rounding=1, + min_max=(0, 25)) + +shift_range = ConfigItem( + datatype=int, + default=5, + group=_("image augmentation"), + info=_("Percentage amount to randomly shift each training image horizontally and " + "vertically."), + rounding=1, + min_max=(0, 25)) + +flip_chance = ConfigItem( + datatype=int, + default=50, + group=_("image augmentation"), + info=_("Percentage chance to randomly flip each training image horizontally.\n" + "NB: This is ignored if the 'no-flip' option is enabled"), + rounding=1, + min_max=(0, 75)) + +color_lightness = ConfigItem( + datatype=int, + default=30, + group=_("color augmentation"), + info=_("Percentage amount to randomly alter the lightness of each training image.\n" + "NB: This is ignored if the 'no-augment-color' option is enabled"), + rounding=1, + min_max=(0, 75)) + +color_ab = ConfigItem( + datatype=int, + default=8, + group=_("color augmentation"), + info=_("Percentage amount to randomly alter the 'a' and 'b' colors of the L*a*b* color " + "space of each training image.\nNB: This is ignored if the 'no-augment-color' option " + "is enabled"), + rounding=1, + min_max=(0, 50)) + +color_clahe_chance = ConfigItem( + datatype=int, + default=50, + group=_("color augmentation"), + info=_("Percentage chance to perform Contrast Limited Adaptive Histogram Equalization on " + "each training image.\nNB: This is ignored if the 'no-augment-color' option is " + "enabled"), + rounding=1, + min_max=(0, 75), + fixed=False) + +color_clahe_max_size = ConfigItem( + datatype=int, + default=4, + group=_("color augmentation"), + info=_("The grid size dictates how much Contrast Limited Adaptive Histogram Equalization is " + "performed on any training image selected for clahe. Contrast will be applied " + "randomly with a gridsize of 0 up to the maximum. This value is a multiplier " + "calculated from the training image size.\nNB: This is ignored if the " + "'no-augment-color' option is enabled"), + rounding=1, + min_max=(1, 8)) + + +__all__ = get_module_objects(__name__) diff --git a/plugins/train/training.py b/plugins/train/training.py new file mode 100644 index 0000000000..a74a73e0ab --- /dev/null +++ b/plugins/train/training.py @@ -0,0 +1,369 @@ +#!
/usr/bin/env python3 +""" Run the training loop for a training plugin """ +from __future__ import annotations + +import logging +import os +import typing as T +import time +import warnings + +import numpy as np +import torch + +from torch.cuda import OutOfMemoryError + +from lib.training import Feeder, LearningRateFinder, LearningRateWarmup +from lib.training.tensorboard import TorchTensorBoard +from lib.utils import get_module_objects, FaceswapError +from plugins.train import train_config as mod_cfg +from plugins.train.trainer import trainer_config as trn_cfg + +from plugins.train.trainer._display import Samples, Timelapse + +if T.TYPE_CHECKING: + from collections.abc import Callable + from plugins.train.trainer._base import TrainerBase + +logger = logging.getLogger(__name__) + + +# Suppress non-Faceswap related Keras warning about backend padding mismatches +warnings.filterwarnings("ignore", + message="You might experience inconsistencies", + category=UserWarning) + + +class Trainer: + """ Handles the feeding of training images to Faceswap models, the generation of Tensorboard + logs and the creation of sample/time-lapse preview images. + + Parameters + ---------- + plugin : :class:`TrainerBase` + The plugin that will be processing each batch + images : dict[Literal["a", "b"], list[str]] + The file paths for the images to be trained on for each side. The dictionary should contain + 2 keys ("a" and "b") with the values being a list of full paths corresponding to each side. + """ + + def __init__(self, plugin: TrainerBase, images: dict[T.Literal["a", "b"], list[str]]) -> None: + self._batch_size = plugin.batch_size + self._plugin = plugin + self._model = plugin.model + + self._feeder = Feeder(images, plugin.model, plugin.batch_size) + + self._exit_early = self._handle_lr_finder() + if self._exit_early: + logger.debug("Exiting from LR Finder") + return + + self._warmup = self._get_warmup() + self._model.state.add_session_batchsize(plugin.batch_size) + self._images = images + self._sides = sorted(key for key in self._images.keys()) + + self._tensorboard = self._set_tensorboard() + self._samples = Samples(self._model, + self._model.coverage_ratio, + trn_cfg.mask_opacity(), + trn_cfg.mask_color()) + + num_images = trn_cfg.preview_images() + assert isinstance(num_images, int) + self._timelapse = Timelapse(self._model, + self._model.coverage_ratio, + num_images, + trn_cfg.mask_opacity(), + trn_cfg.mask_color(), + self._feeder, + self._images) + logger.debug("Initialized %s", self.__class__.__name__) + + @property + def exit_early(self) -> bool: + """ True if the trainer should exit early, without performing any training steps """ + return self._exit_early + + @property + def batch_size(self) -> int: + """int : The batch size that the model is set to train at. """ + return self._batch_size + + def _handle_lr_finder(self) -> bool: + """ Handle the learning rate finder.
+ + If this is a new model, then find the optimal learning rate and return ``True`` if the user + has just requested the graph, otherwise return ``False`` to continue training + + If it is an existing model, set the learning rate to the value found by the learning rate + finder and return ``False`` to continue training + + Returns + ------- + bool + ``True`` if the learning rate finder options dictate that training should not continue + after finding the optimal learning rate + """ + if not self._model.command_line_arguments.use_lr_finder: + return False + + if self._model.state.lr_finder > -1: + learning_rate = self._model.state.lr_finder + logger.info("Setting learning rate from Learning Rate Finder to %s", + f"{learning_rate:.1e}") + self._model.model.optimizer.learning_rate.assign(learning_rate) + self._model.state.update_session_config("learning_rate", learning_rate) + return False + + if self._model.state.iterations == 0 and self._model.state.session_id == 1: + lrf = LearningRateFinder(self) + success = lrf.find() + return mod_cfg.lr_finder_mode() == "graph_and_exit" or not success + + logger.debug("No stored learning rate finder rate. Not setting") + return False + + def _get_warmup(self) -> LearningRateWarmup: + """ Obtain the learning rate warmup instance + + Returns + ------- + :class:`plugins.train.lr_warmup.LRWarmup` + The Learning Rate Warmup object + """ + target_lr = float(self._model.model.optimizer.learning_rate.value.cpu().numpy()) + return LearningRateWarmup(self._model.model, target_lr, self._model.warmup_steps) + + def _set_tensorboard(self) -> TorchTensorBoard | None: + """ Set up Tensorboard callback for logging loss. + + Bypassed if command line option "no-logs" has been selected. + + Returns + ------- + :class:`keras.callbacks.TensorBoard` | None + Tensorboard object for the current training session. ``None`` if Tensorboard + logging is not selected + """ + if self._model.state.current_session["no_logs"]: + logger.verbose("TensorBoard logging disabled") # type: ignore + return None + logger.debug("Enabling TensorBoard Logging") + + logger.debug("Setting up TensorBoard Logging") + log_dir = os.path.join(str(self._model.io.model_dir), + f"{self._model.name}_logs", + f"session_{self._model.state.session_id}") + tensorboard = TorchTensorBoard(log_dir=log_dir, + write_graph=True, + update_freq="batch") + tensorboard.set_model(self._model.model) + logger.verbose("Enabled TensorBoard Logging") # type: ignore + return tensorboard + + def toggle_mask(self) -> None: + """ Toggle the mask overlay on or off based on user input. """ + self._samples.toggle_mask_display() + + def train_one_batch(self) -> np.ndarray: + """ Process a single batch through the model and obtain the loss + + Returns + ------- + :class:`numpy.ndarray` + The total loss in the first position then A losses, by output order, then B losses, by + output order + """ + try: + inputs, targets = self._feeder.get_batch() + loss_t = self._plugin.train_batch(torch.from_numpy(inputs), + [torch.from_numpy(t) for t in targets]) + loss_cpu = loss_t.detach().cpu().numpy() + retval = np.array([sum(loss_cpu), *loss_cpu]) + except OutOfMemoryError as err: + msg = ("You do not have enough GPU memory available to train the selected model at " + "the selected settings. You can try a number of things:" + "\n1) Close any other application that is using your GPU (web browsers are " + "particularly bad for this)." + "\n2) Lower the batchsize (the amount of images fed into the model each " + "iteration)."
+ "\n3) Try enabling 'Mixed Precision' training." + "\n4) Use a more lightweight model, or select the model's 'LowMem' option " + "(in config) if it has one.") + raise FaceswapError(msg) from err + return retval + + def train_one_step(self, + viewer: Callable[[np.ndarray, str], None] | None, + timelapse_kwargs: dict[T.Literal["input_a", "input_b", "output"], + str] | None) -> None: + """ Running training on a batch of images for each side. + + Triggered from the training cycle in :class:`scripts.train.Train`. + + * Runs a training batch through the model. + + * Outputs the iteration's loss values to the console + + * Logs loss to Tensorboard, if logging is requested. + + * If a preview or time-lapse has been requested, then pushes sample images through the \ + model to generate the previews + + * Creates a snapshot if the total iterations trained so far meet the requested snapshot \ + criteria + + Notes + ----- + As every iteration is called explicitly, the Parameters defined should always be ``None`` + except on save iterations. + + Parameters + ---------- + viewer: :func:`scripts.train.Train._show` or ``None`` + The function that will display the preview image + timelapse_kwargs: dict + The keyword arguments for generating time-lapse previews. If a time-lapse preview is + not required then this should be ``None``. Otherwise all values should be full paths + the keys being `input_a`, `input_b`, `output`. + """ + self._model.state.increment_iterations() + logger.trace("Training one step: (iteration: %s)", self._model.iterations) # type: ignore + snapshot_interval = self._model.command_line_arguments.snapshot_interval + do_snapshot = (snapshot_interval != 0 and + self._model.iterations - 1 >= snapshot_interval and + (self._model.iterations - 1) % snapshot_interval == 0) + self._warmup() + loss = self.train_one_batch() + self._log_tensorboard(loss) + loss = self._collate_and_store_loss(loss[1:]) + self._print_loss(loss) + if do_snapshot: + self._model.io.snapshot() + self._update_viewers(viewer, timelapse_kwargs) + + def _log_tensorboard(self, loss: np.ndarray) -> None: + """ Log current loss to Tensorboard log files + + Parameters + ---------- + loss : :class:`numpy.ndarray` + The total loss in the first position then A losses, by output order, then B losses, by + output order + """ + if not self._tensorboard: + return + logger.trace("Updating TensorBoard log") # type: ignore + logs = {log[0]: float(log[1]) + for log in zip(self._model.state.loss_names, loss)} + + self._tensorboard.on_train_batch_end(self._model.iterations, logs=logs) + + def _collate_and_store_loss(self, loss: np.ndarray) -> np.ndarray: + """ Collate the loss into totals for each side. + + The losses are summed into a total for each side. Loss totals are added to + :attr:`model.state._history` to track the loss drop per save iteration for backup purposes. + + If NaN protection is enabled, Checks for NaNs and raises an error if detected. + + Parameters + ---------- + loss : :class:`numpy.ndarray` + The total loss in the first position then A losses, by output order, then B losses, by + output order + + Returns + ------- + :class:`numpy.ndarray` + 2 ``floats`` which is the total loss for each side (eg sum of face + mask loss) + + Raises + ------ + FaceswapError + If a NaN is detected, a :class:`FaceswapError` will be raised + """ + # NaN protection + if mod_cfg.nan_protection() and not all(np.isfinite(val) for val in loss): + logger.critical("NaN Detected. 
Loss: %s", loss) + raise FaceswapError("A NaN was detected and you have NaN protection enabled. Training " + "has been terminated.") + + split = len(loss) // 2 + combined_loss = np.array([sum(loss[:split]), sum(loss[split:])]) + self._model.add_history(combined_loss) + logger.trace("original loss: %s, combined_loss: %s", loss, combined_loss) # type: ignore + return combined_loss + + def _print_loss(self, loss: np.ndarray) -> None: + """ Outputs the loss for the current iteration to the console. + + Parameters + ---------- + loss : :class`numpy.ndarray` + The loss for each side. List should contain 2 ``floats`` side "a" in position 0 and + side "b" in position `. + """ + output = ", ".join([f"Loss {side}: {side_loss:.5f}" + for side, side_loss in zip(("A", "B"), loss)]) + timestamp = time.strftime("%H:%M:%S") + output = f"[{timestamp}] [#{self._model.iterations:05d}] {output}" + print(f"{output}", end="\r") + + def _update_viewers(self, + viewer: Callable[[np.ndarray, str], None] | None, + timelapse_kwargs: dict[T.Literal["input_a", "input_b", "output"], + str] | None) -> None: + """ Update the preview viewer and timelapse output + + Parameters + ---------- + viewer: :func:`scripts.train.Train._show` or ``None`` + The function that will display the preview image + timelapse_kwargs: dict + The keyword arguments for generating time-lapse previews. If a time-lapse preview is + not required then this should be ``None``. Otherwise all values should be full paths + the keys being `input_a`, `input_b`, `output`. + """ + if viewer is not None: + self._samples.images = self._feeder.generate_preview() + samples = self._samples.show_sample() + if samples is not None: + viewer(samples, + "Training - 'S': Save Now. 'R': Refresh Preview. 'M': Toggle Mask. 'F': " + "Toggle Screen Fit-Actual Size. 'ENTER': Save and Quit") + + if timelapse_kwargs: + self._timelapse.output_timelapse(timelapse_kwargs) + + def _clear_tensorboard(self) -> None: + """ Stop Tensorboard logging. + + Tensorboard logging needs to be explicitly shutdown on training termination. Called from + :class:`scripts.train.Train` when training is stopped. + """ + if not self._tensorboard: + return + logger.debug("Ending Tensorboard Session: %s", self._tensorboard) + self._tensorboard.on_train_end() + + def save(self, is_exit: bool = False) -> None: + """ Save the model + + Parameters + ---------- + is_exit: bool, optional + ``True`` if save has been called on model exit. 
Default: ``False`` + """ + self._model.io.save(is_exit=is_exit) + assert self._tensorboard is not None + self._tensorboard.on_save() + if is_exit: + self._clear_tensorboard() + + +__all__ = get_module_objects(__name__) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000000..71f762d260 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,43 @@ +[tool.flake8] +max-line-length = 99 +max-complexity=10 +statistics = true +count = true +exclude = [".git", "__pycache__"] +per-file-ignores = ["__init__.py:F401"] + +[tool.pylint.DESIGN] +min-public-methods = 1 +max-args = 10 +max-attributes = 10 +max-positional-arguments = 10 + +[tool.pylint.TYPECHECK] +generated-members = ["cv2"] + +[[tool.mypy.overrides]] +module = [ + "fastcluster.*", + "ffmpy.*", + "h5py.*", + "imageio_ffmpeg.*", + "keras.*", + "numexpr.*", + "pexpect.*", + "pynvml.*", + "scipy.*", + "sklearn.*", + "tensorboard.*", + "torch.*", + "tqdm.*", + "win32console.*", + "winpty.*",] +ignore_missing_imports = true + +[tool.pytest.ini_options] +testpaths = ["tests"] +pythonpath = ["."] +filterwarnings = ["ignore::DeprecationWarning:keras.*:"] + +[tool.pyright] +reportUnsupportedDunderAll = false diff --git a/requirements.txt b/requirements.txt deleted file mode 100755 index 2ca4acea3d..0000000000 --- a/requirements.txt +++ /dev/null @@ -1,20 +0,0 @@ -tqdm -psutil -pathlib -numpy==1.15.4 -opencv-python -scikit-image -scikit-learn -matplotlib==2.2.2 -ffmpy==0.2.2 -nvidia-ml-py3 -h5py==2.9.0 -Keras==2.2.4 -cmake -dlib -face-recognition - -# tensorflow is included within the docker image. -# If you are looking for dependencies for a manual install, -# you may want to install tensorflow-gpu==1.4.0 for CUDA 8.0 or tensorflow-gpu>=1.11.0 for CUDA 9.0 -# NB: MTCNN will not work with tensorflow releases prior to 1.6.0 diff --git a/requirements/__init__.py b/requirements/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/requirements/_requirements_base.txt b/requirements/_requirements_base.txt new file mode 100644 index 0000000000..b077480b07 --- /dev/null +++ b/requirements/_requirements_base.txt @@ -0,0 +1,20 @@ +packaging>=25.0 +tqdm>=4.67 +psutil>=7.1.0 +numexpr>=2.14.0 +numpy>=2.2.0 +opencv-python>=4.12.0 +pillow>=12.0.0 +scikit-learn>=1.7.2 +fastcluster>=1.3.0 +matplotlib>=3.10.7 +imageio>=2.37.0 +# ffmpeg binary >=0.6.0 breaks convert. +# TODO fix convert to use latest binary +imageio-ffmpeg>=0.4.9,<0.6.0 +ffmpy>=0.6.0 +pywin32>=305 ; sys_platform == "win32" +#torchvision>=0.18.0,<0.25.0 +torchvision>=0.18.0,<0.25.0 +tensorboard>=2.20.0 +keras>=3.12.0,<3.13.0 diff --git a/requirements/_requirements_dev.txt b/requirements/_requirements_dev.txt new file mode 100644 index 0000000000..3e2a558c9c --- /dev/null +++ b/requirements/_requirements_dev.txt @@ -0,0 +1,11 @@ +# Additional optional development tool requirements +flake8 +flake8-pyproject +mypy +pylint +pytest +pytest-mock +types-setuptools +types-PyYAML +types-psutil +types-tensorflow diff --git a/requirements/requirements.py b/requirements/requirements.py new file mode 100644 index 0000000000..4b573d90ab --- /dev/null +++ b/requirements/requirements.py @@ -0,0 +1,205 @@ +#! 
/usr/bin/env python3 +""" Parses the contents of Python requirements.txt files and holds the information in a parsable +format + +NOTE: Only packages from the Python Standard Library should be imported in this module +""" +from __future__ import annotations + +import logging +import typing as T +import os + +from importlib import import_module, util as import_util + +if T.TYPE_CHECKING: + from packaging.markers import Marker + from packaging.requirements import Requirement + from packaging.specifiers import Specifier + +logger = logging.getLogger(__name__) + + +PYTHON_VERSIONS: dict[str, tuple[int, int]] = {"rocm_60": (3, 12)} +""" dict[str, tuple[int, int]] : Mapping of requirement file names to the maximum supported +Python version, if below the project maximum """ + + +class Requirements: + """ Parse requirement information + + Parameters + ---------- + include_dev : bool, optional + ``True`` to additionally load requirements from the dev requirements file + """ + def __init__(self, include_dev: bool = False) -> None: + self._include_dev = include_dev + self._marker: type[Marker] | None = None + self._requirement: type[Requirement] | None = None + self._specifier: type[Specifier] | None = None + self._global_options: dict[str, list[str]] = {} + self._requirements: dict[str, list[Requirement]] = {} + + @property + def packaging_available(self) -> bool: + """ bool : ``True`` if the packaging library is available otherwise ``False`` """ + if self._requirement is not None: + return True + return import_util.find_spec("packaging") is not None + + @property + def requirements(self) -> dict[str, list[Requirement]]: + """ dict[str, list[Requirement]] : backend type as key, list of required packages as + value """ + if not self._requirements: + self._load_requirements() + return self._requirements + + @property + def global_options(self) -> dict[str, list[str]]: + """ dict[str, list[str]] : The global pip install options for each backend """ + if not self._requirements: + self._load_requirements() + return self._global_options + + def __repr__(self) -> str: + """ Pretty print the required packages for logging """ + props = ", ".join( + f"{k}={repr(getattr(self, k))}" + for k, v in self.__class__.__dict__.items() + if isinstance(v, property) and not k.startswith("_")) + return f"{self.__class__.__name__}({props})" + + def _import_packaging(self) -> None: + """ Import the packaging library and set the required classes to class attributes.
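The import is deferred rather than performed at module level because only Python Standard Library packages may be imported when this module loads (see the module note); ``packaging`` may not yet be installed at that point.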
""" + if self._requirement is not None: + return + + logger.debug("Importing packaging library") + mark_mod = import_module("packaging.markers") + req_mod = import_module("packaging.requirements") + spec_mod = import_module("packaging.specifiers") + self._marker = mark_mod.Marker + self._requirement = req_mod.Requirement + self._specifier = spec_mod.Specifier + + @classmethod + def _parse_file(cls, file_path: str) -> tuple[list[str], list[str]]: + """ Parse a requirements file + + Parameters + ---------- + file_path : str + The full path to a requirements file to parse + + Returns + ------- + global_options : list[str] + Any global options collected from the requirements file + requirements : list[str] + The requirements strings from the requirments file + """ + global_options = [] + requirements = [] + with open(file_path, encoding="utf8") as f: + for line in f: + line = line.strip() # Skip blanks, comments and nested requirement files + if not line or line.startswith(("#", "-r")): + continue + + line = line.split("#", maxsplit=1)[0] # Strip inline comments + + if line.startswith("-"): # Collect global option + global_options.append(line) + continue + requirements.append(line) # Collect requirement + + logger.debug("Parsed requirements file '%s'. global_options: %s, requirements: %s", + os.path.basename(file_path), global_options, requirements) + return global_options, requirements + + def parse_requirements(self, packages: list[str]) -> list[Requirement]: + """ Drop in replacement for deprecated pkg_resources.parse_requirements + + Parameters + ---------- + packages: list[str] + List of packages formatted from a requirements.txt file + + Returns + ------- + list[:class:`packaging.Requirement`] + List of Requirement objects + """ + self._import_packaging() + assert self._requirement is not None + requirements = [self._requirement(p) for p in packages] + retval = [r for r in requirements if r.marker is None or r.marker.evaluate()] + if len(retval) != len(requirements): + logger.debug("Filtered invalid packages %s", + [(r.name, r.marker) for r in set(requirements).difference(set(retval))]) + logger.debug("Parsed requirements %s: %s", packages, retval) + return retval + + def _parse_options(self, options: list[str]) -> list[str]: + """ Parse global options from a requirements file and only return valid options + + Parameters + ---------- + options: list[str] + List of global options formatted from a requirements.txt file + + Returns + ------- + list[str] + List of global options valid for the running system + """ + if not options: + return options + assert self._marker is not None + retval = [] + for opt in options: + if ";" not in opt: + retval.append(opt) + continue + directive, marker = opt.split(";", maxsplit=1) + if not self._marker(marker.strip()).evaluate(): + logger.debug("Filtered invalid option: '%s'", opt) + continue + retval.append(directive.strip()) + + logger.debug("Selected options: %s", retval) + return retval + + def _load_requirements(self) -> None: + """ Parse the requirements files and populate information to :attr:`_requirements` """ + req_path = os.path.dirname(os.path.realpath(__file__)) + base_file = os.path.join(req_path, "_requirements_base.txt") + req_files = [os.path.join(req_path, f) + for f in os.listdir(req_path) + if f.startswith("requirements_") + and os.path.splitext(f)[-1] == ".txt"] + + opts_base, reqs_base = self._parse_file(base_file) + parsed_reqs_base = self.parse_requirements(reqs_base) + parsed_opts_base = self._parse_options(opts_base) + + if 
self._include_dev: + opts_dev, reqs_dev = self._parse_file(os.path.join(req_path, "_requirements_dev.txt")) + opts_base += opts_dev + parsed_reqs_base += self.parse_requirements(reqs_dev) + parsed_opts_base += self._parse_options(opts_dev) + + for req_file in req_files: + backend = os.path.splitext(os.path.basename(req_file))[0].replace("requirements_", "") + assert backend + opts, reqs = self._parse_file(req_file) + self._requirements[backend] = parsed_reqs_base + self.parse_requirements(reqs) + self._global_options[backend] = parsed_opts_base + self._parse_options(opts) + logger.debug("[%s] Requirements: %s , Options: %s", + backend, self._requirements[backend], self._global_options[backend]) + + +if __name__ == "__main__": + print(Requirements(include_dev=True)) diff --git a/requirements/requirements_apple-silicon.txt b/requirements/requirements_apple-silicon.txt new file mode 100644 index 0000000000..48599420c0 --- /dev/null +++ b/requirements/requirements_apple-silicon.txt @@ -0,0 +1,5 @@ +-r _requirements_base.txt +# These next 2 should have been installed, but some users complain of errors +decorator +cloudpickle +torch>=2.3.0,<2.10.0 diff --git a/requirements/requirements_cpu.txt b/requirements/requirements_cpu.txt new file mode 100644 index 0000000000..e3567a428b --- /dev/null +++ b/requirements/requirements_cpu.txt @@ -0,0 +1,3 @@ +-r _requirements_base.txt +--extra-index-url https://download.pytorch.org/whl/cpu +torch>=2.3.0,<2.10.0 diff --git a/requirements/requirements_nvidia.txt b/requirements/requirements_nvidia.txt new file mode 100644 index 0000000000..70cfd9676e --- /dev/null +++ b/requirements/requirements_nvidia.txt @@ -0,0 +1,2 @@ +# Meta requirements file for latest Nvidia version +-r _requirements_nvidia_13.txt diff --git a/requirements/requirements_nvidia_11.txt b/requirements/requirements_nvidia_11.txt new file mode 100644 index 0000000000..10dd42b80f --- /dev/null +++ b/requirements/requirements_nvidia_11.txt @@ -0,0 +1,8 @@ +# Cuda compatibility 3.5-9.0 +# GTX7xx - RTX40xx +# Maximum supported Python: 3.13 +-r _requirements_base.txt +# Exclude badly numbered Python2 version of nvidia-ml-py +nvidia-ml-py>=12.535,<300 +--extra-index-url https://download.pytorch.org/whl/cu118 +torch>=2.7.0,<2.8.0 diff --git a/requirements/requirements_nvidia_12.txt b/requirements/requirements_nvidia_12.txt new file mode 100644 index 0000000000..cefd2da151 --- /dev/null +++ b/requirements/requirements_nvidia_12.txt @@ -0,0 +1,7 @@ +# Cuda compatibility 5.0-12.0 +# GTX9xx - RTX50xx +-r _requirements_base.txt +# Exclude badly numbered Python2 version of nvidia-ml-py +nvidia-ml-py>=12.535,<300 +--extra-index-url https://download.pytorch.org/whl/cu126 +torch>=2.7.0,<2.10.0 diff --git a/requirements/requirements_nvidia_13.txt b/requirements/requirements_nvidia_13.txt new file mode 100644 index 0000000000..79ccbcdd0c --- /dev/null +++ b/requirements/requirements_nvidia_13.txt @@ -0,0 +1,7 @@ +# Cuda compatibility 7.5- +# RTX 20xx - +-r _requirements_base.txt +# Exclude badly numbered Python2 version of nvidia-ml-py +nvidia-ml-py>=12.535,<300 +--extra-index-url https://download.pytorch.org/whl/cu130 +torch>=2.9.0,<2.10.0 diff --git a/requirements/requirements_rocm.txt b/requirements/requirements_rocm.txt new file mode 100644 index 0000000000..76f61581ea --- /dev/null +++ b/requirements/requirements_rocm.txt @@ -0,0 +1,2 @@ +# Meta requirements file for latest ROCm version +-r _requirements_rocm_64.txt diff --git a/requirements/requirements_rocm_60.txt b/requirements/requirements_rocm_60.txt 
new file mode 100644 index 0000000000..23d6b4dc3e --- /dev/null +++ b/requirements/requirements_rocm_60.txt @@ -0,0 +1,4 @@ +# Maximum supported Python: 3.12 +-r _requirements_base.txt +--extra-index-url https://download.pytorch.org/whl/rocm6.0 +torch>=2.4.0,<2.5.0 diff --git a/requirements/requirements_rocm_61.txt b/requirements/requirements_rocm_61.txt new file mode 100644 index 0000000000..efaad01b67 --- /dev/null +++ b/requirements/requirements_rocm_61.txt @@ -0,0 +1,4 @@ +# Maximum supported Python: 3.13 +-r _requirements_base.txt +--extra-index-url https://download.pytorch.org/whl/rocm6.1 +torch>=2.5.0,<2.7.0 diff --git a/requirements/requirements_rocm_62.txt b/requirements/requirements_rocm_62.txt new file mode 100644 index 0000000000..0c47e1ae20 --- /dev/null +++ b/requirements/requirements_rocm_62.txt @@ -0,0 +1,5 @@ +# Maximum supported Python: 3.13 +-r _requirements_base.txt +--extra-index-url https://download.pytorch.org/whl/rocm6.2 +--extra-index-url https://download.pytorch.org/whl/rocm6.2.4 +torch>=2.5.0,<2.8.0 diff --git a/requirements/requirements_rocm_63.txt b/requirements/requirements_rocm_63.txt new file mode 100644 index 0000000000..73c53bdb27 --- /dev/null +++ b/requirements/requirements_rocm_63.txt @@ -0,0 +1,3 @@ +-r _requirements_base.txt +--extra-index-url https://download.pytorch.org/whl/rocm6.3 +torch>=2.7.0,<2.10.0 diff --git a/requirements/requirements_rocm_64.txt b/requirements/requirements_rocm_64.txt new file mode 100644 index 0000000000..a4deb56725 --- /dev/null +++ b/requirements/requirements_rocm_64.txt @@ -0,0 +1,3 @@ +-r _requirements_base.txt +--extra-index-url https://download.pytorch.org/whl/rocm6.4 +torch>=2.8.0,<2.10.0 diff --git a/scripts/convert.py b/scripts/convert.py index dbee8ec036..13e5f2ea96 100644 --- a/scripts/convert.py +++ b/scripts/convert.py @@ -1,161 +1,645 @@ #!/usr/bin python3 -""" The script to run the convert process of faceswap """ - +""" Main entry point to the convert process of FaceSwap """ +from __future__ import annotations +from dataclasses import dataclass, field import logging import re import os import sys -from pathlib import Path +import typing as T +from threading import Event +from time import sleep import cv2 +import numpy as np from tqdm import tqdm -from scripts.fsmedia import Alignments, Images, PostProcess, Utils -from lib.faces_detect import DetectedFace -from lib.multithreading import BackgroundGenerator +from scripts import fsmedia +from scripts.fsmedia import PostProcess, finalize +from lib.serializer import get_serializer +from lib.convert import Converter +from lib.align import AlignedFace, DetectedFace, update_legacy_png_header +from lib.gpu_stats import GPUStats +from lib.image import read_image_meta_batch, ImagesLoader +from lib.multithreading import MultiThread, total_cpus from lib.queue_manager import queue_manager -from lib.utils import get_folder, get_image_paths, hash_image_file +from lib.utils import (get_module_objects, FaceswapError, get_folder, + get_image_paths, handle_deprecated_cliopts) +from plugins.extract import ExtractMedia, Extractor from plugins.plugin_loader import PluginLoader +from plugins.train import train_config as mod_cfg + +if T.TYPE_CHECKING: + from argparse import Namespace + from collections.abc import Callable + from plugins.convert.writer._base import Output + from plugins.train.model._base import ModelBase + from lib.align.aligned_face import CenteringType + from lib.queue_manager import EventQueue + + +logger = logging.getLogger(__name__) -from .extract import Plugins 
as Extractor -logger = logging.getLogger(__name__) # pylint: disable=invalid-name +@dataclass +class ConvertItem: + """ A single frame with associated objects passing through the convert process. + + Parameters + ---------- + inbound: :class:`~plugins.extract.extract_media.ExtractMedia` + The ExtractMedia object holding the :attr:`filename`, :attr:`image` and list of + :class:`~lib.align.DetectedFace` objects loaded from disk + feed_faces: list, Optional + list of :class:`lib.align.AlignedFace` objects for feeding into the model's predict + function + reference_faces: list, Optional + list of :class:`lib.align.AlignedFace` objects at model output size, for use as reference + in the convert function + swapped_faces: :class:`np.ndarray` + The swapped faces returned from the model's predict function + """ + inbound: ExtractMedia + feed_faces: list[AlignedFace] = field(default_factory=list) + reference_faces: list[AlignedFace] = field(default_factory=list) + swapped_faces: np.ndarray = field(default_factory=lambda: np.array([])) class Convert(): - """ The convert process. """ - def __init__(self, arguments): - logger.debug("Initializing %s: (args: %s)", self.__class__.__name__, arguments) - self.args = arguments - self.output_dir = get_folder(self.args.output_dir) - self.extractor = None - self.faces_count = 0 + """ The Faceswap Face Conversion Process. + + The conversion process is responsible for swapping the faces on source frames with the output + from a trained model. - self.images = Images(self.args) - self.alignments = Alignments(self.args, False, self.images.is_video) + It leverages a series of user selected post-processing plugins, executed from + :class:`lib.convert.Converter`. - # Update Legacy alignments - Legacy(self.alignments, self.images.input_images, arguments.input_aligned_dir) + The convert process is self contained and should not be referenced by any other scripts, so it + contains no public properties. - self.post_process = PostProcess(arguments) - self.verify_output = False + Parameters + ---------- + arguments: :class:`argparse.Namespace` + The arguments to be passed to the convert process as generated from Faceswap's command + line arguments + """ + def __init__(self, arguments: Namespace) -> None: + logger.debug("Initializing %s: (args: %s)", self.__class__.__name__, arguments) + self._args = handle_deprecated_cliopts(arguments) + + self._images = ImagesLoader(self._args.input_dir, fast_count=True) + self._alignments = self._get_alignments() + self._opts = OptionalActions(self._args, self._images.file_list, self._alignments) - self.opts = OptionalActions(self.args, self.images.input_images, self.alignments) + self._add_queues() + self._predictor = Predict(self._queue_size, arguments) + self._disk_io = DiskIO(self._alignments, self._images, self._predictor, arguments) + self._predictor.launch(self._disk_io.load_queue) + self._validate() + get_folder(self._args.output_dir) + + configfile = self._args.configfile if hasattr(self._args, "configfile") else None + self._converter = Converter(self._predictor.output_size, + self._predictor.coverage_ratio, + self._predictor.centering, + self._disk_io.draw_transparent, + self._disk_io.pre_encode, + arguments, + configfile=configfile) + self._patch_threads = self._get_threads() logger.debug("Initialized %s", self.__class__.__name__) - def process(self): - """ Original & LowMem models go with converter + @property + def _queue_size(self) -> int: + """ int: Size of the converter queues.
2 for single process otherwise 4 """ + retval = 2 if self._args.singleprocess or self._args.jobs == 1 else 4 + logger.debug(retval) + return retval - Note: GAN prediction outputs a mask + an image, while other - predicts only an image. """ - Utils.set_verbosity(self.args.loglevel) + @property + def _pool_processes(self) -> int: + """ int: The number of threads to run in parallel. Based on user options and number of + available processors. """ + if self._args.singleprocess: + retval = 1 + elif self._args.jobs > 0: + retval = min(self._args.jobs, total_cpus(), self._images.count) + else: + retval = min(total_cpus(), self._images.count) + retval = 1 if retval == 0 else retval + logger.debug(retval) + return retval + + def _get_alignments(self) -> fsmedia.Alignments: + """ Perform validation checks and legacy updates and return the alignments object + + Returns + ------- + :class:`~scripts.fsmedia.Alignments` + The alignments file for the convert job + """ + retval = fsmedia.Alignments(self._args, False, self._images.is_video) + if retval.version == 1.0: + logger.error("The alignments file format has been updated since the given alignments " + "file was generated. You need to update the file to proceed.") + logger.error("To do this run the 'Alignments Tool' > 'Extract' Job.") + sys.exit(1) - if not self.alignments.have_alignments_file: - self.load_extractor() + retval.update_legacy_has_source(os.path.basename(self._args.input_dir)) + return retval - model = self.load_model() - converter = self.load_converter(model) + def _validate(self) -> None: + """ Validate the Command Line Options. - batch = BackgroundGenerator(self.prepare_images(), 1) + Ensure that certain cli selections are valid and won't result in an error. Checks: + * If frames have been passed in with video output, ensure user supplies reference + video. + * If "on-the-fly" and a Neural Network mask is selected, warn and switch to 'extended' + * If a mask-type is selected, ensure it exists in the alignments file. + * If a predicted mask-type is selected, ensure the model has been trained with a mask, + otherwise attempt to select the first available mask, or raise an error if none exists. - for item in batch.iterator(): - self.convert(converter, item) + Raises + ------ + FaceswapError + If an invalid selection has been found. - if self.extractor: + """ + if (self._args.writer == "ffmpeg" and + not self._images.is_video and + self._args.reference_video is None): + raise FaceswapError("Output as video selected, but using frames as input. You must " + "provide a reference video ('-ref', '--reference-video').") + + if (self._args.on_the_fly and + self._args.mask_type not in ("none", "extended", "components")): + logger.warning("You have selected an incompatible mask type ('%s') for On-The-Fly " + "conversion.
Switching to 'extended'", self._args.mask_type) + self._args.mask_type = "extended" + + if (not self._args.on_the_fly and + self._args.mask_type not in ("none", "predicted") and + not self._alignments.mask_is_valid(self._args.mask_type)): + msg = (f"You have selected the Mask Type `{self._args.mask_type}` but at least one " + "face does not have this mask stored in the Alignments File.\nYou should " + "generate the required masks with the Mask Tool or set the Mask Type option to " + "an existing Mask Type.\nA summary of existing masks is as follows:\nTotal " + f"faces: {self._alignments.faces_count}, " + f"Masks: {self._alignments.mask_summary}") + raise FaceswapError(msg) + + if self._args.mask_type == "predicted" and not self._predictor.has_predicted_mask: + available_masks = [k for k, v in self._alignments.mask_summary.items() + if k != "none" and v == self._alignments.faces_count] + if not available_masks: + msg = ("Predicted Mask selected, but the model was not trained with a mask and no " + "masks are stored in the Alignments File.\nYou should generate the " + "required masks with the Mask Tool or set the Mask Type to `none`.") + raise FaceswapError(msg) + mask_type = available_masks[0] + logger.warning("Predicted Mask selected, but the model was not trained with a " + "mask. Selecting first available mask: '%s'", mask_type) + self._args.mask_type = mask_type + + def _add_queues(self) -> None: + """ Add the queues for in, patch and out. """ + logger.debug("Adding queues. Queue size: %s", self._queue_size) + for qname in ("convert_in", "convert_out", "patch"): + queue_manager.add_queue(qname, self._queue_size) + + def _get_threads(self) -> MultiThread: + """ Get the threads for patching the converted faces onto the frames. + + Returns + ------- + :class:`lib.multithreading.MultiThread` + The threads that perform the patching of swapped faces onto the output frames + """ + save_queue = queue_manager.get_queue("convert_out") + patch_queue = queue_manager.get_queue("patch") + return MultiThread(self._converter.process, patch_queue, save_queue, + thread_count=self._pool_processes, name="patch") + + def process(self) -> None: + """ The entry point for triggering the Conversion Process. + + Should only be called from :class:`lib.cli.launcher.ScriptExecutor` + + Raises + ------ + FaceswapError + Error raised if the process runs out of memory + """ + logger.debug("Starting Conversion") + # queue_manager.debug_monitor(5) + try: + self._convert_images() + self._disk_io.save_thread.join() queue_manager.terminate_queues() - Utils.finalize(self.images.images_found, - self.faces_count, - self.verify_output) - - def load_extractor(self): - """ Set on the fly extraction """ - logger.warning("No Alignments file found. Extracting on the fly.") - logger.warning("NB: This will use the inferior dlib-hog for extraction " - "and dlib pose predictor for landmarks.
It is recommended " - "to perfom Extract first for superior results") - extract_args = {"detector": "dlib-hog", - "aligner": "dlib", - "loglevel": self.args.loglevel} - self.extractor = Extractor(None, extract_args) - self.extractor.launch_detector() - self.extractor.launch_aligner() - - def load_model(self): - """ Load the model requested for conversion """ - logger.debug("Loading Model") - model_dir = get_folder(self.args.model_dir) - model = PluginLoader.get_model(self.args.trainer)(model_dir, self.args.gpus, predict=True) - logger.debug("Loaded Model") - return model + finalize(self._images.count, + self._predictor.faces_count, + self._predictor.verify_output) + logger.debug("Completed Conversion") + except MemoryError as err: + msg = ("Faceswap ran out of RAM running convert. Conversion is very system RAM " + "heavy, so this can happen in certain circumstances when you have a lot of " + "cpus but not enough RAM to support them all." + "\nYou should lower the number of processes in use by either setting the " + "'singleprocess' flag (-sp) or lowering the number of parallel jobs (-j).") + raise FaceswapError(msg) from err - def load_converter(self, model): - """ Load the requested converter for conversion """ - conv = self.args.converter - converter = PluginLoader.get_converter(conv)( - model.converter(self.args.swap_model), - model=model, - arguments=self.args) - return converter - - def prepare_images(self): - """ Prepare the images for conversion """ - filename = "" - if self.extractor: - load_queue = queue_manager.get_queue("load") - for filename, image in tqdm(self.images.load(), - total=self.images.images_found, - file=sys.stdout): - - if (self.args.discard_frames and - self.opts.check_skipframe(filename) == "discard"): - continue + def _convert_images(self) -> None: + """ Start the multi-threaded patching process, monitor all threads for errors and join on + completion. """ + logger.debug("Converting images") + self._patch_threads.start() + while True: + self._check_thread_error() + if self._disk_io.completion_event.is_set(): + logger.debug("DiskIO completion event set. Joining Pool") + break + if self._patch_threads.completed(): + logger.debug("All patch threads completed") + break + sleep(1) + self._patch_threads.join() + + logger.debug("Putting EOF") + queue_manager.get_queue("convert_out").put("EOF") + logger.debug("Converted images") + + def _check_thread_error(self) -> None: + """ Monitor all running threads for errors, and raise accordingly. + + Raises + ------ + Error + Re-raises any error encountered within any of the running threads + """ + for thread in (self._predictor.thread, + self._disk_io.load_thread, + self._disk_io.save_thread, + self._patch_threads): + thread.check_and_raise_error() + + +class DiskIO(): # pylint:disable=too-many-instance-attributes + """ Disk Input/Output for the converter process. 
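+ + A rough wiring sketch (mirroring how :class:`Convert` assembles the process; the ``alignments``, ``images``, ``predictor`` and ``args`` objects are assumed to already exist, so this is illustrative only): + + >>> disk_io = DiskIO(alignments, images, predictor, args) + >>> predictor.launch(disk_io.load_queue) + >>> disk_io.save_thread.join()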
+ + Background threads to: + * Load images from disk and get the detected faces + * Save images back to disk + + Parameters + ---------- + alignments: :class:`scripts.fsmedia.Alignments` + The alignments for the input video + images: :class:`lib.image.ImagesLoader` + The input images + predictor: :class:`Predict` + The object for generating predictions from the model + arguments: :class:`argparse.Namespace` + The arguments that were passed to the convert process as generated from Faceswap's command + line arguments + """ + + def __init__(self, + alignments: fsmedia.Alignments, + images: ImagesLoader, + predictor: Predict, + arguments: Namespace) -> None: + logger.debug("Initializing %s: (alignments: %s, images: %s, predictor: %s, arguments: %s)", + self.__class__.__name__, alignments, images, predictor, arguments) + self._alignments = alignments + self._images = images + self._args = arguments + self._pre_process = PostProcess(arguments) + self._completion_event = Event() + + # For frame skipping + self._imageidxre = re.compile(r"(\d+)(?!.*\d\.)(?=\.\w+$)") + self._frame_ranges = self._get_frame_ranges() + self._writer = self._get_writer(predictor) + + # Extractor for on the fly detection + self._extractor = self._load_extractor() + + self._queues: dict[T.Literal["load", "save"], EventQueue] = {} + self._threads: dict[T.Literal["load", "save"], MultiThread] = {} + self._init_threads() + logger.debug("Initialized %s", self.__class__.__name__) + + @property + def completion_event(self) -> Event: + """ :class:`threading.Event`: Event that is set when the DiskIO Save task is complete """ + return self._completion_event + + @property + def draw_transparent(self) -> bool: + """ bool: ``True`` if the selected writer can output transparency and its 'draw_transparent' + configuration item is set, otherwise ``False`` """ + return self._writer.output_alpha + + @property + def pre_encode(self) -> Callable[[np.ndarray, T.Any], list[bytes]] | None: + """ python function: Selected writer's pre-encode function, if it has one, + otherwise ``None`` """ + dummy = np.zeros((20, 20, 3), dtype="uint8") + test = self._writer.pre_encode(dummy) + retval: Callable | None = None if test is None else self._writer.pre_encode + logger.debug("Writer pre_encode function: %s", retval) + return retval + + @property + def save_thread(self) -> MultiThread: + """ :class:`lib.multithreading.MultiThread`: The thread that is running the image writing + operation. """ + return self._threads["save"] + + @property + def load_thread(self) -> MultiThread: + """ :class:`lib.multithreading.MultiThread`: The thread that is running the image loading + operation. """ + return self._threads["load"] + + @property + def load_queue(self) -> EventQueue: + """ :class:`~lib.queue_manager.EventQueue`: The queue that images and detected faces are + loaded into. """ + return self._queues["load"] + + @property + def _total_count(self) -> int: + """ int: The total number of frames to be converted """ + if self._frame_ranges and not self._args.keep_unchanged: + retval = sum(fr[1] - fr[0] + 1 for fr in self._frame_ranges) + else: + retval = self._images.count + logger.debug(retval) + return retval + + # Initialization + def _get_writer(self, predictor: Predict) -> Output: + """ Load the selected writer plugin.
+ - frame = os.path.basename(filename) - if self.extractor: - detected_faces = self.detect_faces(load_queue, filename, image) + Parameters + ---------- + predictor: :class:`Predict` + The object for generating predictions from the model + + Returns + ------- + :mod:`plugins.convert.writer` plugin + The requested writer plugin + """ + args = [self._args.output_dir] + if self._args.writer in ("ffmpeg", "gif"): + args.extend([self._total_count, self._frame_ranges]) + if self._args.writer == "ffmpeg": + if self._images.is_video: + args.append(self._args.input_dir) else: - detected_faces = self.alignments_faces(frame, image) + args.append(self._args.reference_video) + if self._args.writer == "patch": + args.append(predictor.output_size) + logger.debug("Writer args: %s", args) + configfile = self._args.configfile if hasattr(self._args, "configfile") else None + return PluginLoader.get_converter("writer", self._args.writer)(*args, + configfile=configfile) - faces_count = len(detected_faces) - if faces_count != 0: - # Post processing requires a dict with "detected_faces" key - self.post_process.do_actions( - {"detected_faces": detected_faces}) - self.faces_count += faces_count + def _get_frame_ranges(self) -> list[tuple[int, int]] | None: + """ Obtain the frame ranges that are to be converted. - if faces_count > 1: - self.verify_output = True - logger.verbose("Found more than one face in " - "an image! '%s'", frame) - - yield filename, image, detected_faces - - def detect_faces(self, load_queue, filename, image): - """ Extract the face from a frame (If alignments file not found) """ - inp = {"filename": filename, - "image": image} - load_queue.put(inp) - faces = next(self.extractor.detect_faces()) - - landmarks = faces["landmarks"] - detected_faces = faces["detected_faces"] - final_faces = list() - - for idx, face in enumerate(detected_faces): - detected_face = DetectedFace() - detected_face.from_dlib_rect(face) - detected_face.landmarksXY = landmarks[idx] - final_faces.append(detected_face) - return final_faces - - def alignments_faces(self, frame, image): - """ Get the face from alignments file """ - if not self.check_alignments(frame): - return list() - - faces = self.alignments.get_faces_in_frame(frame) - detected_faces = list() + If frame ranges have been specified, then split the command line formatted arguments into + ranges that can be used. + + Returns + ------- + list or ``None`` + A list of (start, end) frame ranges to be processed, or ``None`` if the command line + argument was not used + """ + if not self._args.frame_ranges: + logger.debug("No frame range set") + return None + + minframe, maxframe = None, None + if self._images.is_video: + minframe, maxframe = 1, self._images.count + else: + indices = [int(self._imageidxre.findall(os.path.basename(filename))[0]) + for filename in self._images.file_list] + if indices: + minframe, maxframe = min(indices), max(indices) + logger.debug("minframe: %s, maxframe: %s", minframe, maxframe) + + if minframe is None or maxframe is None: + raise FaceswapError("Frame Ranges specified, but could not determine frame numbering " + "from filenames") + + retval = [] + for rng in self._args.frame_ranges: + if "-" not in rng: + raise FaceswapError("Frame Ranges not specified in the correct format") + start, end = rng.split("-") + retval.append((max(int(start), minframe), min(int(end), maxframe))) + logger.debug("frame ranges: %s", retval) + return retval + + def _load_extractor(self) -> Extractor | None: + """ Load the CV2-DNN Face Extractor Chain.
+ + For On-The-Fly conversion we use a CPU based extractor to avoid overloading the GPU. + Results are poor. + + Returns + ------- + :class:`plugins.extract.pipeline.Extractor` + The face extraction chain to be used for on-the-fly conversion + """ + if not self._alignments.have_alignments_file and not self._args.on_the_fly: + logger.error("No alignments file found. Please provide an alignments file for your " + "destination video (recommended) or enable on-the-fly conversion (not " + "recommended).") + sys.exit(1) + if self._alignments.have_alignments_file: + if self._args.on_the_fly: + logger.info("On-The-Fly conversion selected, but an alignments file was found. " + "Using pre-existing alignments file: '%s'", self._alignments.file) + else: + logger.debug("Alignments file found: '%s'", self._alignments.file) + return None + + logger.debug("Loading extractor") + logger.warning("On-The-Fly conversion selected. This will use the inferior cv2-dnn for " + "extraction and will produce poor results.") + logger.warning("It is recommended to generate an alignments file for your destination " + "video with Extract first for superior results.") + extractor = Extractor(detector="cv2-dnn", + aligner="cv2-dnn", + masker=self._args.mask_type, + multiprocess=True, + rotate_images=None, + min_size=20) + extractor.launch() + logger.debug("Loaded extractor") + return extractor + + def _init_threads(self) -> None: + """ Initialize queues and threads. + + Creates the load and save queues and the load and save threads. Starts the threads. + """ + logger.debug("Initializing DiskIO Threads") + for task in T.get_args(T.Literal["load", "save"]): + self._add_queue(task) + self._start_thread(task) + logger.debug("Initialized DiskIO Threads") + + def _add_queue(self, task: T.Literal["load", "save"]) -> None: + """ Add the queue to queue_manager and to :attr:`self._queues` for the given task. + + Parameters + ---------- + task: {"load", "save"} + The task that the queue is to be added for + """ + logger.debug("Adding queue for task: '%s'", task) + if task == "load": + q_name = "convert_in" + elif task == "save": + q_name = "convert_out" + else: + q_name = task + self._queues[task] = queue_manager.get_queue(q_name) + logger.debug("Added queue for task: '%s'", task) + + def _start_thread(self, task: T.Literal["load", "save"]) -> None: + """ Create the thread for the given task, add it to :attr:`self._threads` and start it. + + Parameters + ---------- + task: {"load", "save"} + The task that the thread is to be created for + """ + logger.debug("Starting thread: '%s'", task) + args = self._completion_event if task == "save" else None + func = getattr(self, f"_{task}") + io_thread = MultiThread(func, args, thread_count=1) + io_thread.start() + self._threads[task] = io_thread + logger.debug("Started thread: '%s'", task) + + # Loading tasks + def _load(self, *args) -> None: # pylint:disable=unused-argument + """ Load frames from disk. + + In a background thread: + * Loads frames from disk. + * Discards or passes through cli selected skipped frames + * Pairs the frame with its :class:`~lib.align.DetectedFace` objects + * Performs any pre-processing actions + * Puts the frame and detected faces to the load queue + """ + logger.debug("Load Images: Start") + idx = 0 + for filename, image in self._images.load(): + idx += 1 + if self._queues["load"].shutdown_event.is_set(): + logger.debug("Load Queue: Stop signal received.
Terminating") + break + if image is None or (not image.any() and image.ndim not in (2, 3)): + # All black frames will return not numpy.any() so check dims too + logger.warning("Unable to open image. Skipping: '%s'", filename) + continue + if self._check_skipframe(filename): + if self._args.keep_unchanged: + logger.trace("Saving unchanged frame: %s", filename) # type:ignore + out_file = os.path.join(self._args.output_dir, os.path.basename(filename)) + self._queues["save"].put((out_file, image)) + else: + logger.trace("Discarding frame: '%s'", filename) # type:ignore + continue + + detected_faces = self._get_detected_faces(filename, image) + item = ConvertItem(ExtractMedia(filename, image, detected_faces)) + self._pre_process.do_actions(item.inbound) + self._queues["load"].put(item) + + logger.debug("Putting EOF") + self._queues["load"].put("EOF") + logger.debug("Load Images: Complete") + + def _check_skipframe(self, filename: str) -> bool: + """ Check whether a frame is to be skipped. + + Parameters + ---------- + filename: str + The filename of the frame to check + + Returns + ------- + bool + ``True`` if the frame is to be skipped otherwise ``False`` + """ + if not self._frame_ranges: + return False + indices = self._imageidxre.findall(filename) + if not indices: + logger.warning("Could not determine frame number. Frame will be converted: '%s'", + filename) + return False + idx = int(indices[0]) + skipframe = not any(map(lambda b: b[0] <= idx <= b[1], self._frame_ranges)) + logger.trace("idx: %s, skipframe: %s", idx, skipframe) # type: ignore[attr-defined] + return skipframe + + def _get_detected_faces(self, filename: str, image: np.ndarray) -> list[DetectedFace]: + """ Return the detected faces for the given image. + + If we have an alignments file, then the detected faces are created from that file. If + we're running On-The-Fly then they will be extracted from the extractor. + + Parameters + ---------- + filename: str + The filename to return the detected faces for + image: :class:`numpy.ndarray` + The frame that the detected faces exist in + + Returns + ------- + list + List of :class:`lib.align.DetectedFace` objects + """ + logger.trace("Getting faces for: '%s'", filename) # type:ignore + if not self._extractor: + detected_faces = self._alignments_faces(os.path.basename(filename), image) + else: + detected_faces = self._detect_faces(filename, image) + logger.trace("Got %s faces for: '%s'", len(detected_faces), filename) # type:ignore + return detected_faces + + def _alignments_faces(self, frame_name: str, image: np.ndarray) -> list[DetectedFace]: + """ Return detected faces from an alignments file. + + Parameters + ---------- + frame_name: str + The name of the frame to return the detected faces for + image: :class:`numpy.ndarray` + The frame that the detected faces exist in + + Returns + ------- + list + List of :class:`lib.align.DetectedFace` objects + """ + if not self._check_alignments(frame_name): + return [] + + faces = self._alignments.get_faces_in_frame(frame_name) + detected_faces = [] for rawface in faces: face = DetectedFace() @@ -163,166 +647,563 @@ def alignments_faces(self, frame, image): detected_faces.append(face) return detected_faces - def check_alignments(self, frame): - """ If we have no alignments for this image, skip it """ - have_alignments = self.alignments.frame_exists(frame) + def _check_alignments(self, frame_name: str) -> bool: + """ Ensure that we have alignments for the current frame. 
+ + If we have no alignments for this image, skip it and output a message. + + Parameters + ---------- + frame_name: str + The name of the frame to check that we have alignments for + + Returns + ------- + bool + ``True`` if we have alignments for this frame, otherwise ``False`` + """ + have_alignments = self._alignments.frame_exists(frame_name) if not have_alignments: - tqdm.write("No alignment found for {}, " - "skipping".format(frame)) + tqdm.write(f"No alignment found for {frame_name}, skipping") return have_alignments - def convert(self, converter, item): - """ Apply the conversion transferring faces onto frames """ - try: - filename, image, faces = item - skip = self.opts.check_skipframe(filename) + def _detect_faces(self, filename: str, image: np.ndarray) -> list[DetectedFace]: + """ Extract the face from a frame for On-The-Fly conversion. + + Pulls detected faces out of the Extraction pipeline. + + Parameters + ---------- + filename: str + The filename to return the detected faces for + image: :class:`numpy.ndarray` + The frame that the detected faces exist in + + Returns + ------- + list + List of :class:`lib.align.DetectedFace` objects + """ + assert self._extractor is not None + self._extractor.input_queue.put(ExtractMedia(filename, image)) + faces = next(self._extractor.detected_faces()) + return faces.detected_faces + + # Saving tasks + def _save(self, completion_event: Event) -> None: + """ Save the converted images. + + Puts the selected writer into a background thread and feeds it from the output of the + patch queue. + + Parameters + ---------- + completion_event: :class:`threading.Event` + An event that this process triggers when it has finished saving + """ + logger.debug("Save Images: Start") + write_preview = self._args.redirect_gui and self._writer.is_stream + preview_image = os.path.join(self._writer.output_folder, ".gui_preview.jpg") + logger.debug("Write preview for gui: %s", write_preview) + for idx in tqdm(range(self._total_count), desc="Converting", file=sys.stdout): + if self._queues["save"].shutdown_event.is_set(): + logger.debug("Save Queue: Stop signal received. Terminating") + break + item: tuple[str, np.ndarray | bytes] | T.Literal["EOF"] = self._queues["save"].get() + if item == "EOF": + logger.debug("EOF Received") + break + filename, image = item + # Write out preview image for the GUI every 10 frames if writing to stream + if write_preview and idx % 10 == 0 and not os.path.exists(preview_image): + logger.debug("Writing GUI Preview image: '%s'", preview_image) + assert isinstance(image, np.ndarray) + cv2.imwrite(preview_image, image) + self._writer.write(filename, image) + self._writer.close() + completion_event.set() + logger.debug("Save Images: Complete") + + +class Predict(): # pylint:disable=too-many-instance-attributes + """ Obtains the output from the Faceswap model.
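+ + A minimal usage sketch (mirroring how :class:`Convert` drives the predictor; ``args`` and ``load_queue`` are assumed to already exist, so this is illustrative only): + + >>> predictor = Predict(queue_size=4, arguments=args) + >>> predictor.launch(load_queue) # spawns the background prediction thread + >>> batch = predictor.out_queue.get() # a list of ConvertItem with swapped_faces set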
+ + Parameters + ---------- + queue_size: int + The maximum size of the input queue + arguments: :class:`argparse.Namespace` + The arguments that were passed to the convert process as generated from Faceswap's command + line arguments + """ + def __init__(self, queue_size: int, arguments: Namespace) -> None: + logger.debug("Initializing %s: (args: %s, queue_size: %s)", + self.__class__.__name__, arguments, queue_size) + self._args = arguments + self._in_queue: EventQueue | None = None + self._out_queue = queue_manager.get_queue("patch") + self._serializer = get_serializer("json") + self._faces_count = 0 + self._verify_output = False + + self._model = self._load_model() + self._batchsize = self._get_batchsize(queue_size) + self._sizes = self._get_io_sizes() + self._coverage_ratio = self._model.coverage_ratio + self._y_offset = mod_cfg.vertical_offset() / 100. + self._centering: CenteringType = T.cast("CenteringType", mod_cfg.centering()) + + self._thread: MultiThread | None = None + logger.debug("Initialized %s: (out_queue: %s)", self.__class__.__name__, self._out_queue) + + @property + def thread(self) -> MultiThread: + """ :class:`~lib.multithreading.MultiThread`: The thread that is running the prediction + function from the Faceswap model. """ + assert self._thread is not None + return self._thread + + @property + def in_queue(self) -> EventQueue: + """ :class:`~lib.queue_manager.EventQueue`: The input queue to the predictor. """ + assert self._in_queue is not None + return self._in_queue + + @property + def out_queue(self) -> EventQueue: + """ :class:`~lib.queue_manager.EventQueue`: The output queue from the predictor. """ + return self._out_queue + + @property + def faces_count(self) -> int: + """ int: The total number of faces seen by the Predictor. """ + return self._faces_count + + @property + def verify_output(self) -> bool: + """ bool: ``True`` if multiple faces have been found in frames, otherwise ``False``. """ + return self._verify_output + + @property + def coverage_ratio(self) -> float: + """ float: The coverage ratio that the model was trained at. """ + return self._coverage_ratio + + @property + def centering(self) -> CenteringType: + """ str: The centering that the model was trained on (`"head", "face"` or `"legacy"`) """ + return self._centering + + @property + def has_predicted_mask(self) -> bool: + """ bool: ``True`` if the model was trained to learn a mask, otherwise ``False``. """ + return bool(mod_cfg.Loss.learn_mask()) + + @property + def output_size(self) -> int: + """ int: The size in pixels of the Faceswap model output. """ + return self._sizes["output"] + + def _get_io_sizes(self) -> dict[str, int]: + """ Obtain the input size and output size of the model. + + Returns + ------- + dict + input_size in pixels and output_size in pixels + """ + input_shape = self._model.model.input_shape + input_shape = [input_shape] if not isinstance(input_shape, list) else input_shape + output_shape = self._model.model.output_shape + output_shape = [output_shape] if not isinstance(output_shape, list) else output_shape + retval = {"input": input_shape[0][1], "output": output_shape[-1][1]} + logger.debug(retval) + return retval + + def _load_model(self) -> ModelBase: + """ Load the Faceswap model. 
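+ + The plugin to load is resolved by reading the trainer name from the single ``*_state.json`` file in the model folder (see :func:`_get_model_name`), after which the plugin is built in inference mode (``predict=True``).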
+ + Returns + ------- + :mod:`plugins.train.model` plugin + The trained model in the specified model folder + """ + logger.debug("Loading Model") + model_dir = get_folder(self._args.model_dir, make_folder=False) + if not model_dir: + raise FaceswapError(f"{self._args.model_dir} does not exist.") + trainer = self._get_model_name(model_dir) + model = PluginLoader.get_model(trainer)(model_dir, self._args, predict=True) + model.build() + logger.debug("Loaded Model") + return model + + def _get_batchsize(self, queue_size: int) -> int: + """ Get the batch size for feeding the model. + + Sets the batch size to 1 if inference is being run on CPU, otherwise the minimum of the + input queue size and the model's `convert_batchsize` configuration option. + + Parameters + ---------- + queue_size: int + The queue size that is feeding the predictor + + Returns + ------- + int + The batch size that the model is to be fed at. + """ + logger.debug("Getting batchsize") + is_cpu = GPUStats is None or GPUStats().device_count == 0 + batchsize = 1 if is_cpu else mod_cfg.convert_batchsize() + batchsize = min(queue_size, batchsize) + logger.debug("Got batchsize: %s", batchsize) + return batchsize + + def _get_model_name(self, model_dir: str) -> str: + """ Return the name of the Faceswap model used. + + Retrieve the name of the model from the model's state file. + + Parameters + ---------- + model_dir: str + The folder that contains the trained Faceswap model + + Returns + ------- + str + The name of the Faceswap model being used. + + """ + statefiles = [fname for fname in os.listdir(str(model_dir)) + if fname.endswith("_state.json")] + if len(statefiles) != 1: + raise FaceswapError("There should be 1 state file in your model folder. " + f"{len(statefiles)} were found.") + statefile = os.path.join(str(model_dir), statefiles[0]) + + state = self._serializer.load(statefile) + trainer = state.get("name", None) + + if not trainer: + raise FaceswapError("Trainer name could not be read from state file.") + logger.debug("Trainer from state file: '%s'", trainer) + return trainer + + def launch(self, load_queue: EventQueue) -> None: + """ Launch the prediction process in a background thread. + + Starts the thread that runs the model's prediction function in the background. + + Parameters + ---------- + load_queue: :class:`~lib.queue_manager.EventQueue` + The queue that contains images and detected faces for feeding the model + """ + self._in_queue = load_queue + self._thread = MultiThread(self._predict_faces, thread_count=1) + self._thread.start() + + def _predict_faces(self) -> None: + """ Run Prediction on the Faceswap model in a background thread. + + Reads from the :attr:`self._in_queue`, prepares images for prediction + then puts the predictions back to the :attr:`self.out_queue` + """ + faces_seen = 0 + consecutive_no_faces = 0 + batch: list[ConvertItem] = [] + assert self._in_queue is not None + while True: + item: T.Literal["EOF"] | ConvertItem = self._in_queue.get() + if item == "EOF": + logger.debug("EOF Received") + if batch: # Process out any remaining items + self._process_batch(batch, faces_seen) + break + logger.trace("Got from queue: '%s'", item.inbound.filename) # type:ignore + faces_count = len(item.inbound.detected_faces) - if not skip: - for face in faces: - image = converter.patch_image(image, face) - filename = str(self.output_dir / Path(filename).name) + # Safety measure. If a large stream of frames without faces appears, these will + # stack up in RAM. Keep a count of consecutive frames with no faces.
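+ # (faces_seen counts the faces accumulated for the model feed; consecutive_no_faces + # counts back-to-back frames in which no face was detected.)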
+ # If self._batchsize consecutive frames with no faces appear, force the current batch + # through to clear RAM. + consecutive_no_faces = consecutive_no_faces + 1 if faces_count == 0 else 0 + self._faces_count += faces_count + if faces_count > 1: + self._verify_output = True + logger.verbose("Found more than one face in an image! '%s'", # type:ignore + os.path.basename(item.inbound.filename)) + + self.load_aligned(item) + faces_seen += faces_count + + batch.append(item) + + if faces_seen < self._batchsize and consecutive_no_faces < self._batchsize: + logger.trace("Continuing. Current batchsize: %s, " # type:ignore + "consecutive_no_faces: %s", faces_seen, consecutive_no_faces) + continue + + self._process_batch(batch, faces_seen) + + consecutive_no_faces = 0 + faces_seen = 0 + batch = [] + + logger.debug("Putting EOF") + self._out_queue.put("EOF") + logger.debug("Load queue complete") + + def _process_batch(self, batch: list[ConvertItem], faces_seen: int) -> None: + """ Predict faces on the given batch of images and queue out to the patch thread + + Parameters + ---------- + batch: list + List of :class:`ConvertItem` objects for the current batch + faces_seen: int + The number of faces seen in the current batch + """ + logger.trace("Batching to predictor. Frames: %s, Faces: %s", # type:ignore + len(batch), faces_seen) + feed_batch = [feed_face for item in batch for feed_face in item.feed_faces] + if faces_seen != 0: + feed_faces = self._compile_feed_faces(feed_batch) + batch_size = None + predicted = self._predict(feed_faces, batch_size) + else: + predicted = np.array([]) + + self._queue_out_frames(batch, predicted) + + def load_aligned(self, item: ConvertItem) -> None: + """ Load the model's feed faces and the reference output faces. + + For each detected face in the incoming item, load the feed face and reference face + images, correctly sized for input and output respectively. + + Parameters + ---------- + item: :class:`ConvertItem` + The convert item object, containing the ExtractMedia for the current image + """ + logger.trace("Loading aligned faces: '%s'", item.inbound.filename) # type:ignore + feed_faces = [] + reference_faces = [] + for detected_face in item.inbound.detected_faces: + feed_face = AlignedFace(detected_face.landmarks_xy, + image=item.inbound.image, + centering=self._centering, + size=self._sizes["input"], + coverage_ratio=self._coverage_ratio, + y_offset=self._y_offset, + dtype="float32") + if self._sizes["input"] == self._sizes["output"]: + reference_faces.append(feed_face) + else: + reference_faces.append(AlignedFace(detected_face.landmarks_xy, + image=item.inbound.image, + centering=self._centering, + size=self._sizes["output"], + coverage_ratio=self._coverage_ratio, + y_offset=self._y_offset, + dtype="float32")) + feed_faces.append(feed_face) + item.feed_faces = feed_faces + item.reference_faces = reference_faces + logger.trace("Loaded aligned faces: '%s'", item.inbound.filename) # type:ignore + + @staticmethod + def _compile_feed_faces(feed_faces: list[AlignedFace]) -> np.ndarray: + """ Compile a batch of faces for feeding into the Predictor. + + Parameters + ---------- + feed_faces: list + List of :class:`~lib.align.AlignedFace` objects sized for feeding into the model + + Returns + ------- + :class:`numpy.ndarray` + A batch of faces ready for feeding into the Faceswap model. + """ + logger.trace("Compiling feed face.
Batchsize: %s", len(feed_faces)) # type:ignore + retval = np.stack([T.cast(np.ndarray, feed_face.face)[..., :3] + for feed_face in feed_faces]) / 255.0 + logger.trace("Compiled Feed faces. Shape: %s", retval.shape) # type:ignore + return retval + + def _predict(self, feed_faces: np.ndarray, batch_size: int | None = None) -> np.ndarray: + """ Run the Faceswap models' prediction function. + + Parameters + ---------- + feed_faces: :class:`numpy.ndarray` + The batch to be fed into the model + batch_size: int, optional + Used for plaidml only. Indicates to the model what batch size is being processed. + Default: ``None`` + + Returns + ------- + :class:`numpy.ndarray` + The swapped faces for the given batch + """ + logger.trace("Predicting: Batchsize: %s", len(feed_faces)) # type:ignore - if self.args.draw_transparent: - filename = "{}.png".format(os.path.splitext(filename)[0]) - logger.trace("Set extension to png: `%s`", filename) + if self._model.color_order.lower() == "rgb": + feed_faces = feed_faces[..., ::-1] - cv2.imwrite(filename, image) # pylint: disable=no-member - except Exception as err: - logger.error("Failed to convert image: '%s'. Reason: %s", filename, err) - raise + feed = feed_faces + logger.trace("Input shape(s): %s", [item.shape for item in feed]) # type:ignore + inbound = self._model.model.predict(feed, + verbose=0, # pyright:ignore[reportArgumentType] + batch_size=batch_size) + predicted: list[np.ndarray] = inbound if isinstance(inbound, list) else [inbound] -class OptionalActions(): - """ Process the optional actions for convert """ + if self._model.color_order.lower() == "rgb": + predicted[0] = predicted[0][..., ::-1] - def __init__(self, args, input_images, alignments): + logger.trace("Output shape(s): %s", # type:ignore + [predict.shape for predict in predicted]) + + # Only take last output(s) + if predicted[-1].shape[-1] == 1: # Merge mask to alpha channel + retval = np.concatenate(predicted[-2:], axis=-1).astype("float32") + else: + retval = predicted[-1].astype("float32") + + logger.trace("Final shape: %s", retval.shape) # type:ignore + return retval + + def _queue_out_frames(self, batch: list[ConvertItem], swapped_faces: np.ndarray) -> None: + """ Compile the batch back to original frames and put to the Out Queue. + + For batching, faces are split away from their frames. This compiles all detected faces + back to their parent frame before putting each frame to the out queue in batches. + + Parameters + ---------- + batch: dict + The batch that was used as the input for the model predict function + swapped_faces: :class:`numpy.ndarray` + The predictions returned from the model's predict function + """ + logger.trace("Queueing out batch. Batchsize: %s", len(batch)) # type:ignore + pointer = 0 + for item in batch: + num_faces = len(item.inbound.detected_faces) + if num_faces != 0: + item.swapped_faces = swapped_faces[pointer:pointer + num_faces] + + logger.trace("Putting to queue. ('%s', detected_faces: %s, " # type:ignore + "reference_faces: %s, swapped_faces: %s)", item.inbound.filename, + len(item.inbound.detected_faces), len(item.reference_faces), + item.swapped_faces.shape[0]) + pointer += num_faces + self._out_queue.put(batch) + logger.trace("Queued out batch. Batchsize: %s", len(batch)) # type:ignore + + +class OptionalActions(): # pylint:disable=too-few-public-methods + """ Process specific optional actions for Convert. + + Currently only handles skip faces. This class should probably be (re)moved. 
+ + Parameters + ---------- + arguments : :class:`argparse.Namespace` + The arguments that were passed to the convert process as generated from Faceswap's command + line arguments + input_images : list[str] + List of input image files + alignments : :class:`scripts.fsmedia.Alignments` + The alignments file for this conversion + """ + def __init__(self, + arguments: Namespace, + input_images: list[str], + alignments: fsmedia.Alignments) -> None: logger.debug("Initializing %s", self.__class__.__name__) - self.args = args - self.input_images = input_images - self.alignments = alignments - self.frame_ranges = self.get_frame_ranges() - self.imageidxre = re.compile(r"[^(mp4)](\d+)(?!.*\d)") + self._args = arguments + self._input_images = input_images + self._alignments = alignments - self.remove_skipped_faces() + self._remove_skipped_faces() logger.debug("Initialized %s", self.__class__.__name__) # SKIP FACES # - def remove_skipped_faces(self): - """ Remove deleted faces from the loaded alignments """ + def _remove_skipped_faces(self) -> None: + """ If the user has specified an input aligned directory, remove any non-matching faces + from the alignments file. """ logger.debug("Filtering Faces") - face_hashes = self.get_face_hashes() - if not face_hashes: - logger.debug("No face hashes. Not skipping any faces") + accept_dict = self._get_face_metadata() + if not accept_dict: + logger.debug("No aligned face data. Not skipping any faces") return - pre_face_count = self.alignments.faces_count - self.alignments.filter_hashes(face_hashes, filter_out=False) - logger.info("Faces filtered out: %s", pre_face_count - self.alignments.faces_count) + pre_face_count = self._alignments.faces_count + self._alignments.filter_faces(accept_dict, filter_out=False) + logger.info("Faces filtered out: %s", pre_face_count - self._alignments.faces_count) + + def _get_face_metadata(self) -> dict[str, list[int]]: + """ Check for the existence of an aligned directory for identifying which faces in the + target frames should be swapped. If it exists, scan the folder for the faces' metadata - def get_face_hashes(self): - """ Check for the existence of an aligned directory for identifying - which faces in the target frames should be swapped. - If it exists, obtain the hashes of the faces in the folder """ - face_hashes = list() - input_aligned_dir = self.args.input_aligned_dir + Returns + ------- + dict + Dictionary of source frame names with a list of associated face indices to be accepted + for conversion + """ + retval: dict[str, list[int]] = {} + input_aligned_dir = self._args.input_aligned_dir if input_aligned_dir is None: - logger.verbose("Aligned directory not specified. All faces listed in the " - "alignments file will be converted") - elif not os.path.isdir(input_aligned_dir): + logger.verbose("Aligned directory not specified. All faces listed in " # type:ignore + "the alignments file will be converted") + return retval + if not os.path.isdir(input_aligned_dir): logger.warning("Aligned directory not found.
All faces listed in the " "alignments file will be converted") - else: - file_list = [path for path in get_image_paths(input_aligned_dir)] - logger.info("Getting Face Hashes for selected Aligned Images") - for face in tqdm(file_list, desc="Hashing Faces"): - face_hashes.append(hash_image_file(face)) - logger.debug("Face Hashes: %s", (len(face_hashes))) - if not face_hashes: - logger.error("Aligned directory is empty, no faces will be converted!") - exit(1) - elif len(face_hashes) <= len(self.input_images) / 3: - logger.warning("Aligned directory contains far fewer images than the input " - "directory, are you sure this is the right folder?") - return face_hashes - - # SKIP FRAME RANGES # - def get_frame_ranges(self): - """ split out the frame ranges and parse out 'min' and 'max' values """ - if not self.args.frame_ranges: - return None + return retval - minmax = {"min": 0, # never any frames less than 0 - "max": float("inf")} - rng = [tuple(map(lambda q: minmax[q] if q in minmax.keys() else int(q), - v.split("-"))) - for v in self.args.frame_ranges] - return rng - - def check_skipframe(self, filename): - """ Check whether frame is to be skipped """ - if not self.frame_ranges: - return None - idx = int(self.imageidxre.findall(filename)[0]) - skipframe = not any(map(lambda b: b[0] <= idx <= b[1], - self.frame_ranges)) - if skipframe and self.args.discard_frames: - skipframe = "discard" - return skipframe + log_once = False + filelist = get_image_paths(input_aligned_dir) + for fullpath, metadata in tqdm(read_image_meta_batch(filelist), + total=len(filelist), + desc="Reading Face Data", + leave=False): + if "itxt" not in metadata or "source" not in metadata["itxt"]: + # UPDATE LEGACY FACES FROM ALIGNMENTS FILE + if not log_once: + logger.warning("Legacy faces discovered in '%s'. These faces will be updated", + input_aligned_dir) + log_once = True + data = update_legacy_png_header(fullpath, self._alignments) + if not data: + raise FaceswapError( + f"Some of the faces being passed in from '{input_aligned_dir}' could not " + f"be matched to the alignments file '{self._alignments.file}'\n" + "Please double check your sources and try again.") + meta = data["source"] + else: + meta = metadata["itxt"]["source"] + retval.setdefault(meta["source_filename"], []).append(meta["face_index"]) + if not retval: + raise FaceswapError("Aligned directory is empty, no faces will be converted!") + if len(retval) <= len(self._input_images) / 3: + logger.warning("Aligned directory contains far fewer images than the input " + "directory, are you sure this is the right folder?") + return retval -class Legacy(): - """ Update legacy alignments: - - Rotate landmarks and bounding boxes on legacy alignments - and remove the 'r' parameter - - Add face hashes to alignments file - """ - def __init__(self, alignments, frames, faces_dir): - self.alignments = alignments - self.frames = {os.path.basename(frame): frame - for frame in frames} - self.process(faces_dir) - - def process(self, faces_dir): - """ Run the rotate alignments process """ - rotated = self.alignments.get_legacy_rotation() - hashes = self.alignments.get_legacy_no_hashes() - if not rotated and not hashes: - return - if rotated: - logger.info("Legacy rotated frames found. Converting...") - self.rotate_landmarks(rotated) - self.alignments.save() - if hashes and faces_dir: - logger.info("Legacy alignments found. 
Adding Face Hashes...") - self.add_hashes(hashes, faces_dir) - self.alignments.save() - - def rotate_landmarks(self, rotated): - """ Rotate the landmarks """ - for rotate_item in tqdm(rotated, desc="Rotating Landmarks"): - frame = self.frames.get(rotate_item, None) - if frame is None: - logger.debug("Skipping missing frame: '%s'", rotate_item) - continue - self.alignments.rotate_existing_landmarks(rotate_item, frame) - - def add_hashes(self, hashes, faces_dir): - """ Add Face Hashes to the alignments file """ - all_faces = dict() - face_files = sorted(face for face in os.listdir(faces_dir) if "_" in face) - for face in face_files: - filename, extension = os.path.splitext(face) - index = filename[filename.rfind("_") + 1:] - if not index.isdigit(): - continue - orig_frame = filename[:filename.rfind("_")] + extension - all_faces.setdefault(orig_frame, dict())[int(index)] = os.path.join(faces_dir, face) - for frame in tqdm(hashes): - if frame not in all_faces.keys(): - logger.warning("Skipping missing frame: '%s'", frame) - continue - hash_faces = all_faces[frame] - for index, face_path in hash_faces.items(): - hash_faces[index] = hash_image_file(face_path) - self.alignments.add_face_hashes(frame, hash_faces) +__all__ = get_module_objects(__name__) diff --git a/scripts/extract.py b/scripts/extract.py index ef3f87b125..23b346f738 100644 --- a/scripts/extract.py +++ b/scripts/extract.py @@ -1,415 +1,831 @@ #!/usr/bin python3 -""" The script to run the extract process of faceswap """ +""" Main entry point to the extract process of FaceSwap """ +from __future__ import annotations import logging import os import sys -from pathlib import Path +import typing as T +from argparse import Namespace +from multiprocessing import Process + +import numpy as np from tqdm import tqdm +import torch +from lib.align.alignments import PNGHeaderDict + +from lib.image import encode_image, generate_thumbnail, ImagesLoader, ImagesSaver, read_image_meta +from lib.multithreading import MultiThread +from lib.utils import (get_folder, get_module_objects, handle_deprecated_cliopts, + IMAGE_EXTENSIONS, VIDEO_EXTENSIONS) +from plugins.extract import ExtractMedia, Extractor +from scripts.fsmedia import Alignments, PostProcess, finalize -from lib.faces_detect import DetectedFace -from lib.gpu_stats import GPUStats -from lib.multithreading import MultiThread, PoolProcess, SpawnProcess -from lib.queue_manager import queue_manager, QueueEmpty -from lib.utils import get_folder, hash_encode_image -from plugins.plugin_loader import PluginLoader -from scripts.fsmedia import Alignments, Images, PostProcess, Utils +if T.TYPE_CHECKING: + from lib.align.alignments import PNGHeaderAlignmentsDict -tqdm.monitor_interval = 0 # workaround for TqdmSynchronisationWarning -logger = logging.getLogger(__name__) # pylint: disable=invalid-name +# tqdm.monitor_interval = 0 # workaround for TqdmSynchronisationWarning # TODO? +logger = logging.getLogger(__name__) class Extract(): - """ The extract process. """ - def __init__(self, arguments): + """ The Faceswap Face Extraction Process. + + The extraction process is responsible for detecting faces in a series of images/video, aligning + these faces and then generating a mask. + + It leverages a series of user selected plugins, chained together using + :mod:`plugins.extract.pipeline`. + + The extract process is self contained and should not be referenced by any other scripts, so it + contains no public properties. 
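+ + A typical invocation is routed through :class:`lib.cli.launcher.ScriptExecutor` from the + command line, e.g. (illustrative placeholder paths):: + + python faceswap.py extract -i /path/to/frames -o /path/to/faces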
+ + Parameters + ---------- + arguments: :class:`argparse.Namespace` + The arguments to be passed to the extraction process as generated from Faceswap's command + line arguments + """ + def __init__(self, arguments: Namespace) -> None: logger.debug("Initializing %s: (args: %s", self.__class__.__name__, arguments) - self.args = arguments - self.output_dir = get_folder(self.args.output_dir) - logger.info("Output Directory: %s", self.args.output_dir) - self.images = Images(self.args) - self.alignments = Alignments(self.args, True, self.images.is_video) - self.plugins = Plugins(self.args) - - self.post_process = PostProcess(arguments) - - self.verify_output = False - self.save_interval = None - if hasattr(self.args, "save_interval"): - self.save_interval = self.args.save_interval - logger.debug("Initialized %s", self.__class__.__name__) + self._args = handle_deprecated_cliopts(arguments) + self._input_locations = self._get_input_locations() + self._validate_batchmode() + + configfile = self._args.configfile if hasattr(self._args, "configfile") else None + normalization = None if self._args.normalization == "none" else self._args.normalization + maskers = ["components", "extended"] + maskers += self._args.masker if self._args.masker else [] + recognition = ("vgg_face2" + if arguments.identity or arguments.filter or arguments.nfilter + else None) + self._extractor = Extractor(self._args.detector, + self._args.aligner, + maskers, + recognition=recognition, + configfile=configfile, + multiprocess=not self._args.singleprocess, + rotate_images=self._args.rotate_images, + min_size=self._args.min_size, + normalize_method=normalization, + re_feed=self._args.re_feed, + re_align=self._args.re_align) + self._filter = Filter(self._args.ref_threshold, + self._args.filter, + self._args.nfilter, + self._extractor) + + def _get_input_locations(self) -> list[str]: + """ Obtain the full path to input locations. Will be a list of locations if batch mode is + selected, or a list containing a single location if batch mode is not selected. + + Returns + ------- + list: + The list of input location paths + """ + if not self._args.batch_mode or os.path.isfile(self._args.input_dir): + return [self._args.input_dir] # Not batch mode or a single file + + retval = [os.path.join(self._args.input_dir, fname) + for fname in os.listdir(self._args.input_dir) + if (os.path.isdir(os.path.join(self._args.input_dir, fname)) # folder images + and any(os.path.splitext(iname)[-1].lower() in IMAGE_EXTENSIONS + for iname in os.listdir(os.path.join(self._args.input_dir, fname)))) + or os.path.splitext(fname)[-1].lower() in VIDEO_EXTENSIONS] # video + + logger.debug("Input locations: %s", retval) + return retval + + def _validate_batchmode(self) -> None: + """ Validate the command line arguments. + + If batch mode is selected but the input is a single file, then batch mode is disabled. + + If processing in batch mode, some of the given arguments may not make sense, in which case + a warning is shown and those options are reset. + + """ + if not self._args.batch_mode: + return - def process(self): - """ Perform the extraction process """ + if os.path.isfile(self._args.input_dir): + logger.warning("Batch mode selected but input is not a folder. Switching to normal " + "mode") + self._args.batch_mode = False + + if not self._input_locations: + logger.error("Batch mode selected, but no valid files found in input location: '%s'.
" + "Exiting.", self._args.input_dir) + sys.exit(1) + + if self._args.alignments_path: + logger.warning("Custom alignments path not supported for batch mode. " + "Reverting to default.") + self._args.alignments_path = None + + def _output_for_input(self, input_location: str) -> str: + """ Obtain the path to an output folder for faces for a given input location. + + If not running in batch mode, then the user supplied output location will be returned, + otherwise a sub-folder within the user supplied output location will be returned based on + the input filename + + Parameters + ---------- + input_location: str + The full path to an input video or folder of images + """ + if not self._args.batch_mode: + return self._args.output_dir + + retval = os.path.join(self._args.output_dir, + os.path.splitext(os.path.basename(input_location))[0]) + logger.debug("Returning output: '%s' for input: '%s'", retval, input_location) + return retval + + def process(self) -> None: + """ The entry point for triggering the Extraction Process. + + Should only be called from :class:`lib.cli.launcher.ScriptExecutor` + """ logger.info('Starting, this may take a while...') - Utils.set_verbosity(self.args.loglevel) -# queue_manager.debug_monitor(1) - self.threaded_io("load") - save_thread = self.threaded_io("save") - self.run_extraction() - save_thread.join() - self.alignments.save() - Utils.finalize(self.images.images_found, - self.alignments.faces_count, - self.verify_output) - - def threaded_io(self, task, io_args=None): - """ Perform I/O task in a background thread """ + if self._args.batch_mode: + logger.info("Batch mode selected processing: %s", self._input_locations) + for job_no, location in enumerate(self._input_locations): + if self._args.batch_mode: + logger.info("Processing job %s of %s: '%s'", + job_no + 1, len(self._input_locations), location) + arguments = Namespace(**self._args.__dict__) + arguments.input_dir = location + arguments.output_dir = self._output_for_input(location) + else: + arguments = self._args + extract = _Extract(self._extractor, arguments) + if sys.platform == "linux" and len(self._input_locations) > 1: + # TODO - Running this in a process is hideously hacky. However, there is a memory + # leak in some instances when running in batch mode. Many days have been spent + # trying to track this down to no avail (most likely coming from C-code.) Running + # the extract job inside a process prevents the memory leak in testing. This should + # be replaced if/when the memory leak is found + # Only done for Linux as not reported elsewhere and this new process won't work in + # Windows because it can't fork. + proc = Process(target=extract.process) + proc.start() + proc.join() + else: + extract.process() + self._extractor.reset_phase_index() + + +class Filter(): + """ Obtains and holds face identity embeddings for any filter/nfilter image files + passed in from the command line. 
+ + Parameters + ---------- + threshold: float + The threshold value to use when matching detected faces against the filter embeddings + filter_files: list or ``None`` + The list of filter file(s) passed in as command line arguments + nfilter_files: list or ``None`` + The list of nfilter file(s) passed in as command line arguments + extractor: :class:`~plugins.extract.pipeline.Extractor` + The extractor pipeline for obtaining face identity from images + """ + def __init__(self, + threshold: float, + filter_files: list[str] | None, + nfilter_files: list[str] | None, + extractor: Extractor) -> None: + logger.debug("Initializing %s: (threshold: %s, filter_files: %s, nfilter_files: %s, " + "extractor: %s)", self.__class__.__name__, threshold, filter_files, + nfilter_files, extractor) + self._threshold = threshold + self._filter_files, self._nfilter_files = self._validate_inputs(filter_files, + nfilter_files) + + if not self._filter_files and not self._nfilter_files: + logger.debug("Filter not selected. Exiting %s", self.__class__.__name__) + return + + self._embeddings: list[np.ndarray] = [np.array([]) for _ in self._filter_files] + self._nembeddings: list[np.ndarray] = [np.array([]) for _ in self._nfilter_files] + self._extractor = extractor + + self._get_embeddings() + self._extractor.recognition.add_identity_filters(self.embeddings, + self.n_embeddings, + self._threshold) + logger.debug("Initialized %s", self.__class__.__name__) + + @property + def active(self) -> bool: + """ bool: ``True`` if filter files have been passed in command line arguments. ``False`` if + no filter files have been provided """ + return bool(self._filter_files) or bool(self._nfilter_files) + + @property + def embeddings(self) -> np.ndarray: + """ :class:`numpy.ndarray`: The filter embeddings""" + if self._embeddings and all(np.any(e) for e in self._embeddings): + retval = np.concatenate(self._embeddings, axis=0) + else: + retval = np.array([]) + return retval + + @property + def n_embeddings(self) -> np.ndarray: + """ :class:`numpy.ndarray`: The n-filter embeddings""" + if self._nembeddings and all(np.any(e) for e in self._nembeddings): + retval = np.concatenate(self._nembeddings, axis=0) + else: + retval = np.array([]) + return retval + + @classmethod + def _files_from_folder(cls, input_location: list[str]) -> list[str]: + """ Test whether the input location is a folder and if so, return the list of contained + image files, otherwise return the original input location + + Parameters + ---------- + input_location: list + A list of full paths to individual files or to a folder location + + Returns + ------- + list + Either the original list of files provided, or the image files that exist in the + provided folder location + """ + if not input_location or len(input_location) > 1: + return input_location + + test_folder = input_location[0] + if not os.path.isdir(test_folder): + logger.debug("'%s' is not a folder. 
Returning original list", test_folder) + return input_location + + retval = [os.path.join(test_folder, fname) + for fname in os.listdir(test_folder) + if os.path.splitext(fname)[-1].lower() in IMAGE_EXTENSIONS] + logger.info("Collected files from folder '%s': %s", test_folder, + [os.path.basename(f) for f in retval]) + return retval + + def _validate_inputs(self, + filter_files: list[str] | None, + nfilter_files: list[str] | None) -> tuple[list[str], list[str]]: + """ Validates that the given filter/nfilter files exist, are image files and are unique + + Parameters + ---------- + filter_files: list or ``None`` + The list of filter file(s) passed in as command line arguments + nfilter_files: list or ``None`` + The list of nfilter file(s) passed in as command line arguments + + Returns + ------- + filter_files: list + List of full paths to filter files + nfilter_files: list + List of full paths to nfilter files + """ + error = False + retval: list[list[str]] = [] + + for files in (filter_files, nfilter_files): + filt_files = [] if files is None else self._files_from_folder(files) + for file in filt_files: + if (not os.path.isfile(file) or + os.path.splitext(file)[-1].lower() not in IMAGE_EXTENSIONS): + logger.warning("Filter file '%s' does not exist or is not an image file", file) + error = True + retval.append(filt_files) + + filters = retval[0] + nfilters = retval[1] + f_fnames = set(os.path.basename(fname) for fname in filters) + n_fnames = set(os.path.basename(fname) for fname in nfilters) + if f_fnames.intersection(n_fnames): + error = True + logger.warning("filter and nfilter filenames should be unique. The following " + "filenames exist in both folders: %s", f_fnames.intersection(n_fnames)) + + if error: + logger.error("There was a problem processing filter files. See the above warnings for " + "details") + sys.exit(1) + logger.debug("filter_files: %s, nfilter_files: %s", retval[0], retval[1]) + + return filters, nfilters + + @classmethod + def _identity_from_extracted(cls, filename) -> tuple[np.ndarray, bool]: + """ Test whether the given image is a faceswap extracted face and contains identity + information. If so, return the identity embedding + + Parameters + ---------- + filename: str + Full path to the image file to load + + Returns + ------- + :class:`numpy.ndarray` + The identity embeddings, if they can be obtained from the image header, otherwise an + empty array + bool + ``True`` if the image is a faceswap extracted image otherwise ``False`` + """ + if os.path.splitext(filename)[-1].lower() != ".png": + logger.debug("'%s' not a png. Returning empty array", filename) + return np.array([]), False + + meta = read_image_meta(filename) + if "itxt" not in meta or "alignments" not in meta["itxt"]: + logger.debug("'%s' does not contain faceswap data. Returning empty array", filename) + return np.array([]), False + + align: "PNGHeaderAlignmentsDict" = meta["itxt"]["alignments"] + if "identity" not in align or "vggface2" not in align["identity"]: + logger.debug("'%s' does not contain identity data. Returning empty array", filename) + return np.array([]), True + + retval = np.array(align["identity"]["vggface2"]) + logger.debug("Obtained identity for '%s'. Shape: %s", filename, retval.shape) + + return retval, True + + def _process_extracted(self, item: ExtractMedia) -> None: + """ Process the output from the extraction pipeline. + + If no face has been detected, or multiple faces are detected for the inclusive filter, + embeddings and filenames are removed from the filter. 
+ + If a single face is detected, or multiple faces are detected for the exclusive filter, + embeddings are added to the relevant filter list + + Parameters + ---------- + item: :class:`~plugins.extract.extract_media.ExtractMedia` + The output from the extraction pipeline containing the identity encodings + """ + is_filter = item.filename in self._filter_files + lbl = "filter" if is_filter else "nfilter" + filelist = self._filter_files if is_filter else self._nfilter_files + embeddings = self._embeddings if is_filter else self._nembeddings + identities = np.array([face.identity["vggface2"] for face in item.detected_faces]) + idx = filelist.index(item.filename) + + if len(item.detected_faces) == 0: + logger.warning("No faces detected for %s in file '%s'. Image will not be used", + lbl, os.path.basename(item.filename)) + filelist.pop(idx) + embeddings.pop(idx) + return + + if len(item.detected_faces) == 1: + logger.debug("Adding identity for %s from file '%s'", lbl, item.filename) + embeddings[idx] = identities + return + + if len(item.detected_faces) > 1 and is_filter: + logger.warning("%s faces detected for filter in '%s'. These identities will not be used", + len(item.detected_faces), os.path.basename(item.filename)) + filelist.pop(idx) + embeddings.pop(idx) + return + + if len(item.detected_faces) > 1 and not is_filter: + logger.warning("%s faces detected for nfilter in '%s'. All of these identities will be " + "used", len(item.detected_faces), os.path.basename(item.filename)) + embeddings[idx] = identities + return + + def _identity_from_extractor(self, file_list: list[str], aligned: list[str]) -> None: + """ Obtain the identity embeddings from the extraction pipeline + + Parameters + ---------- + file_list: list + List of full paths to images to run through the extraction pipeline + aligned: list + List of full paths to images that exist in :attr:`file_list` that are faceswap aligned + images + """ + logger.info("Extracting faces to obtain identity from images") + logger.debug("Files requiring full extraction: %s", + [fname for fname in file_list if fname not in aligned]) + logger.debug("Aligned files requiring identity info: %s", aligned) + + loader = PipelineLoader(file_list, self._extractor, aligned_filenames=aligned) + loader.launch() + + for phase in range(self._extractor.passes): + is_final = self._extractor.final_pass + detected_faces: dict[str, ExtractMedia] = {} + self._extractor.launch() + desc = "Obtaining reference face identity" + if self._extractor.passes > 1: + desc = (f"{desc} pass {phase + 1} of {self._extractor.passes}: " + f"{self._extractor.phase_text}") + for extract_media in tqdm(self._extractor.detected_faces(), + total=len(file_list), + file=sys.stdout, + desc=desc): + if is_final: + self._process_extracted(extract_media) + else: + extract_media.remove_image() + # cache extract_media for next run + detected_faces[extract_media.filename] = extract_media + + if not is_final: + logger.debug("Reloading images") + loader.reload(detected_faces) + + self._extractor.reset_phase_index() + + def _get_embeddings(self) -> None: + """ Obtain the embeddings for the given filter lists """ + needs_extraction: list[str] = [] + aligned: list[str] = [] + + for files, embed in zip((self._filter_files, self._nfilter_files), + (self._embeddings, self._nembeddings)): + for idx, file in enumerate(files): + identity, is_aligned = self._identity_from_extracted(file) + if np.any(identity): + logger.debug("Obtained identity from png header: '%s'", file) + embed[idx] = identity[None, ...] 
+ continue + + needs_extraction.append(file) + if is_aligned: + aligned.append(file) + + if needs_extraction: + self._identity_from_extractor(needs_extraction, aligned) + + if not self._nfilter_files and not self._filter_files: + logger.error("No faces were detected from your selected identity filter files") + sys.exit(1) + + logger.debug("Filter: (filenames: %s, shape: %s), nFilter: (filenames: %s, shape: %s)", + [os.path.basename(f) for f in self._filter_files], + self.embeddings.shape, + [os.path.basename(f) for f in self._nfilter_files], + self.n_embeddings.shape) + + +class PipelineLoader(): + """ Handles loading and reloading images into the extraction pipeline. + + Parameters + ---------- + path: str or list of str + Full path to a folder of images or a video file or a list of image files + extractor: :class:`~plugins.extract.pipeline.Extractor` + The extractor pipeline for obtaining face identity from images + aligned_filenames: list, optional + Used for when the loader is used for getting face filter embeddings. List of full path to + image files that exist in :attr:`path` that are aligned faceswap images + """ + def __init__(self, + path: str | list[str], + extractor: Extractor, + aligned_filenames: list[str] | None = None) -> None: + logger.debug("Initializing %s: (path: %s, extractor: %s, aligned_filenames: %s)", + self.__class__.__name__, path, extractor, aligned_filenames) + self._images = ImagesLoader(path, fast_count=True) + self._extractor = extractor + self._threads: list[MultiThread] = [] + self._aligned_filenames = [] if aligned_filenames is None else aligned_filenames + logger.debug("Initialized %s", self.__class__.__name__) + + @property + def is_video(self) -> bool: + """ bool: ``True`` if the input location is a video file, ``False`` if it is a folder of + images """ + return self._images.is_video + + @property + def file_list(self) -> list[str]: + """ list: A full list of files in the source location. If the input is a video + then this is a list of dummy filenames as corresponding to an alignments file """ + return self._images.file_list + + @property + def process_count(self) -> int: + """ int: The number of images or video frames to be processed (IE the total count less + items that are to be skipped from the :attr:`skip_list`)""" + return self._images.process_count + + def add_skip_list(self, skip_list: list[int]) -> None: + """ Add a skip list to the :class:`ImagesLoader` + + Parameters + ---------- + skip_list: list + A list of indices corresponding to the frame indices that should be skipped by the + :func:`load` function. 
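A hedged sketch of the PNG header lookup that `_get_embeddings` performs via `_identity_from_extracted`: `read_image_meta` is Faceswap's own helper (imported above from `lib.image`), the nested keys mirror those used in the method, and the filename is hypothetical.

import numpy as np
from lib.image import read_image_meta

meta = read_image_meta("face_0.png")  # hypothetical previously-extracted face
alignments = meta.get("itxt", {}).get("alignments", {})
identity = np.array(alignments.get("identity", {}).get("vggface2", []))
if identity.size:
    # A stored embedding can be used directly, skipping the extraction pipeline
    print("embedding shape:", identity.shape)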
+ """ + self._images.add_skip_list(skip_list) + + def launch(self) -> None: + """ Launch the image loading pipeline """ + self._threaded_redirector("load") + + def reload(self, detected_faces: dict[str, ExtractMedia]) -> None: + """ Reload images for multiple pipeline passes """ + self._threaded_redirector("reload", (detected_faces, )) + + def check_thread_error(self) -> None: + """ Check if any errors have occurred in the running threads and raise their errors """ + for thread in self._threads: + thread.check_and_raise_error() + + def join(self) -> None: + """ Join all open loader threads """ + for thread in self._threads: + thread.join() + + def _threaded_redirector(self, task: str, io_args: tuple | None = None) -> None: + """ Redirect image input/output tasks to relevant queues in background thread + + Parameters + ---------- + task: str + The name of the task to be put into a background thread + io_args: tuple, optional + Any arguments that need to be provided to the background function + """ logger.debug("Threading task: (Task: '%s')", task) - io_args = tuple() if io_args is None else (io_args, ) - if task == "load": - func = self.load_images - elif task == "save": - func = self.save_faces - elif task == "reload": - func = self.reload_images + io_args = tuple() if io_args is None else io_args + func = getattr(self, f"_{task}") io_thread = MultiThread(func, *io_args, thread_count=1) io_thread.start() - return io_thread + self._threads.append(io_thread) + + def _load(self) -> None: + """ Load the images - def load_images(self): - """ Load the images """ + Loads images from :class:`lib.image.ImagesLoader`, formats them into a dict compatible + with :class:`plugins.extract.Pipeline.Extractor` and passes them into the extraction queue. + """ logger.debug("Load Images: Start") - load_queue = queue_manager.get_queue("load") - for filename, image in self.images.load(): - if load_queue.shutdown.is_set(): + load_queue = self._extractor.input_queue + for filename, image in self._images.load(): + if load_queue.shutdown_event.is_set(): logger.debug("Load Queue: Stop signal received. Terminating") break - if image is None or not image.any(): - logger.warning("Unable to open image. Skipping: '%s'", filename) - continue - imagename = os.path.basename(filename) - if imagename in self.alignments.data.keys(): - logger.trace("Skipping image: '%s'", filename) - continue - item = {"filename": filename, - "image": image} + is_aligned = filename in self._aligned_filenames + item = ExtractMedia(filename, image[..., :3], is_aligned=is_aligned) load_queue.put(item) load_queue.put("EOF") logger.debug("Load Images: Complete") - def reload_images(self, detected_faces): - """ Reload the images and pair to detected face """ + def _reload(self, detected_faces: dict[str, ExtractMedia]) -> None: + """ Reload the images and pair to detected face + + When the extraction pipeline is running in serial mode, images are reloaded from disk, + paired with their extraction data and passed back into the extraction queue + + Parameters + ---------- + detected_faces: dict + Dictionary of :class:`~plugins.extract.extract_media.ExtractMedia` with the filename as + the key for repopulating the image attribute. + """ logger.debug("Reload Images: Start. 
Detected Faces Count: %s", len(detected_faces)) - load_queue = queue_manager.get_queue("detect") - for filename, image in self.images.load(): - if load_queue.shutdown.is_set(): + load_queue = self._extractor.input_queue + for filename, image in self._images.load(): + if load_queue.shutdown_event.is_set(): logger.debug("Reload Queue: Stop signal received. Terminating") break - logger.trace("Reloading image: '%s'", filename) - detect_item = detected_faces.pop(filename, None) - if not detect_item: + logger.trace("Reloading image: '%s'", filename) # type: ignore + extract_media = detected_faces.pop(filename, None) + if not extract_media: logger.warning("Couldn't find faces for: %s", filename) continue - detect_item["image"] = image - load_queue.put(detect_item) + extract_media.set_image(image) + load_queue.put(extract_media) load_queue.put("EOF") logger.debug("Reload Images: Complete") - @staticmethod - def save_faces(): - """ Save the generated faces """ - logger.debug("Save Faces: Start") - save_queue = queue_manager.get_queue("save") - while True: - if save_queue.shutdown.is_set(): - logger.debug("Save Queue: Stop signal received. Terminating") - break - item = save_queue.get() - if item == "EOF": - break - filename, face = item - - logger.trace("Saving face: '%s'", filename) - try: - with open(filename, "wb") as out_file: - out_file.write(face) - except Exception as err: # pylint: disable=broad-except - logger.error("Failed to save image '%s'. Original Error: %s", filename, err) - continue - logger.debug("Save Faces: Complete") - - def run_extraction(self): - """ Run Face Detection """ - save_queue = queue_manager.get_queue("save") - to_process = self.process_item_count() - frame_no = 0 - size = self.args.size if hasattr(self.args, "size") else 256 - align_eyes = self.args.align_eyes if hasattr(self.args, "align_eyes") else False - - if self.plugins.is_parallel: - logger.debug("Using parallel processing") - self.plugins.launch_aligner() - self.plugins.launch_detector() - if not self.plugins.is_parallel: - logger.debug("Using serial processing") - self.run_detection(to_process) - self.plugins.launch_aligner() - - for faces in tqdm(self.plugins.detect_faces(extract_pass="align"), - total=to_process, - file=sys.stdout, - desc="Extracting faces"): - - filename = faces["filename"] - - self.align_face(faces, align_eyes, size, filename) - self.post_process.do_actions(faces) - - faces_count = len(faces["detected_faces"]) - if faces_count == 0: - logger.verbose("No faces were detected in image: %s", - os.path.basename(filename)) - - if not self.verify_output and faces_count > 1: - self.verify_output = True - - self.output_faces(filename, faces, save_queue) - - frame_no += 1 - if frame_no == self.save_interval: - self.alignments.save() - frame_no = 0 - - save_queue.put("EOF") - - def process_item_count(self): - """ Return the number of items to be processedd """ - processed = sum(os.path.basename(frame) in self.alignments.data.keys() - for frame in self.images.input_images) - logger.debug("Items already processed: %s", processed) - - if processed != 0 and self.args.skip_existing: - logger.info("Skipping previously extracted frames: %s", processed) - if processed != 0 and self.args.skip_faces: - logger.info("Skipping frames with detected faces: %s", processed) - - to_process = self.images.images_found - processed - logger.debug("Items to be Processed: %s", to_process) - if to_process == 0: - logger.error("No frames to process. 
Exiting") - queue_manager.terminate_queues() - exit(0) - return to_process - - def run_detection(self, to_process): - """ Run detection only """ - self.plugins.launch_detector() - detected_faces = dict() - for detected in tqdm(self.plugins.detect_faces(extract_pass="detect"), - total=to_process, - file=sys.stdout, - desc="Detecting faces"): - exception = detected.get("exception", False) - if exception: - break - - del detected["image"] - filename = detected["filename"] - - detected_faces[filename] = detected - - self.threaded_io("reload", detected_faces) - - def align_face(self, faces, align_eyes, size, filename): - """ Align the detected face and add the destination file path """ - final_faces = list() - image = faces["image"] - landmarks = faces["landmarks"] - detected_faces = faces["detected_faces"] - for idx, face in enumerate(detected_faces): - detected_face = DetectedFace() - detected_face.from_dlib_rect(face, image) - detected_face.landmarksXY = landmarks[idx] - detected_face.load_aligned(image, size=size, align_eyes=align_eyes) - final_faces.append({"file_location": self.output_dir / Path(filename).stem, - "face": detected_face}) - faces["detected_faces"] = final_faces - - def output_faces(self, filename, faces, save_queue): - """ Output faces to save thread """ - final_faces = list() - for idx, detected_face in enumerate(faces["detected_faces"]): - output_file = detected_face["file_location"] - extension = Path(filename).suffix - out_filename = "{}_{}{}".format(str(output_file), str(idx), extension) - - face = detected_face["face"] - resized_face = face.aligned_face - - face.hash, img = hash_encode_image(resized_face, extension) - save_queue.put((out_filename, img)) - final_faces.append(face.to_alignment()) - self.alignments.data[os.path.basename(filename)] = final_faces - -class Plugins(): - """ Detector and Aligner Plugins and queues """ - def __init__(self, arguments, converter_args=None): - logger.debug("Initializing %s", self.__class__.__name__) - self.args = arguments - self.converter_args = converter_args # Arguments from converter for on the fly extract - if converter_args is not None: - self.loglevel = converter_args["loglevel"] - else: - self.loglevel = self.args.loglevel - - self.detector = self.load_detector() - self.aligner = self.load_aligner() - self.is_parallel = self.set_parallel_processing() - - self.process_detect = None - self.process_align = None - self.add_queues() +class _Extract(): + """ The Actual extraction process. 
+ + This class is called by the parent :class:`Extract` process + + Parameters + ---------- + extractor: :class:`~plugins.extract.pipeline.Extractor` + The extractor pipeline for running extractions + arguments: :class:`argparse.Namespace` + The arguments to be passed to the extraction process as generated from Faceswap's command + line arguments + """ + def __init__(self, + extractor: Extractor, + arguments: Namespace) -> None: + logger.debug("Initializing %s: (extractor: %s, args: %s)", self.__class__.__name__, + extractor, arguments) + self._args = arguments + self._output_dir = None if self._args.skip_saving_faces else get_folder( + self._args.output_dir) + + logger.info("Output Directory: %s", self._output_dir) + self._loader = PipelineLoader(self._args.input_dir, extractor) + + self._alignments = Alignments(self._args, True, self._loader.is_video) + self._extractor = extractor + self._extractor.import_data(self._args.input_dir) + + self._existing_count = 0 + self._set_skip_list() + + self._post_process = PostProcess(arguments) + self._verify_output = False logger.debug("Initialized %s", self.__class__.__name__) - def set_parallel_processing(self): - """ Set whether to run detect and align together or separately """ - detector_vram = self.detector.vram - aligner_vram = self.aligner.vram - gpu_stats = GPUStats() - if (detector_vram == 0 - or aligner_vram == 0 - or gpu_stats.device_count == 0): - logger.debug("At least one of aligner or detector have no VRAM requirement. " - "Enabling parallel processing.") - return True - - if hasattr(self.args, "multiprocess") and not self.args.multiprocess: - logger.info("NB: Parallel processing disabled.You may get faster " - "extraction speeds by enabling it with the -mp switch") - return False - - required_vram = detector_vram + aligner_vram + 320 # 320MB buffer - stats = gpu_stats.get_card_most_free() - free_vram = int(stats["free"]) - logger.verbose("%s - %sMB free of %sMB", - stats["device"], - free_vram, - int(stats["total"])) - if free_vram <= required_vram: - logger.warning("Not enough free VRAM for parallel processing. 
" - "Switching to serial") - return False - return True - - def add_queues(self): - """ Add the required processing queues to Queue Manager """ - for task in ("load", "detect", "align", "save"): - size = 0 - if task == "load" or (not self.is_parallel and task == "detect"): - size = 100 - queue_manager.add_queue(task, maxsize=size) - - def load_detector(self): - """ Set global arguments and load detector plugin """ - if not self.converter_args: - detector_name = self.args.detector.replace("-", "_").lower() - else: - detector_name = self.converter_args["detector"] - logger.debug("Loading Detector: '%s'", detector_name) - # Rotation - rotation = self.args.rotate_images if hasattr(self.args, "rotate_images") else None - # Min acceptable face size: - min_size = self.args.min_size if hasattr(self.args, "min_size") else 0 - - detector = PluginLoader.get_detector(detector_name)( - loglevel=self.loglevel, - rotation=rotation, - min_size=min_size) - - return detector - - def load_aligner(self): - """ Set global arguments and load aligner plugin """ - if not self.converter_args: - aligner_name = self.args.aligner.replace("-", "_").lower() - else: - aligner_name = self.converter_args["aligner"] - - logger.debug("Loading Aligner: '%s'", aligner_name) - - aligner = PluginLoader.get_aligner(aligner_name)( - loglevel=self.loglevel) - - return aligner - - def launch_aligner(self): - """ Launch the face aligner """ - logger.debug("Launching Aligner") - out_queue = queue_manager.get_queue("align") - kwargs = {"in_queue": queue_manager.get_queue("detect"), - "out_queue": out_queue} - - self.process_align = SpawnProcess(self.aligner.run, **kwargs) - event = self.process_align.event - self.process_align.start() - - # Wait for Aligner to take it's VRAM - # The first ever load of the model for FAN has reportedly taken - # up to 3-4 minutes, hence high timeout. - # TODO investigate why this is and fix if possible - for mins in reversed(range(5)): - event.wait(60) - if event.is_set(): - break - if mins == 0: - raise ValueError("Error initializing Aligner") - logger.info("Waiting for Aligner... 
Time out in %s minutes", mins) - - logger.debug("Launched Aligner") - - def launch_detector(self): - """ Launch the face detector """ - logger.debug("Launching Detector") - out_queue = queue_manager.get_queue("detect") - kwargs = {"in_queue": queue_manager.get_queue("load"), - "out_queue": out_queue} - if self.converter_args: - kwargs["processes"] = 1 - mp_func = PoolProcess if self.detector.parent_is_pool else SpawnProcess - self.process_detect = mp_func(self.detector.run, **kwargs) - - event = None - if hasattr(self.process_detect, "event"): - event = self.process_detect.event - - self.process_detect.start() - - if event is None: - logger.debug("Launched Detector") + @property + def _save_interval(self) -> int | None: + """ int: The number of frames to be processed between each saving of the alignments file if + it has been provided, otherwise ``None`` """ + if hasattr(self._args, "save_interval"): + return self._args.save_interval + return None + + @property + def _skip_num(self) -> int: + """ int: Number of frames to skip if extract_every_n has been provided """ + return self._args.extract_every_n if hasattr(self._args, "extract_every_n") else 1 + + def _set_skip_list(self) -> None: + """ Add the skip list to the image loader + + Checks against `extract_every_n` and the existence of alignments data (can exist if + `skip_existing` or `skip_existing_faces` has been provided) and compiles a list of frame + indices that should not be processed, providing these to :class:`lib.image.ImagesLoader`. + """ + if self._skip_num == 1 and not self._alignments.data: + logger.debug("No frames to be skipped") return - - for mins in reversed(range(5)): - event.wait(60) - if event.is_set(): - break - if mins == 0: - raise ValueError("Error initializing Detector") - logger.info("Waiting for Detector... Time out in %s minutes", mins) - - logger.debug("Launched Detector") - - def detect_faces(self, extract_pass="detect"): - """ Detect faces from in an image """ - logger.debug("Running Detection. Pass: '%s'", extract_pass) - if self.is_parallel or extract_pass == "align": - out_queue = queue_manager.get_queue("align") - if not self.is_parallel and extract_pass == "detect": - out_queue = queue_manager.get_queue("detect") - - while True: - try: - faces = out_queue.get(True, 1) - if faces == "EOF": - break - if isinstance(faces, dict) and faces.get("exception"): - pid = faces["exception"][0] - t_back = faces["exception"][1].getvalue() - err = "Error in child process {}. {}".format(pid, t_back) - raise Exception(err) - except QueueEmpty: + skip_list = [] + for idx, filename in enumerate(self._loader.file_list): + if idx % self._skip_num != 0: + logger.trace("Adding image '%s' to skip list due to " # type: ignore + "extract_every_n = %s", filename, self._skip_num) + skip_list.append(idx) + # Items may be in the alignments file if skip-existing[-faces] is selected + elif os.path.basename(filename) in self._alignments.data: + self._existing_count += 1 + logger.trace("Removing image: '%s' due to previously existing", # type: ignore + filename) + skip_list.append(idx) + if self._existing_count != 0: + logger.info("Skipping %s frames due to skip_existing/skip_existing_faces.", + self._existing_count) + logger.debug("Adding skip list: %s", skip_list) + self._loader.add_skip_list(skip_list) + + def process(self) -> None: + """ The entry point for triggering the Extraction Process. 
+ + Should only be called from :class:`lib.cli.launcher.ScriptExecutor` + """ + # from lib.queue_manager import queue_manager ; queue_manager.debug_monitor(3) + self._loader.launch() + self._run_extraction() + self._loader.join() + self._alignments.save() + finalize(self._loader.process_count + self._existing_count, + self._alignments.faces_count, + self._verify_output) + + def _run_extraction(self) -> None: + """ The main Faceswap Extraction process + + Receives items from :class:`plugins.extract.Pipeline.Extractor` and either saves out the + faces and data (if on the final pass) or reprocesses data through the pipeline for serial + processing. + """ + size = self._args.size if hasattr(self._args, "size") else 256 + saver = None if self._args.skip_saving_faces else ImagesSaver(self._output_dir, + as_bytes=True) + for phase in range(self._extractor.passes): + is_final = self._extractor.final_pass + detected_faces: dict[str, ExtractMedia] = {} + self._extractor.launch() + self._loader.check_thread_error() + ph_desc = "Extraction" if self._extractor.passes == 1 else self._extractor.phase_text + desc = f"Running pass {phase + 1} of {self._extractor.passes}: {ph_desc}" + for idx, extract_media in enumerate(tqdm(self._extractor.detected_faces(), + total=self._loader.process_count, + file=sys.stdout, + desc=desc, + leave=False)): + self._loader.check_thread_error() + if is_final: + self._output_processing(extract_media, size) + self._output_faces(saver, extract_media) + if self._save_interval and (idx + 1) % self._save_interval == 0: + self._alignments.save() + else: + extract_media.remove_image() + # cache extract_media for next run + detected_faces[extract_media.filename] = extract_media + + if not is_final: + logger.debug("Reloading images and resetting PyTorch memory cache") + torch.cuda.empty_cache() + self._loader.reload(detected_faces) + if saver is not None: + saver.close() + + def _output_processing(self, extract_media: ExtractMedia, size: int) -> None: + """ Prepare faces for output + + Loads the aligned face, generate the thumbnail, perform any processing actions and verify + the output. + + Parameters + ---------- + extract_media: :class:`~plugins.extract.extract_media.ExtractMedia` + Output from :class:`plugins.extract.pipeline.Extractor` + size: int + The size that the aligned face should be created at + """ + for face in extract_media.detected_faces: + face.load_aligned(extract_media.image, + size=size, + centering="head") + face.thumbnail = generate_thumbnail(face.aligned.face, size=96, quality=60) + self._post_process.do_actions(extract_media) + extract_media.remove_image() + + faces_count = len(extract_media.detected_faces) + if faces_count == 0: + logger.verbose("No faces were detected in image: %s", # type: ignore + os.path.basename(extract_media.filename)) + + if not self._verify_output and faces_count > 1: + self._verify_output = True + + def _output_faces(self, saver: ImagesSaver | None, extract_media: ExtractMedia) -> None: + """ Output faces to save thread + + Set the face filename based on the frame name and put the face to the + :class:`~lib.image.ImagesSaver` save queue and add the face information to the alignments + data. 
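The periodic alignments save in `_run_extraction` above is a simple modulo test against `save_interval`, sketched here in isolation:

def should_checkpoint(idx, save_interval):
    # idx is zero-based, so idx + 1 counts completed frames
    return bool(save_interval) and (idx + 1) % save_interval == 0

assert [i for i in range(10) if should_checkpoint(i, 4)] == [3, 7]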
+ + Parameters + ---------- + saver: :class:`lib.images.ImagesSaver` or ``None`` + The background saver for saving the image or ``None`` if faces are not to be saved + extract_media: :class:`~plugins.extract.extract_media.ExtractMedia` + The output from :class:`~plugins.extract.Pipeline.Extractor` + """ + logger.trace("Outputting faces for %s", extract_media.filename) # type: ignore + final_faces = [] + filename = os.path.splitext(os.path.basename(extract_media.filename))[0] + + skip_idx = 0 + for face_id, face in enumerate(extract_media.detected_faces): + real_face_id = face_id - skip_idx + output_filename = f"{filename}_{real_face_id}.png" + aligned = face.aligned.face + assert aligned is not None + meta: PNGHeaderDict = { + "alignments": face.to_png_meta(), + "source": {"alignments_version": self._alignments.version, + "original_filename": output_filename, + "face_index": real_face_id, + "source_filename": os.path.basename(extract_media.filename), + "source_is_video": self._loader.is_video, + "source_frame_dims": extract_media.image_size}} + image = encode_image(aligned, ".png", metadata=meta) + + sub_folder = extract_media.sub_folders[face_id] + # Binned faces shouldn't risk filename clash, so just use original id + out_name = output_filename if not sub_folder else f"{filename}_{face_id}.png" + + if saver is not None: + saver.save(out_name, image, sub_folder) + + if sub_folder: # This is a filtered out face being binned + skip_idx += 1 continue + final_faces.append(face.to_alignment()) + + self._alignments.data[os.path.basename(extract_media.filename)] = {"faces": final_faces, + "video_meta": {}} + del extract_media + - yield faces - logger.debug("Detection Complete") +__all__ = get_module_objects(__name__) diff --git a/scripts/fsmedia.py b/scripts/fsmedia.py index 1c03af875e..913235610f 100644 --- a/scripts/fsmedia.py +++ b/scripts/fsmedia.py @@ -1,108 +1,144 @@ #!/usr/bin/env python3 -""" Holds the classes for the 3 main Faceswap 'media' objects for - input (extract) and output (convert) tasks. Those being: - Images - Faces - Alignments""" +""" Helper functions for :mod:`~scripts.extract` and :mod:`~scripts.convert`. +Holds the classes for the 2 main Faceswap 'media' objects: Images and Alignments. + +Holds optional pre/post processing functions for convert and extract. 
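The face re-numbering in `_output_faces` (shown above, at the end of scripts/extract.py) keeps saved faces contiguously indexed, while binned (filtered out) faces retain their original index; a minimal sketch:

def number_kept_faces(sub_folders):
    kept, skip_idx = [], 0
    for face_id, sub_folder in enumerate(sub_folders):
        if sub_folder:       # filtered face: binned to a sub-folder, not re-numbered
            skip_idx += 1
            continue
        kept.append(face_id - skip_idx)
    return kept

assert number_kept_faces([None, "filtered", None, None]) == [0, 1, 2]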
+""" +from __future__ import annotations import logging import os -from pathlib import Path +import sys +import typing as T + +from collections.abc import Iterator import cv2 import numpy as np - -from lib.aligner import Extract as AlignerExtract -from lib.alignments import Alignments as AlignmentsBase -from lib.face_filter import FaceFilter as FilterFunc -from lib.utils import (camel_case_split, get_folder, get_image_paths, - set_system_verbosity, _video_extensions) - -logger = logging.getLogger(__name__) # pylint: disable=invalid-name - - -class Utils(): - """ Holds utility functions that are required by more than one media - object """ - - @staticmethod - def set_verbosity(loglevel): - """ Set the system output verbosity """ - set_system_verbosity(loglevel) - - @staticmethod - def finalize(images_found, num_faces_detected, verify_output): - """ Finalize the image processing """ - logger.info("-------------------------") - logger.info("Images found: %s", images_found) - logger.info("Faces detected: %s", num_faces_detected) +import imageio + +from lib.align import Alignments as AlignmentsBase, get_centered_size +from lib.image import count_frames, read_image +from lib.utils import camel_case_split, get_image_paths, get_module_objects, VIDEO_EXTENSIONS + +if T.TYPE_CHECKING: + from collections.abc import Generator + from argparse import Namespace + from lib.align import AlignedFace + from plugins.extract import ExtractMedia + +logger = logging.getLogger(__name__) + + +def finalize(images_found: int, num_faces_detected: int, verify_output: bool) -> None: + """ Output summary statistics at the end of the extract or convert processes. + + Parameters + ---------- + images_found: int + The number of images/frames that were processed + num_faces_detected: int + The number of faces that have been detected + verify_output: bool + ``True`` if multiple faces were detected in frames otherwise ``False``. + """ + logger.info("-------------------------") + logger.info("Images found: %s", images_found) + logger.info("Faces detected: %s", num_faces_detected) + logger.info("-------------------------") + + if verify_output: + logger.info("Note:") + logger.info("Multiple faces were detected in one or more pictures.") + logger.info("Double check your results.") logger.info("-------------------------") - if verify_output: - logger.info("Note:") - logger.info("Multiple faces were detected in one or more pictures.") - logger.info("Double check your results.") - logger.info("-------------------------") - - logger.info("Process Succesfully Completed. Shutting Down...") + logger.info("Process Successfully Completed. Shutting Down...") class Alignments(AlignmentsBase): - """ Override main alignments class for extract """ - def __init__(self, arguments, is_extract, input_is_video=False): + """ Override :class:`lib.align.Alignments` to add custom loading based on command + line arguments. + + Parameters + ---------- + arguments: :class:`argparse.Namespace` + The command line arguments that were passed to Faceswap + is_extract: bool + ``True`` if the process calling this class is extraction otherwise ``False`` + input_is_video: bool, optional + ``True`` if the input to the process is a video, ``False`` if it is a folder of images. 
+ Default: False + """ + def __init__(self, + arguments: Namespace, + is_extract: bool, + input_is_video: bool = False) -> None: logger.debug("Initializing %s: (is_extract: %s, input_is_video: %s)", self.__class__.__name__, is_extract, input_is_video) - self.args = arguments - self.is_extract = is_extract - folder, filename = self.set_folder_filename(input_is_video) - serializer = self.set_serializer() - super().__init__(folder, - filename=filename, - serializer=serializer) + self._args = arguments + self._is_extract = is_extract + folder, filename = self._set_folder_filename(input_is_video) + super().__init__(folder, filename=filename) logger.debug("Initialized %s", self.__class__.__name__) - def set_folder_filename(self, input_is_video): - """ Return the folder for the alignments file""" - if self.args.alignments_path: - logger.debug("Alignments File provided: '%s'", self.args.alignments_path) - folder, filename = os.path.split(str(self.args.alignments_path)) + def _set_folder_filename(self, input_is_video: bool) -> tuple[str, str]: + """ Return the folder and the filename for the alignments file. + + If the input is a video, the alignments file will be stored in the same folder + as the video, named `<video_name>_alignments.fsa`. + + If the input is a folder of images, the alignments file will be stored in the folder with + the images and just be called 'alignments' + + Parameters + ---------- + input_is_video: bool + ``True`` if the input to the process is a video, ``False`` if it is a folder of images. + + Returns + ------- + folder: str + The folder where the alignments file will be stored + filename: str + The filename of the alignments file + """ + if self._args.alignments_path: + logger.debug("Alignments File provided: '%s'", self._args.alignments_path) + folder, filename = os.path.split(str(self._args.alignments_path)) elif input_is_video: - logger.debug("Alignments from Video File: '%s'", self.args.input_dir) - folder, filename = os.path.split(self.args.input_dir) - filename = "{}_alignments".format(os.path.splitext(filename)[0]) + logger.debug("Alignments from Video File: '%s'", self._args.input_dir) + folder, filename = os.path.split(self._args.input_dir) + filename = f"{os.path.splitext(filename)[0]}_alignments.fsa" else: - logger.debug("Alignments from Input Folder: '%s'", self.args.input_dir) - folder = str(self.args.input_dir) + logger.debug("Alignments from Input Folder: '%s'", self._args.input_dir) + folder = str(self._args.input_dir) filename = "alignments" logger.debug("Setting Alignments: (folder: '%s' filename: '%s')", folder, filename) return folder, filename - def set_serializer(self): - """ Set the serializer to be used for loading and - saving alignments """ - if hasattr(self.args, "serializer") and self.args.serializer: - logger.debug("Serializer provided: '%s'", self.args.serializer) - serializer = self.args.serializer - else: - # If there is a full filename then this will be overriden - # by filename extension - serializer = "json" - logger.debug("No Serializer defaulting to: '%s'", serializer) - return serializer - - def load(self): - """ Override parent loader to handle skip existing on extract """ - data = dict() - if not self.is_extract: - if not self.have_alignments_file: - return data - data = super().load() + def _load(self) -> dict[str, T.Any]: + """ Override the parent :func:`~lib.align.Alignments._load` to handle skip existing + frames and faces on extract. 
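Condensing the location rules from `_set_folder_filename` above (when no custom alignments path is given): video inputs store `<video_name>_alignments.fsa` beside the video, and image folders store `alignments` inside the folder. A sketch with hypothetical paths:

import os

def folder_filename(input_dir, input_is_video):
    if input_is_video:
        folder, fname = os.path.split(input_dir)
        return folder, f"{os.path.splitext(fname)[0]}_alignments.fsa"
    return str(input_dir), "alignments"

assert folder_filename("/data/clip.mp4", True) == ("/data", "clip_alignments.fsa")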
+ + If skip existing has been selected, existing alignments are loaded and returned to the + calling script. + + Returns + ------- + dict + Any alignments that have already been extracted if skip existing has been selected + otherwise an empty dictionary + """ + data: dict[str, T.Any] = {} + if not self._is_extract and not self.have_alignments_file: + return data + if not self._is_extract: + data = super()._load() return data - skip_existing = bool(hasattr(self.args, 'skip_existing') - and self.args.skip_existing) - skip_faces = bool(hasattr(self.args, 'skip_faces') - and self.args.skip_faces) + skip_existing = hasattr(self._args, 'skip_existing') and self._args.skip_existing + skip_faces = hasattr(self._args, 'skip_faces') and self._args.skip_faces if not skip_existing and not skip_faces: logger.debug("No skipping selected. Returning empty dictionary") @@ -112,316 +148,474 @@ def load(self): logger.warning("Skip Existing/Skip Faces selected, but no alignments file found!") return data - try: - with open(self.file, self.serializer.roptions) as align: - data = self.serializer.unmarshal(align.read()) - except IOError as err: - logger.error("Error: '%s' not read: %s", self.file, err.strerror) - exit(1) + data = super()._load() if skip_faces: - # Remove items from algnments that have no faces so they will + # Remove items from alignments that have no faces so they will # be re-detected - del_keys = [key for key, val in data.items() if not val] + del_keys = [key for key, val in data.items() if not val["faces"]] logger.debug("Frames with no faces selected for redetection: %s", len(del_keys)) for key in del_keys: if key in data: - logger.trace("Selected for redetection: '%s'", key) + logger.trace("Selected for redetection: '%s'", # type:ignore[attr-defined] + key) del data[key] return data class Images(): - """ Holds the full frames/images """ - def __init__(self, arguments): + """ Handles the loading of frames from a folder of images or a video file for extract + and convert processes. + + Parameters + ---------- + arguments: :class:`argparse.Namespace` + The command line arguments that were passed to Faceswap + """ + def __init__(self, arguments: Namespace) -> None: logger.debug("Initializing %s", self.__class__.__name__) - self.args = arguments - self.is_video = self.check_input_folder() - self.input_images = self.get_input_images() + self._args = arguments + self._is_video = self._check_input_folder() + self._input_images = self._get_input_images() + self._images_found = self._count_images() logger.debug("Initialized %s", self.__class__.__name__) @property - def images_found(self): - """ Number of images or frames """ - if self.is_video: - cap = cv2.VideoCapture(self.args.input_dir) # pylint: disable=no-member - retval = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) # pylint: disable=no-member - cap.release() + def is_video(self) -> bool: + """bool: ``True`` if the input is a video file otherwise ``False``. """ + return self._is_video + + @property + def input_images(self) -> str | list[str]: + """str or list: Path to the video file if the input is a video otherwise list of + image paths. """ + return self._input_images + + @property + def images_found(self) -> int: + """int: The number of frames that exist in the video file, or the folder of images. """ + return self._images_found + + def _count_images(self) -> int: + """ Get the number of Frames from a video file or folder of images. 
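The skip-faces pruning in `_load` above removes frames whose alignments entry holds no faces, so that extraction re-detects them; a toy sketch of the same dictionary surgery:

data = {"frame_0.png": {"faces": [{"x": 1}]},   # has a face: kept
        "frame_1.png": {"faces": []}}           # no faces: will be re-detected
for key in [k for k, val in data.items() if not val["faces"]]:
    del data[key]
assert list(data) == ["frame_0.png"]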
+ + Returns + ------- + int + The number of frames in the image source + """ + if self._is_video: + retval = int(count_frames(self._args.input_dir, fast=True)) else: - retval = len(self.input_images) + retval = len(self._input_images) return retval - def check_input_folder(self): - """ Check whether the input is a folder or video """ - if not os.path.exists(self.args.input_dir): - logger.error("Input location %s not found.", self.args.input_dir) - exit(1) - if (os.path.isfile(self.args.input_dir) and - os.path.splitext(self.args.input_dir)[1] in _video_extensions): - logger.info("Input Video: %s", self.args.input_dir) + def _check_input_folder(self) -> bool: + """ Check whether the input is a folder or video. + + Returns + ------- + bool + ``True`` if the input is a video otherwise ``False`` + """ + if not os.path.exists(self._args.input_dir): + logger.error("Input location %s not found.", self._args.input_dir) + sys.exit(1) + if (os.path.isfile(self._args.input_dir) and + os.path.splitext(self._args.input_dir)[1].lower() in VIDEO_EXTENSIONS): + logger.info("Input Video: %s", self._args.input_dir) retval = True else: - logger.info("Input Directory: %s", self.args.input_dir) + logger.info("Input Directory: %s", self._args.input_dir) retval = False return retval - def get_input_images(self): - """ Return the list of images or video file that is to be processed """ - if self.is_video: - input_images = self.args.input_dir + def _get_input_images(self) -> str | list[str]: + """ Return the list of images or path to video file that is to be processed. + + Returns + ------- + str or list + Path to the video file if the input is a video otherwise list of image paths. + """ + if self._is_video: + input_images = self._args.input_dir else: - input_images = get_image_paths(self.args.input_dir) + input_images = get_image_paths(self._args.input_dir) return input_images - def load(self): - """ Load an image and yield it with it's filename """ - iterator = self.load_video_frames if self.is_video else self.load_disk_frames + def load(self) -> Generator[tuple[str, np.ndarray], None, None]: + """ Generator to load frames from a folder of images or from a video file. + + Yields + ------ + filename: str + The filename of the current frame + image: :class:`numpy.ndarray` + A single frame + """ + iterator = self._load_video_frames if self._is_video else self._load_disk_frames for filename, image in iterator(): yield filename, image - def load_disk_frames(self): - """ Load frames from disk """ + def _load_disk_frames(self) -> Generator[tuple[str, np.ndarray], None, None]: + """ Generator to load frames from a folder of images. + + Yields + ------ + filename: str + The filename of the current frame + image: :class:`numpy.ndarray` + A single frame + """ logger.debug("Input is separate Frames. Loading images") - for filename in self.input_images: - logger.trace("Loading image: '%s'", filename) - try: - image = cv2.imread(filename) # pylint: disable=no-member - except Exception as err: # pylint: disable=broad-except - logger.error("Failed to load image '%s'. Original Error: %s", filename, err) + for filename in self._input_images: + image = read_image(filename, raise_error=False) + if image is None: continue yield filename, image - def load_video_frames(self): - """ Return frames from a video file """ + def _load_video_frames(self) -> Generator[tuple[str, np.ndarray], None, None]: + """ Generator to load frames from a video file. 
+ + Yields + ------ + filename: str + The filename of the current frame + image: :class:`numpy.ndarray` + A single frame + """ logger.debug("Input is video. Capturing frames") - vidname = os.path.splitext(os.path.basename(self.args.input_dir))[0] - cap = cv2.VideoCapture(self.args.input_dir) # pylint: disable=no-member - i = 0 - while True: - ret, frame = cap.read() - if not ret: - logger.debug("Video terminated") - break - i += 1 - # Keep filename format for outputted face - filename = "{}_{:06d}.png".format(vidname, i) - logger.trace("Loading video frame: '%s'", filename) + vidname, ext = os.path.splitext(os.path.basename(self._args.input_dir)) + reader = imageio.get_reader(self._args.input_dir, "ffmpeg") # type:ignore[arg-type] + for i, frame in enumerate(T.cast(Iterator[np.ndarray], reader)): + # Convert to BGR for cv2 compatibility + frame = frame[:, :, ::-1] + filename = f"{vidname}_{i + 1:06d}{ext}" + logger.trace("Loading video frame: '%s'", filename) # type:ignore[attr-defined] yield filename, frame - cap.release() + reader.close() + + def load_one_image(self, filename: str) -> np.ndarray: + """ Obtain a single image for the given filename. + + Parameters + ---------- + filename: str + The filename to return the image for + + Returns + ------- + :class:`numpy.ndarray` + The image for the requested filename. + """ + logger.trace("Loading image: '%s'", filename) # type:ignore[attr-defined] + if self._is_video: + if filename.isdigit(): + frame_no = filename + else: + frame_no = os.path.splitext(filename)[0][filename.rfind("_") + 1:] + logger.trace( # type:ignore[attr-defined] + "Extracted frame_no %s from filename '%s'", frame_no, filename) + retval = self._load_one_video_frame(int(frame_no)) + else: + retval = read_image(filename, raise_error=True) + return retval - @staticmethod - def load_one_image(filename): - """ load requested image """ - logger.trace("Loading image: '%s'", filename) - return cv2.imread(filename) # pylint: disable=no-member + def _load_one_video_frame(self, frame_no: int) -> np.ndarray: + """ Obtain a single frame from a video file. + + Parameters + ---------- + frame_no: int + The frame index for the required frame + + Returns + ------- + :class:`numpy.ndarray` + The image for the requested frame index. + """ + logger.trace("Loading video frame: %s", frame_no) # type:ignore[attr-defined] + reader = imageio.get_reader(self._args.input_dir, "ffmpeg") # type:ignore[arg-type] + reader.set_image_index(frame_no - 1) + frame = reader.get_next_data()[:, :, ::-1] # type:ignore[index] + reader.close() + return frame class PostProcess(): - """ Optional post processing tasks """ - def __init__(self, arguments): + """ Optional pre/post processing tasks for convert and extract. + + Builds a pipeline of actions that have optionally been requested to be performed + in this session. 
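Two details of `_load_video_frames` above are worth seeing in isolation: imageio yields RGB frames, which the `[:, :, ::-1]` slice flips to BGR for cv2 compatibility, and the pseudo-filenames are 1-based and zero-padded. A minimal sketch:

import numpy as np

rgb = np.zeros((2, 2, 3), dtype=np.uint8)
rgb[..., 0] = 255                 # pure red in RGB channel order
bgr = rgb[:, :, ::-1]             # channel flip: red moves to the last channel
assert bgr[0, 0].tolist() == [0, 0, 255]
assert f"clip_{0 + 1:06d}.mp4" == "clip_000001.mp4"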
+ + Parameters + ---------- + arguments: :class:`argparse.Namespace` + The command line arguments that were passed to Faceswap + """ + def __init__(self, arguments: Namespace) -> None: logger.debug("Initializing %s", self.__class__.__name__) - self.args = arguments - self.actions = self.set_actions() + self._args = arguments + self._actions = self._set_actions() logger.debug("Initialized %s", self.__class__.__name__) - def set_actions(self): - """ Compile the actions to be performed into a list """ - postprocess_items = self.get_items() - actions = list() + def _set_actions(self) -> list[PostProcessAction]: + """ Compile the requested actions to be performed into a list + + Returns + ------- + list + The list of :class:`PostProcessAction` to be performed + """ + postprocess_items = self._get_items() + actions: list["PostProcessAction"] = [] for action, options in postprocess_items.items(): - options = dict() if options is None else options + options = {} if options is None else options args = options.get("args", tuple()) - kwargs = options.get("kwargs", dict()) + kwargs = options.get("kwargs", {}) args = args if isinstance(args, tuple) else tuple() - kwargs = kwargs if isinstance(kwargs, dict) else dict() + kwargs = kwargs if isinstance(kwargs, dict) else {} task = globals()[action](*args, **kwargs) - logger.debug("Adding Postprocess action: '%s'", task) - actions.append(task) + if task.valid: + logger.debug("Adding Postprocess action: '%s'", task) + actions.append(task) - for action in actions: - action_name = camel_case_split(action.__class__.__name__) + for ppaction in actions: + action_name = camel_case_split(ppaction.__class__.__name__) logger.info("Adding post processing item: %s", " ".join(action_name)) return actions - def get_items(self): - """ Set the post processing actions """ - postprocess_items = dict() + def _get_items(self) -> dict[str, dict[str, tuple | dict] | None]: + """ Check the passed in command line arguments for requested actions, + + For any requested actions, add the item to the actions list along with + any relevant arguments and keyword arguments. + + Returns + ------- + dict + The name of the action to be performed as the key. Any action specific + arguments and keyword arguments as the value. 
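The `_get_items`/`_set_actions` pattern above is a small name-based registry: action class names map to optional args/kwargs and are looked up via `globals()` at run time. A self-contained sketch with a hypothetical stand-in action:

class DummyAction:
    """ Hypothetical stand-in for a PostProcessAction subclass """
    def __init__(self, *args, **kwargs):
        self.valid = True

registry = {"DummyAction": {"kwargs": {"level": 2}}}
actions = []
for name, options in registry.items():
    options = {} if options is None else options
    task = globals()[name](*options.get("args", ()), **options.get("kwargs", {}))
    if task.valid:   # invalid parameter sets silently disable the action
        actions.append(task)
assert len(actions) == 1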
+ """ + postprocess_items: dict[str, dict[str, tuple | dict] | None] = {} # Debug Landmarks - if (hasattr(self.args, 'debug_landmarks') - and self.args.debug_landmarks): + if (hasattr(self._args, 'debug_landmarks') and self._args.debug_landmarks): postprocess_items["DebugLandmarks"] = None - # Blurry Face - if hasattr(self.args, 'blur_thresh') and self.args.blur_thresh: - kwargs = {"blur_thresh": self.args.blur_thresh} - postprocess_items["BlurryFaceFilter"] = {"kwargs": kwargs} - - # Face Filter post processing - if ((hasattr(self.args, "filter") and self.args.filter is not None) or - (hasattr(self.args, "nfilter") and - self.args.nfilter is not None)): - face_filter = dict() - filter_lists = dict() - if hasattr(self.args, "ref_threshold"): - face_filter["ref_threshold"] = self.args.ref_threshold - for filter_type in ('filter', 'nfilter'): - filter_args = getattr(self.args, filter_type, None) - filter_args = None if not filter_args else filter_args - filter_lists[filter_type] = filter_args - face_filter["filter_lists"] = filter_lists - postprocess_items["FaceFilter"] = {"kwargs": face_filter} - logger.debug("Postprocess Items: %s", postprocess_items) return postprocess_items - def do_actions(self, output_item): - """ Perform the requested post-processing actions """ - for action in self.actions: + def do_actions(self, extract_media: ExtractMedia) -> None: + """ Perform the requested optional post-processing actions on the given image. + + Parameters + ---------- + extract_media: :class:`~plugins.extract.extract_media.ExtractMedia` + The :class:`~plugins.extract.extract_media.ExtractMedia` object to perform the + action on. + + Returns + ------- + :class:`~plugins.extract.extract_media.ExtractMedia` + The original :class:`~plugins.extract.extract_media.ExtractMedia` with any actions + applied + """ + for action in self._actions: logger.debug("Performing postprocess action: '%s'", action.__class__.__name__) - action.process(output_item) + action.process(extract_media) + +class PostProcessAction(): + """ Parent class for Post Processing Actions. -class PostProcessAction(): # pylint: disable=too-few-public-methods - """ Parent class for Post Processing Actions - Usuable in Extract or Convert or both - depending on context """ - def __init__(self, *args, **kwargs): + Usable in Extract or Convert or both depending on context. Any post-processing actions should + inherit from this class. + + Parameters + ----------- + args: tuple + Varies for specific post process action + kwargs: dict + Varies for specific post process action + """ + def __init__(self, *args, **kwargs) -> None: logger.debug("Initializing %s: (args: %s, kwargs: %s)", self.__class__.__name__, args, kwargs) + self._valid = True # Set to False if invalid parameters passed in to disable logger.debug("Initialized base class %s", self.__class__.__name__) - def process(self, output_item): - """ Override for specific post processing action """ + @property + def valid(self) -> bool: + """bool: ``True`` if the action if the parameters passed in for this action are valid, + otherwise ``False`` """ + return self._valid + + def process(self, extract_media: ExtractMedia) -> None: + """ Override for specific post processing action + + Parameters + ---------- + extract_media: :class:`~plugins.extract.extract_media.ExtractMedia` + The :class:`~plugins.extract.extract_media.ExtractMedia` object to perform the + action on. 
+ """ raise NotImplementedError -class BlurryFaceFilter(PostProcessAction): # pylint: disable=too-few-public-methods - """ Move blurry faces to a different folder - Extract Only """ - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.blur_thresh = kwargs["blur_thresh"] - logger.debug("Initialized %s", self.__class__.__name__) - - def process(self, output_item): - """ Detect and move blurry face """ - extractor = AlignerExtract() - - for idx, detected_face in enumerate(output_item["detected_faces"]): - frame_name = detected_face["file_location"].parts[-1] - face = detected_face["face"] - logger.trace("Checking for blurriness. Frame: '%s', Face: %s", frame_name, idx) - aligned_landmarks = face.aligned_landmarks - resized_face = face.aligned_face - size = face.aligned["size"] - padding = int(size * 0.1875) - feature_mask = extractor.get_feature_mask( - aligned_landmarks / size, - size, padding) - feature_mask = cv2.blur( # pylint: disable=no-member - feature_mask, (10, 10)) - isolated_face = cv2.multiply( # pylint: disable=no-member - feature_mask, - resized_face.astype(float)).astype(np.uint8) - blurry, focus_measure = self.is_blurry(isolated_face) - - if blurry: - blur_folder = detected_face["file_location"].parts[:-1] - blur_folder = get_folder(Path(*blur_folder) / Path("blurry")) - detected_face["file_location"] = blur_folder / Path(frame_name) - logger.verbose("%s's focus measure of %s was below the blur threshold, " - "moving to 'blurry'", frame_name, "{0:.2f}".format(focus_measure)) - - def is_blurry(self, image): - """ Convert to grayscale, and compute the focus measure of the image using the - Variance of Laplacian method """ - gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) # pylint: disable=no-member - focus_measure = self.variance_of_laplacian(gray) - - # if the focus measure is less than the supplied threshold, - # then the image should be considered "blurry" - retval = (focus_measure < self.blur_thresh, focus_measure) - logger.trace("Returning: (is_blurry: %s, focus_measure %s)", retval[0], retval[1]) - return retval - - @staticmethod - def variance_of_laplacian(image): - """ Compute the Laplacian of the image and then return the focus - measure, which is simply the variance of the Laplacian """ - retval = cv2.Laplacian(image, cv2.CV_64F).var() # pylint: disable=no-member - logger.trace("Returning: %s", retval) - return retval - - -class DebugLandmarks(PostProcessAction): # pylint: disable=too-few-public-methods - """ Draw debug landmarks on face - Extract Only """ - - def process(self, output_item): - """ Draw landmarks on image """ - for idx, detected_face in enumerate(output_item["detected_faces"]): - face = detected_face["face"] - logger.trace("Drawing Landmarks. Frame: '%s'. 
Face: %s", - detected_face["file_location"].parts[-1], idx) - aligned_landmarks = face.aligned_landmarks - for (pos_x, pos_y) in aligned_landmarks: - cv2.circle( # pylint: disable=no-member - face.aligned_face, - (pos_x, pos_y), 2, (0, 0, 255), -1) - - -class FaceFilter(PostProcessAction): - """ Filter in or out faces based on input image(s) - Extract or Convert """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - filter_lists = kwargs["filter_lists"] - ref_threshold = kwargs.get("ref_threshold", 0.6) - self.filter = self.load_face_filter(filter_lists, ref_threshold) - logger.debug("Initialized %s", self.__class__.__name__) - - def load_face_filter(self, filter_lists, ref_threshold): - """ Load faces to filter out of images """ - if not any(val for val in filter_lists.values()): - return None - - filter_files = [self.set_face_filter(f_type, filter_lists[f_type]) - for f_type in ("filter", "nfilter")] - - if any(filters for filters in filter_files): - facefilter = FilterFunc(filter_files[0], - filter_files[1], - ref_threshold) - logger.debug("Face filter: %s", facefilter) - return facefilter - - @staticmethod - def set_face_filter(f_type, f_args): - """ Set the required filters """ - if not f_args: - return list() - - logger.info("%s: %s", f_type.title(), f_args) - filter_files = f_args if isinstance(f_args, list) else [f_args] - filter_files = list(filter(lambda fpath: Path(fpath).exists(), filter_files)) - logger.debug("Face Filter files: %s", filter_files) - return filter_files - - def process(self, output_item): - """ Filter in/out wanted/unwanted faces """ - if not self.filter: - return - - ret_faces = list() - for idx, detected_face in enumerate(output_item["detected_faces"]): - if not self.filter.check(detected_face["face"]): - logger.verbose("Skipping not recognized face! Frame: %s Face %s", - detected_face["file_location"].parts[-1], idx) - continue - logger.trace("Accepting recognised face. Frame: %s. Face: %s", - detected_face["file_location"].parts[-1], idx) - ret_faces.append(detected_face) - output_item["detected_faces"] = ret_faces +class DebugLandmarks(PostProcessAction): + """ Draw debug landmarks on face output. 
Extract Only """ + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self._face_size = 0 + self._legacy_size = 0 + self._font = cv2.FONT_HERSHEY_SIMPLEX + self._font_scale = 0.0 + self._font_pad = 0 + + def _initialize_font(self, size: int) -> None: + """ Set the font scaling sizes on first call + + Parameters + ---------- + size: int + The pixel size of the saved aligned face + """ + self._font_scale = size / 512 + self._font_pad = size // 64 + + def _border_text(self, + image: np.ndarray, + text: str, + color: tuple[int, int, int], + position: tuple[int, int]) -> None: + """ Create text on an image with a black border + + Parameters + ---------- + image: :class:`numpy.ndarray` + The image to put bordered text on to + text: str + The text to place on the image + color: tuple + The color of the text + position: tuple + The (x, y) co-ordinates to place the text + """ + # Draw the black border at full thickness first, then the colour on top at half thickness + thickness = 2 + for idx in range(2): + text_color = (0, 0, 0) if idx == 0 else color + cv2.putText(image, + text, + position, + self._font, + self._font_scale, + text_color, + thickness, + lineType=cv2.LINE_AA) + thickness //= 2 + + def _annotate_face_box(self, face: AlignedFace) -> None: + """ Annotate the face extract box and print the original size in pixels + + Parameters + ---------- + face: :class:`~lib.align.AlignedFace` + The object containing the aligned face to annotate + """ + assert face.face is not None + color = (0, 255, 0) + roi = face.get_cropped_roi(face.size, self._face_size, "face") + cv2.rectangle(face.face, tuple(roi[:2]), tuple(roi[2:]), color, 1) + + # Size in top right corner + roi_pnts = np.array([[roi[0], roi[1]], + [roi[0], roi[3]], + [roi[2], roi[3]], + [roi[2], roi[1]]]) + orig_roi = face.transform_points(roi_pnts, invert=True) + size = int(round(((orig_roi[1][0] - orig_roi[0][0]) ** 2 + + (orig_roi[1][1] - orig_roi[0][1]) ** 2) ** 0.5)) + text_img = face.face.copy() + text = f"{size}px" + text_size = cv2.getTextSize(text, self._font, self._font_scale, 1)[0] + pos_x = roi[2] - (text_size[0] + self._font_pad) + pos_y = roi[1] + text_size[1] + self._font_pad + + self._border_text(text_img, text, color, (pos_x, pos_y)) + cv2.addWeighted(text_img, 0.75, face.face, 0.25, 0, face.face) + + def _print_stats(self, face: AlignedFace) -> None: + """ Print various metrics on the output face images + + Parameters + ---------- + face: :class:`~lib.align.AlignedFace` + The loaded aligned face + """ + assert face.face is not None + text_image = face.face.copy() + texts = [f"pitch: {face.pose.pitch:.2f}", + f"yaw: {face.pose.yaw:.2f}", + f"roll: {face.pose.roll:.2f}", + f"distance: {face.average_distance:.2f}"] + colors = [(255, 0, 0), (0, 0, 255), (0, 255, 0), (255, 255, 255)] + text_sizes = [cv2.getTextSize(text, self._font, self._font_scale, 1)[0] for text in texts] + + final_y = face.size - text_sizes[-1][1] + pos_y = [(size[1] + self._font_pad) * (idx + 1) + for idx, size in enumerate(text_sizes)][:-1] + [final_y] + pos_x = self._font_pad + + for idx, text in enumerate(texts): + self._border_text(text_image, text, colors[idx], (pos_x, pos_y[idx])) + + # Apply text to face + cv2.addWeighted(text_image, 0.75, face.face, 0.25, 0, face.face) + + def process(self, extract_media: ExtractMedia) -> None: + """ Draw landmarks on a face.
+ + Parameters + ---------- + extract_media: :class:`~plugins.extract.extract_media.ExtractMedia` + The :class:`~plugins.extract.extract_media.ExtractMedia` object that contains the faces + to draw the landmarks on to + """ + frame = os.path.splitext(os.path.basename(extract_media.filename))[0] + for idx, face in enumerate(extract_media.detected_faces): + if not self._face_size: + self._face_size = get_centered_size(face.aligned.centering, + "face", + face.aligned.size) + logger.debug("set face size: %s", self._face_size) + if not self._legacy_size: + self._legacy_size = get_centered_size(face.aligned.centering, + "legacy", + face.aligned.size) + logger.debug("set legacy size: %s", self._legacy_size) + if not self._font_scale: + self._initialize_font(face.aligned.size) + + logger.trace("Drawing Landmarks. Frame: '%s'. Face: %s", # type:ignore[attr-defined] + frame, idx) + # Landmarks + assert face.aligned.face is not None + for (pos_x, pos_y) in face.aligned.landmarks.astype("int32"): + cv2.circle(face.aligned.face, (pos_x, pos_y), 1, (0, 255, 255), -1) + # Pose + center = (face.aligned.size // 2, face.aligned.size // 2) + points = (face.aligned.pose.xyz_2d * face.aligned.size).astype("int32") + cv2.line(face.aligned.face, center, tuple(points[1]), (0, 255, 0), 1) + cv2.line(face.aligned.face, center, tuple(points[0]), (255, 0, 0), 1) + cv2.line(face.aligned.face, center, tuple(points[2]), (0, 0, 255), 1) + # Face centering + self._annotate_face_box(face.aligned) + # Legacy centering + roi = face.aligned.get_cropped_roi(face.aligned.size, self._legacy_size, "legacy") + cv2.rectangle(face.aligned.face, tuple(roi[:2]), tuple(roi[2:]), (0, 0, 255), 1) + self._print_stats(face.aligned) + + +__all__ = get_module_objects(__name__) diff --git a/scripts/gui.py b/scripts/gui.py index da2b815d1a..efacc94068 100644 --- a/scripts/gui.py +++ b/scripts/gui.py @@ -2,122 +2,195 @@ """ The optional GUI for faceswap """ import logging -import os import sys import tkinter as tk from tkinter import messagebox, ttk -from lib.gui import (CliOptions, CommandNotebook, ConsoleOut, Session, DisplayNotebook, - get_config, get_images, initialize_images, initialize_config, MainMenuBar, - ProcessWrapper, StatusBar) +from lib.gui import (TaskBar, CliOptions, CommandNotebook, ConsoleOut, DisplayNotebook, + get_images, gui_config as cfg, initialize_images, initialize_config, + LastSession, MainMenuBar, preview_trigger, ProcessWrapper, StatusBar) +from lib.utils import get_module_objects -logger = logging.getLogger(__name__) # pylint: disable=invalid-name +logger = logging.getLogger(__name__) class FaceswapGui(tk.Tk): - """ The Graphical User Interface """ + """ The Graphical User Interface - def __init__(self, pathscript): + Launch the Faceswap GUI + + Parameters + ---------- + debug : bool + Output to the terminal rather than to Faceswap's internal console + config_file : str | None + Path to a custom .ini configuration file. 
``None`` to use the default config file + """ + + def __init__(self, debug, config_file): logger.debug("Initializing %s", self.__class__.__name__) super().__init__() + cfg.load_config(config_file) + + self._init_args = {"debug": debug} + self._config = self.initialize_globals() + self.set_fonts() + self._config.set_geometry(1200, 640, cfg.fullscreen()) - self.initialize_globals(pathscript) - self.set_geometry() - self.wrapper = ProcessWrapper(pathscript) + self.wrapper = ProcessWrapper() + self.objects = {} get_images().delete_preview() + preview_trigger().clear(trigger_type=None) self.protocol("WM_DELETE_WINDOW", self.close_app) + self.build_gui() + self._last_session = LastSession(self._config) logger.debug("Initialized %s", self.__class__.__name__) - def initialize_globals(self, pathscript): + def initialize_globals(self): """ Initialize config and images global constants """ cliopts = CliOptions() - scaling_factor = self.get_scaling() - pathcache = os.path.join(pathscript, "lib", "gui", ".cache") statusbar = StatusBar(self) - session = Session() - initialize_config(cliopts, scaling_factor, pathcache, statusbar, session) + config = initialize_config(self, cliopts, statusbar) initialize_images() + return config - def get_scaling(self): - """ Get the display DPI """ - dpi = self.winfo_fpixels("1i") - scaling = dpi / 72.0 - logger.debug("dpi: %s, scaling: %s'", dpi, scaling) - return scaling - - def set_geometry(self): - """ Set GUI geometry """ - scaling_factor = get_config().scaling_factor - self.tk.call("tk", "scaling", scaling_factor) - width = int(1200 * scaling_factor) - height = int(640 * scaling_factor) - logger.debug("Geometry: %sx%s", width, height) - self.geometry("{}x{}+80+80".format(str(width), str(height))) - - def build_gui(self, debug_console): + def set_fonts(self): + """ Set global default font """ + tk.font.nametofont("TkFixedFont").configure(size=self._config.default_font[1]) + for font in ("TkDefaultFont", "TkHeadingFont", "TkMenuFont"): + tk.font.nametofont(font).configure(family=self._config.default_font[0], + size=self._config.default_font[1]) + + def build_gui(self, rebuild=False): """ Build the GUI """ logger.debug("Building GUI") - self.title("Faceswap.py") - self.tk.call('wm', 'iconphoto', self._w, get_images().icons["favicon"]) - self.configure(menu=MainMenuBar(self)) - - topcontainer, bottomcontainer = self.add_containers() - - CommandNotebook(topcontainer) - DisplayNotebook(topcontainer) - ConsoleOut(bottomcontainer, debug_console) + if not rebuild: + self.tk.call('wm', 'iconphoto', self._w, get_images().icons["favicon"]) + self.configure(menu=MainMenuBar(self)) + + if rebuild: + objects = list(self.objects.keys()) + for obj in objects: + self.objects[obj].destroy() + del self.objects[obj] + + self.objects["taskbar"] = TaskBar(self) + self.add_containers() + + self.objects["command"] = CommandNotebook(self.objects["container_top"]) + self.objects["display"] = DisplayNotebook(self.objects["container_top"]) + self.objects["console"] = ConsoleOut(self.objects["container_bottom"], + self._init_args["debug"]) + self.set_initial_focus() + self.set_layout() + self._config.set_default_options() logger.debug("Built GUI") def add_containers(self): """ Add the paned window containers that hold each main area of the gui """ logger.debug("Adding containers") - maincontainer = tk.PanedWindow(self, - sashrelief=tk.RAISED, - orient=tk.VERTICAL) + maincontainer = ttk.PanedWindow(self, + orient=tk.VERTICAL, + name="pw_main") maincontainer.pack(fill=tk.BOTH, expand=True) - 
topcontainer = tk.PanedWindow(maincontainer, - sashrelief=tk.RAISED, - orient=tk.HORIZONTAL) + topcontainer = ttk.PanedWindow(maincontainer, + orient=tk.HORIZONTAL, + name="pw_top") maincontainer.add(topcontainer) - bottomcontainer = ttk.Frame(maincontainer, height=150) + bottomcontainer = ttk.Frame(maincontainer, name="frame_bottom") maincontainer.add(bottomcontainer) + self.objects["container_main"] = maincontainer + self.objects["container_top"] = topcontainer + self.objects["container_bottom"] = bottomcontainer logger.debug("Added containers") - return topcontainer, bottomcontainer - def close_app(self): + def set_initial_focus(self): + """ Set the tab focus from settings """ + tab = cfg.tab() + logger.debug("Setting focus for tab: %s", tab) + self._config.set_active_tab_by_name(tab) + logger.debug("Focus set to: %s", tab) + + def set_layout(self): + """ Set initial layout """ + self.update_idletasks() + r_width = self.winfo_width() + r_height = self.winfo_height() + w_ratio = cfg.options_panel_width() / 100.0 + h_ratio = 1 - (cfg.console_panel_height() / 100.0) + width = round(r_width * w_ratio) + height = round(r_height * h_ratio) + logger.debug("Setting Initial Layout: (root_width: %s, root_height: %s, width_ratio: %s, " + "height_ratio: %s, width: %s, height: %s", r_width, r_height, w_ratio, + h_ratio, width, height) + self.objects["container_top"].sashpos(0, width) + self.objects["container_main"].sashpos(0, height) + self.update_idletasks() + + def rebuild(self): + """ Rebuild the GUI on config change """ + logger.debug("Redrawing GUI") + session_state = self._last_session.to_dict() + get_images().__init__() # pylint:disable=unnecessary-dunder-call + self.set_fonts() + self.build_gui(rebuild=True) + if session_state is not None: + self._last_session.from_dict(session_state) + logger.debug("GUI Redrawn") + + def close_app(self, *args): # pylint:disable=unused-argument """ Close Python. This is here because the graph animation function continues to run even when tkinter has gone away """ logger.debug("Close Requested") - confirm = messagebox.askokcancel - confirmtxt = "Processes are still running. Are you sure...?" - tk_vars = get_config().tk_vars - if (tk_vars["runningtask"].get() - and not confirm("Close", confirmtxt)): - logger.debug("Close Cancelled") + + if not self._confirm_close_on_running_task(): + return + if not self._config.project.confirm_close(): return - if tk_vars["runningtask"].get(): + + if self._config.tk_vars.running_task.get(): self.wrapper.task.terminate() + + self._last_session.save() get_images().delete_preview() + preview_trigger().clear(trigger_type=None) self.quit() logger.debug("Closed GUI") - exit() + sys.exit(0) + + def _confirm_close_on_running_task(self): + """ Pop a confirmation box to close the GUI if a task is running + + Returns + ------- + bool: ``True`` if user confirms close, ``False`` if user cancels close + """ + if not self._config.tk_vars.running_task.get(): + logger.debug("No tasks currently running") + return True + + confirmtxt = "Processes are still running.\n\nAre you sure you want to exit?" + if not messagebox.askokcancel("Close", confirmtxt, default="cancel", icon="warning"): + logger.debug("Close Cancelled") + return False + logger.debug("Close confirmed") + return True -class Gui(): # pylint: disable=too-few-public-methods +class Gui(): """ The GUI process. 
""" def __init__(self, arguments): - cmd = sys.argv[0] - pathscript = os.path.realpath(os.path.dirname(cmd)) - self.args = arguments - self.root = FaceswapGui(pathscript) + self.root = FaceswapGui(arguments.debug, arguments.configfile) def process(self): """ Builds the GUI """ - self.root.build_gui(self.args.debug) self.root.mainloop() + + +__all__ = get_module_objects(__name__) diff --git a/scripts/train.py b/scripts/train.py index e2ff1507f4..bd4e6d6dcd 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -1,274 +1,454 @@ #!/usr/bin python3 -""" The script to run the training process of faceswap """ - +""" Main entry point to the training process of FaceSwap """ +from __future__ import annotations import logging import os import sys +import typing as T -from threading import Lock from time import sleep +from threading import Event import cv2 -import tensorflow as tf -from keras.backend.tensorflow_backend import set_session +import numpy as np +from lib.gui.utils.image import TRAININGPREVIEW +from lib.image import read_image_meta from lib.keypress import KBHit -from lib.multithreading import MultiThread -from lib.queue_manager import queue_manager -from lib.utils import (get_folder, get_image_paths, set_system_verbosity) +from lib.multithreading import MultiThread, FSThread +from lib.training import Preview, PreviewBuffer, TriggerType +from lib.utils import (get_folder, get_image_paths, get_module_objects, handle_deprecated_cliopts, + FaceswapError, IMAGE_EXTENSIONS) from plugins.plugin_loader import PluginLoader +from plugins.train.training import Trainer + +if T.TYPE_CHECKING: + import argparse + from collections.abc import Callable + from plugins.train.model._base import ModelBase + -logger = logging.getLogger(__name__) # pylint: disable=invalid-name +logger = logging.getLogger(__name__) class Train(): - """ The training process. """ - def __init__(self, arguments): + """ The Faceswap Training Process. + + The training process is responsible for training a model on a set of source faces and a set of + destination faces. + + The training process is self contained and should not be referenced by any other scripts, so it + contains no public properties. 
+ + Parameters + ---------- + arguments: argparse.Namespace + The arguments to be passed to the training process as generated from Faceswap's command + line arguments + """ + def __init__(self, arguments: argparse.Namespace) -> None: logger.debug("Initializing %s: (args: %s", self.__class__.__name__, arguments) - self.args = arguments - self.timelapse = self.set_timelapse() - self.images = self.get_images() - self.stop = False - self.save_now = False - self.preview_buffer = dict() - self.lock = Lock() - - self.trainer_name = self.args.trainer + self._args = handle_deprecated_cliopts(arguments) + + if self._args.summary: + # If just outputting summary we don't need to initialize everything + return + + self._images = self._get_images() + self._timelapse = self._set_timelapse() + gui_cache = os.path.join( + os.path.realpath(os.path.dirname(sys.argv[0])), "lib", "gui", ".cache") + self._gui_triggers: dict[T.Literal["mask", "refresh"], str] = { + "mask": os.path.join(gui_cache, ".preview_mask_toggle"), + "refresh": os.path.join(gui_cache, ".preview_trigger")} + self._stop: bool = False + self._save_now: bool = False + self._preview = PreviewInterface(self._args.preview) + logger.debug("Initialized %s", self.__class__.__name__) - def set_timelapse(self): - """ Set timelapse paths if requested """ - if (not self.args.timelapse_input_a and - not self.args.timelapse_input_b and - not self.args.timelapse_output): - return None - if not self.args.timelapse_input_a or not self.args.timelapse_input_b: - raise ValueError("To enable the timelapse, you have to supply " - "all the parameters (--timelapse-input-A and " - "--timelapse-input-B).") - - for folder in (self.args.timelapse_input_a, - self.args.timelapse_input_b, - self.args.timelapse_output): - if folder is not None and not os.path.isdir(folder): - raise ValueError("The Timelapse path '{}' does not exist".format(folder)) + def _get_images(self) -> dict[T.Literal["a", "b"], list[str]]: + """ Check the image folders exist and contains valid extracted faces. Obtain image paths. - kwargs = {"input_a": self.args.timelapse_input_a, - "input_b": self.args.timelapse_input_b, - "output": self.args.timelapse_output} - logger.debug("Timelapse enabled: %s", kwargs) - return kwargs - - def get_images(self): - """ Check the image dirs exist, contain images and return the image - objects """ + Returns + ------- + dict + The image paths for each side. The key is the side, the value is the list of paths + for that side. 
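+
+            For example (the paths shown are illustrative only)::
+
+                {"a": ["/faces/a/00001_0.png", ...],
+                 "b": ["/faces/b/00001_0.png", ...]}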
+ """ logger.debug("Getting image paths") - images = dict() + images = {} for side in ("a", "b"): - image_dir = getattr(self.args, "input_{}".format(side)) + side = T.cast(T.Literal["a", "b"], side) + image_dir = getattr(self._args, f"input_{side}") if not os.path.isdir(image_dir): logger.error("Error: '%s' does not exist", image_dir) - exit(1) + sys.exit(1) - if not os.listdir(image_dir): + images[side] = get_image_paths(image_dir, ".png") + if not images[side]: logger.error("Error: '%s' contains no images", image_dir) - exit(1) - - images[side] = get_image_paths(image_dir) - logger.info("Model A Directory: %s", self.args.input_a) - logger.info("Model B Directory: %s", self.args.input_b) + sys.exit(1) + # Validate the first image is a detected face + test_image = next(img for img in images[side]) + meta = read_image_meta(test_image) + logger.debug("Test file: (filename: %s, metadata: %s)", test_image, meta) + if "itxt" not in meta or "alignments" not in meta["itxt"]: + logger.error("The input folder '%s' contains images that are not extracted faces.", + image_dir) + logger.error("You can only train a model on faces generated from Faceswap's " + "extract process. Please check your sources and try again.") + sys.exit(1) + + logger.info("Model %s Directory: '%s' (%s images)", + side.upper(), image_dir, len(images[side])) logger.debug("Got image paths: %s", [(key, str(len(val)) + " images") for key, val in images.items()]) + self._validate_image_counts(images) return images - def process(self): - """ Call the training process object """ - logger.debug("Starting Training Process") - logger.info("Training data directory: %s", self.args.model_dir) - set_system_verbosity(self.args.loglevel) - thread = self.start_thread() - # queue_manager.debug_monitor(1) + @classmethod + def _validate_image_counts(cls, images: dict[T.Literal["a", "b"], list[str]]) -> None: + """ Validate that there are sufficient images to commence training without raising an + error. + + Confirms that there are at least 24 images in each folder. Whilst this is not enough images + to train a Neural Network to any successful degree, it should allow the process to train + without raising errors when generating previews. + + A warning is raised if there are fewer than 250 images on any side. + + Parameters + ---------- + images: dict + The image paths for each side. The key is the side, the value is the list of paths + for that side. + """ + counts = {side: len(paths) for side, paths in images.items()} + msg = ("You need to provide a significant number of images to successfully train a Neural " + "Network. Aim for between 500 - 5000 images per side.") + if any(count < 25 for count in counts.values()): + logger.error("At least one of your input folders contains fewer than 25 images.") + logger.error(msg) + sys.exit(1) + if any(count < 250 for count in counts.values()): + logger.warning("At least one of your input folders contains fewer than 250 images. " + "Results are likely to be poor.") + logger.warning(msg) + + def _set_timelapse(self) -> dict[T.Literal["input_a", "input_b", "output"], str]: + """ Set time-lapse paths if requested. 
+ + Returns + ------- + dict + The time-lapse keyword arguments for passing to the trainer + + """ + if (not self._args.timelapse_input_a and + not self._args.timelapse_input_b and + not self._args.timelapse_output): + return {} + if (not self._args.timelapse_input_a or + not self._args.timelapse_input_b or + not self._args.timelapse_output): + raise FaceswapError("To enable the timelapse, you have to supply all the parameters " + "(--timelapse-input-A, --timelapse-input-B and " + "--timelapse-output).") + + timelapse_output = get_folder(self._args.timelapse_output) - if self.args.preview: - err = self.monitor_preview(thread) - else: - err = self.monitor_console(thread) + for side in ("a", "b"): + side = T.cast(T.Literal["a", "b"], side) + folder = getattr(self._args, f"timelapse_input_{side}") + if folder is not None and not os.path.isdir(folder): + raise FaceswapError(f"The Timelapse path '{folder}' does not exist") + + training_folder = getattr(self._args, f"input_{side}") + if folder == training_folder: + continue # Time-lapse folder is training folder + + filenames = [fname for fname in os.listdir(folder) + if os.path.splitext(fname)[-1].lower() in IMAGE_EXTENSIONS] + if not filenames: + raise FaceswapError(f"The Timelapse path '{folder}' does not contain any valid " + "images") + + # Time-lapse images must appear in the training set, as we need access to alignment and + # mask info. Check filenames are there to save failing much later in the process. + training_images = [os.path.basename(img) for img in self._images[side]] + if not all(img in training_images for img in filenames): + raise FaceswapError(f"All images in the Timelapse folder '{folder}' must exist in " + f"the training folder '{training_folder}'") + + TKey = T.Literal["input_a", "input_b", "output"] + kwargs = {T.cast(TKey, "input_a"): self._args.timelapse_input_a, + T.cast(TKey, "input_b"): self._args.timelapse_input_b, + T.cast(TKey, "output"): timelapse_output} + logger.debug("Timelapse enabled: %s", kwargs) + return kwargs - self.end_thread(thread, err) + def process(self) -> None: + """ The entry point for triggering the Training Process. + + Should only be called from :class:`lib.cli.launcher.ScriptExecutor` + """ + if self._args.summary: + self._load_model() + return + logger.debug("Starting Training Process") + logger.info("Training data directory: %s", self._args.model_dir) + thread = self._start_thread() + # from lib.queue_manager import queue_manager; queue_manager.debug_monitor(1) + err = self._monitor(thread) + self._end_thread(thread, err) logger.debug("Completed Training Process") - def start_thread(self): - """ Put the training process in a thread so we can keep control """ + def _start_thread(self) -> MultiThread: + """ Put the :func:`_training` into a background thread so we can keep control. + + Returns + ------- + :class:`lib.multithreading.MultiThread` + The background thread for running training + """ logger.debug("Launching Trainer thread") - thread = MultiThread(target=self.training) + thread = MultiThread(target=self._training) thread.start() logger.debug("Launched Trainer thread") return thread - def end_thread(self, thread, err): - """ On termination output message and join thread back to main """ + def _end_thread(self, thread: MultiThread, err: bool) -> None: + """ Output message and join thread back to main on termination. 
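+
+        A condensed, hedged sketch of the full thread lifecycle as driven by
+        :func:`process`:
+
+        >>> thread = MultiThread(target=self._training)  # _start_thread
+        >>> thread.start()
+        >>> err = self._monitor(thread)  # blocks until stop, error or keypress
+        >>> self._end_thread(thread, err)  # sets _stop and joins the thread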
+ + Parameters + ---------- + thread: :class:`lib.multithreading.MultiThread` + The background training thread + err: bool + Whether an error has been detected in :func:`_monitor` + """ logger.debug("Ending Training thread") if err: msg = "Error caught! Exiting..." log = logger.critical else: msg = ("Exit requested! The trainer will complete its current cycle, " - "save the models and quit (it can take up a couple of seconds " - "depending on your training speed). If you want to kill it now, " - "press Ctrl + c") + "save the models and quit (This can take a couple of minutes " + "depending on your training speed).") + if not self._args.redirect_gui: + msg += " If you want to kill it now, press Ctrl + c" log = logger.info log(msg) - self.stop = True + self._stop = True thread.join() sys.stdout.flush() - logger.debug("Ended Training thread") + logger.debug("Ended training thread") - def training(self): - """ The training process to be run inside a thread """ + def _training(self) -> None: + """ The training process to be run inside a thread. """ + trainer = None try: - sleep(1) # Let preview instructions flush out to logger + sleep(0.5) # Let preview instructions flush out to logger logger.debug("Commencing Training") logger.info("Loading data, this may take a while...") - - if self.args.allow_growth: - self.set_tf_allow_growth() - - model = self.load_model() - trainer = self.load_trainer(model) - self.run_training_cycle(model, trainer) + model = self._load_model() + trainer = self._load_trainer(model) + if trainer.exit_early: + logger.debug("Trainer exits early") + self._stop = True + return + self._run_training_cycle(trainer) except KeyboardInterrupt: try: logger.debug("Keyboard Interrupt Caught. Saving Weights and exiting") - model.save_models() - trainer.clear_tensorboard() + if trainer is not None: + trainer.save(is_exit=True) except KeyboardInterrupt: logger.info("Saving model weights has been cancelled!") - exit(0) + sys.exit(0) except Exception as err: raise err - def load_model(self): - """ Load the model requested for training """ + def _load_model(self) -> ModelBase: + """ Load the model requested for training. + + Returns + ------- + :file:`plugins.train.model` plugin + The requested model plugin + """ logger.debug("Loading Model") - model_dir = get_folder(self.args.model_dir) - model = PluginLoader.get_model(self.trainer_name)( + model_dir = get_folder(self._args.model_dir) + model: ModelBase = PluginLoader.get_model(self._args.trainer)( model_dir, - self.args.gpus, - no_logs=self.args.no_logs, - warp_to_landmarks=self.args.warp_to_landmarks, - no_flip=self.args.no_flip, - training_image_size=self.image_size, - alignments_paths=self.alignments_paths, - preview_scale=self.args.preview_scale) + self._args, + predict=False) + model.build() logger.debug("Loaded Model") return model - @property - def image_size(self): - """ Get the training set image size for storing in model data """ - image = cv2.imread(self.images["a"][0]) # pylint: disable=no-member - size = image.shape[0] - logger.debug("Training image size: %s", size) - return size + def _load_trainer(self, model: ModelBase) -> Trainer: + """ Load the trainer requested for training. 
- @property - def alignments_paths(self): - """ Set the alignments path to input dirs if not provided """ - alignments_paths = dict() - for side in ("a", "b"): - alignments_path = getattr(self.args, "alignments_path_{}".format(side)) - if not alignments_path: - image_path = getattr(self.args, "input_{}".format(side)) - alignments_path = os.path.join(image_path, "alignments.json") - alignments_paths[side] = alignments_path - logger.debug("Alignments paths: %s", alignments_paths) - return alignments_paths - - def load_trainer(self, model): - """ Load the trainer requested for training """ + Parameters + ---------- + model: :file:`plugins.train.model` plugin + The requested model plugin + + Returns + ------- + :class:`plugins.train.trainer.run_train.Trainer` + The model training loop with the requested trainer plugin loaded + """ logger.debug("Loading Trainer") - trainer = PluginLoader.get_trainer(model.trainer) - trainer = trainer(model, - self.images, - self.args.batch_size) + trainer = "distributed" if self._args.distributed else "original" + if trainer == "distributed": + import torch # pylint:disable=import-outside-toplevel + gpu_count = torch.cuda.device_count() + if gpu_count < 2: + logger.warning("Distributed selected but fewer than 2 GPUs detected. Switching " + "to Original") + trainer = "original" + + retval = Trainer(PluginLoader.get_trainer(trainer)(model, self._args.batch_size), + self._images) logger.debug("Loaded Trainer") - return trainer + return retval + + def _run_training_cycle(self, trainer: Trainer) -> None: + """ Perform the training cycle. + + Handles the background training, updating previews/time-lapse on each save interval, + and saving the model. - def run_training_cycle(self, model, trainer): - """ Perform the training cycle """ + Parameters + ---------- + trainer: :file:`plugins.train.trainer` plugin + The requested model trainer plugin + """ logger.debug("Running Training Cycle") - if self.args.write_image or self.args.redirect_gui or self.args.preview: - display_func = self.show + update_preview_images = False + if self._args.write_image or self._args.redirect_gui or self._args.preview: + display_func: Callable | None = self._show else: display_func = None - for iteration in range(0, self.args.iterations): - logger.trace("Training iteration: %s", iteration) - save_iteration = iteration % self.args.save_interval == 0 - viewer = display_func if save_iteration or self.save_now else None - timelapse = self.timelapse if save_iteration else None + for iteration in range(1, self._args.iterations + 1): + logger.trace("Training iteration: %s", iteration) # type:ignore + save_iteration = iteration % self._args.save_interval == 0 or iteration == 1 + gui_triggers = self._process_gui_triggers() + + if self._preview.should_toggle_mask or gui_triggers["mask"]: + trainer.toggle_mask() + update_preview_images = True + + if self._preview.should_refresh or gui_triggers["refresh"] or update_preview_images: + viewer = display_func + update_preview_images = False + else: + viewer = None + + timelapse = self._timelapse if save_iteration else {} trainer.train_one_step(viewer, timelapse) - if self.stop: + + if viewer is not None and not save_iteration: + # Spammy but required by GUI to know to update window + print("\x1b[2K", end="\r") # Clear last line + logger.info("[Preview Updated]") + + if self._stop: logger.debug("Stop received. 
Terminating") break - elif save_iteration: - logger.trace("Save Iteration: (iteration: %s", iteration) - model.save_models() - elif self.save_now: - logger.trace("Save Requested: (iteration: %s", iteration) - model.save_models() - self.save_now = False - logger.debug("Training cycle complete") - model.save_models() - trainer.clear_tensorboard() - self.stop = True - - def monitor_preview(self, thread): - """ Generate the preview window and wait for keyboard input """ - logger.debug("Launching Preview Monitor") - logger.info("R|=====================================================================") - logger.info("R|- Using live preview -") - logger.info("R|- Press 'ENTER' on the preview window to save and quit -") - logger.info("R|- Press 'S' on the preview window to save model weights immediately -") - logger.info("R|=====================================================================") - err = False - while True: - try: - with self.lock: - for name, image in self.preview_buffer.items(): - cv2.imshow(name, image) # pylint: disable=no-member - key = cv2.waitKey(1000) # pylint: disable=no-member - if self.stop: - logger.debug("Stop received") - break - if thread.has_error: - logger.debug("Thread error detected") - err = True - break - if key == ord("\n") or key == ord("\r"): - logger.debug("Exit requested") - break - if key == ord("s"): - logger.info("Save requested") - self.save_now = True - except KeyboardInterrupt: - logger.debug("Keyboard Interrupt received") - break - logger.debug("Closed Preview Monitor") - return err + if save_iteration or self._save_now: + logger.debug("Saving (save_iterations: %s, save_now: %s) Iteration: " + "(iteration: %s)", save_iteration, self._save_now, iteration) + trainer.save(is_exit=False) + self._save_now = False + update_preview_images = True - def monitor_console(self, thread): - """ Monitor the console - NB: A custom function needs to be used for this because - input() blocks """ - logger.debug("Launching Console Monitor") - logger.info("R|===============================================") - logger.info("R|- Starting -") - logger.info("R|- Press 'ENTER' to save and quit -") - logger.info("R|- Press 'S' to save model weights immediately -") - logger.info("R|===============================================") - keypress = KBHit(is_gui=self.args.redirect_gui) + logger.debug("Training cycle complete") + trainer.save(is_exit=True) + self._stop = True + + def _output_startup_info(self) -> None: + """ Print the startup information to the console. """ + logger.debug("Launching Monitor") + logger.info("===================================================") + logger.info(" Starting") + if self._args.preview: + logger.info(" Using live preview") + if sys.stdout.isatty(): + logger.info(" Press '%s' to save and quit", + "Stop" if self._args.redirect_gui else "ENTER") + if not self._args.redirect_gui and sys.stdout.isatty(): + logger.info(" Press 'S' to save model weights immediately") + logger.info("===================================================") + + def _check_keypress(self, keypress: KBHit) -> bool: + """ Check if a keypress has been detected. 
+ + Parameters + ---------- + keypress: :class:`lib.keypress.KBHit` + The keypress monitor + + Returns + ------- + bool + ``True`` if an exit keypress has been detected otherwise ``False`` + """ + retval = False + if keypress.kbhit(): + console_key = keypress.getch() + if console_key in ("\n", "\r"): + logger.debug("Exit requested") + retval = True + if console_key in ("s", "S"): + logger.info("Save requested") + self._save_now = True + return retval + + def _process_gui_triggers(self) -> dict[T.Literal["mask", "refresh"], bool]: + """ Check whether a file drop has occurred from the GUI to manually update the preview. + + Returns + ------- + dict + The trigger name as key and boolean as value + """ + retval: dict[T.Literal["mask", "refresh"], bool] = {key: False + for key in self._gui_triggers} + if not self._args.redirect_gui: + return retval + + for trigger, filename in self._gui_triggers.items(): + if os.path.isfile(filename): + logger.debug("GUI Trigger received for: '%s'", trigger) + retval[trigger] = True + logger.debug("Removing gui trigger file: %s", filename) + os.remove(filename) + if trigger == "refresh": + print("\x1b[2K", end="\r") # Clear last line + logger.info("Refresh preview requested...") + return retval + + def _monitor(self, thread: MultiThread) -> bool: + """ Monitor the background :func:`_training` thread for key presses and errors. + + Parameters + ---------- + thread: :class:`~lib.multithreading.MultiThread` + The thread containing the training loop + + Returns + ------- + bool + ``True`` if there has been an error in the background thread otherwise ``False`` + """ + self._output_startup_info() + keypress = KBHit(is_gui=self._args.redirect_gui) err = False while True: try: @@ -276,65 +456,170 @@ def monitor_console(self, thread): logger.debug("Thread error detected") err = True break - if self.stop: + if self._stop: logger.debug("Stop received") break - if keypress.kbhit(): - key = keypress.getch() - if key in ("\n", "\r"): - logger.debug("Exit requested") - break - if key in ("s", "S"): - logger.info("Save requested") - self.save_now = True + + # Preview Monitor + if self._preview.should_quit: + break + if self._preview.should_save: + self._save_now = True + + # Console Monitor + if self._check_keypress(keypress): + break # Exit requested + + sleep(1) except KeyboardInterrupt: logger.debug("Keyboard Interrupt received") break + logger.debug("Closing Monitor") + self._preview.shutdown() keypress.set_normal_term() - logger.debug("Closed Console Monitor") + logger.debug("Closed Monitor") return err - @staticmethod - def keypress_monitor(keypress_queue): - """ Monitor stdin for keypress """ - while True: - keypress_queue.put(sys.stdin.read(1)) - - @staticmethod - def set_tf_allow_growth(): - """ Allow TensorFlow to manage VRAM growth """ - # pylint: disable=no-member - logger.debug("Setting Tensorflow 'allow_growth' option") - config = tf.ConfigProto() - config.gpu_options.allow_growth = True - config.gpu_options.visible_device_list = "0" - set_session(tf.Session(config=config)) - logger.debug("Set Tensorflow 'allow_growth' option") - - def show(self, image, name=""): - """ Generate the preview and write preview file output """ - logger.trace("Updating preview: (name: %s)", name) + def _show(self, image: np.ndarray, name: str = "") -> None: + """ Generate the preview and write preview file output. + + Handles the output and display of preview images. 
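+
+        Up to three output routes may be active, keyed from the stored command line
+        arguments: ``write_image`` saves ``training_preview.png`` beside the running
+        script, ``redirect_gui`` writes to the GUI's preview cache and ``preview``
+        pushes the image to the pop-up viewer's buffer. For example (the name shown
+        is illustrative):
+
+        >>> self._show(image, name="a-b")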
+ + Parameters + ---------- + image: :class:`numpy.ndarray` + The preview image to be displayed and/or written out + name: str, optional + The name of the image for saving or display purposes. If an empty string is passed + then it will automatically be named. Default: "" + """ + logger.debug("Updating preview: (name: %s)", name) try: scriptpath = os.path.realpath(os.path.dirname(sys.argv[0])) - if self.args.write_image: - logger.trace("Saving preview to disk") - img = "training_preview.jpg" + if self._args.write_image: + logger.debug("Saving preview to disk") + img = "training_preview.png" imgfile = os.path.join(scriptpath, img) - cv2.imwrite(imgfile, image) # pylint: disable=no-member - logger.trace("Saved preview to: '%s'", img) - if self.args.redirect_gui: - logger.trace("Generating preview for GUI") - img = ".gui_training_preview.jpg" - imgfile = os.path.join(scriptpath, "lib", "gui", - ".cache", "preview", img) - cv2.imwrite(imgfile, image) # pylint: disable=no-member - logger.trace("Generated preview for GUI: '%s'", img) - if self.args.preview: - logger.trace("Generating preview for display: '%s'", name) - with self.lock: - self.preview_buffer[name] = image - logger.trace("Generated preview for display: '%s'", name) + cv2.imwrite(imgfile, image) # pylint:disable=no-member + logger.debug("Saved preview to: '%s'", img) + if self._args.redirect_gui: + logger.debug("Generating preview for GUI") + img = TRAININGPREVIEW + imgfile = os.path.join(scriptpath, "lib", "gui", ".cache", "preview", img) + cv2.imwrite(imgfile, image) # pylint:disable=no-member + logger.debug("Generated preview for GUI: '%s'", imgfile) + if self._args.preview: + logger.debug("Generating preview for display: '%s'", name) + self._preview.buffer.add_image(name, image) + logger.debug("Generated preview for display: '%s'", name) except Exception as err: logging.error("could not preview sample") raise err - logger.trace("Updated preview: (name: %s)", name) + logger.debug("Updated preview: (name: %s)", name) + + +class PreviewInterface(): + """ Run the preview window in a thread and interface with it + + Parameters + ---------- + use_preview: bool + ``True`` if pop-up preview window has been requested otherwise ``False`` + """ + def __init__(self, use_preview: bool) -> None: + self._active = use_preview + self._triggers: TriggerType = {"toggle_mask": Event(), + "refresh": Event(), + "save": Event(), + "quit": Event(), + "shutdown": Event()} + self._buffer = PreviewBuffer() + self._thread = self._launch_thread() + + @property + def buffer(self) -> PreviewBuffer: + """ :class:`PreviewBuffer`: The thread safe preview image object """ + return self._buffer + + @property + def should_toggle_mask(self) -> bool: + """ bool: Check whether the mask should be toggled and return the value. If ``True`` is + returned then resets mask toggle back to ``False`` """ + if not self._active: + return False + retval = self._triggers["toggle_mask"].is_set() + if retval: + logger.debug("Sending toggle mask") + self._triggers["toggle_mask"].clear() + return retval + + @property + def should_refresh(self) -> bool: + """ bool: Check whether the preview should be updated and return the value.
If ``True`` is + returned then resets the refresh trigger back to ``False`` """ + if not self._active: + return False + retval = self._triggers["refresh"].is_set() + if retval: + logger.debug("Sending should refresh") + self._triggers["refresh"].clear() + return retval + + @property + def should_save(self) -> bool: + """ bool: Check whether a save request has been made. If ``True`` is returned then save + trigger is set back to ``False`` """ + if not self._active: + return False + retval = self._triggers["save"].is_set() + if retval: + logger.debug("Sending should save") + self._triggers["save"].clear() + return retval + + @property + def should_quit(self) -> bool: + """ bool: Check whether an exit request has been made. ``True`` if an exit request has + been made otherwise ``False``. + + Raises + ------ + Error + Re-raises any error within the preview thread + """ + if self._thread is None: + return False + + self._thread.check_and_raise_error() + + retval = self._triggers["quit"].is_set() + if retval: + logger.debug("Sending should stop") + return retval + + def _launch_thread(self) -> FSThread | None: + """ Launch the preview viewer in its own thread if preview has been selected + + Returns + ------- + :class:`lib.multithreading.FSThread` or ``None`` + The thread that holds the preview viewer if preview is selected otherwise ``None`` + """ + if not self._active: + return None + thread = FSThread(target=Preview, + name="preview", + args=(self._buffer, ), + kwargs={"triggers": self._triggers}) + thread.start() + return thread + + def shutdown(self) -> None: + """ Send a signal to shut down the preview window. """ + if not self._active: + return + logger.debug("Sending shutdown to preview viewer") + self._triggers["shutdown"].set() + + +__all__ = get_module_objects(__name__) diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index 13459d1465..0000000000 --- a/setup.cfg +++ /dev/null @@ -1,3 +0,0 @@ -[flake8] -max-line-length = 99 -exclude = .git, __pycache__ diff --git a/setup.py b/setup.py index 2da4bf743e..32891e929d 100755 --- a/setup.py +++ b/setup.py @@ -1,794 +1,1015 @@ #!/usr/bin/env python3 """ Install packages for faceswap.py """ +# pylint:disable=too-many-lines +from __future__ import annotations -# >>> ENV -import ctypes -import locale +import logging +import json import os import re import sys -import platform +import typing as T +from importlib import import_module +from shutil import which +from string import printable +from subprocess import PIPE, Popen -from subprocess import CalledProcessError, run, PIPE, Popen +from lib.logger import log_setup +from lib.system import Cuda, Packages, ROCm, System +from lib.utils import get_module_objects, PROJECT_ROOT +from requirements.requirements import Requirements, PYTHON_VERSIONS -INSTALL_FAILED = False -# Revisions of tensorflow-gpu and cuda/cudnn requirements -TENSORFLOW_REQUIREMENTS = {"1.2": ["8.0", "5.1"], - "1.4": ["8.0", "6.0"], - "1.12": ["9.0", "7.2"]} +if T.TYPE_CHECKING: + from packaging.requirements import Requirement + import pip + import lib.utils as lib_utils +logger = logging.getLogger(__name__) +BackendType: T.TypeAlias = T.Literal['nvidia', 'apple_silicon', 'cpu', 'rocm', "all"] -class Environment(): - """ The current install environment """ - def __init__(self): - self.macos_required_packages = ["pynvx==0.0.4"] - self.conda_required_packages = [("ffmpeg", "conda-forge"), ("tk", )] - self.output = Output() - # Flag that setup is being run by installer so steps can be skipped - self.is_installer = False -
self.cuda_path = "" - self.cuda_version = "" - self.cudnn_version = "" - self.enable_docker = False - self.enable_cuda = False - self.required_packages = self.get_required_packages() - self.missing_packages = list() - self.conda_missing_packages = list() - - self.process_arguments() - self.check_permission() - self.check_system() - self.check_python() - self.output_runtime_info() - self.check_pip() - self.upgrade_pip() - - self.installed_packages = self.get_installed_packages() - self.get_installed_conda_packages() +# Conda packages that are required for a specific backend +_CONDA_BACKEND_REQUIRED: dict[BackendType, list[str]] = { + "all": ["tk", "git"]} @property - def encoding(self): - """ Get system encoding """ - return locale.getpreferredencoding() +# Conda packages that are required for a specific OS +_CONDA_OS_REQUIRED: dict[T.Literal["darwin", "linux", "windows"], list[str]] = { + "linux": ["xorg-libxft"]} # required to fix TK fonts on Linux @property - def os_version(self): - """ Get OS Verion """ - return platform.system(), platform.release() +# Mapping of Conda packages to channel if not in conda-forge +_CONDA_MAPPING: dict[str, str] = {} @property - def py_version(self): - """ Get Python Verion """ - return platform.python_version(), platform.architecture()[0] +# Force output to utf-8 +sys.stdout.reconfigure(encoding="utf-8", errors="replace") # type:ignore[union-attr] @property - def is_macos(self): - """ Check whether MacOS """ - return bool(platform.system() == "Darwin") @property - def is_conda(self): - """ Check whether using Conda """ - return bool("conda" in sys.version.lower()) +class _InstallState: # pylint:disable=too-few-public-methods + """ Marker to track if a step has failed installing """ + failed = False + messages: list[str] = [] @property - def ld_library_path(self): - """ Get the ld library path """ - return os.environ.get("LD_LIBRARY_PATH", None) + +class Environment(): + """ The current install environment + + Parameters + ---------- + updater : bool, Optional + ``True`` if the script is being called by Faceswap's internal updater. ``False`` if full + setup is running. Default: ``False`` + """ + _backends = ("nvidia", "apple_silicon", "rocm", "cpu") + + def __init__(self, updater: bool = False) -> None: + self.updater = updater + self.system = System() + logger.debug("Running on: %s", self.system) + if not updater: + self.system.validate() + self.is_installer: bool = False # Flag setup is being run by installer to skip steps + self.include_dev_tools: bool = False + self.backend: T.Literal["nvidia", "apple_silicon", "cpu", "rocm"] | None = None + self.enable_docker: bool = False + self.cuda_cudnn = ["", ""] + self.requirement_version = "" + self.rocm_version: tuple[int, ...]
= (0, 0, 0) + self._process_arguments() + self._output_runtime_info() + self._check_pip() @property - def is_admin(self): - """ Check whether user is admin """ - try: - retval = os.getuid() == 0 - except AttributeError: - retval = ctypes.windll.shell32.IsUserAnAdmin() != 0 - return retval + def cuda_version(self) -> str: + """ str : The detected globally installed Cuda Version """ + return self.cuda_cudnn[0] @property - def is_virtualenv(self): - """ Check whether this is a virtual environment """ - if not self.is_conda: - retval = (hasattr(sys, "real_prefix") or - (hasattr(sys, "base_prefix") and sys.base_prefix != sys.prefix)) - else: - prefix = os.path.dirname(sys.prefix) - retval = (os.path.basename(prefix) == "envs") - return retval + def cudnn_version(self) -> str: + """ str : The detected globally installed cuDNN Version """ + return self.cuda_cudnn[1] - def process_arguments(self): - """ Process any cli arguments """ - argv = [arg for arg in sys.argv] - for arg in argv: + def set_backend(self, backend: T.Literal["nvidia", "apple_silicon", "cpu", "rocm"]) -> None: + """ Set the backend to install for + + Parameters + ---------- + backend : Literal["nvidia", "apple_silicon", "cpu", "rocm"] + The backend to setup faceswap for + """ + logger.debug("Setting backend to '%s'", backend) + self.backend = backend + + def set_requirements(self, requirements: str) -> None: + """ Validate that the requirements are compatible with the running Python version and + set the requirements file version to use for the install + + Parameters + ---------- + requirements : str + The requirements file version to use for install + """ + if requirements in PYTHON_VERSIONS: + self.system.validate_python(max_version=PYTHON_VERSIONS[requirements]) + logger.debug("Setting requirements to '%s'", requirements) + self.requirement_version = requirements + + def _parse_backend_from_cli(self, arg: str) -> None: + """ Parse a command line argument and populate :attr:`backend` if valid + + Parameters
 + ---------- + arg : str + The command line argument to parse + """ + arg = arg.lower() + if not any(arg.startswith(b) for b in self._backends): + return + self.set_backend(next(b for b in self._backends if arg.startswith(b))) # type:ignore[misc] + if arg == "cpu": + self.set_requirements("cpu") + return + # Get Cuda/ROCm requirements file + assert self.backend is not None + req_files = sorted([os.path.splitext(f)[0].replace("requirements_", "") + for f in os.listdir(os.path.join(PROJECT_ROOT, "requirements")) + if os.path.splitext(f)[-1] == ".txt" + and f.startswith("requirements_") + and self.backend in f]) + if arg == self.backend: # Default to latest + logger.debug("No version specified. Defaulting to latest requirements") + self.set_requirements(req_files[-1]) + return + lookup = [r.replace("_", "") for r in req_files] + if arg not in lookup: + logger.debug("Defaulting to latest requirements for unknown lookup '%s'", arg) + self.set_requirements(req_files[-1]) + return + self.set_requirements(req_files[lookup.index(arg)]) + + def _process_arguments(self) -> None: + """ Process any cli arguments and dummy in cli arguments if calling from updater.
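+
+        For example (the flag and requirements file name here are illustrative):
+
+        >>> sys.argv = ["setup.py", "--rocm57"]
+        >>> # _parse_backend_from_cli("rocm57") sets backend "rocm" and resolves
+        >>> # requirements "rocm_5_7" if requirements/requirements_rocm_5_7.txt exists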
""" + args = sys.argv[:] + if self.updater: + get_backend = T.cast("lib_utils", # type:ignore[attr-defined,valid-type] + import_module("lib.utils")).get_backend + args.append(f"--{get_backend()}") + logger.debug(args) + if self.system.is_macos and self.system.machine == "arm64": + self.set_backend("apple_silicon") + self.set_requirements("apple-silicon") + for arg in args: if arg == "--installer": self.is_installer = True - if arg == "--gpu": - self.enable_cuda = True - - @staticmethod - def get_required_packages(): - """ Load requirements list """ - packages = list() - pypath = os.path.dirname(os.path.realpath(__file__)) - requirements_file = os.path.join(pypath, "requirements.txt") - with open(requirements_file) as req: - for package in req.readlines(): - package = package.strip() - if package and (not package.startswith("#")): - packages.append(package) - return packages - - def check_permission(self): - """ Check for Admin permissions """ - if self.is_admin: - self.output.info("Running as Root/Admin") - else: - self.output.warning("Running without root/admin privileges") - - def check_system(self): - """ Check the system """ - self.output.info("The tool provides tips for installation\n" - "and installs required python packages") - self.output.info("Setup in %s %s" % (self.os_version[0], self.os_version[1])) - if not self.os_version[0] in ["Windows", "Linux", "Darwin"]: - self.output.error("Your system %s is not supported!" % self.os_version[0]) - exit(1) - - def check_python(self): - """ Check python and virtual environment status """ - self.output.info("Installed Python: {0} {1}".format(self.py_version[0], - self.py_version[1])) - if not (self.py_version[0].split(".")[0] == "3" - and self.py_version[0].split(".")[1] in ("3", "4", "5", "6") - and self.py_version[1] == "64bit"): - self.output.error("Please run this script with Python version 3.3, 3.4, 3.5 or 3.6 " - "64bit and try again.") - exit(1) - - def output_runtime_info(self): - """ Output runtime info """ - if self.is_conda: - self.output.info("Running in Conda") - if self.is_virtualenv: - self.output.info("Running in a Virtual Environment") - self.output.info("Encoding: {}".format(self.encoding)) - - def check_pip(self): + continue + if arg == "--dev": + self.include_dev_tools = True + continue + if not self.backend and arg.startswith("--"): + self._parse_backend_from_cli(arg[2:]) + + def _output_runtime_info(self) -> None: + """ Output run time info """ + logger.info("Setup in %s %s", self.system.system.title(), self.system.release) + logger.info("Running as %s", "Root/Admin" if self.system.is_admin else "User") + if self.system.is_conda: + logger.info("Running in Conda") + if self.system.is_virtual_env: + logger.info("Running in a Virtual Environment") + logger.info("Encoding: %s", self.system.encoding) + + def _check_pip(self) -> None: """ Check installed pip version """ try: - import pip - except ImportError: - self.output.error("Import pip failed. 
Please Install python3-pip and try again") - exit(1) - - def upgrade_pip(self): - """ Upgrade pip to latest version """ - if not self.is_conda: - # Don't do this with Conda, as we must use conda's pip - self.output.info("Upgrading pip...") - pipexe = [sys.executable, "-m", "pip"] - pipexe.extend(["install", "--no-cache-dir", "-qq", "--upgrade"]) - if not self.is_admin and not self.is_virtualenv: - pipexe.append("--user") - pipexe.append("pip") - run(pipexe) - import pip - pip_version = pip.__version__ - self.output.info("Installed pip: {}".format(pip_version)) - - def get_installed_packages(self): - """ Get currently installed packages """ - installed_packages = dict() - chk = Popen("{} -m pip freeze".format(sys.executable), - shell=True, stdout=PIPE) - installed = chk.communicate()[0].decode(self.encoding).splitlines() - - for pkg in installed: - if "==" not in pkg: - continue - item = pkg.split("==") - installed_packages[item[0]] = item[1] - return installed_packages + _pip = T.cast("pip", import_module("pip")) # type:ignore[valid-type] + except ModuleNotFoundError: + logger.error("Import pip failed. Please Install python3-pip and try again") + sys.exit(1) + logger.info("Pip version: %s", _pip.__version__) # type:ignore[attr-defined] + + def _configure_keras(self) -> None: + """ Set up the keras.json file to use Torch as the backend """ + if "KERAS_HOME" in os.environ: + keras_dir = os.environ["KERAS_HOME"] + else: + keras_base_dir = os.path.expanduser("~") + if not os.access(keras_base_dir, os.W_OK): + keras_base_dir = "/tmp" + keras_dir = os.path.join(keras_base_dir, ".keras") + keras_dir = os.path.expanduser(keras_dir) + os.makedirs(keras_dir, exist_ok=True) + conf_file = os.path.join(keras_dir, "keras.json") + config = {} + if os.path.exists(conf_file): + try: + with open(conf_file, "r", encoding="utf-8") as c_file: + config = json.load(c_file) + except ValueError: + pass + config["backend"] = "torch" + with open(conf_file, "w", encoding="utf-8") as c_file: + c_file.write(json.dumps(config, indent=4)) + logger.info("Keras config written to: %s", conf_file) + + def set_config(self) -> None: + """ Set the backend in the faceswap config file """ + config = {"backend": self.backend} + pypath = os.path.dirname(os.path.realpath(__file__)) + config_file = os.path.join(pypath, "config", ".faceswap") + with open(config_file, "w", encoding="utf8") as cnf: + json.dump(config, cnf) + logger.info("Faceswap config written to: %s", config_file) + self._configure_keras() + + +class RequiredPackages(): + """ Holds information about installed and required packages. 
+    Handles updating dependencies based on running platform/backend
+
+    Parameters
+    ----------
+    environment : :class:`Environment`
+        Environment class holding information about the running system
+    """
+    def __init__(self, environment: Environment) -> None:
+        self._env = environment
+        self._packages = Packages()
+        self._requirements = Requirements(include_dev=self._env.include_dev_tools)
+        self._check_packaging()
+        self.conda = self._get_missing_conda()
+        self.python = self._get_missing_python(
+            self._requirements.requirements[self._env.requirement_version])
+        self.pip_arguments = [
+            x.strip()
+            for p in self._requirements.global_options[self._env.requirement_version]
+            for x in p.split()]
+        """ list[str] : Any additional pip arguments that are required for installing from pip for
+        the given backend """
 
-    def get_installed_conda_packages(self):
-        """ Get currently installed conda packages """
-        if not self.is_conda:
-            return
-        chk = os.popen("conda list").read()
-        installed = [re.sub(" +", " ", line.strip())
-                     for line in chk.splitlines() if not line.startswith("#")]
-        for pkg in installed:
-            item = pkg.split(" ")
-            self.installed_packages[item[0]] = item[1]
-
-    def update_tf_dep(self):
-        """ Update Tensorflow Dependency """
-        if self.is_conda:
-            self.update_tf_dep_conda()
-            return
+    @property
+    def packages_need_install(self) -> bool:
+        """bool : ``True`` if there are packages available that need to be installed """
+        return bool(self.conda or self.python)
 
-        if not self.enable_cuda:
-            self.required_packages.append("tensorflow")
+    def _check_packaging(self) -> None:
+        """ Install packaging if it is not available """
+        if self._requirements.packaging_available:
             return
+        cmd = [sys.executable, "-u", "-m", "pip", "install", "--no-cache-dir"]
+        if not self._env.system.is_admin and not self._env.system.is_virtual_env:
+            cmd.append("--user")
+        cmd.append("packaging")
+        logger.info("Installing required package...")
+        installer = Installer(self._env, ["Packaging"], cmd, False, False)
+        if installer() != 0:
+            logger.error("Unable to install package: %s. 
Process aborted", "packaging")
+            sys.exit(1)
+
+    def _get_missing_python(self, requirements: list[Requirement]
+                            ) -> list[dict[T.Literal["name", "package"], str]]:
+        """ Check for missing Python dependencies
+
+        Parameters
+        ----------
+        requirements : list[:class:`packaging.requirements.Requirement`]
+            The packages that are required to be installed
+
+        Returns
+        -------
+        list[dict[Literal["name", "package"], str]]
+            List of missing Python packages to install
+        """
+        retval: list[dict[T.Literal["name", "package"], str]] = []
+        for req in requirements:
+            package: dict[T.Literal["name", "package"], str] = {
+                "name": req.name.title(),
+                "package": f"{req.name}{req.specifier}"}
+            installed_version = self._packages.installed_python.get(req.name, "")
+            if not installed_version:
+                logger.debug("Adding new Python package '%s'", package["package"])
+                retval.append(package)
+                continue
+            if not req.specifier.contains(installed_version):
+                logger.debug("Adding Python package '%s' for specifier change from '%s' to '%s'",
+                             package["package"], installed_version, str(req.specifier))
+                retval.append(package)
+                continue
+            logger.debug("Skipping installed Python package '%s'", package["package"])
+        logger.debug("Selected missing Python packages: %s", retval)
+        return retval
 
-        tf_ver = None
-        cudnn_inst = self.cudnn_version.split(".")
-        for key, val in TENSORFLOW_REQUIREMENTS.items():
-            cuda_req = val[0]
-            cudnn_req = val[1].split(".")
-            if cuda_req == self.cuda_version and (cudnn_req[0] == cudnn_inst[0] and
-                                                  cudnn_req[1] <= cudnn_inst[1]):
-                tf_ver = key
-                break
-        if tf_ver:
-            tf_ver = "tensorflow-gpu=={}.0".format(tf_ver)
-            self.required_packages.append(tf_ver)
-            return
+    def _get_required_conda(self) -> list[dict[T.Literal["package", "channel"], str]]:
+        """ Add backend specific packages to Conda required packages
 
-        self.output.warning(
-            "Tensorflow currently has no official prebuild for your CUDA, cuDNN "
-            "combination.\nEither install a combination that Tensorflow supports or "
-            "build and install your own tensorflow-gpu.\r\n"
-            "CUDA Version: {}\r\n"
-            "cuDNN Version: {}\r\n"
-            "Help:\n"
-            "Building Tensorflow: https://www.tensorflow.org/install/install_sources\r\n"
-            "Tensorflow supported versions: "
-            "https://www.tensorflow.org/install/source#tested_build_configurations".format(
-                self.cuda_version, self.cudnn_version))
-
-        custom_tf = input("Location of custom tensorflow-gpu wheel (leave "
-                          "blank to manually install): ")
-        if not custom_tf:
-            return
+        Returns
+        -------
+        list[dict[Literal["package", "channel"], str]]
+            List of required Conda package names and the channel to install from
+        """
+        retval: list[dict[T.Literal["package", "channel"], str]] = []
+        assert self._env.backend is not None
+        to_add = (_CONDA_BACKEND_REQUIRED.get(self._env.backend, [])
+                  + _CONDA_BACKEND_REQUIRED.get("all", [])
+                  + _CONDA_OS_REQUIRED.get(self._env.system.system, []))
+        if not to_add:
+            logger.debug("No packages to add for '%s'('%s'). All backend packages: %s. 
All OS " + "packages: %s", + self._env.backend, self._env.system, + _CONDA_BACKEND_REQUIRED, _CONDA_OS_REQUIRED) + return retval + for pkg in to_add: + channel = _CONDA_MAPPING.get(pkg, "conda-forge") + retval.append({"package": pkg, "channel": channel}) + logger.debug("Adding conda required package '%s' for system '%s'('%s'))", + pkg, self._env.backend, self._env.system.system) + return retval - custom_tf = os.path.realpath(os.path.expanduser(custom_tf)) - if not os.path.isfile(custom_tf): - self.output.error("{} not found".format(custom_tf)) - elif os.path.splitext(custom_tf)[1] != ".whl": - self.output.error("{} is not a valid pip wheel".format(custom_tf)) - elif custom_tf: - self.required_packages.append(custom_tf) - - def update_tf_dep_conda(self): - """ Update Conda TF Dependency """ - if not self.enable_cuda: - self.required_packages.append("tensorflow==1.12.0") - else: - self.required_packages.append("tensorflow-gpu==1.12.0") - - -class Output(): - """ Format and display output """ - def __init__(self): - self.red = "\033[31m" - self.green = "\033[32m" - self.yellow = "\033[33m" - self.default_color = "\033[0m" - self.term_support_color = platform.system() in ("Linux", "Darwin") - - @staticmethod - def __indent_text_block(text): - """ Indent a text block """ - lines = text.splitlines() - if len(lines) > 1: - out = lines[0] + "\r\n" - for i in range(1, len(lines)-1): - out = out + " " + lines[i] + "\r\n" - out = out + " " + lines[-1] - return out - return text - - def info(self, text): - """ Format INFO Text """ - trm = "INFO " - if self.term_support_color: - trm = "{}INFO {} ".format(self.green, self.default_color) - print(trm + self.__indent_text_block(text)) - - def warning(self, text): - """ Format WARNING Text """ - trm = "WARNING " - if self.term_support_color: - trm = "{}WARNING{} ".format(self.yellow, self.default_color) - print(trm + self.__indent_text_block(text)) - - def error(self, text): - """ Format ERROR Text """ - global INSTALL_FAILED - trm = "ERROR " - if self.term_support_color: - trm = "{}ERROR {} ".format(self.red, self.default_color) - print(trm + self.__indent_text_block(text)) - INSTALL_FAILED = True - - -class Checks(): - """ Pre-installation checks """ - def __init__(self, environment): - self.env = environment - self.output = Output() - self.tips = Tips() + def _get_missing_conda(self) -> dict[str, list[dict[T.Literal["name", "package"], str]]]: + """ Check for conda missing dependencies + + Returns + ------- + dict[str, list[dict[Literal["name", "package"], str]]] + The Conda packages to install grouped by channel + """ + retval: dict[str, list[dict[T.Literal["name", "package"], str]]] = {} + if not self._env.system.is_conda: + return retval + required = self._get_required_conda() + requirements = self._requirements.parse_requirements( + [p["package"] for p in required]) + channels = [p["channel"] for p in required] + installed = {k: v for k, v in self._packages.installed_conda.items() if v[1] != "pypi"} + for req, channel in zip(requirements, channels): + spec_str = str(req.specifier).replace("==", "=") if req.specifier else "" + package: dict[T.Literal["name", "package"], str] = {"name": req.name.title(), + "package": f"{req.name}{spec_str}"} + exists = installed.get(req.name) + if req.name == "tk" and self._env.system.is_linux: + # Default TK has bad fonts under Linux. 
+ # Ref: https://github.com/ContinuumIO/anaconda-issues/issues/6833 + # This versioning will fail in parse_requirements, so we need to do it here + package["package"] = f"{req.name}=*=xft_*" # Swap out for explicit XFT version + if exists is not None and not exists[1].startswith("xft"): # Replace noxft version + exists = None + if not exists: + logger.debug("Adding new Conda package '%s'", package["package"]) + retval.setdefault(channel, []).append(package) + continue + if exists[-1] != channel: + logger.debug("Adding Conda package '%s' for channel change from '%s' to '%s'", + package["package"], exists[-1], channel) + retval.setdefault(channel, []).append(package) + continue + if not req.specifier.contains(exists[0]): + logger.debug("Adding Conda package '%s' for specifier change from '%s' to '%s'", + package["package"], exists[0], spec_str) + retval.setdefault(channel, []).append(package) + continue + logger.debug("Skipping installed Conda package '%s'", package["package"]) + logger.debug("Selected missing Conda packages: %s", retval) + return retval + + +class Checks(): # pylint:disable=too-few-public-methods + """ Pre-installation checks + Parameters + ---------- + environment : :class:`Environment` + Environment class holding information about the running system + """ + def __init__(self, environment: Environment) -> None: + self._env: Environment = environment + self._tips: Tips = Tips() # Checks not required for installer - if self.env.is_installer: - self.env.update_tf_dep() + if self._env.is_installer: return + # Checks not required for Apple Silicon + if self._env.backend == "apple_silicon": + return + self._user_input() + self._check_cuda() + self._check_rocm() + if self._env.system.is_windows: + self._tips.pip() + + def _rocm_ask_enable(self) -> None: + """ Set backend to 'rocm' if OS is Linux and ROCm support required """ + if not self._env.system.is_linux: + return + logger.info("ROCm support:\r\nIf you are using an AMD GPU, then select 'yes'." + "\r\nCPU/non-AMD GPU users should answer 'no'.\r\n") + i = input("Enable ROCm Support? [y/N] ").strip() + if i not in ("", "Y", "y", "n", "N"): + logger.warning("Invalid selection '%s'", i) + self._rocm_ask_enable() + return + if i not in ("Y", "y"): + return + logger.info("ROCm Support Enabled") + self._env.set_backend("rocm") + versions = ["6.0", "6.1", "6.2", "6.3", "6.4"] + i = input(f"Which ROCm version? 
[{', '.join(versions)}] ").strip()
+        i = versions[-1] if not i else i
+        if i not in versions:
+            logger.warning("Invalid selection '%s'", i)
+            self._rocm_ask_enable()
+            return
+        logger.info("ROCm Version %s Selected", i)
+        self._env.set_requirements(f"rocm_{i.replace('.', '')}")
 
-        # Ask Docker/Cuda
-        self.docker_ask_enable()
-        self.cuda_ask_enable()
-        if self.env.os_version[0] != "Linux" and self.env.enable_docker and self.env.enable_cuda:
-            self.docker_confirm()
-        if self.env.enable_docker:
-            self.docker_tips()
-            exit(0)
-
-        # Check for CUDA and cuDNN
-        if self.env.enable_cuda and self.env.os_version[0] in ("Linux", "Windows"):
-            self.cuda_check()
-            self.cudnn_check()
-        elif self.env.enable_cuda and self.env.os_version[0] not in ("Linux", "Windows"):
-            self.tips.macos()
-            self.output.warning("Cannot find CUDA on macOS")
-            self.env.cuda_version = input("Manually specify CUDA version: ")
-
-        self.env.update_tf_dep()
-        self.check_system_dependencies()
-        if self.env.os_version[0] == "Windows":
-            self.tips.pip()
-
-    def docker_ask_enable(self):
+    def _docker_ask_enable(self) -> None:
         """ Enable or disable Docker """
-        i = input("Enable Docker? [y/N] ")
+        i = input("Enable Docker? [y/N] ").strip()
+        if i not in ("", "Y", "y", "n", "N"):
+            logger.warning("Invalid selection '%s'", i)
+            self._docker_ask_enable()
+            return
         if i in ("Y", "y"):
-            self.output.info("Docker Enabled")
-            self.env.enable_docker = True
+            logger.info("Docker Enabled")
+            self._env.enable_docker = True
         else:
-            self.output.info("Docker Disabled")
-            self.env.enable_docker = False
-
-    def docker_confirm(self):
-        """ Warn if nvidia-docker on non-linux system """
-        self.output.warning("Nvidia-Docker is only supported on Linux.\r\n"
-                            "Only CPU is supported in Docker for your system")
-        self.docker_ask_enable()
-        if self.env.enable_docker:
-            self.output.warning("CUDA Disabled")
-            self.env.enable_cuda = False
-
-    def docker_tips(self):
-        """ Provide tips for Docker use """
-        if not self.env.enable_cuda:
-            self.tips.docker_no_cuda()
-        else:
-            self.tips.docker_cuda()
+            logger.info("Docker Disabled")
+            self._env.enable_docker = False
 
-    def cuda_ask_enable(self):
+    def _cuda_ask_enable(self) -> None:
         """ Enable or disable CUDA """
-        i = input("Enable CUDA? [Y/n] ")
-        if i in ("", "Y", "y"):
-            self.output.info("CUDA Enabled")
-            self.env.enable_cuda = True
-        else:
-            self.output.info("CUDA Disabled")
-            self.env.enable_cuda = False
-
-    def cuda_check(self):
-        """ Check Cuda for Linux or Windows """
-        if self.env.os_version[0] == "Linux":
-            self.cuda_check_linux()
-        elif self.env.os_version[0] == "Windows":
-            self.cuda_check_windows()
-
-    def cuda_check_linux(self):
-        """ Check Linux CUDA Version """
-        chk = os.popen("ldconfig -p | grep -P \"libcudart.so.\\d+.\\d+\" | head -n 1").read()
-        if self.env.ld_library_path and not chk:
-            paths = self.env.ld_library_path.split(":")
-            for path in paths:
-                chk = os.popen("ls {} | grep -P -o \"libcudart.so.\\d+.\\d+\" | "
-                               "head -n 1".format(path)).read()
-                if chk:
-                    break
-        if not chk:
-            self.output.error("CUDA not found. Install and try again.\n"
-                              "Recommended version: CUDA 9.0 cuDNN 7.1.3\n"
-                              "CUDA: https://developer.nvidia.com/cuda-downloads\n"
-                              "cuDNN: https://developer.nvidia.com/rdp/cudnn-download")
+        i = input("Enable CUDA? 
[Y/n] ").strip() + if i not in ("", "Y", "y", "n", "N"): + logger.warning("Invalid selection '%s'", i) + self._cuda_ask_enable() return - cudavers = chk.strip().replace("libcudart.so.", "") - self.env.cuda_version = cudavers[:cudavers.find(" ")] - if self.env.cuda_version: - self.output.info("CUDA version: " + self.env.cuda_version) - self.env.cuda_path = chk[chk.find("=>") + 3:chk.find("targets") - 1] - - def cuda_check_windows(self): - """ Check Windows CUDA Version """ - cuda_keys = [key - for key in os.environ.keys() - if key.lower().startswith("cuda_path_v")] - if not cuda_keys: - self.output.error("CUDA not found. See " - "https://github.com/deepfakes/faceswap/blob/master/INSTALL.md#cuda " - "for instructions") + if i not in ("", "Y", "y"): return - - self.env.cuda_version = cuda_keys[0].replace("CUDA_PATH_V", "").replace("_", ".") - self.env.cuda_path = os.environ[cuda_keys[0]] - self.output.info("CUDA version: " + self.env.cuda_version) - - def cudnn_check(self): - """ Check Linux or Windows cuDNN Version from cudnn.h """ - cudnn_checkfile = os.path.join(self.env.cuda_path, "include", "cudnn.h") - if not os.path.isfile(cudnn_checkfile): - self.output.error("cuDNN not found. See " - "https://github.com/deepfakes/faceswap/blob/master/INSTALL.md#cudnn " - "for instructions") + logger.info("CUDA Enabled") + self._env.set_backend("nvidia") + versions = ["11", "12", "13"] + i = input("Which Cuda version: 11 (GTX7xx-8xx), 12 (GTX9xx-10xx) or 13 (RTX20xx-)? " + f"[{', '.join(versions)}] ").strip() + i = "13" if not i else i + if i not in versions: + logger.warning("Invalid selection '%s'", i) + self._cuda_ask_enable() return - found = 0 - with open(cudnn_checkfile, "r") as ofile: - for line in ofile: - if line.lower().startswith("#define cudnn_major"): - major = line[line.rfind(" ") + 1:].strip() - found += 1 - elif line.lower().startswith("#define cudnn_minor"): - minor = line[line.rfind(" ") + 1:].strip() - found += 1 - elif line.lower().startswith("#define cudnn_patchlevel"): - patchlevel = line[line.rfind(" ") + 1:].strip() - found += 1 - if found == 3: - break - if found != 3: - self.output.error("cuDNN version could not be determined. See " - "https://github.com/deepfakes/faceswap/blob/master/INSTALL.md#cudnn " - "for instructions") + logger.info("CUDA Version %s Selected", i) + self._env.set_requirements(f"nvidia_{i}") + + def _docker_confirm(self) -> None: + """ Warn if nvidia-docker on non-Linux system """ + logger.warning("Nvidia-Docker is only supported on Linux.\r\n" + "Only CPU is supported in Docker for your system") + self._docker_ask_enable() + if self._env.enable_docker: + logger.warning("CUDA Disabled") + self._env.set_backend("cpu") + + def _docker_tips(self) -> None: + """ Provide tips for Docker use """ + if self._env.backend != "nvidia": + self._tips.docker_no_cuda() + else: + self._tips.docker_cuda() + + def _user_input(self) -> None: + """ Get user input for AMD/ROCm/Cuda/Docker """ + if self._env.backend is None: + self._rocm_ask_enable() + if self._env.backend is None: + self._docker_ask_enable() + self._cuda_ask_enable() + if not self._env.system.is_linux and (self._env.enable_docker + and self._env.backend == "nvidia"): + self._docker_confirm() + if self._env.enable_docker: + self._docker_tips() + self._env.set_config() + sys.exit(0) + + def _check_cuda(self) -> None: + """ Check for Cuda and cuDNN Locations. 
""" + if self._env.backend != "nvidia": + logger.debug("Skipping Cuda checks as not enabled") + return + if not any((self._env.system.is_linux, self._env.system.is_windows)): return + cuda = Cuda() + if cuda.versions: + str_vers = ", ".join(".".join(str(x) for x in v) for v in cuda.versions) + msg = (f"Globally installed Cuda version{'s' if len(cuda.versions) > 1 else ''} " + f"{str_vers} found. PyTorch uses it's own version of Cuda, so if you have " + "GPU issues, you should remove these global installs") + _InstallState.messages.append(msg) + self._env.cuda_cudnn[0] = str_vers + logger.debug("CUDA version: %s", self._env.cuda_version) + if cuda.cudnn_versions: + str_vers = ", ".join(".".join(str(x) for x in v) + for v in cuda.cudnn_versions.values()) + msg = ("Globally installed CuDNN version" + f"{'s' if len(cuda.cudnn_versions) > 1 else ''} {str_vers} found. PyTorch uses " + "its own version of Cuda, so if you have GPU issues, you should remove these " + "global installs") + _InstallState.messages.append(msg) + self._env.cuda_cudnn[1] = str_vers + logger.debug("cuDNN version: %s", self._env.cudnn_version) + + def _check_rocm(self) -> None: + """ Check for ROCm version """ + if self._env.backend != "rocm" or not self._env.system.is_linux: + logger.debug("Skipping ROCm checks as not enabled") + return + rocm = ROCm() - self.env.cudnn_version = "{}.{}".format(major, minor) - self.output.info("cuDNN version: {}.{}".format(self.env.cudnn_version, patchlevel)) - - def check_system_dependencies(self): - """ Check that system applications are installed """ - self.output.info("Checking System Dependencies...") - self.cmake_check() - if self.env.os_version[0] == "Windows": - self.visual_studio_check() - self.check_cplus_plus() - if self.env.os_version[0] == "Linux": - self.gcc_check() - self.gpp_check() - - def gcc_check(self): - """ Check installed gcc version for linux """ - chk = Popen("gcc --version", shell=True, stdout=PIPE, stderr=PIPE) - stdout, stderr = chk.communicate() - if stderr: - self.output.error("gcc not installed. Please install gcc for your distribution") + if rocm.is_valid or rocm.valid_installed: + self._env.rocm_version = max(rocm.valid_versions) + logger.info("ROCm version: %s", ".".join(str(v) for v in self._env.rocm_version)) + if rocm.is_valid: return - gcc = [re.sub(" +", " ", line.strip()) - for line in stdout.decode(self.env.encoding).splitlines() - if line.lower().strip().startswith("gcc")][0] - version = gcc[gcc.rfind(" ") + 1:] - self.output.info("gcc version: {}".format(version)) - - def gpp_check(self): - """ Check installed g++ version for linux """ - chk = Popen("g++ --version", shell=True, stdout=PIPE, stderr=PIPE) - stdout, stderr = chk.communicate() - if stderr: - self.output.error("g++ not installed. 
Please install g++ for your distribution")
+        if rocm.valid_installed:
+            str_vers = ".".join(str(v) for v in self._env.rocm_version)
+            _InstallState.messages.append(
+                f"Valid ROCm version {str_vers} is installed, but is not your default version.\n"
+                "You may need to change this to enable GPU acceleration")
             return
-        gpp = [re.sub(" +", " ", line.strip())
-               for line in stdout.decode(self.env.encoding).splitlines()
-               if line.lower().strip().startswith("g++")][0]
-        version = gpp[gpp.rfind(" ") + 1:]
-        self.output.info("g++ version: {}".format(version))
-
-    def cmake_check(self):
-        """ Check CMake is installed """
-        chk = Popen("cmake --version", shell=True, stdout=PIPE, stderr=PIPE)
-        stdout, stderr = chk.communicate()
-        stdout = stdout.decode(self.env.encoding)
-        if stderr and self.env.os_version[0] == "Windows":
-            stdout, stderr = self.cmake_check_windows()
-        if stderr:
-            self.output.error("CMake could not be found. See "
-                              "https://github.com/deepfakes/faceswap/blob/master/INSTALL.md#cmake "
-                              "for instructions")
+
+        if rocm.versions:
+            str_vers = ", ".join(".".join(str(x) for x in v) for v in rocm.versions)
+            msg = f"Incompatible ROCm version{'s' if len(rocm.versions) > 1 else ''}: {str_vers}\n"
+        else:
+            msg = "ROCm not found\n"
+        _InstallState.messages.append(f"{msg}\n")
+        str_min = ".".join(str(v) for v in rocm.version_min)
+        str_max = ".".join(str(v) for v in rocm.version_max)
+        valid = f"{str_min} to {str_max}" if str_min != str_max else str_min
+        msg += ("The installation can proceed, but you will need to install ROCm version "
+                f"{valid} to enable GPU acceleration")
+        _InstallState.messages.append(msg)
+
+
+class Status():
+    """ Simple Status output for intercepting Conda/Pip installs and keeping the terminal clean
+
+    Parameters
+    ----------
+    is_conda : bool
+        ``True`` if installing packages from Conda. ``False`` if installing from pip
+    """
+    def __init__(self, is_conda: bool):
+        self._is_conda = is_conda
+        self._last_line = ""
+        self._max_width = 79  # Keep short because of NSIS Details window size
+        self._prefix = "> "
+        self._conda_tracked: dict[str, dict[T.Literal["size", "done"], float]] = {}
+        self._re_pip_pkg = re.compile(r"^Downloading\s(?P<lib>\w+)\b.*?\s\((?P<size>.+)\)")
+        self._re_pip_http = re.compile(r"https?://[^\s]*/([^/\s]+)")
+        self._re_pip_progress = re.compile(r"^Progress\s+(?P<done>\d+).+?(?P<total>\d+)")
+        self._re_conda = re.compile(
+            r"(?P<lib>^\S+)\s+\|\s+(?P<tot>\d+\.?\d*\s\w+).*\|\s+(?P<prg>\d+)%")
+
+    def _clear_line(self) -> None:
+        """ Clear the last printed line from the console """
+        print(" " * self._max_width, end="\r")
+
+    def _print(self, line: str) -> None:
+        """ Clear the last line and print the new line to the console
+
+        Parameters
+        ----------
+        line : str
+            The line to print
+        """
+        full_line = f"{self._prefix}{line}"
+        output = full_line
+        if len(output) > self._max_width:
+            output = f"{output[:self._max_width - 3]}..." 
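+            # Illustrative worked example (comment added for clarity): with the
+            # default _max_width of 79, a longer status line keeps its first 76
+            # characters and gains "...", giving exactly 79 columns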
+ if len(output) < len(self._last_line): + self._clear_line() + self._last_line = full_line + print(output, end="\r") + + def _parse_size(self, size: str) -> float: + """ Parse the string representation of a package size and return as megabytes + + Parameters + ---------- + size : str + The string representation of a package size + + Returns + ------- + float + The size in megabytes + """ + size, unit = size.strip().split(" ", maxsplit=1) + if unit.lower() == "b": + return float(size) / 1024 / 1024 + if unit.lower() == "kb": + return float(size) / 1024 + if unit.lower() == "mb": + return float(size) + if unit.lower() == "gb": + return float(size) * 1024 + return float(size) # Should never happen, but to prevent error + + def _print_conda(self, line: str) -> None: + """ Output progress for Conda installs + + Parameters + ---------- + line : str + The conda install line to parse + """ + progress = self._re_conda.match(line) + if progress is None: + self._print(line) return - cmake = [re.sub(" +", " ", line.strip()) - for line in stdout.splitlines() - if line.lower().strip().startswith("cmake")][0] - version = cmake[cmake.rfind(" ") + 1:] - self.output.info("CMake version: {}".format(version)) - - def cmake_check_windows(self): - """ Additional checks for cmake on Windows """ - chk = Popen("wmic product where \"name = 'cmake'\" get installlocation,version", - shell=True, stdout=PIPE, stderr=PIPE) - stdout, stderr = chk.communicate() - if stderr: - return False, stderr - lines = [re.sub(" +", " ", line.strip()) - for line in stdout.decode(self.env.encoding).splitlines() - if line.strip()] - stdout = lines[1] - location = stdout[:stdout.rfind(" ")] + "bin" - self.output.info("CMake not found in %PATH%. Temporarily adding: \"{}\"".format(location)) - os.environ["PATH"] += ";{}".format(location) - stdout = "cmake {}".format(stdout) - return stdout, False - - def visual_studio_check(self): - """ Check Visual Studio 2015 is installed for Windows - - Somewhat hacky solution which checks for the existence - of the VS2015 Performance Report + info = progress.groupdict() + if info["lib"] not in self._conda_tracked: + self._conda_tracked[info["lib"]] = {"size": self._parse_size(info["tot"]), + "done": float(info["prg"])} + else: + self._conda_tracked[info["lib"]]["done"] = float(info["prg"]) + count = len(self._conda_tracked) + total_size = sum(v["size"] for v in self._conda_tracked.values()) + prog = min(sum(v["done"] for v in self._conda_tracked.values()) / count, 100.) + self._print(f"Downloading {count} packages ({total_size:.1f} MB) {prog:.1f}%") + + def _print_pip(self, line: str) -> None: + """ Output progress for Pip installs + + Parameters + ---------- + line : str + The pip install line to parse """ - chk = Popen("reg query HKLM\\SOFTWARE\\Microsoft\\VisualStudio\\14.0\\VSPerf", - shell=True, stdout=PIPE, stderr=PIPE) - _, stderr = chk.communicate() - if stderr: - self.output.error("Visual Studio 2015 could not be found. See " - "https://github.com/deepfakes/faceswap/blob/master/" - "INSTALL.md#microsoft-visual-studio-2015 for instructions") + if (line.lower().startswith("installing collected packages:") and + len(line) > self._max_width): + count = len(line.split(":", maxsplit=1)[-1].split(",")) + line = f"Installing {count} collected packages..." 
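+        # Assumption for illustration: pip runs with --progress-bar=raw (see
+        # Install._from_pip), emitting lines such as "Progress 512 of 2048",
+        # which are parsed below into a done/total percentage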
+ progress = self._re_pip_progress.match(line) + if progress is None: + self._print(line) return - self.output.info("Visual Studio 2015 version: 14.0") - - def check_cplus_plus(self): - """ Check Visual C++ Redistributable 2015 is instlled for Windows """ - keys = ( - "HKLM\\SOFTWARE\\Classes\\Installer\\Dependencies\\{d992c12e-cab2-426f-bde3-fb8c53950b0d}", - "HKLM\\SOFTWARE\\WOW6432Node\\Microsoft\\VisualStudio\\14.0\\VC\\Runtimes\\x64") - for key in keys: - chk = Popen("reg query {}".format(key), shell=True, stdout=PIPE, stderr=PIPE) - stdout, stderr = chk.communicate() - if stdout: - break - if stderr: - self.output.error("Visual C++ 2015 could not be found. Make sure you have selected " - "'Visual C++' in Visual Studio 2015 Configuration or download the " - "Visual C++ 2015 Redistributable from: " - "https://www.microsoft.com/en-us/download/details.aspx?id=48145") + info = progress.groupdict() + done = (int(info["done"]) / int(info["total"])) * 100.0 + last_line = self._last_line.strip()[len(self._prefix):] + pkg = self._re_pip_pkg.match(self._re_pip_http.sub(r"\1", last_line)) + if pkg is not None: + info = pkg.groupdict() + last_line = f"Downloading {info['lib']} ({info['size']})" + self._print(f"{last_line} {done:.1f}%") + + def __call__(self, line: str) -> None: + """ Update the output status with the given line + + Parameters + ---------- + line : str + A cleansed line from either Conda or Pip installers + """ + if self._is_conda: + self._print_conda(line.strip()) + else: + self._print_pip(line.strip()) + + def close(self) -> None: + """ Reset all progress bars and re-enable the cursor """ + self._clear_line() + + +class Installer(): + """ Uses the python Subprocess module to install packages. + + Parameters + ---------- + environment : :class:`Environment` + Environment class holding information about the running system + packages : list[str] + The list of package names that are to be installed + command : list + The command to run + is_conda : bool + ``True`` if conda install command is running. ``False`` if pip install command is running + is_gui : bool + ``True`` if the process is being called from the Faceswap GUI + """ + def __init__(self, # pylint:disable=too-many-positional-arguments + environment: Environment, + packages: list[str], + command: list[str], + is_conda: bool, + is_gui: bool) -> None: + self._output_information(packages) + logger.debug("argv: %s", command) + self._env = environment + self._packages = packages + self._command = command + self._is_conda = is_conda + self._is_gui = is_gui + self._status = Status(is_conda) + self._re_ansi_escape = re.compile(r'\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])') + self._seen_lines: set[str] = set() + self.error_lines: list[str] = [] + + @classmethod + def _output_information(cls, packages: list[str]): + """ INFO log the packages to be installed, splitting along multiple lines for long package + lists (68 chars = 79 chars - (log-level spacing + indent)) + + Parameters + ---------- + packages : list[str] + The list of package names that are to be installed + """ + output = "" + sep = ", " + for pkg in packages: + current = pkg + sep + if len(output) + len(current) > 68: + logger.info(" %s", output) + output = current + else: + output += current + if output: + logger.info(" %s", output[:-len(sep)]) + + def _clean_line(self, text: str) -> str: + """Remove ANSI escape sequences and special characters from text. 
+ + Parameters + ---------- + text : str + The text to clean + + Returns + ------- + str + The cleansed text + """ + clean = self._re_ansi_escape.sub("", text.rstrip()) + return ''.join(c for c in clean if c in set(printable)) + + def _seen_line_log(self, text: str, is_error: bool = False) -> str: + """ Output gets spammed to the log file when conda is waiting/processing. Only log each + unique line once. + + Parameters + ---------- + text : str + The text to log + is_error : bool, optional + ``True`` if the line comes from an error. Default: ``False`` + + Returns + ------- + str + The cleansed log line + + """ + clean = self._clean_line(text) + if clean in self._seen_lines: + return "" + clean = f"ERROR: {clean}" if is_error else clean + logger.debug(clean) + self._seen_lines.add(clean) + return clean + + def __call__(self) -> int: + """ Install a package using the Subprocess module + + Returns + ------- + int + The return code of the package install process + """ + with Popen(self._command, + bufsize=0, stdout=PIPE, stderr=PIPE) as proc: + lines = b"" + while True: + if proc.stdout is not None: + lines = proc.stdout.readline() + returncode = proc.poll() + if lines == b"" and returncode is not None: + break + for line in lines.split(b"\r"): + clean = self._seen_line_log(line.decode("utf-8", errors="replace")) + if not self._is_gui and clean: + self._status(clean) + if returncode and proc.stderr is not None: + for line in proc.stderr.readlines(): + clean = self._seen_line_log(line.decode("utf-8", errors="replace"), + is_error=True) + if clean: + self.error_lines.append(clean.replace("ERROR:", "").strip()) + + logger.debug("Packages: %s, returncode: %s", self._packages, returncode) + if not self._is_gui: + self._status.close() + return returncode + + +class Install(): # pylint:disable=too-few-public-methods + """ Handles installation of Faceswap requirements + + Parameters + ---------- + environment : :class:`Environment` + Environment class holding information about the running system + is_gui : bool, Optional + ``True`` if the caller is the Faceswap GUI. Used to prevent output of progress bars + which get scrambled in the GUI + """ + def __init__(self, environment: Environment, is_gui: bool = False) -> None: + self._env = environment + self._is_gui = is_gui + if not self._env.is_installer and not self._env.updater: + self._ask_continue() + self._packages = RequiredPackages(environment) + if self._env.updater and not self._packages.packages_need_install: + logger.info("All Dependencies are up to date") return - vscpp = [re.sub(" +", " ", line.strip()) - for line in stdout.decode(self.env.encoding).splitlines() - if line.lower().strip().startswith(("displayname", "version"))][0] - version = vscpp[vscpp.find("REG_SZ") + 7:] - self.output.info("Visual Studio C++ version: {}".format(version)) - - -class Install(): - """ Install the requirements """ - def __init__(self, environment): - self.output = Output() - self.env = environment - - if not self.env.is_installer: - self.ask_continue() - self.check_missing_dep() - self.check_conda_missing_dep() - self.install_missing_dep() - self.output.info("All python3 dependencies are met.\r\nYou are good to go.\r\n\r\n" - "Enter: 'python faceswap.py -h' to see the options\r\n" - " 'python faceswap.py gui' to launch the GUI") - - def ask_continue(self): + self._install_packages() + self._finalize() + + def _ask_continue(self) -> None: """ Ask Continue with Install """ - inp = input("Please ensure your System Dependencies are met. Continue? 
[y/N] ")
+        if _InstallState.messages:
+            for msg in _InstallState.messages:
+                logger.warning(msg)
+        text = "Please ensure your System Dependencies are met."
+        if self._env.backend == "rocm":
+            text += ("\r\nPlease ensure that your AMD GPU is supported by the "
+                     "installed ROCm version before proceeding.")
+        text += "\r\nContinue? [y/N] "
+        inp = input(text)
         if inp in ("", "N", "n"):
-            self.output.error("Please install system dependencies to continue")
-            exit(1)
-
-    def check_missing_dep(self):
-        """ Check for missing dependencies """
-        if self.env.enable_cuda and self.env.is_macos:
-            self.env.required_packages.extend(self.env.macos_required_packages)
-        for pkg in self.env.required_packages:
-            key = pkg.split("==")[0]
-            if key not in self.env.installed_packages:
-                self.env.missing_packages.append(pkg)
-                continue
-            else:
-                if len(pkg.split("==")) > 1:
-                    if pkg.split("==")[1] != self.env.installed_packages.get(key):
-                        self.env.missing_packages.append(pkg)
-                        continue
-
-    def check_conda_missing_dep(self):
-        """ Check for conda missing dependencies """
-        if not self.env.is_conda:
+            logger.info("Installation cancelled")
+            sys.exit(0)
+
+    def _from_pip(self,
+                  packages: list[dict[T.Literal["name", "package"], str]],
+                  extra_args: list[str] | None = None) -> None:
+        """ Install packages from pip
+
+        Parameters
+        ----------
+        packages : list[dict[T.Literal["name", "package"], str]]
+            The formatted list of packages to be installed
+        extra_args : list[str] | None, optional
+            Any extra arguments to provide to pip. Default: ``None`` (no extra arguments)
+        """
+        pipexe = [sys.executable,
+                  "-u", "-m", "pip", "install", "--no-cache-dir", "--progress-bar=raw"]
+
+        if not self._env.system.is_admin and not self._env.system.is_virtual_env:
+            pipexe.append("--user")  # install as user to solve perm restriction
+        if extra_args is not None:
+            pipexe.extend(extra_args)
+        pipexe.extend([p["package"] for p in packages])
+        names = [p["name"] for p in packages]
+        installer = Installer(self._env, names, pipexe, False, self._is_gui)
+        if installer() != 0:
+            msg = f"Unable to install Python packages: {', '.join(names)}"
+            logger.warning("%s. Please install these packages manually", msg)
+            for line in installer.error_lines:
+                _InstallState.messages.append(line)
+            _InstallState.failed = True
+
+    def _from_conda(self,
+                    packages: list[dict[T.Literal["name", "package"], str]],
+                    channel: str) -> None:
+        """ Install packages from conda
+
+        Parameters
+        ----------
+        packages : list[dict[T.Literal["name", "package"], str]]
+            The full formatted packages to be installed
+        channel : str
+            The Conda channel to install from.
+        """
+        conda = which("conda")
+        assert conda is not None
+        condaexe = [conda, "install", "-y", "-c", channel,
+                    "--override-channels", "--strict-channel-priority"]
+        condaexe += [p["package"] for p in packages]
+        names = [p["name"] for p in packages]
+        retcode = Installer(self._env, names, condaexe, True, self._is_gui)()
+        if retcode != 0:
+            logger.warning("Unable to install Conda packages: %s. 
" + "Please install these packages manually", ', '.join(names)) + _InstallState.failed = True + + def _install_packages(self) -> None: + """ Install the required packages """ + if self._packages.conda: + logger.info("Installing Conda packages...") + for channel, packages in self._packages.conda.items(): + self._from_conda(packages, channel) + if self._packages.python: + logger.info("Installing Python packages...") + packages = [p for p in self._packages.python if p["name"] != "Packaging"] + self._from_pip(packages, extra_args=self._packages.pip_arguments) + + def _finalize(self) -> None: + """ Output final information on completion """ + if self._env.updater: return - for pkg in self.env.conda_required_packages: - key = pkg[0].split("==")[0] - if key not in self.env.installed_packages: - self.env.conda_missing_packages.append(pkg) - continue - else: - if len(pkg[0].split("==")) > 1: - if pkg[0].split("==")[1] != self.env.installed_conda_packages.get(key): - self.env.conda_missing_packages.append(pkg) - continue - - def install_missing_dep(self): - """ Install missing dependencies """ - if self.env.missing_packages: - self.install_python_packages() - if self.env.conda_missing_packages: - self.install_conda_packages() - - def install_python_packages(self): - """ Install required pip packages """ - self.output.info("Installing Required Python Packages. This may take some time...") - for pkg in self.env.missing_packages: - if self.env.is_conda: - verbose = pkg.startswith("tensorflow") - if self.conda_installer(pkg, verbose=verbose): - continue - self.pip_installer(pkg) - - def install_conda_packages(self): - """ Install required conda packages """ - self.output.info("Installing Required Conda Packages. This may take some time...") - for pkg in self.env.conda_missing_packages: - channel = None if len(pkg) != 2 else pkg[1] - self.conda_installer(pkg[0], channel=channel, conda_only=True) - - def conda_installer(self, package, channel=None, verbose=False, conda_only=False): - """ Install a conda package """ - success = True - condaexe = ["conda", "install", "-y"] - if not verbose: - condaexe.append("-q") - if channel: - condaexe.extend(["-c", channel]) - condaexe.append(package) - self.output.info("Installing {}".format(package)) - try: - if verbose: - run(condaexe, check=True) - else: - with open(os.devnull, "w") as devnull: - run(condaexe, stdout=devnull, stderr=devnull, check=True) - except CalledProcessError: - if not conda_only: - self.output.info("Couldn't install {} with Conda. Trying pip".format(package)) - else: - self.output.warning("Couldn't install {} with Conda. " - "Please install this package manually".format(package)) - success = False - return success - - def pip_installer(self, package): - """ Install a pip package """ - pipexe = [sys.executable, "-m", "pip"] - # hide info/warning and fix cache hang - pipexe.extend(["install", "-qq", "--no-cache-dir"]) - # install as user to solve perm restriction - if not self.env.is_admin and not self.env.is_virtualenv: - pipexe.append("--user") - if package.startswith("dlib"): - opt = "yes" if self.env.enable_cuda else "no" - pipexe.extend(["--install-option=--{}".format(opt), - "--install-option=DLIB_USE_CUDA"]) - if self.env.os_version[0] == "Windows": - pipexe.extend(["--global-option=-G", - "--global-option=Visual Studio 14 2015"]) - msg = ("Compiling {}. 
This will take a while...\n" - "Please ignore the following UserWarning: " - "'Disabling all use of wheels...'".format(package)) + if not _InstallState.failed: + if _InstallState.messages: + for msg in _InstallState.messages: + logger.warning(msg) + logger.info("All Faceswap dependencies are met. You are good to go.\r\n\r\n" + "Enter: 'python faceswap.py -h' to see the options\r\n" + " 'python faceswap.py gui' to launch the GUI") else: - msg = "Installing {}".format(package) - self.output.info(msg) - pipexe.append(package) - try: - run(pipexe, check=True) - except CalledProcessError: - self.output.warning("Couldn't install {} with pip. " - "Please install this package manually".format(package)) + msg = "Some packages failed to install. " + if not _InstallState.messages: + msg += ("This may be temporary and might be fixed by re-running this script. " + "Otherwise check 'faceswap_setup.log' to see which failed and install " + "these packages manually.") + else: + msg += ("Further information can be found in 'faceswap_setup.log'. The following " + "output shows specific error(s) that were collected:\r\n") + msg += "\r\n".join(_InstallState.messages) + logger.error(msg) + sys.exit(1) class Tips(): """ Display installation Tips """ - def __init__(self): - self.output = Output() - - def docker_no_cuda(self): + @classmethod + def docker_no_cuda(cls) -> None: """ Output Tips for Docker without Cuda """ - self.output.info( - "1. Install Docker\n" - "https://www.docker.com/community-edition\n\n" - "2. Build Docker Image For Faceswap\n" - "docker build -t deepfakes-cpu -f Dockerfile.cpu .\n\n" - "3. Mount faceswap volume and Run it\n" - "# without GUI\n" - "docker run -p 8888:8888 \\ \n" - "\t--hostname deepfakes-cpu --name deepfakes-cpu \\ \n" - "\t-v {path}:/srv \\ \n" - "\tdeepfakes-cpu\n\n" - "# with gui. tools.py gui working.\n" - "## enable local access to X11 server\n" - "xhost +local:\n" - "## create container\n" - "nvidia-docker run -p 8888:8888 \\ \n" - "\t--hostname deepfakes-cpu --name deepfakes-cpu \\ \n" - "\t-v {path}:/srv \\ \n" - "\t-v /tmp/.X11-unix:/tmp/.X11-unix \\ \n" - "\t-e DISPLAY=unix$DISPLAY \\ \n" - "\t-e AUDIO_GID=`getent group audio | cut -d: -f3` \\ \n" - "\t-e VIDEO_GID=`getent group video | cut -d: -f3` \\ \n" - "\t-e GID=`id -g` \\ \n" - "\t-e UID=`id -u` \\ \n" - "\tdeepfakes-cpu \n\n" - "4. Open a new terminal to run faceswap.py in /srv\n" - "docker exec -it deepfakes-cpu bash".format(path=sys.path[0])) - self.output.info("That's all you need to do with a docker. Have fun.") - - def docker_cuda(self): - """ Output Tips for Docker wit Cuda""" - self.output.info( - "1. Install Docker\n" - "https://www.docker.com/community-edition\n\n" - "2. Install latest CUDA\n" - "CUDA: https://developer.nvidia.com/cuda-downloads\n\n" - "3. Install Nvidia-Docker & Restart Docker Service\n" - "https://github.com/NVIDIA/nvidia-docker\n\n" - "4. Build Docker Image For Faceswap\n" - "docker build -t deepfakes-gpu -f Dockerfile.gpu .\n\n" - "5. 
Mount faceswap volume and Run it\n"
-            "# without gui \n"
-            "docker run -p 8888:8888 \\ \n"
-            "\t--hostname deepfakes-gpu --name deepfakes-gpu \\ \n"
-            "\t-v {path}:/srv \\ \n"
-            "\tdeepfakes-gpu\n\n"
-            "# with gui.\n"
-            "## enable local access to X11 server\n"
-            "xhost +local:\n"
-            "## enable nvidia device if working under bumblebee\n"
-            "echo ON > /proc/acpi/bbswitch\n"
-            "## create container\n"
-            "nvidia-docker run -p 8888:8888 \\ \n"
-            "\t--hostname deepfakes-gpu --name deepfakes-gpu \\ \n"
-            "\t-v {path}:/srv \\ \n"
-            "\t-v /tmp/.X11-unix:/tmp/.X11-unix \\ \n"
-            "\t-e DISPLAY=unix$DISPLAY \\ \n"
-            "\t-e AUDIO_GID=`getent group audio | cut -d: -f3` \\ \n"
-            "\t-e VIDEO_GID=`getent group video | cut -d: -f3` \\ \n"
-            "\t-e GID=`id -g` \\ \n"
-            "\t-e UID=`id -u` \\ \n"
-            "\tdeepfakes-gpu\n\n"
-            "6. Open a new terminal to interact with the project\n"
-            "docker exec deepfakes-gpu python /srv/tools.py gui\n".format(path=sys.path[0]))
-
-    def macos(self):
+        logger.info(
+            "1. Install Docker from: https://www.docker.com/get-started\n\n"
+            "2. Enter the Faceswap folder and build the Docker Image For Faceswap:\n"
+            "   docker build -t faceswap-cpu -f Dockerfile.cpu .\n\n"
+            "3. Launch and enter the Faceswap container:\n"
+            "   a. Headless:\n"
+            "      docker run --rm -it -v ./:/srv faceswap-cpu\n\n"
+            "   b. GUI:\n"
+            "      xhost +local: && \\ \n"
+            "      docker run --rm -it \\ \n"
+            "      -v ./:/srv \\ \n"
+            "      -v /tmp/.X11-unix:/tmp/.X11-unix \\ \n"
+            "      -e DISPLAY=${DISPLAY} \\ \n"
+            "      faceswap-cpu \n")
+        logger.info("That's all you need to do with docker. Have fun.")
+
+    @classmethod
+    def docker_cuda(cls) -> None:
+        """ Output Tips for Docker with Cuda"""
+        logger.info(
+            "1. Install Docker from: https://www.docker.com/get-started\n\n"
+            "2. Install latest CUDA 11 and cuDNN 8 from: https://developer.nvidia.com/cuda-"
+            "downloads\n\n"
+            "3. Install the Nvidia Container Toolkit from https://docs.nvidia.com/datacenter/"
+            "cloud-native/container-toolkit/latest/install-guide\n\n"
+            "4. Restart Docker Service\n\n"
+            "5. Enter the Faceswap folder and build the Docker Image For Faceswap:\n"
+            "   docker build -t faceswap-gpu -f Dockerfile.gpu .\n\n"
+            "6. Launch and enter the Faceswap container:\n"
+            "   a. Headless:\n"
+            "      docker run --runtime=nvidia --rm -it -v ./:/srv faceswap-gpu\n\n"
+            "   b. GUI:\n"
+            "      xhost +local: && \\ \n"
+            "      docker run --runtime=nvidia --rm -it \\ \n"
+            "      -v ./:/srv \\ \n"
+            "      -v /tmp/.X11-unix:/tmp/.X11-unix \\ \n"
+            "      -e DISPLAY=${DISPLAY} \\ \n"
+            "      faceswap-gpu \n")
+        logger.info("That's all you need to do with docker. Have fun.")
+
+    @classmethod
+    def macos(cls) -> None:
         """ Output Tips for macOS"""
-        self.output.info(
+        logger.info(
             "setup.py does not directly support macOS. The following tips should help:\n\n"
 
             "1. Install system dependencies:\n"
             "XCode from the Apple Store\n"
-            "XQuartz: https://www.xquartz.org/\n\n"
-
-            "2a. It is recommended to use Anaconda for your Python Virtual Environment as this\n"
-            "will handle the installation of CUDA and cuDNN for you:\n"
-            "https://www.anaconda.com/distribution/\n\n"
+            "XQuartz: https://www.xquartz.org/\n\n")
 
-            "2b. If you do not want to use Anaconda, or if you wish to compile DLIB with GPU\n"
-            "support you will need to manually install CUDA and cuDNN:\n"
-            "CUDA: https://developer.nvidia.com/cuda-downloads"
-            "cuDNN: https://developer.nvidia.com/rdp/cudnn-download\n\n")
-
-    def pip(self):
+    @classmethod
+    def pip(cls) -> None:
         """ Pip Tips """
-        self.output.info("1. 
Install PIP requirements\n" - "You may want to execute `chcp 866` in cmd line\n" - "to fix Unicode issues on Windows when installing dependencies") + logger.info("1. Install PIP requirements\n" + "You may want to execute `chcp 65001` in cmd line\n" + "to fix Unicode issues on Windows when installing dependencies") if __name__ == "__main__": + logfile = os.path.join(os.path.dirname(os.path.realpath(sys.argv[0])), "faceswap_setup.log") + log_setup("INFO", logfile, "setup") + logger.debug("Setup called with args: %s", sys.argv) ENV = Environment() Checks(ENV) - if INSTALL_FAILED: - exit(1) + ENV.set_config() + if _InstallState.failed: + sys.exit(1) Install(ENV) + + +__all__ = get_module_objects(__name__) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/data/imgs/test_img1.jpg b/tests/data/imgs/test_img1.jpg new file mode 100644 index 0000000000..411f46dce8 Binary files /dev/null and b/tests/data/imgs/test_img1.jpg differ diff --git a/tests/data/imgs/test_img2.jpg b/tests/data/imgs/test_img2.jpg new file mode 100644 index 0000000000..fde67669b8 Binary files /dev/null and b/tests/data/imgs/test_img2.jpg differ diff --git a/tests/data/imgs/test_img3.jpg b/tests/data/imgs/test_img3.jpg new file mode 100644 index 0000000000..039a58b123 Binary files /dev/null and b/tests/data/imgs/test_img3.jpg differ diff --git a/tests/data/imgs/test_img4.jpg b/tests/data/imgs/test_img4.jpg new file mode 100644 index 0000000000..42fe62a020 Binary files /dev/null and b/tests/data/imgs/test_img4.jpg differ diff --git a/tests/data/vid/test.mp4 b/tests/data/vid/test.mp4 new file mode 100644 index 0000000000..a2cdeb0243 Binary files /dev/null and b/tests/data/vid/test.mp4 differ diff --git a/tests/lib/__init__.py b/tests/lib/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/lib/config/__init__.py b/tests/lib/config/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/lib/config/config_test.py b/tests/lib/config/config_test.py new file mode 100644 index 0000000000..9fa1dddc9e --- /dev/null +++ b/tests/lib/config/config_test.py @@ -0,0 +1,276 @@ +#!/usr/bin python3 +""" Pytest unit tests for :mod:`lib.config.config` """ + +import pytest + +from lib.config import config as config_mod + +from tests.lib.config.helpers import FakeConfigItem + + +# pylint:disable=too-few-public-methods,protected-access,invalid-name + +def get_instance(mocker, module="plugins.test.test_config"): + """ Generate a FaceswapConfig instance, substituting the calling module for the one given """ + mocker.patch("lib.config.config.FaceswapConfig.__module__", module) + return config_mod.FaceswapConfig() + + +_MODULES = (("plugins.test.test_config", "path_valid"), + ("plugins.test.test", "path_invalid"), + ("plugins.config.test_config", "folder_invalid")) +_MODULE_IDS = [x[-1] for x in _MODULES] + + +@pytest.mark.parametrize(("module", "mod_status"), _MODULES, ids=_MODULE_IDS) +def test_FaceswapConfig_init(module, mod_status, mocker): + """ Test that :class:`lib.config.config.FaceswapConfig` initializes correctly """ + mocker.patch("lib.config.config.FaceswapConfig.set_defaults", mocker.MagicMock()) + mocker.patch("lib.config.config.ConfigFile.on_load", mocker.MagicMock()) + if mod_status.endswith("invalid"): + with pytest.raises(AssertionError): + get_instance(mocker, module=module) + return + + test = get_instance(mocker, module=module) + assert test._plugin_group == "test" + assert isinstance(test._ini, 
config_mod.ConfigFile)
+    test.set_defaults.assert_called_once()
+    test._ini.on_load.assert_called_once_with(test.sections)  # pylint:disable=no-member
+    assert config_mod._CONFIGS["test"] == test
+
+
+def test_FaceswapConfig_add_section(mocker):
+    """ Test :class:`lib.config.config.FaceswapConfig.add_section` works """
+    instance = get_instance(mocker)
+    title = "my.test.section"
+    info = "And here is some test help text"
+    assert title not in instance.sections
+    instance.add_section(title, info)
+    assert title in instance.sections
+    assert isinstance(instance.sections[title], config_mod.ConfigSection)
+    assert instance.sections[title].helptext == info
+
+
+def test_FaceswapConfig_add_item(mocker):
+    """ Test :class:`lib.config.config.FaceswapConfig.add_item` works """
+    instance = get_instance(mocker)
+    section = "my.test.section"
+    title = "test_option"
+    config_item = "TEST_CONFIG_ITEM"
+
+    assert section not in instance.sections
+    with pytest.raises(KeyError):  # Fail adding item to non-existent key
+        instance.add_item(section, title, config_item)
+
+    instance.add_section(section, "")
+    assert title not in instance.sections[section].options
+    instance.add_item(section, title, config_item)
+    assert title in instance.sections[section].options
+    assert instance.sections[section].options[title] == config_item
+
+
+@pytest.mark.parametrize("filename",
+                         ("test_defaults.py", "train_defaults.py", "different_name.py"))
+def test_FaceswapConfig_import_defaults_from_module(mocker, filename):
+    """ Test :class:`lib.config.config.FaceswapConfig._import_defaults_from_module` works """
+    mocker.patch("lib.config.config.ConfigItem", FakeConfigItem)
+
+    class DummyMod:
+        """ Dummy Module for loading config items """
+        opt1 = FakeConfigItem(10)
+        opt2 = FakeConfigItem(20)
+        invalid = "invalid"
+        HELPTEXT = "Test help text"
+    mock_mod = mocker.MagicMock(return_value=DummyMod)
+    mocker.patch("lib.config.config.import_module", mock_mod)
+
+    instance = get_instance(mocker)
+    module_path = "test.module.path"
+    plugin_type = "test"
+    section = plugin_type + "." 
+ filename[:-3].replace("_defaults", "") + + assert section not in instance.sections + + instance._import_defaults_from_module(filename, module_path, plugin_type) + + mock_mod.assert_called_once_with(f"{module_path}.{filename[:-3]}") + + assert section in instance.sections + assert instance.sections[section].helptext == DummyMod.HELPTEXT + assert len(instance.sections[section].options) == 2 + assert isinstance(instance.sections[section].options["opt1"], FakeConfigItem) + assert isinstance(instance.sections[section].options["opt2"], FakeConfigItem) + + +def test_FaceswapConfig_defaults_from_plugin(mocker): + """ Test :class:`lib.config.config.FaceswapConfig._defaults_from_plugin` works """ + mocker.patch("lib.config.config.ConfigItem", FakeConfigItem) + dir_tree = [("plugins/train/model/plugin_a", [], ['plugin_a_defaults.py', '__init__.py']), + ("plugins/extract", [], ['extract_defaults.py', '__init__.py']), + ("plugins/convert/writer", [], ['writer_defaults.py', '__init__.py']), + ("plugins/train", ["model", "trainer"], ['train_config.py', '__init__.py'])] + mock_walk = mocker.MagicMock(return_value=dir_tree) + mocker.patch("lib.config.config.os.walk", mock_walk) + + instance = get_instance(mocker) + + instance._import_defaults_from_module = mocker.MagicMock() + + instance._defaults_from_plugin("test") + + assert instance._import_defaults_from_module.call_count == 3 # 3 valid, 1 invalid + + +def test_FaceswapConfig_set_defaults_global(mocker): + """ Test :class:`lib.config.config.FaceswapConfig.set_defaults` works for global sections """ + mocker.patch("lib.config.config.ConfigItem", FakeConfigItem) + + class DummyMod: + """ Dummy Module for loading config items """ + opt1 = FakeConfigItem(10) + opt2 = FakeConfigItem(20) + invalid = "invalid" + HELPTEXT = "Test help text" + mocker.patch("lib.config.config.sys.modules", + config_mod.sys.modules | {"plugins.test.test_config": DummyMod}) + + instance = get_instance(mocker) + + instance.add_section = mocker.MagicMock() + instance.add_item = mocker.MagicMock() + + instance.set_defaults("") + instance.add_section.assert_not_called() + instance.add_item.assert_not_called() + + instance.set_defaults("test") + instance.add_section.assert_called_once() + assert instance.add_item.call_count == 2 + + +def test_FaceswapConfig_set_defaults_subsection(mocker): + """ Test :class:`lib.config.config.FaceswapConfig.set_defaults` works for sub-sections """ + mocker.patch("lib.config.config.ConfigItem", FakeConfigItem) + + class DummyGlobal(config_mod.GlobalSection): + """ Dummy GlobalSection class """ + opt1 = FakeConfigItem(30) + opt2 = FakeConfigItem(40) + opt3 = FakeConfigItem(50) + invalid = "invalid" + helptext = "Section help text" + + class DummyMod: + """ Dummy Module class for loading config items """ + opt1 = FakeConfigItem(10) + opt2 = FakeConfigItem(20) + sect1 = DummyGlobal + invalid = "invalid" + HELPTEXT = "Test help text" + mocker.patch("lib.config.config.sys.modules", + config_mod.sys.modules | {"plugins.test.test_config": DummyMod}) + + instance = get_instance(mocker) + + instance.add_section = mocker.MagicMock() + instance.add_item = mocker.MagicMock() + + instance.set_defaults("test") + assert instance.add_section.call_count == 2 # global + subsection + assert instance.add_item.call_count == 5 # global + subsection + + +def test_FaceswapConfig_set_defaults(mocker): + """ Test :class:`lib.config.config.FaceswapConfig._set_defaults` works """ + instance = get_instance(mocker) + + class DummySection1: + """ Dummy ConfigSection class """ + 
options = {"opt1": FakeConfigItem(10),
+                   "opt2": FakeConfigItem(20),
+                   "opt3": FakeConfigItem(30)}
+
+    class DummySection2:
+        """ Dummy ConfigSection class """
+        options = {"opt1": FakeConfigItem(40),
+                   "opt2": FakeConfigItem(50),
+                   "opt3": FakeConfigItem(60)}
+
+    class DummySection3:
+        """ Dummy ConfigSection class """
+        options = {"opt1": FakeConfigItem(70),
+                   "opt2": FakeConfigItem(80),
+                   "opt3": FakeConfigItem(90)}
+
+    instance.set_defaults = mocker.MagicMock()
+    sections = {"zzz_section": DummySection1(),
+                "mmm_section": DummySection2(),
+                "aaa_section": DummySection3()}
+    instance.sections = sections
+
+    instance._set_defaults()
+
+    instance.set_defaults.assert_called_once()
+    for sect_name, sect in instance.sections.items():
+        for key, opt in sect.options.items():
+            assert opt._name == f"test.{sect_name}.{key}"
+    assert list(instance.sections) == sorted(sections)
+
+
+def test_FaceswapConfig_save(mocker):
+    """ Test :class:`lib.config.config.FaceswapConfig.save_config` works """
+    instance = get_instance(mocker)
+    instance._ini.update_from_app = mocker.MagicMock()
+    instance.sections = "TEST_SECTIONS"
+
+    instance.save_config()
+
+    instance._ini.update_from_app.assert_called_once_with(instance.sections)
+
+
+def test_get_configs(mocker):
+    """ Test :class:`lib.config.config.get_configs` works """
+    mock_gen_configs = mocker.MagicMock()
+    mocker.patch("lib.config.config.generate_configs", mock_gen_configs)
+    mocker.patch("lib.config.config._CONFIGS", "TEST_ALL_CONFIGS")
+
+    result = config_mod.get_configs()
+    mock_gen_configs.assert_called_once_with(force=True)
+    assert result == "TEST_ALL_CONFIGS"
+
+
+def test_generate_configs(mocker):
+    """ Test :class:`lib.config.config.generate_configs` works """
+    _root = "/path/to/faceswap"
+    mocker.patch("lib.config.config.PROJECT_ROOT", _root)
+
+    dir_tree = [
+        (f"{_root}/plugins/train", [], ['train_config.py', '__init__.py']),  # Success
+        (f"{_root}/plugins/extract", [], ['extract_config.py', '__init__.py']),  # Success
+        (f"{_root}/plugins/convert/writer", [], ['writer_config.py', '__init__.py']),  # Too deep
+        # Wrong name
+        (f"{_root}/plugins/train", ["model", "trainer"], ['train_defaults.py', '__init__.py'])]
+    mock_walk = mocker.MagicMock(return_value=dir_tree)
+    mocker.patch("lib.config.config.os.walk", mock_walk)
+
+    mock_initialized = mocker.MagicMock()
+
+    class DummyConfig(config_mod.FaceswapConfig):
+        """ Dummy FaceswapConfig class """
+        def __init__(self,  # pylint:disable=unused-argument,super-init-not-called
+                     *args,
+                     **kwargs):
+            mock_initialized()
+
+    class DummyMod:
+        """ Dummy Module to load configs from """
+        mod1 = DummyConfig
+
+    mock_mod = mocker.MagicMock(return_value=DummyMod)
+    mocker.patch("lib.config.config.import_module", mock_mod)
+
+    config_mod.generate_configs(False)
+
+    assert mock_mod.call_count == 2  # 2 modules imported
+    assert mock_initialized.call_count == 2  # 2 configs loaded
diff --git a/tests/lib/config/helpers.py b/tests/lib/config/helpers.py
new file mode 100644
index 0000000000..0d1e9ce852
--- /dev/null
+++ b/tests/lib/config/helpers.py
@@ -0,0 +1,51 @@
+#!/usr/bin python3
+""" Helper mock items for ConfigItems """
+
+import pytest
+
+
+class FakeConfigItem:
+    """ ConfigItem substitute """
+    def __init__(self, value):
+        self.value = value
+        self._name = ""
+
+    @property
+    def ini_value(self):
+        """ Dummy ini value """
+        return self.value.lower() if isinstance(self.value, str) else self.value
+
+    @property
+    def helptext(self):
+        """ Dummy help text """
+        return f"Test helptext for {self._name}:{self.value}"
+
+    def get(self):
+        """ Return the value """
+        return self.value
+
+    def set(self, value):
+        """ Set the value """
+        self.value = value
+
+    def set_name(self, name):
+        """ Set the name """
+        self._name = name
+
+    def __call__(self):
+        return self.value
+
+    def __repr__(self):
+        return f"FakeConfigItem(value={self.value!r})"
+
+
+@pytest.fixture
+def patch_config(monkeypatch: pytest.MonkeyPatch):
+    """ Fixture to patch user config values """
+
+    def _apply(module, cfg_dict):
+        """ Create the fake ConfigItem object """
+        for key, value in cfg_dict.items():
+            monkeypatch.setattr(module, key, FakeConfigItem(value))
+
+    return _apply
diff --git a/tests/lib/config/ini_test.py b/tests/lib/config/ini_test.py
new file mode 100644
index 0000000000..09730f92a6
--- /dev/null
+++ b/tests/lib/config/ini_test.py
@@ -0,0 +1,377 @@
+#!/usr/bin python3
+""" Pytest unit tests for :mod:`lib.config.ini` """
+
+import os
+import pytest
+
+from lib.config import ini as ini_mod
+
+from tests.lib.config.helpers import FakeConfigItem
+
+# pylint:disable=protected-access,invalid-name
+
+
+_GROUPS = ("group1", "group2", "group3")
+_CONFIG = ("custom", "custom_missing", "root", "root_missing")
+
+
+@pytest.mark.parametrize("plugin_group", _GROUPS)
+@pytest.mark.parametrize("config", _CONFIG)
+def test_ConfigFile(tmpdir, mocker, plugin_group, config):
+    """ Test that :class:`lib.config.ini.ConfigFile` initializes correctly """
+    root_conf = tmpdir.mkdir("root").mkdir("config").join(f"{plugin_group}.ini")
+    root_dir = os.path.dirname(os.path.dirname(root_conf))
+    if config != "root_missing":
+        root_conf.write("")
+    mocker.patch("lib.config.ini.PROJECT_ROOT", root_dir)
+
+    conf_file = None
+    if config.startswith("custom"):
+        conf_file = tmpdir.mkdir("config").join("test_custom_config.ini")
+        if config == "custom":
+            conf_file.write("")
+
+    mock_load = mocker.MagicMock()
+    mocker.patch("lib.config.ini.ConfigFile.load", mock_load)
+
+    if config == "custom_missing":  # Error on explicit missing
+        with pytest.raises(ValueError):
+            ini_mod.ConfigFile("group2test", conf_file)
+        return
+
+    instance = ini_mod.ConfigFile(plugin_group, conf_file)
+    file_path = conf_file if config == "custom" else root_conf
+    assert instance._file_path == file_path
+    assert instance._plugin_group == plugin_group
+    assert instance._parser.optionxform is str
+
+    if config in ("custom", "root"):  # load when exists
+        mock_load.assert_called_once()
+    else:
+        mock_load.assert_not_called()  # Don't load when it doesn't
+
+
+def test_ConfigFile_load(mocker):
+    """ Test that :class:`lib.config.ini.ConfigFile.load` calls correctly """
+    instance = ini_mod.ConfigFile("test")
+
+    mock_read = mocker.MagicMock()
+    instance._parser.read = mock_read
+
+    instance.load()
+
+    mock_read.assert_called_once()
+
+
+def test_ConfigFile_save(mocker):
+    """ Test that :class:`lib.config.ini.ConfigFile.save` calls correctly """
+    instance = ini_mod.ConfigFile("test")
+
+    mock_write = mocker.MagicMock()
+    instance._parser.write = mock_write
+
+    instance.save()
+
+    mock_write.assert_called_once()
+
+
+class 
FakeConfigSection: # pylint:disable=too-few-public-methods + """ Fake config section """ + def __init__(self, num_opts=2): + self.options = {f"opt{i}": FakeConfigItem(f"test_value{i}") for i in range(num_opts)} + self.helptext = f"Test helptext for {num_opts} options" + + +def get_local_remote(sections=[2, 1, 3]): # pylint:disable=dangerous-default-value + """ Obtain an object representing inputs to a ConfigParser and a matching object representing + Faceswap Config """ + parser_sections = {f"section{i}": {f"opt{idx}": f"test_value{idx}" for idx in range(s)} + for i, s in enumerate(sections)} + fs_sections = {f"section{i}": FakeConfigSection(s) for i, s in enumerate(sections)} + return parser_sections, fs_sections + + +def test_ConfigFile_is_synced_structure(): + """ Test that :class:`lib.config.ini.ConfigFile.is_synced_structure` is logical """ + instance = ini_mod.ConfigFile("test") + + sect_sizes = [2, 1, 3] + parser_sects, fs_sects = get_local_remote(sect_sizes) + + # No Config + test = instance._is_synced_structure(fs_sects) + assert test is False + + # Sects exist + for section in parser_sects: + instance._parser.add_section(section) + + test = instance._is_synced_structure(fs_sects) + assert test is False + + # Some Options missing + for section, options in parser_sects.items(): + for opt, val in options.items(): + instance._parser.set(section, opt, val) + break + + test = instance._is_synced_structure(fs_sects) + assert test is False + + # Structure matches + for section, options in parser_sects.items(): + for opt, val in options.items(): + instance._parser.set(section, opt, val) + + test = instance._is_synced_structure(fs_sects) + assert test is True + + # Extra saved section + instance._parser.add_section("text_extra_section") + test = instance._is_synced_structure(fs_sects) + assert test is False + + # Structure matches + del instance._parser["text_extra_section"] + test = instance._is_synced_structure(fs_sects) + assert test is True + + # Extra Option + instance._parser.set(section, "opt_test_extra_option", "val_test_extra_option") + test = instance._is_synced_structure(fs_sects) + assert test is False + + +def testConfigFile_format_help(): + """ Test that :class:`lib.config.ini.ConfigFile.format_help` inserts # on each line """ + instance = ini_mod.ConfigFile("test") + text = "This\nis a test\n\n\nof some text\n" + result = instance.format_help(text) + assert all(x.startswith("#") for x in result.splitlines() if x) + + +@pytest.mark.parametrize("section", + ("section1", "another_section", "section_test")) +def testConfigFile_insert_section(mocker, section): + """ Test that :class:`lib.config.ini.ConfigFile._insert_section` calls correctly """ + helptext = f"{section}_helptext" + + instance = ini_mod.ConfigFile("test") + instance.format_help = mocker.MagicMock(return_value=helptext) + + parser = instance._parser + + assert section not in parser + + instance._insert_section(section, helptext, parser) + + instance.format_help.assert_called_once_with(helptext, is_section=True) + + assert section in parser + assert helptext in parser[section] + + +@pytest.mark.parametrize(("section", "name", "value"), + (("section1", "opt1", "value1"), + ("another_section", "my_option", "what_its_worth"))) +def testConfigFile_insert_option(mocker, section, name, value): + """ Test that :class:`lib.config.ini.ConfigFile._insert_option` calls correctly """ + helptext = f"{section}_helptext" + + instance = ini_mod.ConfigFile("test") + instance.format_help = mocker.MagicMock(return_value=helptext) + + 
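+    # format_help is mocked out above, so only the raw option-insertion logic is exercised here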
parser = instance._parser
+    parser.add_section(section)
+
+    assert name not in parser[section]
+
+    instance._insert_option(section, name, helptext, value, parser)
+
+    instance.format_help.assert_called_once_with(helptext, is_section=False)
+    assert name in parser[section]
+    assert parser[section][name] == value
+
+
+_ini, _app, = get_local_remote([2, 1, 3])
+_ini_extra, _app_extra = get_local_remote(sections=[3, 1, 3])
+_ini_value, _ = get_local_remote(sections=[2, 1, 3])
+_ini_value["section0"]["opt0"] = "updated_value"
+
+_SYNC = ((_ini, _app, "synced"),
+         (_ini, _app_extra, "new_from_app"),
+         (_ini_extra, _app, "del_from_app"),
+         (_ini_value, _app, "updated_ini"))
+_SYNC_IDS = [x[-1] for x in _SYNC]
+
+
+@pytest.mark.parametrize(("ini_config", "app_config", "status"), _SYNC, ids=_SYNC_IDS)
+@pytest.mark.parametrize("exists", (True, False), ids=("exists", "not_exists"))
+def testConfigFile_sync_from_app(ini_config,  # pylint:disable=too-many-branches # noqa[C901]
+                                 app_config,
+                                 status,
+                                 exists,
+                                 mocker):
+    """ Test :class:`lib.config.ini.ConfigFile._sync_from_app` logic """
+    mocker.patch("lib.config.ini.ConfigFile._exists", exists)
+
+    instance = ini_mod.ConfigFile("test")
+    instance.save = mocker.MagicMock()
+
+    original_parser = instance._parser
+
+    if exists:
+        for section, opts in ini_config.items():
+            original_parser.add_section(section)
+            for name, opt in opts.items():
+                original_parser[section][name] = opt
+
+        opt_pairs = [({k: v.value for k, v in opts.options.items()},
+                      dict(original_parser[s].items()))
+                     for s, opts in app_config.items()]
+        # Sanity check that the loaded parser is set correctly
+        if status == "synced":
+            assert all(set(x[0]) == set(x[1]) for x in opt_pairs)
+        elif status == "new_from_app":
+            assert any(len(x[1]) < len(x[0]) for x in opt_pairs)
+        elif status == "del_from_app":
+            assert any(len(x[0]) < len(x[1]) for x in opt_pairs)
+        elif status == "updated_ini":
+            vals = [(set(x[0].values()), set(x[1].values())) for x in opt_pairs]
+            assert not all(a == i for a, i in vals)
+    else:
+        for section in ini_config:
+            assert section not in instance._parser
+
+    instance._sync_from_app(app_config)  # Sync
+
+    instance.save.assert_called_once()  # Saved
+    if exists:
+        assert instance._parser is not original_parser  # New config Generated
+    else:
+        assert instance._parser is original_parser  # Blank Config pre-exists
+
+    opt_pairs = [({k: v.value for k, v in opts.options.items()},
+                  {k: v for k, v in instance._parser[s].items() if k.startswith("opt")})
+                 for s, opts in app_config.items()]
+
+    # Test options are now in sync
+    assert all(set(x[0]) == set(x[1]) for x in opt_pairs)
+    # Test that ini value kept
+    vals = [(set(x[0].values()), set(x[1].values())) for x in opt_pairs]
+    if exists and status == "updated_ini":
+        assert any("updated_value" in i for _, i in vals)
+        assert any(a != i for a, i in vals)
+    else:
+        assert not any("updated_value" in i for _, i in vals)
+        assert all(a == i for a, i in vals)
+
+
+@pytest.mark.parametrize(("section", "option", "value", "datatype"),
+                         (("section1", "opt_str", "test_str", str),
+                          ("section2", "opt_bool", "True", bool),
+                          ("section3", "opt_int", "42", int),
+                          ("section4", "opt_float", "42.69", float),
+                          ("section5", "opt_other", "[test_other]", str)),
+                         ids=("str", "bool", "int", "float", "other"))
+def testConfigFile_get_converted_value(section, option, value, datatype):
+    """ Test :class:`lib.config.ini.ConfigFile._get_converted_value` logic """
+    instance = ini_mod.ConfigFile("test")
+    instance._parser.add_section(section)
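+    # ConfigParser stores every option value as a string; casting it back to the
+    # requested datatype is the behaviour under test here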
instance._parser[section][option] = value + + result = instance._get_converted_value(section, option, datatype) + assert isinstance(result, datatype) + assert datatype(value) == result + + +_ini, _app, = get_local_remote([2, 1, 3]) +_ini_changed, _ = get_local_remote(sections=[2, 1, 3]) +_ini_changed["section0"]["opt0"] = "updated_value" +_ini_changed["section2"]["opt1"] = "updated_value" + +_SYNC_TO = ((_ini, _app, "synced"), (_ini_changed, _app, "updated_ini")) +_SYNC__TO_IDS = [x[-1] for x in _SYNC_TO] + + +@pytest.mark.parametrize(("ini_config", "app_config", "status"), _SYNC_TO, ids=_SYNC__TO_IDS) +def testConfigFile_sync_to_app(ini_config, app_config, status, mocker): + """ Test :class:`lib.config.ini.ConfigFile._sync_to_app` logic """ + + for sect in app_config.values(): # Add a dummy datatype param to FSConfig + for opt in sect.options.values(): + setattr(opt, "datatype", str) + + instance = ini_mod.ConfigFile("test") + instance._get_converted_value = mocker.MagicMock(return_value="updated_value") + + for section, opts in ini_config.items(): # Load up the dummy ini info + instance._parser.add_section(section) + for name, opt in opts.items(): + instance._parser[section][name] = opt + + instance._sync_to_app(app_config) + + app_values = {sname: set(v.value for v in sect.options.values()) + for sname, sect in app_config.items()} + sect_values = {sname: set(instance._parser[sname].values()) + for sname in instance._parser.sections()} + + if status == "synced": # No items change + instance._get_converted_value.assert_not_called() + else: # 2 items updated in the config.ini + assert instance._get_converted_value.call_count == 2 + + # App and ini values must now match + assert set(app_values) == set(sect_values) + for sect in app_values: + assert set(app_values[sect]) == set(sect_values[sect]) + + +@pytest.mark.parametrize("structure_synced", + (True, False), + ids=("struc_synced", "not_struc_synced")) +@pytest.mark.parametrize("exists", (True, False), ids=("exists", "not_exists")) +def testConfigFile_sync_on_load(structure_synced, exists, mocker): + """ Test :class:`lib.config.ini.ConfigFile.on_load` logic """ + mocker.patch("lib.config.ini.ConfigFile._exists", exists) + _, app_config = get_local_remote() + + instance = ini_mod.ConfigFile("test") + instance._sync_from_app = mocker.MagicMock() + instance._sync_to_app = mocker.MagicMock() + instance._is_synced_structure = mocker.MagicMock(return_value=structure_synced) + + instance.on_load(app_config) + + instance._is_synced_structure.assert_called_once_with(app_config) + instance._sync_to_app.assert_called_once_with(app_config) + + if not exists or not structure_synced: + instance._sync_from_app.assert_called_with(app_config) + call_count = 2 if (not exists and not structure_synced) else 1 + else: + call_count = 0 + assert instance._sync_from_app.call_count == call_count + + +@pytest.mark.parametrize("app_config", + (get_local_remote([2, 1, 3])[1], + get_local_remote([4, 2, 6, 8])[1], + get_local_remote([3])[1])) +def testConfigFile_sync_update_from_app(app_config, mocker): + """ Test :class:`lib.config.ini.ConfigFile.update_from_app` logic """ + instance = ini_mod.ConfigFile("test") + instance.save = mocker.MagicMock() + for sect in app_config: + # Updating from app always replaces the existing parser with a new one + assert sect not in instance._parser.sections() + + instance.update_from_app(app_config) + + instance.save.assert_called_once() + for sect_name, sect in app_config.items(): + assert sect_name in instance._parser.sections() + 
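+        # Every app option should now be mirrored in the parser in its ini formatted form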
for opt_name, val in sect.options.items():
+            assert opt_name in instance._parser[sect_name]
+            assert instance._parser[sect_name][opt_name] == val.ini_value
diff --git a/tests/lib/config/objects_test.py b/tests/lib/config/objects_test.py
new file mode 100644
index 0000000000..9bcdd8f6eb
--- /dev/null
+++ b/tests/lib/config/objects_test.py
@@ -0,0 +1,411 @@
+#!/usr/bin python3
+""" Pytest unit tests for :mod:`lib.config.objects` """
+import pytest
+
+from lib.config.objects import ConfigItem
+# pylint:disable=invalid-name
+
+_TEST_GROUP = "TestGroup"
+_TEST_INFO = "TestInfo"
+
+
+_STR_CONFIG = (  # type:ignore[var-annotated]
+    ("TestDefault", ["TestDefault", "Other"], "success-choices"),
+    ("TestDefault", [], "success-no-choices"),
+    ("#ffffff", "colorchooser", "success-colorchooser"),
+    ("FailDefault", ["TestDefault", "Other"], "fail-choices"),
+    ("TestDefault", "Invalid", "fail-invalid-choices"),
+    ("TestDefault", "colorchooser", "fail-colorchooser"),
+    (1, [], "fail-int"),
+    (1.1, [], "fail-float"),
+    (True, [], "fail-bool"),
+    (["test", "list"], [], "fail-list"))
+_STR_PARAMS = ["default", "choices", "status"]
+
+
+@pytest.mark.parametrize(_STR_PARAMS, _STR_CONFIG, ids=[x[-1] for x in _STR_CONFIG])
+def test_ConfigItem_str(default, choices, status):
+    """ Test that datatypes validate for strings and value is set correctly """
+    dtype = str
+    if status.startswith("success"):
+        dclass = ConfigItem(datatype=dtype,
+                            default=default,
+                            group=_TEST_GROUP,
+                            info=_TEST_INFO,
+                            choices=choices)
+        assert dclass.value == default.lower()
+    else:
+        with pytest.raises(ValueError):
+            ConfigItem(datatype=dtype,
+                       default=default,
+                       group=_TEST_GROUP,
+                       info=_TEST_INFO,
+                       choices=choices)
+
+
+_INT_CONFIG = ((10, (0, 100), 1, "success"),
+               (20, None, 1, "fail-min-max-missing"),
+               (30, (0.1, 100.1), 1, "fail-min-max-dtype"),
+               (35, "TestMinMax", 1, "fail-min-max-type"),
+               (40, (0, 100), -1, "fail-rounding-missing"),
+               (50, (0, 100), 1.1, "fail-rounding-dtype"),
+               ("TestDefault", (0, 100), 1, "fail-str"),
+               (1.1, (0, 100), 1, "fail-float"),
+               (True, (0, 100), 1, "fail-bool"),
+               ([0, 1], [0.0, 100.0], 1, "fail-list"))
+_INT_PARAMS = ["default", "min_max", "rounding", "status"]
+
+
+@pytest.mark.parametrize(_INT_PARAMS, _INT_CONFIG, ids=[x[-1] for x in _INT_CONFIG])
+def test_ConfigItem_int(default, min_max, rounding, status):
+    """ Test that datatypes validate for integers and value is set correctly """
+    dtype = int
+    if status.startswith("success"):
+        dclass = ConfigItem(datatype=dtype,
+                            default=default,
+                            group=_TEST_GROUP,
+                            info=_TEST_INFO,
+                            min_max=min_max,
+                            rounding=rounding)
+        assert dclass.value == default
+    else:
+        with pytest.raises(ValueError):
+            ConfigItem(datatype=dtype,
+                       default=default,
+                       group=_TEST_GROUP,
+                       info=_TEST_INFO,
+                       min_max=min_max,
+                       rounding=rounding)
+
+
+_FLOAT_CONFIG = ((10.0, (0.0, 100.0), 1, "success"),
+                 (20.1, None, 1, "fail-min-max-missing"),
+                 (30.2, (1, 100), 1, "fail-min-max-dtype"),
+                 (35.0, "TestMinMax", 1, "fail-min-max-type"),
+                 (40.3, (0.0, 100.0), -1, "fail-rounding-missing"),
+                 (50.4, (0.0, 100.0), 1.1, "fail-rounding-dtype"),
+                 ("TestDefault", (0.0, 100.0), 1, "fail-str"),
+                 (1, (0.0, 100.0), 1, "fail-int"),
+                 (True, (0.0, 100.0), 1, "fail-bool"),
+                 ([0.1, 1.2], [0.0, 100.0], 1, "fail-list"))
+_FLOAT_PARAMS = ["default", "min_max", "rounding", "status"]
+
+
+@pytest.mark.parametrize(_FLOAT_PARAMS, _FLOAT_CONFIG, ids=[x[-1] for x in _FLOAT_CONFIG])
+def test_ConfigItem_float(default, min_max, rounding, status):
+    """ Test that datatypes validate for floats and 
value is set correctly """ + dtype = float + if status.startswith("success"): + dclass = ConfigItem(datatype=dtype, + default=default, + group=_TEST_GROUP, + info=_TEST_INFO, + min_max=min_max, + rounding=rounding) + assert dclass.value == default + else: + with pytest.raises(ValueError): + ConfigItem(datatype=dtype, + default=default, + group=_TEST_GROUP, + info=_TEST_INFO, + min_max=min_max, + rounding=rounding) + + +_BOOL_CONFIG = ((True, "success-true"), + (False, "success-false"), + ("True", "fail-str"), + (42, "fail-int"), + (42.69, "fail-float"), + ([True, False], "fail-list")) +_BOOL_PARAMS = ["default", "status"] + + +@pytest.mark.parametrize(_BOOL_PARAMS, _BOOL_CONFIG, ids=[x[-1] for x in _BOOL_CONFIG]) +def test_ConfigItem_bool(default, status): + """ Test that datatypes validate for bool and value is set correctly """ + dtype = bool + if status.startswith("success"): + dclass = ConfigItem(datatype=dtype, + default=default, + group=_TEST_GROUP, + info=_TEST_INFO) + assert dclass.value is default + else: + with pytest.raises(ValueError): + ConfigItem(datatype=dtype, + default=default, + group=_TEST_GROUP, + info=_TEST_INFO) + + +_LIST_CONFIG = ( # type:ignore[var-annotated] + (["TestDefault"], ["TestDefault", "Other"], "success"), + (["TestDefault", "Fail"], ["TestDefault", "Other"], "fail-invalid-choice"), + (["TestDefault"], [], "fail-no-choices"), + ([1, 2], [1, 2, 3], "fail-dtype"), + ("test", ["TestDefault", "Other"], "fail-str"), + (1, ["TestDefault", "Other"], "fail-int"), + (1.1, ["TestDefault", "Other"], "fail-float"), + (True, ["TestDefault", "Other"], "fail-bool")) +_LIST_PARAMS = ["default", "choices", "status"] + + +@pytest.mark.parametrize(_LIST_PARAMS, _LIST_CONFIG, ids=[x[-1] for x in _LIST_CONFIG]) +def test_ConfigItem_list(default, choices, status): + """ Test that datatypes validate for strings and value is set correctly """ + dtype = list + if status.startswith("success"): + dclass = ConfigItem(datatype=dtype, + default=default, + group=_TEST_GROUP, + info=_TEST_INFO, + choices=choices) + assert dclass.value == [x.lower() for x in default] + else: + with pytest.raises(ValueError): + ConfigItem(datatype=dtype, + default=default, + group=_TEST_GROUP, + info=_TEST_INFO, + choices=choices) + + +_REQ_CONFIG = (("TestGroup", "TestInfo", "success"), + ("", "TestGroup", "fail-no-group"), + ("TestGroup", "", "fail-no-info")) +_REQ_PARAMS = ["group", "info", "status"] + + +@pytest.mark.parametrize(_REQ_PARAMS, _REQ_CONFIG, ids=[x[-1] for x in _REQ_CONFIG]) +def test_ConfigItem_missing_required(group, info, status): + """ Test that an error is raised when either group or info are not provided """ + dtype = str + default = "test" + if status.startswith("success"): + dclass = ConfigItem(datatype=dtype, + default=default, + group=group, + info=info) + assert dclass.group == group + assert dclass.info == info + assert isinstance(dclass.helptext, str) and dclass.helptext + assert dclass.name == "" + else: + with pytest.raises(ValueError): + ConfigItem(datatype=dtype, + default=default, + group=group, + info=info) + + +_NAME_CONFIG = (("TestName", "success"), + ("", "fail-no-name"), + (100, "fail-dtype")) + + +@pytest.mark.parametrize(("name", "status"), _NAME_CONFIG, ids=[x[-1] for x in _NAME_CONFIG]) +def test_ConfigItem_set_name(name, status): + """ Test that setting the config item's name functions correctly """ + dtype = str + default = "test" + dclass = ConfigItem(datatype=dtype, + default=default, + group="TestGroup", + info="TestInfo") + if 
status.startswith("success"): + dclass.set_name(name) + assert dclass.name == name + else: + with pytest.raises(AssertionError): + dclass.set_name(name) + + +_STR_SET_CONFIG = ( # type:ignore[var-annotated] + ("NewValue", ["TestDefault", "NewValue"], "success-choices"), + ("NoValue", ["TestDefault", "NewValue"], "success-fallback"), + ("NewValue", [], "success-no-choices"), + ("#AAAAAA", "colorchooser", "success-colorchooser"), + ("NewValue", "colorchooser", "fail-colorchooser"), + (1, [], "fail-int"), + (1.1, [], "fail-float"), + (True, [], "fail-bool"), + (["test", "list"], [], "fail-list")) +_STR_SET_PARAMS = ("value", "choices", "status") + + +@pytest.mark.parametrize(_STR_SET_PARAMS, _STR_SET_CONFIG, ids=[x[-1] for x in _STR_SET_CONFIG]) +def test_ConfigItem_set_str(value, choices, status): + """ Test that strings validate and set correctly """ + default = "#ffffff" if choices == "colorchooser" else "TestDefault" + dtype = str + dclass = ConfigItem(datatype=dtype, + default=default, + group=_TEST_GROUP, + info=_TEST_INFO, + choices=choices) + + with pytest.raises(ValueError): # Confirm setting fails when name not set + dclass.set(value) + + dclass.set_name("TestName") + + if status.startswith("success"): + dclass.set(value) + if status == "success-fallback": + assert dclass.value == dclass() == dclass.get() == dclass.default.lower() + else: + assert dclass.value == dclass() == dclass.get() == value.lower() + else: + with pytest.raises(ValueError): + dclass.set(value) + + +_INT_SET_CONFIG = ((10, "success"), + ("Test", "fail-str"), + (1.1, "fail-float"), + (["test", "list"], "fail-list")) +_INT_SET_PARAMS = ("value", "status") + + +@pytest.mark.parametrize(_INT_SET_PARAMS, _INT_SET_CONFIG, ids=[x[-1] for x in _INT_SET_CONFIG]) +def test_ConfigItem_set_int(value, status): + """ Test that ints validate and set correctly """ + default = 20 + dtype = int + dclass = ConfigItem(datatype=dtype, + default=default, + group=_TEST_GROUP, + info=_TEST_INFO, + min_max=(0, 10), + rounding=1) + + with pytest.raises(ValueError): # Confirm setting fails when name not set + dclass.set(value) + + dclass.set_name("TestName") + + if status.startswith("success"): + dclass.set(value) + assert dclass.value == dclass() == dclass.get() == value + else: + with pytest.raises(ValueError): + dclass.set(value) + + +_FLOAT_SET_CONFIG = ((69.42, "success"), + ("Test", "fail-str"), + (42, "fail-int"), + (True, "fail-bool"), + (["test", "list"], "fail-list")) +_FLOAT_SET_PARAMS = ("value", "status") + + +@pytest.mark.parametrize(_FLOAT_SET_PARAMS, + _FLOAT_SET_CONFIG, + ids=[x[-1] for x in _FLOAT_SET_CONFIG]) +def test_ConfigItem_set_float(value, status): + """ Test that floats validate and set correctly """ + default = 20.025 + dtype = float + dclass = ConfigItem(datatype=dtype, + default=default, + group=_TEST_GROUP, + info=_TEST_INFO, + min_max=(0.0, 100.0), + rounding=1) + + with pytest.raises(ValueError): # Confirm setting fails when name not set + dclass.set(value) + + dclass.set_name("TestName") + + if status.startswith("success"): + dclass.set(value) + assert dclass.value == dclass() == dclass.get() == value + else: + with pytest.raises(ValueError): + dclass.set(value) + + +_BOOL_SET_CONFIG = ((True, "success-true"), + (False, "success-false"), + ("Test", "fail-str"), + (42, "fail-int"), + (42.69, "fail-float"), + (["test", "list"], "fail-list")) +_BOOL_SET_PARAMS = ("value", "status") + + +@pytest.mark.parametrize(_BOOL_SET_PARAMS, _BOOL_SET_CONFIG, ids=[x[-1] for x in _BOOL_SET_CONFIG]) +def 
test_ConfigItem_set_bool(value, status): + """ Test that bools validate and set correctly """ + default = True + dtype = bool + dclass = ConfigItem(datatype=dtype, + default=default, + group=_TEST_GROUP, + info=_TEST_INFO) + + with pytest.raises(ValueError): # Confirm setting fails when name not set + dclass.set(value) + + dclass.set_name("TestName") + + if status.startswith("success"): + dclass.set(value) + assert dclass.value == dclass() == dclass.get() == value + else: + with pytest.raises(ValueError): + dclass.set(value) + + +_LIST_SET_CONFIG = ((["NewValue"], "success-choices"), + ("NewValue, TestDefault", "success-delim-comma"), + ("NewValue TestDefault", "success-delim-space"), + ("NewValue", "success-delim-1value"), + (["NoValue"], "success-fallback1"), + (["NewValue", "NoValue"], "success-fallback2"), + ("NewValue, NoValue", "success-fallback-delim-comma"), + ("NewValue NoValue", "success-fallback-delim-space"), + ("NoValue", "success-fallback-delim-1value"), + (1, "fail-int"), + (1.1, "fail-float"), + (True, "fail-bool")) +_LIST_SET_PARAMS = ("value", "status") + + +@pytest.mark.parametrize(_LIST_SET_PARAMS, _LIST_SET_CONFIG, ids=[x[-1] for x in _LIST_SET_CONFIG]) +def test_ConfigItem_set_list(value, status): + """ Test that lists validate and set correctly """ + default = ["TestDefault"] + choices = ["TestDefault", "NewValue"] + dtype = list + dclass = ConfigItem(datatype=dtype, + default=default, + group=_TEST_GROUP, + info=_TEST_INFO, + choices=choices) + + with pytest.raises(ValueError): # Confirm setting fails when name not set + dclass.set(value) + + dclass.set_name("TestName") + + if status.startswith("success"): + dclass.set(value) + + if not isinstance(value, list): + value = [x.strip() for x in value.split(",")] if "," in value else value.split() + assert dclass.value == dclass() == dclass.get() + expected = [x.lower() for x in value] + if status.startswith("success-fallback"): + expected = [x.lower() for x in value if x in choices] + if not expected: + expected = [x.lower() for x in default] + assert set(expected) == set(dclass.value) + + else: + with pytest.raises(ValueError): + dclass.set(value) diff --git a/tests/lib/gpu_stats/__init__.py b/tests/lib/gpu_stats/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/lib/gpu_stats/_base_test.py b/tests/lib/gpu_stats/_base_test.py new file mode 100644 index 0000000000..4edbd4df52 --- /dev/null +++ b/tests/lib/gpu_stats/_base_test.py @@ -0,0 +1,134 @@ +#!/usr/bin python3 +""" Pytest unit tests for :mod:`lib.gpu_stats._base` """ +import typing as T + +from dataclasses import dataclass +from unittest.mock import MagicMock + +import pytest +import pytest_mock + +# pylint:disable=protected-access +from lib.gpu_stats._base import BiggestGPUInfo, GPUInfo, _GPUStats +from lib.utils import get_backend + + +@dataclass +class _DummyData: + """ Dummy data for initializing and testing :class:`~lib.gpu_stats._base._GPUStats` """ + device_count = 2 + active_devices = [0, 1] + handles = [0, 1] + driver = "test_driver" + device_names = ['test_device_0', 'test_device_1'] + vram = [1024, 2048] + free_vram = [512, 1024] + + +@pytest.fixture(name="gpu_stats_instance") +def fixture__gpu_stats_instance(mocker: pytest_mock.MockerFixture) -> _GPUStats: + """ Create a fixture of the :class:`~lib.gpu_stats._base._GPUStats` object + + Parameters + ---------- + mocker: :class:`pytest_mock.MockerFixture` + Mocker for dummying in function calls + """ + mocker.patch.object(_GPUStats, '_initialize') + 
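+    # Patch the hardware lookups below so the fixture needs no physical GPU or vendor libraries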
mocker.patch.object(_GPUStats, '_shutdown') + mocker.patch.object(_GPUStats, '_get_device_count', return_value=_DummyData.device_count) + mocker.patch.object(_GPUStats, '_get_active_devices', return_value=_DummyData.active_devices) + mocker.patch.object(_GPUStats, '_get_handles', return_value=_DummyData.handles) + mocker.patch.object(_GPUStats, '_get_driver', return_value=_DummyData.driver) + mocker.patch.object(_GPUStats, '_get_device_names', return_value=_DummyData.device_names) + mocker.patch.object(_GPUStats, '_get_vram', return_value=_DummyData.vram) + mocker.patch.object(_GPUStats, '_get_free_vram', return_value=_DummyData.free_vram) + gpu_stats = _GPUStats() + return gpu_stats + + +def test__gpu_stats_init_(gpu_stats_instance: _GPUStats) -> None: + """ Test that the base :class:`~lib.gpu_stats._base._GPUStats` class initializes correctly + + Parameters + ---------- + gpu_stats_instance: :class:`_GPUStats` + Fixture instance of the _GPUStats base class + """ + # Ensure that the object is initialized and shutdown correctly + assert gpu_stats_instance._is_initialized is False + assert T.cast(MagicMock, gpu_stats_instance._initialize).call_count == 1 + assert T.cast(MagicMock, gpu_stats_instance._shutdown).call_count == 1 + + # Ensure that the object correctly gets and stores the device count, active devices, + # handles, driver, device names, and VRAM information + assert gpu_stats_instance.device_count == _DummyData.device_count + assert gpu_stats_instance._active_devices == _DummyData.active_devices + assert gpu_stats_instance._handles == _DummyData.handles + assert gpu_stats_instance._driver == _DummyData.driver + assert gpu_stats_instance._device_names == _DummyData.device_names + assert gpu_stats_instance._vram == _DummyData.vram + + +def test__gpu_stats_properties(gpu_stats_instance: _GPUStats) -> None: + """ Test that the :class:`~lib.gpu_stats._base._GPUStats` properties are set and formatted + correctly. 
+
+    Parameters
+    ----------
+    gpu_stats_instance: :class:`_GPUStats`
+        Fixture instance of the _GPUStats base class
+    """
+    assert gpu_stats_instance.cli_devices == ['0: test_device_0', '1: test_device_1']
+    assert gpu_stats_instance.sys_info == GPUInfo(vram=_DummyData.vram,
+                                                  vram_free=_DummyData.free_vram,
+                                                  driver=_DummyData.driver,
+                                                  devices=_DummyData.device_names,
+                                                  devices_active=_DummyData.active_devices)
+
+
+def test__gpu_stats_get_card_most_free(mocker: pytest_mock.MockerFixture,
+                                       gpu_stats_instance: _GPUStats) -> None:
+    """ Confirm that :func:`lib.gpu_stats._base._GPUStats.get_card_most_free` functions
+    correctly
+
+    Parameters
+    ----------
+    mocker: :class:`pytest_mock.MockerFixture`
+        Mocker for dummying in function calls
+    gpu_stats_instance: :class:`_GPUStats`
+        Fixture instance of the _GPUStats base class
+    """
+    assert gpu_stats_instance.get_card_most_free() == BiggestGPUInfo(card_id=1,
+                                                                     device='test_device_1',
+                                                                     free=1024,
+                                                                     total=2048)
+    mocker.patch.object(_GPUStats, '_get_active_devices', return_value=[])
+    gpu_stats = _GPUStats()
+    assert gpu_stats.get_card_most_free() == BiggestGPUInfo(card_id=-1,
+                                                            device='No GPU devices found',
+                                                            free=2048,
+                                                            total=2048)
+
+
+def test__gpu_stats_no_active_devices(
+        caplog: pytest.LogCaptureFixture,
+        gpu_stats_instance: _GPUStats,  # pylint:disable=unused-argument
+        mocker: pytest_mock.MockerFixture) -> None:
+    """ Ensure that no active GPUs raises a warning when not in CPU mode
+
+    Parameters
+    ----------
+    caplog: :class:`pytest.LogCaptureFixture`
+        Pytest's log capturing fixture
+    gpu_stats_instance: :class:`_GPUStats`
+        Fixture instance of the _GPUStats base class
+    mocker: :class:`pytest_mock.MockerFixture`
+        Mocker for dummying in function calls
+    """
+    if get_backend() == "cpu":
+        return
+    caplog.set_level("WARNING")
+    mocker.patch.object(_GPUStats, '_get_active_devices', return_value=[])
+    _GPUStats()
+    assert "No GPU detected" in caplog.messages
diff --git a/tests/lib/gui/__init__.py b/tests/lib/gui/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/lib/gui/stats/__init__.py b/tests/lib/gui/stats/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/lib/gui/stats/event_reader_test.py b/tests/lib/gui/stats/event_reader_test.py
new file mode 100644
index 0000000000..0e4094aba6
--- /dev/null
+++ b/tests/lib/gui/stats/event_reader_test.py
@@ -0,0 +1,752 @@
+#!/usr/bin python3
+""" Pytest unit tests for :mod:`lib.gui.stats.event_reader` """
+# pylint:disable=protected-access
+from __future__ import annotations
+import json
+import os
+import typing as T
+
+from shutil import rmtree
+from time import time
+from unittest.mock import MagicMock
+
+import numpy as np
+import pytest
+import pytest_mock
+
+from tensorboard.compat.proto import event_pb2
+
+from lib.gui.analysis.event_reader import (_Cache, _CacheData, _EventParser,
+                                           _LogFiles, EventData, TensorBoardLogs)
+
+if T.TYPE_CHECKING:
+    from collections.abc import Iterator
+
+
+def test__logfiles(tmp_path: str):
+    """ Test the _LogFiles class operates correctly
+
+    Parameters
+    ----------
+    tmp_path: :class:`pathlib.Path`
+        Temporary folder for dummy data
+    """
+    # dummy logfiles + junk data
+    sess_1 = os.path.join(tmp_path, "session_1", "train")
+    sess_2 = os.path.join(tmp_path, "session_2", "train")
+    os.makedirs(sess_1)
+    os.makedirs(sess_2)
+
+    test_log_1 = os.path.join(sess_1, "events.out.tfevents.123.456.v2")
+    test_log_2 = os.path.join(sess_2, "events.out.tfevents.789.012.v2")
+    test_log_junk = os.path.join(sess_2, "test_file.txt")
+
+    for fname in (test_log_1, test_log_2, test_log_junk):
+        with open(fname, "a", encoding="utf-8"):
+            pass
+
+    log_files = _LogFiles(tmp_path)
+    # Test all correct
+    assert isinstance(log_files._filenames, dict)
+    assert len(log_files._filenames) == 2
+    assert log_files._filenames == {1: test_log_1, 2: test_log_2}
+
+    assert log_files.session_ids == [1, 2]
+
+    assert log_files.get(1) == test_log_1
+    assert log_files.get(2) == test_log_2
+
+    # Remove a file, refresh and check again
+    rmtree(sess_1)
+    log_files.refresh()
+    assert log_files._filenames == {2: test_log_2}
+    assert log_files.get(2) == test_log_2
+    assert log_files.get(3) == ""
+
+
+def test__cachedata():
+    """ Test the _CacheData class operates correctly """
+    labels = ["label_a", "label_b"]
+    timestamps = np.array([1.23, 4.56], dtype="float64")
+    loss = np.array([[2.34, 5.67], [3.45, 6.78]], dtype="float32")
+
+    # Initial test
+    cache = _CacheData(labels, timestamps, loss)
+    assert cache.labels == labels
+    assert cache._timestamps_shape == timestamps.shape
+    assert cache._loss_shape == loss.shape
+    np.testing.assert_array_equal(cache.timestamps, timestamps)
+    np.testing.assert_array_equal(cache.loss, loss)
+
+    # Add data test
+    new_timestamps = np.array([2.34, 6.78], dtype="float64")
+    new_loss = np.array([[3.45, 7.89], [8.90, 1.23]], dtype="float32")
+
+    expected_timestamps = np.concatenate([timestamps, new_timestamps])
+    expected_loss = np.concatenate([loss, new_loss])
+
+    cache.add_live_data(new_timestamps, new_loss)
+    assert cache.labels == labels
+    assert cache._timestamps_shape == expected_timestamps.shape
+    assert cache._loss_shape == expected_loss.shape
+    np.testing.assert_array_equal(cache.timestamps, expected_timestamps)
+    np.testing.assert_array_equal(cache.loss, expected_loss)
+
+
+# _Cache tests
+class Test_Cache:  # pylint:disable=invalid-name
+    """ Test that :class:`lib.gui.analysis.event_reader._Cache` works correctly """
+    @staticmethod
+    def test_init() -> None:
+        """ Test __init__ """
+        cache = _Cache()
+        assert isinstance(cache._data, dict)
+        assert isinstance(cache._carry_over, dict)
+        assert isinstance(cache._loss_labels, list)
+        assert not cache._data
+        assert not cache._carry_over
+        assert not cache._loss_labels
+
+    @staticmethod
+    def test_is_cached() -> None:
+        """ Test is_cached function works """
+        cache = _Cache()
+
+        data = _CacheData(["test_1", "test_2"],
+                          np.array([1.23, ], dtype="float64"),
+                          np.array([[2.34, ], [4.56]], dtype="float32"))
+        cache._data[1] = data
+        assert cache.is_cached(1)
+        assert not cache.is_cached(2)
+
+    @staticmethod
+    def test_cache_data(mocker: pytest_mock.MockerFixture) -> None:
+        """ Test cache_data function works
+
+        Parameters
+        ----------
+        mocker: :class:`pytest_mock.MockerFixture`
+            Mocker for checking that the latest live data is added to the cache
+        """
+        cache = _Cache()
+
+        session_id = 1
+        data = {1: EventData(4., [1., 2.]), 2: EventData(5., [3., 4.])}
+        labels = ['label1', 'label2']
+        is_live = False
+
+        cache.cache_data(session_id, data, labels, is_live)
+        assert cache._loss_labels == labels
+        assert cache.is_cached(session_id)
+        np.testing.assert_array_equal(cache._data[session_id].timestamps, np.array([4., 5.]))
+        np.testing.assert_array_equal(cache._data[session_id].loss, np.array([[1., 2.], [3., 4.]]))
+
+        add_live = mocker.patch("lib.gui.analysis.event_reader._Cache._add_latest_live")
+        is_live = True
+        cache.cache_data(session_id, data, labels, is_live)
+        assert add_live.called
+
+    @staticmethod
+    def test__to_numpy() -> None:
+        """ Test _to_numpy function works """
+        cache = 
_Cache() + cache._loss_labels = ['label1', 'label2'] + data = {1: EventData(4., [1., 2.]), 2: EventData(5., [3., 4.])} + + # Non-live + is_live = False + times, loss = cache._to_numpy(data, is_live) + np.testing.assert_array_equal(times, np.array([4., 5.])) + np.testing.assert_array_equal(loss, np.array([[1., 2.], [3., 4.]])) + + # Correctly collected live + is_live = True + times, loss = cache._to_numpy(data, is_live) + np.testing.assert_array_equal(times, np.array([4., 5.])) + np.testing.assert_array_equal(loss, np.array([[1., 2.], [3., 4.]])) + + # Incorrectly collected live + live_data = {1: EventData(4., [1., 2.]), + 2: EventData(5., [3.]), + 3: EventData(6., [4., 5., 6.])} + times, loss = cache._to_numpy(live_data, is_live) + np.testing.assert_array_equal(times, np.array([4.])) + np.testing.assert_array_equal(loss, np.array([[1., 2.]])) + + @staticmethod + def test__collect_carry_over() -> None: + """ Test _collect_carry_over function works """ + data = {1: EventData(3., [4., 5.]), 2: EventData(6., [7., 8.])} + carry_over = {1: EventData(3., [2., 3.])} + expected = {1: EventData(3., [2., 3., 4., 5.]), 2: EventData(6., [7., 8.])} + + cache = _Cache() + cache._carry_over = carry_over + cache._collect_carry_over(data) + assert data == expected + + @staticmethod + def test__process_data() -> None: + """ Test _process_data function works """ + cache = _Cache() + cache._loss_labels = ['label1', 'label2'] + + data = {1: EventData(4., [5., 6.]), + 2: EventData(5., [7., 8.]), + 3: EventData(6., [9.])} + is_live = False + expected_timestamps = np.array([4., 5.]) + expected_loss = np.array([[5., 6.], [7., 8.]]) + expected_carry_over = {3: EventData(6., [9.])} + + timestamps, loss = cache._process_data(data, is_live) + np.testing.assert_array_equal(timestamps, expected_timestamps) + np.testing.assert_array_equal(loss, expected_loss) + assert not cache._carry_over + + is_live = True + timestamps, loss = cache._process_data(data, is_live) + np.testing.assert_array_equal(timestamps, expected_timestamps) + np.testing.assert_array_equal(loss, expected_loss) + assert cache._carry_over == expected_carry_over + + @staticmethod + def test__add_latest_live() -> None: + """ Test _add_latest_live function works """ + session_id = 1 + labels = ['label1', 'label2'] + data = {1: EventData(3., [5., 6.]), 2: EventData(4., [7., 8.])} + new_timestamp = np.array([5.], dtype="float64") + new_loss = np.array([[8., 9.]], dtype="float32") + expected_timestamps = np.array([3., 4., 5.]) + expected_loss = np.array([[5., 6.], [7., 8.], [8., 9.]]) + + cache = _Cache() + cache.cache_data(session_id, data, labels) # Initial data + cache._add_latest_live(session_id, new_loss, new_timestamp) + + assert cache.is_cached(session_id) + assert cache._loss_labels == labels + np.testing.assert_array_equal(cache._data[session_id].timestamps, expected_timestamps) + np.testing.assert_array_equal(cache._data[session_id].loss, expected_loss) + + @staticmethod + def test_get_data() -> None: + """ Test get_data function works """ + session_id = 1 + + cache = _Cache() + assert cache.get_data(session_id, "loss") is None + assert cache.get_data(session_id, "timestamps") is None + + labels = ['label1', 'label2'] + data = {1: EventData(3., [5., 6.]), 2: EventData(4., [7., 8.])} + expected_timestamps = np.array([3., 4.]) + expected_loss = np.array([[5., 6.], [7., 8.]]) + + cache.cache_data(session_id, data, labels, is_live=False) + get_timestamps = cache.get_data(session_id, "timestamps") + get_loss = cache.get_data(session_id, "loss") + + assert 
isinstance(get_timestamps, dict)
+        assert len(get_timestamps) == 1
+        assert list(get_timestamps) == [session_id]
+        result = get_timestamps[session_id]
+        assert list(result) == ["timestamps"]
+        np.testing.assert_array_equal(result["timestamps"], expected_timestamps)
+
+        assert isinstance(get_loss, dict)
+        assert len(get_loss) == 1
+        assert list(get_loss) == [session_id]
+        result = get_loss[session_id]
+        assert list(result) == ["loss", "labels"]
+        np.testing.assert_array_equal(result["loss"], expected_loss)
+
+
+# TensorBoardLogs
+class TestTensorBoardLogs:
+    """ Test that :class:`lib.gui.analysis.event_reader.TensorBoardLogs` works correctly """
+
+    @pytest.fixture(name="tensorboardlogs_instance")
+    def tensorboardlogs_fixture(self,
+                                tmp_path: str,
+                                request: pytest.FixtureRequest) -> TensorBoardLogs:
+        """ Pytest fixture for :class:`lib.gui.analysis.event_reader.TensorBoardLogs`
+
+        Parameters
+        ----------
+        tmp_path: :class:`pathlib.Path`
+            Temporary folder for dummy data
+        request: :class:`pytest.FixtureRequest`
+            Pytest request fixture for registering the teardown finalizer
+
+        Returns
+        -------
+        :class:`lib.gui.analysis.event_reader.TensorBoardLogs`
+            The class instance for testing
+        """
+        sess_1 = os.path.join(tmp_path, "session_1", "train")
+        sess_2 = os.path.join(tmp_path, "session_2", "train")
+        os.makedirs(sess_1)
+        os.makedirs(sess_2)
+
+        test_log_1 = os.path.join(sess_1, "events.out.tfevents.123.456.v2")
+        test_log_2 = os.path.join(sess_2, "events.out.tfevents.789.012.v2")
+
+        for fname in (test_log_1, test_log_2):
+            with open(fname, "a", encoding="utf-8"):
+                pass
+
+        tblogs_instance = TensorBoardLogs(tmp_path, False)
+
+        def teardown():
+            try:
+                rmtree(tmp_path)
+            except PermissionError:
+                pass
+
+        request.addfinalizer(teardown)
+        return tblogs_instance
+
+    @staticmethod
+    def test_init(tensorboardlogs_instance: TensorBoardLogs) -> None:
+        """ Test __init__ works correctly
+
+        Parameters
+        ----------
+        tensorboardlogs_instance: :class:`lib.gui.analysis.event_reader.TensorBoardLogs`
+            The class instance to test
+        """
+        tb_logs = tensorboardlogs_instance
+        assert isinstance(tb_logs._log_files, _LogFiles)
+        assert isinstance(tb_logs._cache, _Cache)
+        assert not tb_logs._is_training
+
+        is_training = True
+        folder = tb_logs._log_files._logs_folder
+        tb_logs = TensorBoardLogs(folder, is_training)
+        assert tb_logs._is_training
+
+    @staticmethod
+    def test_session_ids(tensorboardlogs_instance: TensorBoardLogs) -> None:
+        """ Test session_ids property works correctly
+
+        Parameters
+        ----------
+        tensorboardlogs_instance: :class:`lib.gui.analysis.event_reader.TensorBoardLogs`
+            The class instance to test
+        """
+        tb_logs = tensorboardlogs_instance
+        assert tb_logs.session_ids == [1, 2]
+
+    @staticmethod
+    def test_set_training(tensorboardlogs_instance: TensorBoardLogs) -> None:
+        """ Test set_training works correctly
+
+        Parameters
+        ----------
+        tensorboardlogs_instance: :class:`lib.gui.analysis.event_reader.TensorBoardLogs`
+            The class instance to test
+        """
+        tb_logs = tensorboardlogs_instance
+        assert not tb_logs._is_training
+        assert tb_logs._training_iterator is None
+        tb_logs.set_training(True)
+        assert tb_logs._is_training
+        assert tb_logs._training_iterator is not None
+        tb_logs.set_training(False)
+        assert not tb_logs._is_training
+        assert tb_logs._training_iterator is None
+
+    @staticmethod
+    def test__cache_data(tensorboardlogs_instance: TensorBoardLogs,
+                         mocker: pytest_mock.MockerFixture) -> None:
+        """ Test _cache_data works correctly
+
+        Parameters
+        ----------
+        tensorboardlogs_instance: :class:`lib.gui.analysis.event_reader.TensorBoardLogs`
+            The class instance to test
+        mocker: :class:`pytest_mock.MockerFixture`
+            Mocker for checking event parser caching is called
+        """
+        tb_logs = tensorboardlogs_instance
+        session_id = 1
+        cacher = mocker.patch("lib.gui.analysis.event_reader._EventParser.cache_events")
+        tb_logs._cache_data(session_id)
+        assert cacher.called
+        cacher.reset_mock()
+
+        tb_logs.set_training(True)
+        tb_logs._cache_data(session_id)
+        assert cacher.called
+
+    @staticmethod
+    def test__check_cache(tensorboardlogs_instance: TensorBoardLogs,
+                          mocker: pytest_mock.MockerFixture) -> None:
+        """ Test _check_cache works correctly
+
+        Parameters
+        ----------
+        tensorboardlogs_instance: :class:`lib.gui.analysis.event_reader.TensorBoardLogs`
+            The class instance to test
+        mocker: :class:`pytest_mock.MockerFixture`
+            Mocker for checking _cache_data is called
+        """
+        is_cached = mocker.patch("lib.gui.analysis.event_reader._Cache.is_cached")
+        cache_data = mocker.patch("lib.gui.analysis.event_reader.TensorBoardLogs._cache_data")
+        tb_logs = tensorboardlogs_instance
+
+        # Session ID not training
+        is_cached.return_value = False
+        tb_logs._check_cache(1)
+        assert is_cached.called
+        assert cache_data.called
+        is_cached.reset_mock()
+        cache_data.reset_mock()
+
+        is_cached.return_value = True
+        tb_logs._check_cache(1)
+        assert is_cached.called
+        assert not cache_data.called
+        is_cached.reset_mock()
+        cache_data.reset_mock()
+
+        # Session ID and training
+        tb_logs.set_training(True)
+        tb_logs._check_cache(1)
+        assert not cache_data.called
+        cache_data.reset_mock()
+
+        tb_logs._check_cache(2)
+        assert cache_data.called
+        cache_data.reset_mock()
+
+        # No session id
+        tb_logs.set_training(False)
+        is_cached.return_value = False
+
+        tb_logs._check_cache(None)
+        assert is_cached.called
+        assert cache_data.called
+        is_cached.reset_mock()
+        cache_data.reset_mock()
+
+        is_cached.return_value = True
+        tb_logs._check_cache(None)
+        assert is_cached.called
+        assert not cache_data.called
+        is_cached.reset_mock()
+        cache_data.reset_mock()
+
+    @staticmethod
+    def test_get_loss(tensorboardlogs_instance: TensorBoardLogs,
+                      mocker: pytest_mock.MockerFixture) -> None:
+        """ Test get_loss works correctly
+
+        Parameters
+        ----------
+        tensorboardlogs_instance: :class:`lib.gui.analysis.event_reader.TensorBoardLogs`
+            The class instance to test
+        mocker: :class:`pytest_mock.MockerFixture`
+            Mocker for checking _cache_data is called
+        """
+        tb_logs = tensorboardlogs_instance
+
+        mocker.patch("lib.gui.analysis.event_reader.RecordIterator")
+        tb_logs.get_loss(3)
+
+        check_cache = mocker.patch("lib.gui.analysis.event_reader.TensorBoardLogs._check_cache")
+        get_data = mocker.patch("lib.gui.analysis.event_reader._Cache.get_data")
+        get_data.return_value = None
+
+        assert isinstance(tb_logs.get_loss(None), dict)
+        assert check_cache.call_count == 2
+        assert get_data.call_count == 2
+        check_cache.reset_mock()
+        get_data.reset_mock()
+
+        assert isinstance(tb_logs.get_loss(1), dict)
+        assert check_cache.call_count == 1
+        assert get_data.call_count == 1
+        check_cache.reset_mock()
+        get_data.reset_mock()
+
+    @staticmethod
+    def test_get_timestamps(tensorboardlogs_instance: TensorBoardLogs,
+                            mocker: pytest_mock.MockerFixture) -> None:
+        """ Test get_timestamps works correctly
+
+        Parameters
+        ----------
+        tensorboardlogs_instance: :class:`lib.gui.analysis.event_reader.TensorBoardLogs`
+            The class instance to test
+        mocker: :class:`pytest_mock.MockerFixture`
+            Mocker for checking _cache_data is called
+        """
+        tb_logs = tensorboardlogs_instance
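+        # Stub the record iterator so get_timestamps does not need real event files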
mocker.patch("lib.gui.analysis.event_reader.RecordIterator")
+
+        tb_logs.get_timestamps(3)
+
+        check_cache = mocker.patch("lib.gui.analysis.event_reader.TensorBoardLogs._check_cache")
+        get_data = mocker.patch("lib.gui.analysis.event_reader._Cache.get_data")
+        get_data.return_value = None
+
+        assert isinstance(tb_logs.get_timestamps(None), dict)
+        assert check_cache.call_count == 2
+        assert get_data.call_count == 2
+        check_cache.reset_mock()
+        get_data.reset_mock()
+
+        assert isinstance(tb_logs.get_timestamps(1), dict)
+        assert check_cache.call_count == 1
+        assert get_data.call_count == 1
+        check_cache.reset_mock()
+        get_data.reset_mock()
+
+
+# EventParser
+class Test_EventParser:  # pylint:disable=invalid-name
+    """ Test that :class:`lib.gui.analysis.event_reader._EventParser` works correctly """
+    def _create_example_event(self,
+                              step: int,
+                              loss_value: float,
+                              timestamp: float,
+                              serialize: bool = True) -> bytes:
+        """ Generate a test TensorBoard event
+
+        Parameters
+        ----------
+        step: int
+            The step value to use
+        loss_value: float
+            The loss value to store
+        timestamp: float
+            The timestamp to store
+        serialize: bool, optional
+            ``True`` to serialize the event to bytes, ``False`` to return the Event object
+        """
+        tags = {0: "keras", 1: "batch_total", 2: "batch_face_a", 3: "batch_face_b"}
+        event = event_pb2.Event(step=step)
+        event.summary.value.add(tag=tags[step],  # pylint:disable=no-member
+                                simple_value=loss_value)
+        event.wall_time = timestamp
+        retval = event.SerializeToString() if serialize else event
+        return retval
+
+    @pytest.fixture(name="mock_iterator")
+    def iterator(self) -> Iterator[bytes]:
+        """ Dummy iterator for generating test events
+
+        Returns
+        -------
+        Iterator[bytes]
+            Serialized test Tensorboard Events
+        """
+        return iter([self._create_example_event(i, 1 + (i / 10), time()) for i in range(4)])
+
+    @pytest.fixture(name="mock_cache")
+    def mock_cache(self):
+        """ Dummy :class:`_Cache` for testing """
+        class _CacheMock:
+            def __init__(self):
+                self.data = {}
+                self._loss_labels = []
+
+            def is_cached(self, session_id):
+                """ Dummy is_cached method """
+                return session_id in self.data
+
+            def cache_data(self, session_id, data, labels,
+                           is_live=False):  # pylint:disable=unused-argument
+                """ Dummy cache_data method """
+                self.data[session_id] = {'data': data, 'labels': labels}
+
+        return _CacheMock()
+
+    @pytest.fixture(name="event_parser_instance")
+    def event_parser_fixture(self,
+                             mock_iterator: Iterator[bytes],
+                             mock_cache: _Cache) -> _EventParser:
+        """ Pytest fixture for :class:`lib.gui.analysis.event_reader._EventParser`
+
+        Parameters
+        ----------
+        mock_iterator: Iterator[bytes]
+            Dummy iterator for generating TF Event data
+        mock_cache: :class:`_CacheMock`
+            Dummy _Cache object
+
+        Returns
+        -------
+        :class:`lib.gui.analysis.event_reader._EventParser`
+            The class instance for testing
+        """
+        event_parser = _EventParser(mock_iterator, mock_cache, live_data=False)
+        return event_parser
+
+    def test__init_(self,
+                    event_parser_instance: _EventParser,
+                    mock_iterator: Iterator[bytes],
+                    mock_cache: _Cache) -> None:
+        """ Test __init__ works correctly
+
+        Parameters
+        ----------
+        event_parser_instance: :class:`lib.gui.analysis.event_reader._EventParser`
+            The class instance to test
+        mock_iterator: Iterator[bytes]
+            Dummy iterator for generating TF Event data
+        mock_cache: :class:`_CacheMock`
+            Dummy _Cache object
+        """
+        event_parse = event_parser_instance
+        assert not hasattr(event_parse._iterator, "__name__")
+        evp_live = _EventParser(mock_iterator, 
mock_cache, live_data=True)
+        assert evp_live._iterator.__name__ == "_get_latest_live"  # type:ignore[attr-defined]
+
+    def test__get_latest_live(self, event_parser_instance: _EventParser) -> None:
+        """ Test _get_latest_live works correctly
+
+        Parameters
+        ----------
+        event_parser_instance: :class:`lib.gui.analysis.event_reader._EventParser`
+            The class instance to test
+        """
+        event_parse = event_parser_instance
+        test = list(event_parse._get_latest_live(event_parse._iterator))
+        assert len(test) == 4
+
+    def test_cache_events(self,
+                          event_parser_instance: _EventParser,
+                          mocker: pytest_mock.MockerFixture,
+                          monkeypatch: pytest.MonkeyPatch) -> None:
+        """ Test cache_events works correctly
+
+        Parameters
+        ----------
+        event_parser_instance: :class:`lib.gui.analysis.event_reader._EventParser`
+            The class instance to test
+        mocker: :class:`pytest_mock.MockerFixture`
+            Mocker for capturing method calls
+        monkeypatch: :class:`pytest.MonkeyPatch`
+            For patching different iterators for testing output
+        """
+        monkeypatch.setattr("lib.utils._FS_BACKEND", "cpu")
+
+        event_parse = event_parser_instance
+        event_parse._parse_outputs = T.cast(MagicMock, mocker.MagicMock())  # type:ignore
+        event_parse._process_event = T.cast(MagicMock, mocker.MagicMock())  # type:ignore
+        event_parse._cache.cache_data = T.cast(MagicMock, mocker.MagicMock())  # type:ignore
+
+        # keras model
+        monkeypatch.setattr(event_parse,
+                            "_iterator",
+                            iter([self._create_example_event(0, 1., time())]))
+        event_parse.cache_events(1)
+        assert event_parse._parse_outputs.called
+        assert not event_parse._process_event.called
+        assert event_parse._cache.cache_data.called
+        event_parse._parse_outputs.reset_mock()
+        event_parse._process_event.reset_mock()
+        event_parse._cache.cache_data.reset_mock()
+
+        # Batch item
+        monkeypatch.setattr(event_parse,
+                            "_iterator",
+                            iter([self._create_example_event(1, 1., time())]))
+        event_parse.cache_events(1)
+        assert not event_parse._parse_outputs.called
+        assert event_parse._process_event.called
+        assert event_parse._cache.cache_data.called
+        event_parse._parse_outputs.reset_mock()
+        event_parse._process_event.reset_mock()
+        event_parse._cache.cache_data.reset_mock()
+
+        # No summary value: the event is skipped, but the (empty) data is still cached
+        monkeypatch.setattr(event_parse,
+                            "_iterator",
+                            iter([event_pb2.Event(step=1).SerializeToString()]))
+        event_parse.cache_events(1)
+        assert not event_parse._parse_outputs.called
+        assert not event_parse._process_event.called
+        assert event_parse._cache.cache_data.called
+        event_parse._parse_outputs.reset_mock()
+        event_parse._process_event.reset_mock()
+        event_parse._cache.cache_data.reset_mock()
+
+    def test__parse_outputs(self,
+                            event_parser_instance: _EventParser,
+                            mocker: pytest_mock.MockerFixture) -> None:
+        """ Test _parse_outputs works correctly
+
+        Parameters
+        ----------
+        event_parser_instance: :class:`lib.gui.analysis.event_reader._EventParser`
+            The class instance to test
+        mocker: :class:`pytest_mock.MockerFixture`
+            Mocker for event object
+        """
+        event_parse = event_parser_instance
+        model = {"config": {"layers": [{"name": "decoder_a",
+                                        "config": {"output_layers": [["face_out_a", 0, 0]]}},
+                                       {"name": "decoder_b",
+                                        "config": {"output_layers": [["face_out_b", 0, 0]]}}],
+                            "output_layers": [["decoder_a", 1, 0], ["decoder_b", 1, 0]]}}
+        data = json.dumps(model).encode("utf-8")
+
+        event = mocker.MagicMock()
+        event.summary.value.__getitem__ = lambda self, x: event
+        event.tensor.string_val.__getitem__ = lambda self, x: data
+
+        assert not event_parse._loss_labels
+        event_parse._parse_outputs(event)
+        assert 
event_parse._loss_labels == ["face_out_a", "face_out_b"] + + def test__get_outputs(self, event_parser_instance: _EventParser) -> None: + """ Test _get_outputs works correctly + + Parameters + ---------- + event_parser_instance: :class:`lib.gui.analysis.event_reader._EventParser` + The class instance to test + """ + outputs = [["decoder_a", 1, 0], ["decoder_b", 1, 0]] + model_config = {"output_layers": outputs} + + expected = np.array([[out] for out in outputs]) + actual = event_parser_instance._get_outputs(model_config, is_sub_model=False) + assert isinstance(actual, np.ndarray) + assert actual.shape == (2, 1, 3) + np.testing.assert_equal(expected, actual) + + outputs = [["encoder", 1, 0]] + model_config = {"output_layers": outputs} + + expected = np.array([outputs]) + actual = event_parser_instance._get_outputs(model_config, is_sub_model=True) + assert isinstance(actual, np.ndarray) + assert actual.shape == (1, 1, 3) + np.testing.assert_equal(expected, actual) + + def test__process_event(self, event_parser_instance: _EventParser) -> None: + """ Test _process_event works correctly + + Parameters + ---------- + event_parser_instance: :class:`lib.gui.analysis.event_reader._EventParser` + The class instance to test + """ + event_parse = event_parser_instance + event_data = EventData() + assert not event_data.timestamp + assert not event_data.loss + timestamp = time() + loss = [1.1, 2.2] + event = self._create_example_event(1, 1.0, timestamp, serialize=False) # batch_total + event_parse._process_event(event, event_data) + event = self._create_example_event(2, loss[0], time(), serialize=False) # face A + event_parse._process_event(event, event_data) + event = self._create_example_event(3, loss[1], time(), serialize=False) # face B + event_parse._process_event(event, event_data) + + # Original timestamp and both loss values collected + assert event_data.timestamp == timestamp + np.testing.assert_almost_equal(event_data.loss, loss) # float rounding diff --git a/tests/lib/gui/stats/moving_average_test.py b/tests/lib/gui/stats/moving_average_test.py new file mode 100644 index 0000000000..cfa1881253 --- /dev/null +++ b/tests/lib/gui/stats/moving_average_test.py @@ -0,0 +1,111 @@ +#!/usr/bin python3 +""" Pytest unit tests for :mod:`lib.gui.stats.moving_average` """ + +import numpy as np +import pytest + +from lib.gui.analysis.moving_average import ExponentialMovingAverage as EMA + +# pylint:disable=[protected-access,invalid-name] + + +_INIT_PARAMS = ((np.array([1, 2, 3], dtype="float32"), 0.0), + (np.array([4, 5, 6], dtype="float64"), 0.25), + (np.array([7, 8, 9], dtype="uint8"), 1.0), + (np.array([0, np.nan, 1], dtype="float32"), 0.74), + (np.array([2, 3, np.inf], dtype="float32"), 0.33), + (np.array([4, 5, 6], dtype="float32"), -1.0), + (np.array([7, 8, 9], dtype="float32"), 99.0)) +_INIT_IDS = ["float32", "float64", "uint8", "nan", "inf", "amount:-1", "amount:99"] + + +@pytest.mark.parametrize(("data", "amount"), _INIT_PARAMS, ids=_INIT_IDS) +def test_ExponentialMovingAverage_init(data: np.ndarray, amount: float): + """ Test that moving_average.MovingAverage correctly initializes """ + attrs = {"_data": np.ndarray, + "_alpha": float, + "_dtype": str, + "_row_size": int, + "_out": np.ndarray} + + instance = EMA(data, amount) + # Verify required attributes exist and are of the correct type + for attr, attr_type in attrs.items(): + assert attr in instance.__dict__ + assert isinstance(getattr(instance, attr), attr_type) + # Verify we are testing all existing attributes + for key in instance.__dict__: 
+ assert key in attrs + + # Verify numeric sanitization + assert not np.any(np.isnan(instance._data)) + assert not np.any(np.isinf(instance._data)) + + # Check alpha clamp logic + expected_alpha = 1. - min(0.999, max(0.001, amount)) + assert instance._alpha == expected_alpha + + # dtype assignment logic + expected_dtype = "float32" if data.dtype == np.float32 else "float64" + assert instance._dtype == expected_dtype + + # ensure row size is positive and output matches shape and dtype + assert instance._row_size > 0 + assert instance._out.shape == data.shape + assert instance._out.dtype == expected_dtype + + +def naive_ewma(data: np.ndarray, alpha: float) -> np.ndarray: + """ A simple ewma implementation to test for correctness """ + out = np.empty_like(data, dtype=data.dtype) + out[0] = data[0] + for i in range(1, len(data)): + out[i] = alpha * data[i] + (1 - alpha) * out[i - 1] + return out + + +@pytest.mark.parametrize("alpha", [0.001, 0.01, 0.25, 0.33, 0.5, 0.66, 0.75, 0.90, 0.999]) +@pytest.mark.parametrize("dtype", ("float32", "float64")) +def test_ExponentialMovingAverage_matches_naive(alpha: float, dtype: str) -> None: + """ Make sure that we get sane results out for various data sizes against our reference + for various amounts """ + rows = max(5, int(np.random.random() * 25000)) + data = np.random.rand(rows).astype(dtype) + instance = EMA(data, 1 - alpha) + out = instance() + + ref = naive_ewma(data, alpha) + np.testing.assert_allclose(out, ref, rtol=3e-6, atol=3e-6) + + +@pytest.mark.parametrize("dtype", ("float32", "float64")) +def test_ExponentialMovingAverage_small_data(dtype: str) -> None: + """ Make sure we get sane results out of our small path """ + data = np.array([1., 2., 3.], dtype=dtype) + instance = EMA(data, 0.5) + out = instance() + ref = naive_ewma(data, instance._alpha) + np.testing.assert_allclose(out, ref) + + +@pytest.mark.parametrize("dtype", ("float32", "float64")) +def test_ExponentialMovingAverage_large_data_safe_path(dtype: str) -> None: + """ Make sure we get sane results out of our safe path """ + data = np.random.rand(50000).astype(dtype) + instance = EMA(data, 0.1) + # Force safe path + instance._row_size = 10 + + out = instance() + ref = naive_ewma(data, instance._alpha) + + np.testing.assert_allclose(out, ref, rtol=1e-6, atol=1e-6) + + +@pytest.mark.parametrize("dtype", ("float32", "float64")) +def test_ExponentialMovingAverage_empty_input(dtype: str) -> None: + """ Test that we get no data on an empty input """ + data = np.array([], dtype=dtype) + instance = EMA(data, 0.5) + out = instance() + assert out.size == 0 diff --git a/tests/lib/model/__init__.py b/tests/lib/model/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/lib/model/initializers_test.py b/tests/lib/model/initializers_test.py new file mode 100644 index 0000000000..5a39eb15fb --- /dev/null +++ b/tests/lib/model/initializers_test.py @@ -0,0 +1,66 @@ +#!/usr/bin/env python3 +""" Tests for Faceswap Initializers. + +Adapted from Keras tests. 
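+
+Each initializer is exercised statistically: a convolution-shaped tensor is
+generated on the CPU device and its mean and standard deviation are compared
+against theoretical targets within a small tolerance.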
+""" + +import pytest +import numpy as np + +from keras import device, initializers as k_initializers, Variable + +from lib.model import initializers +from lib.utils import get_backend + +CONV_SHAPE = (3, 3, 256, 2048) +CONV_ID = get_backend().upper() + + +def _runner(init, shape, target_mean=None, target_std=None, + target_max=None, target_min=None): + with device("cpu"): + variable = Variable(init(shape)) + output = variable.numpy() + lim = 3e-2 + if target_std is not None: + assert abs(output.std() - target_std) < lim + if target_mean is not None: + assert abs(output.mean() - target_mean) < lim + if target_max is not None: + assert abs(output.max() - target_max) < lim + if target_min is not None: + assert abs(output.min() - target_min) < lim + + +@pytest.mark.parametrize('tensor_shape', [CONV_SHAPE], ids=[CONV_ID]) +def test_icnr(tensor_shape): + """ ICNR Initialization Test + + Parameters + ---------- + tensor_shape: tuple + The shape of the tensor to feed to the initializer + """ + with device("cpu"): + fan_in, _ = initializers.compute_fans(tensor_shape) + std = np.sqrt(2. / fan_in) + _runner(initializers.ICNR(initializer=k_initializers.he_uniform(), scale=2), + tensor_shape, + target_mean=0, + target_std=std) + + +@pytest.mark.parametrize('tensor_shape', [CONV_SHAPE], ids=[CONV_ID]) +def test_convolution_aware(tensor_shape): + """ Convolution Aware Initialization Test + + Parameters + ---------- + tensor_shape: tuple + The shape of the tensor to feed to the initializer + """ + with device("cpu"): + fan_in, _ = initializers.compute_fans(tensor_shape) + std = np.sqrt(2. / fan_in) + _runner(initializers.ConvolutionAware(seed=123), tensor_shape, + target_mean=0, target_std=std) diff --git a/tests/lib/model/layers_test.py b/tests/lib/model/layers_test.py new file mode 100644 index 0000000000..27743486c7 --- /dev/null +++ b/tests/lib/model/layers_test.py @@ -0,0 +1,172 @@ +#!/usr/bin/env python3 +""" Tests for Faceswap Custom Layers. + +Adapted from Keras tests. +""" + + +import pytest +import numpy as np + +from numpy.testing import assert_allclose + +from keras import device, Input, Model, backend as K + +from lib.model import layers +from lib.utils import get_backend +from tests.utils import has_arg + + +# pylint:disable=dangerous-default-value,too-many-locals,too-many-branches +def layer_test(layer_cls, # noqa:C901 + kwargs={}, + input_shape=None, + input_dtype=None, + input_data=None, + expected_output=None, + expected_output_dtype=None, + fixed_batch_size=False): + """Test routine for a layer with a single input tensor + and single output tensor. 
+ """ + with device("cpu"): + # generate input data + # pylint:disable=duplicate-code + if input_data is None: + assert input_shape + if not input_dtype: + input_dtype = K.floatx() + input_data_shape = list(input_shape) + for i, var_e in enumerate(input_data_shape): + if var_e is None: + input_data_shape[i] = np.random.randint(1, 4) + input_data = 10 * np.random.random(input_data_shape) + input_data = input_data.astype(input_dtype) + else: + if input_shape is None: + input_shape = input_data.shape + if input_dtype is None: + input_dtype = input_data.dtype + if expected_output_dtype is None: + expected_output_dtype = input_dtype + + # instantiation + layer = layer_cls(**kwargs) + + # test get_weights , set_weights at layer level + weights = layer.get_weights() + layer.set_weights(weights) + + layer.build(input_shape) + expected_output_shape = layer.compute_output_shape(input_shape) + + # test in functional API + if fixed_batch_size: + inp = Input(batch_shape=input_shape, dtype=input_dtype) + else: + inp = Input(shape=input_shape[1:], dtype=input_dtype) + outp = layer(inp) + assert outp.dtype == expected_output_dtype + + # check with the functional API + model = Model(inp, outp) + + actual_output = model.predict(input_data, verbose=0) # type:ignore + actual_output_shape = actual_output.shape + for expected_dim, actual_dim in zip(expected_output_shape, + actual_output_shape): + if expected_dim is not None: + assert expected_dim == actual_dim + + if expected_output is not None: + assert_allclose(actual_output, expected_output, rtol=1e-3) + + # test serialization, weight setting at model level + model_config = model.get_config() + recovered_model = model.__class__.from_config(model_config) + if model.weights: + weights = model.get_weights() + recovered_model.set_weights(weights) + _output = recovered_model.predict(input_data, verbose=0) # type:ignore + assert_allclose(_output, actual_output, rtol=1e-3) + + # test training mode (e.g. useful when the layer has a + # different behavior at training and testing time). 
+        if has_arg(layer.call, 'training'):
+            model.compile('rmsprop', 'mse')
+            model.train_on_batch(input_data, actual_output)
+
+        # test instantiation from layer config
+        layer_config = layer.get_config()
+        layer = layer.__class__.from_config(layer_config)
+
+        # for further checks in the caller function
+        return actual_output
+
+
+@pytest.mark.parametrize('dummy', [None], ids=[get_backend().upper()])
+def test_global_min_pooling_2d(dummy):  # pylint:disable=unused-argument
+    """ Global Min Pooling 2D layer test """
+    layer_test(layers.GlobalMinPooling2D, input_shape=(2, 4, 4, 1024))
+
+
+@pytest.mark.parametrize('dummy', [None], ids=[get_backend().upper()])
+def test_global_std_pooling_2d(dummy):  # pylint:disable=unused-argument
+    """ Global Standard Deviation Pooling 2D layer test """
+    layer_test(layers.GlobalStdDevPooling2D, input_shape=(2, 4, 4, 1024))
+
+
+@pytest.mark.parametrize('dummy', [None], ids=[get_backend().upper()])
+def test_k_resize_images(dummy):  # pylint:disable=unused-argument
+    """ K Resize Images layer test """
+    layer_test(layers.KResizeImages, input_shape=(2, 4, 4, 1024))
+
+
+@pytest.mark.parametrize('dummy', [None], ids=[get_backend().upper()])
+def test_l2_normalize(dummy):  # pylint:disable=unused-argument
+    """ L2 Normalize layer test """
+    layer_test(layers.L2Normalize, kwargs={"axis": 1}, input_shape=(2, 4, 4, 1024))
+
+
+@pytest.mark.parametrize('dummy', [None], ids=[get_backend().upper()])
+def test_pixel_shuffler(dummy):  # pylint:disable=unused-argument
+    """ Pixel Shuffler layer test """
+    layer_test(layers.PixelShuffler, input_shape=(2, 4, 4, 1024))
+
+
+@pytest.mark.parametrize('dummy', [None], ids=[get_backend().upper()])
+def test_quick_gelu(dummy):  # pylint:disable=unused-argument
+    """ Quick GELU activation layer test """
+    layer_test(layers.QuickGELU, input_shape=(2, 4, 4, 1024))
+
+
+@pytest.mark.parametrize('dummy', [None], ids=[get_backend().upper()])
+def test_reflection_padding_2d(dummy):  # pylint:disable=unused-argument
+    """ Reflection Padding 2D layer test """
+    layer_test(layers.ReflectionPadding2D, input_shape=(2, 4, 4, 512))
+
+
+@pytest.mark.parametrize('dummy', [None], ids=[get_backend().upper()])
+def test_swish(dummy):  # pylint:disable=unused-argument
+    """ Swish activation layer test """
+    layer_test(layers.Swish, input_shape=(2, 4, 4, 1024))
+
+
+_PARAMS = ("multiply", "truediv", "add", "subtract")
+_IDS = [f"{x}[{get_backend().upper()}]" for x in _PARAMS]
+
+
+@pytest.mark.parametrize("operation", _PARAMS, ids=_IDS)
+def test_scalar_op(operation):
+    """ Scalar operation layer test """
+    val = 2.0
+    np_ops = {"multiply": np.multiply,
+              "truediv": np.true_divide,
+              "add": np.add,
+              "subtract": np.subtract}
+    input_data = np.random.random((2, 4, 4, 1024)).astype("float32")
+    output_data = np_ops[operation](input_data, val)
+    layer_test(layers.ScalarOp,
+               kwargs={"operation": operation, "value": val},
+               input_data=input_data,
+               expected_output=output_data)
diff --git a/tests/lib/model/losses/feature_loss_test.py b/tests/lib/model/losses/feature_loss_test.py
new file mode 100644
index 0000000000..42e4af2a61
--- /dev/null
+++ b/tests/lib/model/losses/feature_loss_test.py
@@ -0,0 +1,25 @@
+#!/usr/bin/env python3
+""" Tests for Faceswap Feature Losses. Adapted from Keras tests. """
""" +import pytest +import numpy as np +from keras import device, Variable + +# pylint:disable=import-error +from lib.model.losses.feature_loss import LPIPSLoss +from lib.utils import get_backend + + +_NETS = ("alex", "squeeze", "vgg16") +_IDS = [f"LPIPS_{x}[{get_backend().upper()}]" for x in _NETS] + + +@pytest.mark.parametrize("net", _NETS, ids=_IDS) +def test_loss_output(net): + """ Basic dtype and value tests for loss functions. """ + with device("cpu"): + y_a = Variable(np.random.random((2, 32, 32, 3))) + y_b = Variable(np.random.random((2, 32, 32, 3))) + objective_output = LPIPSLoss(net)(y_a, y_b) + output = objective_output.detach().numpy() # type:ignore + assert output.dtype == "float32" and not np.any(np.isnan(output)) + assert output < 0.1 # LPIPS loss is reduced 10x diff --git a/tests/lib/model/losses/loss_test.py b/tests/lib/model/losses/loss_test.py new file mode 100644 index 0000000000..bf35bb5117 --- /dev/null +++ b/tests/lib/model/losses/loss_test.py @@ -0,0 +1,68 @@ +#!/usr/bin/env python3 +""" Tests for Faceswap Losses. + +Adapted from Keras tests. +""" + +import pytest +import numpy as np + +from keras import device, losses as k_losses, Variable + +from lib.model.losses.loss import (FocalFrequencyLoss, GeneralizedLoss, GradientLoss, + LaplacianPyramidLoss, LInfNorm, LossWrapper) +from lib.model.losses.feature_loss import LPIPSLoss +from lib.model.losses.perceptual_loss import DSSIMObjective, GMSDLoss, LDRFLIPLoss, MSSIMLoss + +from lib.utils import get_backend + + +_PARAMS = ((FocalFrequencyLoss, 1.0), + (GeneralizedLoss, 1.0), + (GradientLoss, 200.0), + (LaplacianPyramidLoss, 1.0), + (LInfNorm, 1.0)) +_IDS = [f"{x[0].__name__}[{get_backend().upper()}]" for x in _PARAMS] + + +@pytest.mark.parametrize(["loss_func", "max_target"], _PARAMS, ids=_IDS) +def test_loss_output(loss_func, max_target): + """ Basic dtype and value tests for loss functions. """ + with device("cpu"): + y_a = Variable(np.random.random((2, 32, 32, 3))) + y_b = Variable(np.random.random((2, 32, 32, 3))) + objective_output = loss_func()(y_a, y_b) + output = objective_output.detach().numpy() + assert output.dtype == "float32" and not np.any(np.isnan(output)) + assert output < max_target + + +_LWPARAMS = [(FocalFrequencyLoss, ()), + (GeneralizedLoss, ()), + (GradientLoss, ()), + (LaplacianPyramidLoss, ()), + (LInfNorm, ()), + (LPIPSLoss, ("squeeze", )), + (DSSIMObjective, ()), + (GMSDLoss, ()), + (LDRFLIPLoss, ()), + (MSSIMLoss, ()), + (k_losses.LogCosh, ()), + (k_losses.MeanAbsoluteError, ()), + (k_losses.MeanSquaredError, ())] +_LWIDS = [f"{x[0].__name__}[{get_backend().upper()}]" for x in _LWPARAMS] + + +@pytest.mark.parametrize(["loss_func", "func_args"], _LWPARAMS, ids=_LWIDS) +def test_loss_wrapper(loss_func, func_args): + """ Test penalized loss wrapper works as expected """ + with device("cpu"): + p_loss = LossWrapper() + p_loss.add_loss(loss_func(*func_args), 1.0, -1) + p_loss.add_loss(k_losses.MeanSquaredError(), 2.0, 3) + y_a = Variable(np.random.random((2, 32, 32, 4))) + y_b = Variable(np.random.random((2, 32, 32, 3))) + + output = p_loss(y_a, y_b) + output = output.detach().numpy() # type:ignore + assert output.dtype == "float32" and not np.any(np.isnan(output)) diff --git a/tests/lib/model/losses/perceptual_loss_test.py b/tests/lib/model/losses/perceptual_loss_test.py new file mode 100644 index 0000000000..9a37829323 --- /dev/null +++ b/tests/lib/model/losses/perceptual_loss_test.py @@ -0,0 +1,25 @@ +#!/usr/bin/env python3 +""" Tests for Faceswap Feature Losses. Adapted from Keras tests. 
""" +import pytest +import numpy as np +from keras import device, Variable + +# pylint:disable=import-error +from lib.model.losses.perceptual_loss import DSSIMObjective, GMSDLoss, LDRFLIPLoss, MSSIMLoss +from lib.utils import get_backend + + +_PARAMS = [DSSIMObjective, GMSDLoss, LDRFLIPLoss, MSSIMLoss] +_IDS = [f"{x.__name__}[{get_backend().upper()}]" for x in _PARAMS] + + +@pytest.mark.parametrize("loss_func", _PARAMS, ids=_IDS) +def test_loss_output(loss_func): + """ Basic dtype and value tests for loss functions. """ + with device("cpu"): + y_a = Variable(np.random.random((2, 32, 32, 3))) + y_b = Variable(np.random.random((2, 32, 32, 3))) + objective_output = loss_func()(y_a, y_b) + output = objective_output.detach().numpy() # type:ignore + assert output.dtype == "float32" and not np.any(np.isnan(output)) + assert output < 1.0 diff --git a/tests/lib/model/nn_blocks_test.py b/tests/lib/model/nn_blocks_test.py new file mode 100644 index 0000000000..34c0b2fe66 --- /dev/null +++ b/tests/lib/model/nn_blocks_test.py @@ -0,0 +1,85 @@ +#!/usr/bin/env python3 +""" Tests for Faceswap Custom Layers. + +Adapted from Keras tests. +""" + +from itertools import product + +import pytest +import numpy as np + +from numpy.testing import assert_allclose + +from keras import device, Input, Model, backend as K + +from lib.model import nn_blocks +from lib.utils import get_backend +from plugins.train import train_config as cfg +# pylint:disable=unused-import +from tests.lib.config.helpers import patch_config # noqa:[F401] + + +def block_test(layer_func, # pylint:disable=dangerous-default-value,too-many-locals + kwargs={}, + input_shape=None): + """Test routine for faceswap neural network blocks. """ + # generate input data + # pylint:disable=duplicate-code + assert input_shape + input_dtype = K.floatx() + input_data_shape = list(input_shape) + for i, var_e in enumerate(input_data_shape): + if var_e is None: + input_data_shape[i] = np.random.randint(1, 4) + input_data = 10 * np.random.random(input_data_shape) + input_data = input_data.astype(input_dtype) + expected_output_dtype = input_dtype + + # test in functional API + inp = Input(shape=input_shape[1:], dtype=input_dtype) + outp = layer_func(inp, **kwargs) + assert outp.dtype == expected_output_dtype + + # check with the functional API + model = Model(inp, outp) + + actual_output = model.predict(input_data, verbose=0) # type:ignore + + # test serialization, weight setting at model level + model_config = model.get_config() + recovered_model = model.__class__.from_config(model_config) + if model.weights: + weights = model.get_weights() + recovered_model.set_weights(weights) + _output = recovered_model.predict(input_data, verbose=0) # type:ignore + assert_allclose(_output, actual_output, rtol=1e-3) + + # for further checks in the caller function + return actual_output + + +_PARAMS = ["use_icnr_init", "use_convaware_init", "use_reflect_padding"] +_VALUES = list(product([True, False], repeat=len(_PARAMS))) +_IDS = [f"{'|'.join([_PARAMS[idx] for idx, b in enumerate(v) if b])}[{get_backend().upper()}]" + for v in _VALUES] + + +@pytest.mark.parametrize(_PARAMS, _VALUES, ids=_IDS) +def test_blocks(use_icnr_init, + use_convaware_init, + use_reflect_padding, + patch_config): # pylint:disable=redefined-outer-name # noqa:[F811] + """ Test for all blocks contained within the NNBlocks Class """ + config = {"icnr_init": use_icnr_init, + "conv_aware_init": use_convaware_init, + "reflect_padding": use_reflect_padding} + patch_config(cfg, config) + with device("cpu"): + 
+        block_test(nn_blocks.Conv2DOutput(64, 3), input_shape=(2, 8, 8, 32))
+        block_test(nn_blocks.Conv2DBlock(64), input_shape=(2, 8, 8, 32))
+        block_test(nn_blocks.SeparableConv2DBlock(64), input_shape=(2, 8, 8, 32))
+        block_test(nn_blocks.UpscaleBlock(64), input_shape=(2, 4, 4, 128))
+        block_test(nn_blocks.Upscale2xBlock(64, fast=True), input_shape=(2, 4, 4, 128))
+        block_test(nn_blocks.Upscale2xBlock(64, fast=False), input_shape=(2, 4, 4, 128))
+        block_test(nn_blocks.ResidualBlock(64), input_shape=(2, 4, 4, 64))
diff --git a/tests/lib/model/normalization_test.py b/tests/lib/model/normalization_test.py
new file mode 100644
index 0000000000..d5f03fb396
--- /dev/null
+++ b/tests/lib/model/normalization_test.py
@@ -0,0 +1,105 @@
+#!/usr/bin/env python3
+""" Tests for Faceswap Normalization.
+
+Adapted from Keras tests.
+"""
+from itertools import product
+
+import numpy as np
+import pytest
+
+from keras import device, regularizers, models, layers
+
+from lib.model import normalization
+from lib.utils import get_backend
+
+from tests.lib.model.layers_test import layer_test
+
+
+@pytest.mark.parametrize('dummy', [None], ids=[get_backend().upper()])
+def test_instance_normalization(dummy):  # pylint:disable=unused-argument
+    """ Basic test for instance normalization. """
+    layer_test(normalization.InstanceNormalization,
+               kwargs={'epsilon': 0.1,
+                       'gamma_regularizer': regularizers.l2(0.01),
+                       'beta_regularizer': regularizers.l2(0.01)},
+               input_shape=(3, 4, 2))
+    layer_test(normalization.InstanceNormalization,
+               kwargs={'epsilon': 0.1,
+                       'axis': 1},
+               input_shape=(1, 4, 1))
+    layer_test(normalization.InstanceNormalization,
+               kwargs={'gamma_initializer': 'ones',
+                       'beta_initializer': 'ones'},
+               input_shape=(3, 4, 2, 4))
+    layer_test(normalization.InstanceNormalization,
+               kwargs={'epsilon': 0.1,
+                       'axis': 1,
+                       'scale': False,
+                       'center': False},
+               input_shape=(3, 4, 2, 4))
+
+
+@pytest.mark.parametrize('dummy', [None], ids=[get_backend().upper()])
+def test_group_normalization(dummy):  # pylint:disable=unused-argument
+    """ Basic test for group normalization. """
+    layer_test(normalization.GroupNormalization,
+               kwargs={'epsilon': 0.1,
+                       'gamma_regularizer': regularizers.l2(0.01),
+                       'beta_regularizer': regularizers.l2(0.01)},
+               input_shape=(4, 3, 4, 128))
+    layer_test(normalization.GroupNormalization,
+               kwargs={'epsilon': 0.1,
+                       'axis': 1},
+               input_shape=(4, 1, 4, 256))
+    layer_test(normalization.GroupNormalization,
+               kwargs={'gamma_init': 'ones',
+                       'beta_init': 'ones'},
+               input_shape=(4, 64))
+    layer_test(normalization.GroupNormalization,
+               kwargs={'epsilon': 0.1,
+                       'axis': 1,
+                       'group': 16},
+               input_shape=(3, 64))
+
+
+_PARAMS_NORM = ["center", "scale"]
+_VALUES_NORM = list(product([True, False], repeat=len(_PARAMS_NORM)))
+_IDS = [f"{'|'.join([_PARAMS_NORM[idx] for idx, b in enumerate(v) if b])}[{get_backend().upper()}]"
+        for v in _VALUES_NORM]
+
+
+@pytest.mark.parametrize(_PARAMS_NORM, _VALUES_NORM, ids=_IDS)
+def test_adain_normalization(center, scale):
+    """ Basic test for Ada Instance Normalization. """
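+    # builds a three-input model (a content tensor plus two 1x1 conditioning
+    # tensors) and checks the predicted shape against compute_output_shape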
""" + with device("cpu"): + norm = normalization.AdaInstanceNormalization(center=center, scale=scale) + shapes = [(4, 8, 8, 1280), (4, 1, 1, 1280), (4, 1, 1, 1280)] + norm.build(shapes) + expected_output_shape = norm.compute_output_shape(shapes) + inputs = [layers.Input(shape=shapes[0][1:]), + layers.Input(shape=shapes[1][1:]), + layers.Input(shape=shapes[2][1:])] + model = models.Model(inputs, norm(inputs)) + data = [10 * np.random.random(shape) for shape in shapes] + + actual_output = model.predict(data, verbose=0) + actual_output_shape = actual_output.shape + + for expected_dim, actual_dim in zip(expected_output_shape, + actual_output_shape): + if expected_dim is not None: + assert expected_dim == actual_dim + + +_PARAMS = ["partial", "bias"] +_VALUES = [(0.0, False), (0.25, False), (0.5, True), (0.75, False), (1.0, True)] # type:ignore +_IDS = [f"partial={v[0]}|bias={v[1]}[{get_backend().upper()}]" for v in _VALUES] + + +@pytest.mark.parametrize(_PARAMS, _VALUES, ids=_IDS) +def test_rms_normalization(partial, bias): # pylint:disable=unused-argument + """ Basic test for RMS Layer normalization. """ + layer_test(normalization.RMSNormalization, + kwargs={"partial": partial, "bias": bias}, + input_shape=(4, 512)) diff --git a/tests/lib/model/optimizers_test.py b/tests/lib/model/optimizers_test.py new file mode 100644 index 0000000000..3f986ace50 --- /dev/null +++ b/tests/lib/model/optimizers_test.py @@ -0,0 +1,59 @@ +#!/usr/bin/env python3 +""" Tests for Faceswap Initializers. + +Adapted from Keras tests. +""" +import pytest + +import numpy as np + +from keras import device, layers as kl, optimizers as k_optimizers, Sequential + +from lib.model import optimizers +from lib.utils import get_backend + +from tests.utils import generate_test_data, to_categorical + + +def get_test_data(): + """ Obtain randomized test data for training """ + np.random.seed(1337) + (x_train, y_train), _ = generate_test_data(num_train=1000, + num_test=200, + input_shape=(10,), + classification=True, + num_classes=2) + y_train = to_categorical(y_train) + return x_train, y_train + + +def _test_optimizer(optimizer, target=0.75): + x_train, y_train = get_test_data() + + model = Sequential() + model.add(kl.Input((x_train.shape[1], ))) + model.add(kl.Dense(10)) + model.add(kl.Activation("relu")) + model.add(kl.Dense(y_train.shape[1])) + model.add(kl.Activation("softmax")) + model.compile(loss="categorical_crossentropy", + optimizer=optimizer, + metrics=["accuracy"]) + + history = model.fit(x_train, y_train, epochs=2, batch_size=16, verbose=0) # type:ignore + assert history.history["accuracy"][-1] >= target + config = k_optimizers.serialize(optimizer) + optim = k_optimizers.deserialize(config) + new_config = k_optimizers.serialize(optim) + config["class_name"] = config["class_name"].lower() # type:ignore + new_config["class_name"] = new_config["class_name"].lower() # type:ignore + assert config == new_config + + +# TODO remove the next line that supresses a weird pytest bug when it tears down the tempdir +@pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") +@pytest.mark.parametrize("dummy", [None], ids=[get_backend().upper()]) +def test_adabelief(dummy): # pylint:disable=unused-argument + """ Test for custom Adam optimizer """ + with device("cpu"): + _test_optimizer(optimizers.AdaBelief(), target=0.20) diff --git a/tests/lib/system/__init__.py b/tests/lib/system/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/lib/system/sysinfo_test.py 
b/tests/lib/system/sysinfo_test.py new file mode 100644 index 0000000000..ae5bfdf76c --- /dev/null +++ b/tests/lib/system/sysinfo_test.py @@ -0,0 +1,258 @@ +#!/usr/bin python3 +""" Pytest unit tests for :mod:`lib.system.sysinfo` """ + +import platform +import typing as T + +from collections import namedtuple +from io import StringIO +from unittest.mock import MagicMock + +import pytest +import pytest_mock + +# pylint:disable=import-error +from lib.gpu_stats import GPUInfo +from lib.system.sysinfo import _Configs, _State, _SysInfo, get_sysinfo +from lib.system import Cuda, Packages, ROCm, System + +# pylint:disable=protected-access + + +# _SysInfo +@pytest.fixture(name="sys_info_instance") +def sys_info_fixture() -> _SysInfo: + """ Single :class:`~lib.system.sysinfo._SysInfo` object for tests """ + return _SysInfo() + + +def test_init(sys_info_instance: _SysInfo) -> None: + """ Test :class:`lib.system.sysinfo._SysInfo` __init__ and attributes """ + assert isinstance(sys_info_instance, _SysInfo) + + attrs = ["_state_file", "_configs", "_system", + "_python", "_packages", "_gpu", "_cuda", "_rocm"] + assert all(a in sys_info_instance.__dict__ for a in attrs) + assert all(a in attrs for a in sys_info_instance.__dict__) + + assert isinstance(sys_info_instance._state_file, str) + assert isinstance(sys_info_instance._configs, str) + assert isinstance(sys_info_instance._system, System) + assert isinstance(sys_info_instance._python, dict) + assert sys_info_instance._python == {"implementation": platform.python_implementation(), + "version": platform.python_version()} + assert isinstance(sys_info_instance._packages, Packages) + assert isinstance(sys_info_instance._gpu, GPUInfo) + assert isinstance(sys_info_instance._cuda, Cuda) + assert isinstance(sys_info_instance._rocm, ROCm) + + +def test_properties(sys_info_instance: _SysInfo) -> None: + """ Test :class:`lib.system.sysinfo._SysInfo` properties """ + ints = ["_ram_free", "_ram_total", "_ram_available", "_ram_used"] + strs = ["_fs_command", "_conda_version", "_git_commits", "_cuda_versions", + "_cuda_version", "_cudnn_versions", "_rocm_version", "_rocm_versions"] + + for prop in ints: + assert hasattr(sys_info_instance, prop), f"sysinfo missing property '{prop}'" + assert isinstance(getattr(sys_info_instance, prop), + int), f"sysinfo property '{prop}' not int" + + for prop in strs: + assert hasattr(sys_info_instance, prop), f"sysinfo missing property '{prop}'" + assert isinstance(getattr(sys_info_instance, prop), + str), f"sysinfo property '{prop}' not str" + + +def test_get_gpu_info(sys_info_instance: _SysInfo) -> None: + """ Test _get_gpu_info method of :class:`lib.system.sysinfo._SysInfo` returns as expected """ + assert hasattr(sys_info_instance, "_get_gpu_info") + gpu_info = sys_info_instance._get_gpu_info() + assert isinstance(gpu_info, GPUInfo) + + +def test__format_ram(sys_info_instance: _SysInfo, monkeypatch: pytest.MonkeyPatch) -> None: + """ Test the _format_ram method of :class:`lib.system.sysinfo._SysInfo` """ + assert hasattr(sys_info_instance, "_format_ram") + svmem = namedtuple("svmem", ["available", "free", "total", "used"]) + data = svmem(12345678, 1234567, 123456789, 123456) + monkeypatch.setattr("psutil.virtual_memory", lambda *args, **kwargs: data) + ram_info = sys_info_instance._format_ram() + + assert isinstance(ram_info, str) + assert ram_info == "Total: 117MB, Available: 11MB, Used: 0MB, Free: 1MB" + + +def test_full_info(sys_info_instance: _SysInfo) -> None: + """ Test the full_info method of 
:class:`lib.system.sysinfo._SysInfo` returns as expected """
+    assert hasattr(sys_info_instance, "full_info")
+    sys_info = sys_info_instance.full_info()
+    assert isinstance(sys_info, str)
+
+    sections = ["System Information", "Pip Packages", "Configs"]
+    for section in sections:
+        assert section in sys_info, f"Section {section} not in full_info"
+    if sys_info_instance._system.is_conda:
+        assert "Conda Packages" in sys_info
+    else:
+        assert "Conda Packages" not in sys_info
+
+    keys = ["backend", "os_platform", "os_machine", "os_release", "py_conda_version",
+            "py_implementation", "py_version", "py_command", "py_virtual_env", "sys_cores",
+            "sys_processor", "sys_ram", "encoding", "git_branch", "git_commits",
+            "gpu_cuda_versions", "gpu_cuda", "gpu_cudnn", "gpu_rocm_versions", "gpu_rocm_version",
+            "gpu_driver", "gpu_devices", "gpu_vram", "gpu_devices_active"]
+    for key in keys:
+        assert f"{key}:" in sys_info, f"'{key}:' not in full_info"
+
+
+# get_sys_info
+def test_get_sys_info(mocker: pytest_mock.MockerFixture) -> None:
+    """ Test that the :func:`~lib.system.sysinfo.get_sysinfo` function executes correctly """
+    sys_info = get_sysinfo()
+    assert isinstance(sys_info, str)
+    full_info = mocker.patch("lib.system.sysinfo._SysInfo.full_info")
+    get_sysinfo()
+    assert full_info.called
+
+
+# _Configs
+@pytest.fixture(name="configs_instance")
+def configs_fixture():
+    """ Pytest fixture for :class:`~lib.system.sysinfo._Configs` """
+    return _Configs()
+
+
+def test__configs__init__(configs_instance: _Configs) -> None:
+    """ Test __init__ and attributes for :class:`~lib.system.sysinfo._Configs` """
+    assert hasattr(configs_instance, "config_dir")
+    assert isinstance(configs_instance.config_dir, str)
+    assert hasattr(configs_instance, "configs")
+    assert isinstance(configs_instance.configs, str)
+
+
+def test__configs__get_configs(configs_instance: _Configs) -> None:
+    """ Test the _get_configs method for :class:`~lib.system.sysinfo._Configs` """
+    assert hasattr(configs_instance, "_get_configs")
+    assert isinstance(configs_instance._get_configs(), str)
+
+
+def test__configs__parse_configs(configs_instance: _Configs,
+                                 mocker: pytest_mock.MockerFixture) -> None:
+    """ Test _parse_configs function for :class:`~lib.system.sysinfo._Configs` """
+    assert hasattr(configs_instance, "_parse_configs")
+    assert isinstance(configs_instance._parse_configs([]), str)
+    configs_instance._parse_ini = T.cast(MagicMock, mocker.MagicMock())  # type:ignore
+    configs_instance._parse_json = T.cast(MagicMock, mocker.MagicMock())  # type:ignore
+    configs_instance._parse_configs(config_files=["test.ini", ".faceswap"])
+    assert configs_instance._parse_ini.called
+    assert configs_instance._parse_json.called
+
+
+def test__configs__parse_ini(configs_instance: _Configs,
+                             monkeypatch: pytest.MonkeyPatch) -> None:
+    """ Test _parse_ini function for :class:`~lib.system.sysinfo._Configs` """
+    assert hasattr(configs_instance, "_parse_ini")
+
+    file = ("[test.ini_header]\n"
+            "# Test Header\n\n"
+            "param = value")
+    monkeypatch.setattr("builtins.open", lambda *args, **kwargs: StringIO(file))
+
+    converted = configs_instance._parse_ini("test.ini")
+    assert isinstance(converted, str)
+    assert converted == ("\n[test.ini_header]\n"
+                         "param: value\n")
+
+
+def test__configs__parse_json(configs_instance: _Configs,
+                              monkeypatch: pytest.MonkeyPatch) -> None:
+    """ Test _parse_json function for :class:`~lib.system.sysinfo._Configs` """
+    assert hasattr(configs_instance, "_parse_json")
+    file = '{"test": "param"}'
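+    # patch builtins.open so _parse_json reads from an in-memory buffer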
monkeypatch.setattr("builtins.open", lambda *args, **kwargs: StringIO(file)) + + converted = configs_instance._parse_json(".file") + assert isinstance(converted, str) + assert converted == ("test: param\n") + + +def test__configs__format_text(configs_instance: _Configs) -> None: + """ Test _format_text function for :class:`~lib.utils.sysinfo._Configs` """ + assert hasattr(configs_instance, "_format_text") + key, val = " test_key ", "test_val " + formatted = configs_instance._format_text(key, val) + assert isinstance(formatted, str) + assert formatted == "test_key: test_val\n" + + +# _State +@pytest.fixture(name="state_instance") +def state_fixture(): + """ Pytest fixture for :class:`~lib.utils.sysinfo._State` """ + return _State() + + +def test__state__init__(state_instance: _State) -> None: + """ Test __init__ and attributes for :class:`~lib.utils.sysinfo._State` """ + assert hasattr(state_instance, "_model_dir") + assert state_instance._model_dir is None + assert hasattr(state_instance, "_trainer") + assert state_instance._trainer is None + assert hasattr(state_instance, "state_file") + assert isinstance(state_instance.state_file, str) + + +def test__state__is_training(state_instance: _State, + monkeypatch: pytest.MonkeyPatch) -> None: + """ Test _is_training function for :class:`~lib.utils.sysinfo._State` """ + assert hasattr(state_instance, "_is_training") + assert isinstance(state_instance._is_training, bool) + assert not state_instance._is_training + monkeypatch.setattr("sys.argv", ["faceswap.py", "train"]) + assert state_instance._is_training + monkeypatch.setattr("sys.argv", ["faceswap.py", "extract"]) + assert not state_instance._is_training + + +def test__state__get_arg(state_instance: _State, + monkeypatch: pytest.MonkeyPatch) -> None: + """ Test _get_arg function for :class:`~lib.utils.sysinfo._State` """ + assert hasattr(state_instance, "_get_arg") + assert state_instance._get_arg("-t", "--test_arg") is None + monkeypatch.setattr("sys.argv", ["test", "command", "-t", "test_option"]) + assert state_instance._get_arg("-t", "--test_arg") == "test_option" + + +def test__state__get_state_file(state_instance: _State, + mocker: pytest_mock.MockerFixture, + monkeypatch: pytest.MonkeyPatch) -> None: + """ Test _get_state_file function for :class:`~lib.utils.sysinfo._State` """ + assert hasattr(state_instance, "_get_state_file") + assert isinstance(state_instance._get_state_file(), str) + + mock_is_training = mocker.patch("lib.system.sysinfo._State._is_training") + + # Not training or missing training arguments + mock_is_training.return_value = False + assert state_instance._get_state_file() == "" + mock_is_training.return_value = False + + monkeypatch.setattr(state_instance, "_model_dir", None) + assert state_instance._get_state_file() == "" + monkeypatch.setattr(state_instance, "_model_dir", "test_dir") + + monkeypatch.setattr(state_instance, "_trainer", None) + assert state_instance._get_state_file() == "" + monkeypatch.setattr(state_instance, "_trainer", "test_trainer") + + # Training but file not found + assert state_instance._get_state_file() == "" + + # State file is just a json dump + file = ('{\n' + ' "test": "json",\n' + '}') + monkeypatch.setattr("os.path.isfile", lambda *args, **kwargs: True) + monkeypatch.setattr("builtins.open", lambda *args, **kwargs: StringIO(file)) + assert state_instance._get_state_file().endswith(file) diff --git a/tests/lib/system/system_test.py b/tests/lib/system/system_test.py new file mode 100644 index 0000000000..e608d4185c --- /dev/null +++ 
b/tests/lib/system/system_test.py @@ -0,0 +1,256 @@ +#!/usr/bin python3 +""" Pytest unit tests for :mod:`lib.system.system` """ + +import ctypes +import locale +import os +import platform +import sys + +import pytest +import pytest_mock + +# pylint:disable=import-error +import lib.system.system as system_mod +from lib.system.system import _lines_from_command, VALID_PYTHON, Packages, System +# pylint:disable=protected-access + + +def test_valid_python() -> None: + """ Confirm python version has a min and max and that it is Python 3 """ + assert len(VALID_PYTHON) == 2 + assert all(len(v) == 2 for v in VALID_PYTHON) + assert all(isinstance(x, int) for v in VALID_PYTHON for x in v) + assert all(v[0] == 3 for v in VALID_PYTHON) + assert VALID_PYTHON[0] <= VALID_PYTHON[1] + + +def test_lines_from_command(mocker: pytest_mock.MockerFixture) -> None: + """ Confirm lines from command executes as expected """ + input_ = ["test", "input"] + subproc_out = " this \nis\n test\noutput \n" + mock_run = mocker.patch("lib.system.system.run") + mock_run.return_value.stdout = subproc_out + result = _lines_from_command(input_) + assert mock_run.called + assert result == subproc_out.splitlines() + + +# System +@pytest.fixture(name="system_instance") +def system_fixture() -> System: + """ Single :class:`lib.system.System` object for tests """ + return System() + + +def test_system_init(system_instance: System) -> None: + """ Test :class:`lib.system.System` __init__ and attributes """ + assert isinstance(system_instance, System) + + attrs = ["platform", "system", "machine", "release", "processor", "cpu_count", + "python_implementation", "python_version", "python_architecture", "encoding", + "is_conda", "is_admin", "is_virtual_env"] + assert all(a in system_instance.__dict__ for a in attrs) + assert all(a in attrs for a in system_instance.__dict__) + + assert system_instance.platform == platform.platform() + assert system_instance.system == platform.system().lower() + assert system_instance.machine == platform.machine() + assert system_instance.release == platform.release() + assert system_instance.processor == platform.processor() + assert system_instance.cpu_count == os.cpu_count() + assert system_instance.python_implementation == platform.python_implementation() + assert system_instance.python_version == platform.python_version() + assert system_instance.python_architecture == platform.architecture()[0] + assert system_instance.encoding == locale.getpreferredencoding() + assert system_instance.is_conda == ("conda" in sys.version.lower() or + os.path.exists(os.path.join(sys.prefix, "conda-meta"))) + assert isinstance(system_instance.is_admin, bool) + assert isinstance(system_instance.is_virtual_env, bool) + + +def test_system_properties(system_instance: System) -> None: + """ Test :class:`lib.system.System` properties """ + assert hasattr(system_instance, "is_linux") + assert isinstance(system_instance.is_linux, bool) + if platform.system().lower() == "linux": + assert system_instance.is_linux + assert not system_instance.is_macos + assert not system_instance.is_windows + + assert hasattr(system_instance, "is_macos") + assert isinstance(system_instance.is_macos, bool) + if platform.system().lower() == "darwin": + assert system_instance.is_macos + assert not system_instance.is_linux + assert not system_instance.is_windows + + assert hasattr(system_instance, "is_windows") + assert isinstance(system_instance.is_windows, bool) + if platform.system().lower() == "windows": + assert system_instance.is_windows + assert 
not system_instance.is_linux + assert not system_instance.is_macos + + +def test_system_get_permissions(system_instance: System) -> None: + """ Test :class:`lib.system.System` _get_permissions method """ + assert hasattr(system_instance, "_get_permissions") + is_admin = system_instance._get_permissions() + if platform.system() == "Windows": + assert is_admin == (ctypes.windll.shell32.IsUserAnAdmin() != 0) # type:ignore + else: + assert is_admin == (os.getuid() == 0) # type:ignore # pylint:disable=no-member + + +def test_system_check_virtual_env(system_instance: System, + monkeypatch: pytest.MonkeyPatch) -> None: + """ Test :class:`lib.system.System` _check_virtual_env method """ + system_instance.is_conda = True + monkeypatch.setattr(system_mod.sys, "prefix", "/home/user/miniconda3/envs/testenv") + assert system_instance._check_virtual_env() + monkeypatch.setattr(system_mod.sys, "prefix", "/home/user/miniconda3/bin/") + assert not system_instance._check_virtual_env() + + system_instance.is_conda = False + monkeypatch.setattr(system_mod.sys, "base_prefix", "/home/user/venv/") + monkeypatch.setattr(system_mod.sys, "prefix", "/usr/bin/") + assert system_instance._check_virtual_env() + monkeypatch.setattr(system_mod.sys, "base_prefix", "/usr/bin/") + assert not system_instance._check_virtual_env() + + +def test_system_validate_python(system_instance: System, + monkeypatch: pytest.MonkeyPatch, + mocker: pytest_mock.MockerFixture) -> None: + """ Test :class:`lib.system.System` _validate_python method """ + monkeypatch.setattr(system_mod, "VALID_PYTHON", (((3, 11), (3, 13)))) + monkeypatch.setattr(system_mod.sys, "version_info", (3, 12, 0)) + monkeypatch.setattr("builtins.input", lambda _: "") + system_instance.python_architecture = "64bit" + + assert system_instance.validate_python() + assert system_instance.validate_python(max_version=(3, 12)) + + sys_exit = mocker.patch("lib.system.system.sys.exit") + system_instance.python_architecture = "32bit" + system_instance.validate_python() + assert sys_exit.called + system_instance.python_architecture = "64bit" + + system_instance.validate_python(max_version=(3, 11)) + assert sys_exit.called + + for vers in ((3, 10, 0), (3, 14, 0)): + monkeypatch.setattr(system_mod.sys, "version_info", vers) + system_instance.validate_python() + assert sys_exit.called + + +@pytest.mark.parametrize("system_name, machine, is_conda, should_exit", [ + ("other", "x86_64", False, True), # Unsupported OS + ("darwin", "arm64", True, False), # Apple Silicon inside conda + ("darwin", "arm64", False, True), # Apple Silicon outside conda + ("linux", "x86_64", True, False), # Supported + ("windows", "x86_64", True, False), # Supported + ]) +def test_system_validate(system_instance: System, + mocker: pytest_mock.MockerFixture, + system_name, + machine, + is_conda, + should_exit) -> None: + """ Test :class:`lib.system.System` _validate method """ + validate_python = mocker.patch("lib.system.System.validate_python") + system_instance.system = system_name + system_instance.machine = machine + system_instance.is_conda = is_conda + sys_exit = mocker.patch("lib.system.system.sys.exit") + system_instance.validate() + if should_exit: + assert sys_exit.called + else: + assert not sys_exit.called + assert validate_python.called + + +# Packages +@pytest.fixture(name="packages_instance") +def packages_fixture() -> Packages: + """ Single :class:`lib.system.Packages` object for tests """ + return Packages() + + +def test_packages_init(packages_instance: Packages, mocker: 
pytest_mock.MockerFixture) -> None: + """ Test :class:`lib.system.Packages` __init__ and attributes """ + assert isinstance(packages_instance, Packages) + + attrs = ["_conda_exe", "_installed_python", "_installed_conda"] + assert all(a in packages_instance.__dict__ for a in attrs) + assert all(a in attrs for a in packages_instance.__dict__) + + assert isinstance(packages_instance._conda_exe, + str) or packages_instance._conda_exe is None + assert isinstance(packages_instance._installed_python, dict) + assert isinstance(packages_instance._installed_conda, + list) or packages_instance._installed_conda is None + + which = mocker.patch("lib.system.system.which") + Packages() + which.assert_called_once_with("conda") + + +def test_packages_properties(packages_instance: Packages) -> None: + """ Test :class:`lib.system.Packages` properties """ + for prop in ("installed_python", "installed_conda"): + assert hasattr(packages_instance, prop) + assert isinstance(getattr(packages_instance, prop), dict) + pretty = f"{prop}_pretty" + assert hasattr(packages_instance, pretty) + assert isinstance(getattr(packages_instance, pretty), str) + + +def test_packages_get_installed_python(packages_instance: Packages, + mocker: pytest_mock.MockerFixture, + monkeypatch: pytest.MonkeyPatch) -> None: + """ Test :class:`lib.system.Packages` get_installed_python method """ + lines_from_command = mocker.patch("lib.system.system._lines_from_command") + monkeypatch.setattr(system_mod.sys, "executable", "python") + out = packages_instance._get_installed_python() + lines_from_command.assert_called_once_with(["python", "-m", "pip", "freeze", "--local"]) + assert isinstance(out, dict) + + monkeypatch.setattr(system_mod, "_lines_from_command", lambda _: ["pacKage1==1.0.0", + "PACKAGE2==1.1.0", + "# Ignored", + "malformed=1.2.3", + "package3==0.2.1"]) + out = packages_instance._get_installed_python() + assert out == {"package1": "1.0.0", "package2": "1.1.0", "package3": "0.2.1"} + + +def test_packages_get_installed_conda(packages_instance: Packages, + mocker: pytest_mock.MockerFixture, + monkeypatch: pytest.MonkeyPatch) -> None: + """ Test :class:`lib.system.Packages` get_installed_conda method """ + packages_instance._conda_exe = None + packages_instance._installed_conda = None + packages_instance._get_installed_conda() + assert packages_instance._installed_conda is None + + packages_instance._conda_exe = "conda" + lines_from_command = mocker.patch("lib.system.system._lines_from_command") + packages_instance._get_installed_conda() + lines_from_command.assert_called_once_with(["conda", "list", "--show-channel-urls"]) + + monkeypatch.setattr(system_mod, "_lines_from_command", lambda _: []) + packages_instance._get_installed_conda() + assert packages_instance._installed_conda == ["Could not get Conda package list"] + + _pkgs = [ + "package1 4.15.0 pypi_0 pypi", + "pkg2 2025b h78e105d_0 conda-forge", + "Packag3 3.1.3 pypi_0 defaults"] + monkeypatch.setattr(system_mod, "_lines_from_command", lambda _: _pkgs) + packages_instance._get_installed_conda() + assert packages_instance._installed_conda == _pkgs diff --git a/tests/lib/training/__init__.py b/tests/lib/training/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/lib/training/augmentation_test.py b/tests/lib/training/augmentation_test.py new file mode 100644 index 0000000000..9687e6b601 --- /dev/null +++ b/tests/lib/training/augmentation_test.py @@ -0,0 +1,524 @@ +#!/usr/bin python3 +""" Pytest unit tests for :mod:`lib.training.augmentation` """ 
+import typing as T + +import cv2 +import numpy as np +import pytest +import pytest_mock + +from lib.config import ConfigValueType +from lib.training.augmentation import (ConstantsAugmentation, ConstantsColor, ConstantsTransform, + ConstantsWarp, ImageAugmentation) +from plugins.train.trainer import trainer_config as cfg + +# pylint:disable=unused-import +from tests.lib.config.helpers import patch_config # noqa[F401] + +# pylint:disable=protected-access,redefined-outer-name + + +MODULE_PREFIX = "lib.training.augmentation" + + +# CONSTANTS # +_CLAHE_CONF = (({"color_clahe_chance": 12, "color_clahe_max_size": 2}, 64), + ({"color_clahe_chance": 25, "color_clahe_max_size": 4}, 128), + ({"color_clahe_chance": 50, "color_clahe_max_size": 6}, 256), + ({"color_clahe_chance": 75, "color_clahe_max_size": 8}, 384)) + + +@pytest.mark.parametrize(("config", "size"), _CLAHE_CONF, ids=[x[-1] for x in _CLAHE_CONF]) +def test_constants_get_clahe(config: dict[str, T.Any], + size: int, + patch_config) -> None: # noqa[F811] + """ Test ConstantsAugmentation._get_clahe works as expected """ + patch_config(cfg, config) + contrast, chance, max_size = ConstantsAugmentation._get_clahe(size) + assert isinstance(contrast, int) + assert isinstance(chance, float) + assert isinstance(max_size, int) + assert contrast == max(2, size // 128) + assert chance == config["color_clahe_chance"] / 100. + assert max_size == config["color_clahe_max_size"] + + +_LAB_CONF = ({"color_lightness": 30, "color_ab": 8}, + {"color_lightness": 8, "color_ab": 25}, + {"color_lightness": 63, "color_ab": 12}) + + +@pytest.mark.parametrize(("config"), _LAB_CONF) +def test_constants_get_lab(config: dict[str, T.Any], patch_config) -> None: # noqa[F811] + """ Test ConstantsAugmentation._get_lab works as expected """ + patch_config(cfg, config) + lab_adjust = ConstantsAugmentation._get_lab() + assert isinstance(lab_adjust, np.ndarray) + assert lab_adjust.dtype == np.float32 + assert lab_adjust.shape == (3, ) + assert lab_adjust[0] == config["color_lightness"] / 100. + assert lab_adjust[1] == config["color_ab"] / 100. + assert lab_adjust[2] == config["color_ab"] / 100. 
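+
+
+# A minimal illustrative sketch (not used by the tests) of the scaling convention
+# the assertions above check: integer config percentages become float32 fractions.
+# The helper name `_expected_lab` is hypothetical and exists only for illustration.
+def _expected_lab(color_lightness: int, color_ab: int) -> np.ndarray:
+    """ Reference for the L*a*b* adjustment scaling asserted above """
+    return np.array([color_lightness / 100., color_ab / 100., color_ab / 100.],
+                    dtype="float32")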
+ + +_CLAHE_LAB_CONF = ( + {"color_clahe_chance": 50, "color_clahe_max_size": 4.0, "color_lightness": 30, "color_ab": 8}, + {"color_clahe_chance": 30, "color_clahe_max_size": 6.0, "color_lightness": 20, "color_ab": 6}, + {"color_clahe_chance": 75, "color_clahe_max_size": 8.0, "color_lightness": 10, "color_ab": 12}) + + +@pytest.mark.parametrize(("config"), _CLAHE_LAB_CONF) +def test_constants_get_color(config: dict[str, T.Any], + patch_config, # noqa[F811] + mocker: pytest_mock.MockerFixture) -> None: + """ Test ConstantsAugmentation._get_color works as expected """ + patch_config(cfg, config) + clahe_mock = mocker.patch(f"{MODULE_PREFIX}.ConstantsAugmentation._get_clahe", + return_value=(1, 2.0, 3)) + lab_mock = mocker.patch(f"{MODULE_PREFIX}.ConstantsAugmentation._get_lab", + return_value=np.array([1.0, 2.0, 3.0], dtype="float32")) + color = ConstantsAugmentation._get_color(256) + clahe_mock.assert_called_once_with(256) + lab_mock.assert_called_once_with() + assert isinstance(color, ConstantsColor) + assert isinstance(color.clahe_base_contrast, int) + assert isinstance(color.clahe_chance, float) + assert isinstance(color.clahe_max_size, int) + assert isinstance(color.lab_adjust, np.ndarray) + + assert color.clahe_base_contrast == clahe_mock.return_value[0] + assert color.clahe_chance == clahe_mock.return_value[1] + assert color.clahe_max_size == clahe_mock.return_value[2] + assert np.all(color.lab_adjust == lab_mock.return_value) + + +_TRANSFORM_CONF = ( + ({"rotation_range": 25, "zoom_amount": 1, "shift_range": 6, "flip_chance": 10}, 64), + ({"rotation_range": 6, "zoom_amount": 2, "shift_range": 5, "flip_chance": 60}, 96), + ({"rotation_range": 39, "zoom_amount": 3, "shift_range": 4, "flip_chance": 23}, 128), + ({"rotation_range": 12, "zoom_amount": 4, "shift_range": 3, "flip_chance": 52}, 256), + ({"rotation_range": 47, "zoom_amount": 5, "shift_range": 2, "flip_chance": 33}, 384), + ({"rotation_range": 3, "zoom_amount": 6, "shift_range": 1, "flip_chance": 44}, 512)) + + +@pytest.mark.parametrize(("config", "size"), _TRANSFORM_CONF) +def test_constants_get_transform(config: dict[str, T.Any], + size: int, + patch_config) -> None: # noqa[F811] + """ Test ConstantsAugmentation._get_transform works as expected """ + patch_config(cfg, config) + transform = ConstantsAugmentation._get_transform(size) + assert isinstance(transform, ConstantsTransform) + assert isinstance(transform.rotation, int) + assert isinstance(transform.zoom, float) + assert isinstance(transform.shift, float) + assert isinstance(transform.flip, float) + assert transform.rotation == config["rotation_range"] + assert transform.zoom == config["zoom_amount"] / 100. + assert transform.shift == (config["shift_range"] / 100.) * size + assert transform.flip == config["flip_chance"] / 100. + + +@pytest.mark.parametrize(("size", "batch_size"), ((64, 16), (384, 32))) +def test_constants_get_warp_to_landmarks(size: int, batch_size: int) -> None: + """ Test ConstantsAugmentation._get_warp_to_landmarks works as expected """ + anchors, grids = ConstantsAugmentation._get_warp_to_landmarks(size, batch_size) + assert isinstance(anchors, np.ndarray) + assert isinstance(grids, np.ndarray) + + assert anchors.dtype == np.int32 + assert anchors.shape == (batch_size, 8, 2) + assert anchors.min() == 0 + assert anchors.max() == size - 1 + + assert grids.dtype == np.float32 + assert grids.shape == (2, size, size) + assert grids.min() == 0. 
+    assert grids.max() == size - 1
+
+
+@pytest.mark.parametrize(("size", "batch_size"), ((64, 16), (384, 32)))
+def test_constants_get_warp(size: int, batch_size: int, mocker: pytest_mock.MockerFixture) -> None:
+    """ Test ConstantsAugmentation._get_warp works as expected """
+    warp_lm_mock = mocker.patch(
+        f"{MODULE_PREFIX}.ConstantsAugmentation._get_warp_to_landmarks",
+        return_value=((np.random.random((batch_size, 8, 2)) * 100).astype("int32"),
+                      (np.random.random((2, size, size))).astype("float32")))
+    warp_pad = int(1.25 * size)
+
+    warps = ConstantsAugmentation._get_warp(size, batch_size)
+
+    warp_lm_mock.assert_called_once_with(size, batch_size)
+
+    assert isinstance(warps, ConstantsWarp)
+
+    assert isinstance(warps.maps, np.ndarray)
+    assert warps.maps.dtype == "float32"
+    assert warps.maps.shape == (batch_size, 2, 5, 5)
+    assert warps.maps.min() == 0.
+    assert warps.maps.mean() == size / 2.
+    assert warps.maps.max() == size
+
+    assert isinstance(warps.pad, tuple)
+    assert len(warps.pad) == 2
+    assert all(isinstance(x, int) for x in warps.pad)
+    assert all(x == warp_pad for x in warps.pad)
+
+    assert isinstance(warps.slices, slice)
+    assert warps.slices.step is None
+    assert warps.slices.start == warp_pad // 10
+    assert warps.slices.stop == -warp_pad // 10
+
+    assert isinstance(warps.scale, float)
+    assert warps.scale == 5 / 256 * size
+
+    assert isinstance(warps.lm_edge_anchors, np.ndarray)
+    assert warps.lm_edge_anchors.dtype == warp_lm_mock.return_value[0].dtype
+    assert warps.lm_edge_anchors.shape == warp_lm_mock.return_value[0].shape
+    assert np.all(warps.lm_edge_anchors == warp_lm_mock.return_value[0])
+
+    assert isinstance(warps.lm_grids, np.ndarray)
+    assert warps.lm_grids.dtype == warp_lm_mock.return_value[1].dtype
+    assert warps.lm_grids.shape == warp_lm_mock.return_value[1].shape
+    assert np.all(warps.lm_grids == warp_lm_mock.return_value[1])
+
+    assert isinstance(warps.lm_scale, float)
+    assert warps.lm_scale == 2 / 256 * size
+
+
+_CONFIG = T.cast(
+    dict[str, ConfigValueType],
+    {"color_clahe_chance": 50, "color_clahe_max_size": 4, "color_lightness": 30, "color_ab": 8,
+     "rotation_range": 10, "zoom_amount": 5, "shift_range": 5, "flip_chance": 50})
+
+
+@pytest.mark.parametrize(("size", "batch_size"), ((64, 16), (384, 32)))
+def test_constants_from_config(size: int,
+                               batch_size: int,
+                               patch_config,  # noqa[F811]
+                               mocker: pytest_mock.MockerFixture
+                               ) -> None:
+    """ Test that ConstantsAugmentation.from_config executes correctly """
+    patch_config(cfg, _CONFIG)
+    constants = ConstantsAugmentation.from_config(size, batch_size)
+    assert isinstance(constants, ConstantsAugmentation)
+    assert isinstance(constants.color, ConstantsColor)
+    assert isinstance(constants.transform, ConstantsTransform)
+    assert isinstance(constants.warp, ConstantsWarp)
+
+    color_mock = mocker.patch(f"{MODULE_PREFIX}.ConstantsAugmentation._get_color")
+    transform_mock = mocker.patch(f"{MODULE_PREFIX}.ConstantsAugmentation._get_transform")
+    warp_mock = mocker.patch(f"{MODULE_PREFIX}.ConstantsAugmentation._get_warp")
+    ConstantsAugmentation.from_config(size, batch_size)
+    color_mock.assert_called_once_with(size)
+    transform_mock.assert_called_once_with(size)
+    warp_mock.assert_called_once_with(size, batch_size)
+
+
+# IMAGE AUGMENTATION #
+def get_batch(batch_size, size: int) -> np.ndarray:
+    """ Obtain a batch of random uint8 image data for the given batch size and height/width """
+    return (np.random.random((batch_size, size, size, 3)) * 255).astype("uint8")
+
+
+def get_instance(batch_size,
size) -> ImageAugmentation: + """ Obtain an ImageAugmentation instance for the given batch size and size """ + return ImageAugmentation(batch_size, size) + + +@pytest.mark.parametrize(("size", "batch_size"), ((64, 16), (384, 32))) +def test_image_augmentation_init(size: int, + batch_size: int, + patch_config) -> None: # noqa[F811] + """ Test ImageAugmentation initializes """ + patch_config(cfg, _CONFIG) + attrs = {"_processing_size": int, + "_batch_size": int, + "_constants": ConstantsAugmentation} + instance = get_instance(batch_size, size) + + assert all(x in instance.__dict__ for x in attrs) + assert all(x in attrs for x in instance.__dict__) + assert isinstance(instance._batch_size, int) + assert isinstance(instance._processing_size, int) + assert isinstance(instance._constants, ConstantsAugmentation) + assert instance._batch_size == batch_size + assert instance._processing_size == size + + +@pytest.mark.parametrize(("size", "batch_size"), ((64, 16), (384, 32))) +def test_image_augmentation_random_lab(size: int, + batch_size: int, + patch_config, # noqa[F811] + mocker: pytest_mock.MockerFixture) -> None: + """ Test that ImageAugmentation._random_lab executes as expected """ + patch_config(cfg, _CONFIG) + batch = get_batch(batch_size, size) + original = batch.copy() + instance = get_instance(batch_size, size) + + instance._random_lab(batch) + assert original.shape == batch.shape + assert original.dtype == batch.dtype + assert not np.allclose(original, batch) + + randoms_mock = mocker.patch(f"{MODULE_PREFIX}.np.random.uniform") + instance._random_lab(batch) + randoms_mock.assert_called_once() + + +@pytest.mark.parametrize(("size", "batch_size"), ((64, 16), (384, 32))) +def test_image_augmentation_random_clahe(size: int, # pylint:disable=too-many-locals + batch_size: int, + patch_config, # noqa[F811] + mocker: pytest_mock.MockerFixture) -> None: + """ Test that ImageAugmentation._random_clahe executes as expected """ + # Expected output + patch_config(cfg, _CONFIG) + batch = get_batch(batch_size, size) + original = batch.copy() + instance = get_instance(batch_size, size) + + instance._random_clahe(batch) + assert original.shape == batch.shape + assert original.dtype == batch.dtype + assert not np.allclose(original, batch) + + # Functions called + rand_ret = np.random.rand(batch_size) + rand_mock = mocker.patch(f"{MODULE_PREFIX}.np.random.rand", + return_value=rand_ret) + + where_ret = np.where(rand_ret < instance._constants.color.clahe_chance) + where_mock = mocker.patch(f"{MODULE_PREFIX}.np.where", + return_value=where_ret) + + randint_ret = np.random.randint(instance._constants.color.clahe_max_size, + size=where_ret[0].shape[0], + dtype="uint8") + randint_mock = mocker.patch(f"{MODULE_PREFIX}.np.random.randint", + return_value=randint_ret) + + grid_sizes = (randint_ret * + (instance._constants.color.clahe_base_contrast // + 2)) + instance._constants.color.clahe_base_contrast + clahe_calls = [mocker.call(clipLimit=2.0, tileGridSize=(grid, grid)) for grid in grid_sizes] + clahe_mock = mocker.patch(f"{MODULE_PREFIX}.cv2.createCLAHE", + return_value=cv2.createCLAHE(clipLimit=2.0, tileGridSize=(3, 3))) + + batch = get_batch(batch_size, size) + instance._random_clahe(batch) + + rand_mock.assert_called_once_with(batch_size) + where_mock.assert_called_once() + randint_mock.assert_called_once_with(instance._constants.color.clahe_max_size + 1, + size=where_ret[0].shape[0], + dtype="uint8") + clahe_mock.assert_has_calls(clahe_calls) # type:ignore + + +@pytest.mark.parametrize(("size", 
"batch_size"), ((64, 16), (384, 32))) +def test_image_augmentation_color_adjust(size: int, + batch_size: int, + patch_config, # noqa[F811] + mocker: pytest_mock.MockerFixture) -> None: + """ Test that ImageAugmentation._color_adjust executes as expected """ + patch_config(cfg, _CONFIG) + batch = get_batch(batch_size, size) + output = get_instance(batch_size, size).color_adjust(batch) + assert output.shape == batch.shape + assert output.dtype == batch.dtype + assert not np.allclose(output, batch) + + batch_convert_mock = mocker.patch(f"{MODULE_PREFIX}.batch_convert_color") + lab_mock = mocker.patch(f"{MODULE_PREFIX}.ImageAugmentation._random_lab") + clahe_mock = mocker.patch(f"{MODULE_PREFIX}.ImageAugmentation._random_clahe") + + batch = get_batch(batch_size, size) + get_instance(batch_size, size).color_adjust(batch) + + assert batch_convert_mock.call_count == 2 + lab_mock.assert_called_once() + clahe_mock.assert_called_once() + + +@pytest.mark.parametrize(("size", "batch_size"), ((64, 16), (384, 32))) +def test_image_augmentation_transform(size: int, + batch_size: int, + patch_config, # noqa[F811] + mocker: pytest_mock.MockerFixture) -> None: + """ Test that ImageAugmentation.transform executes as expected """ + patch_config(cfg, _CONFIG) + batch = get_batch(batch_size, size) + instance = get_instance(batch_size, size) + original = batch.copy() + instance.transform(batch) + + assert original.shape == batch.shape + assert original.dtype == batch.dtype + assert not np.allclose(original, batch) + + rand_ret = [np.random.uniform(-10, 10, size=batch_size).astype("float32"), + np.random.uniform(.95, 1.05, size=batch_size).astype("float32"), + np.random.uniform(-9.2, 9.2, size=(batch_size, 2)).astype("float32")] + rand_calls = [mocker.call(-instance._constants.transform.rotation, + instance._constants.transform.rotation, + size=batch_size), + mocker.call(1 - instance._constants.transform.zoom, + 1 + instance._constants.transform.zoom, + size=batch_size), + mocker.call(-instance._constants.transform.shift, + instance._constants.transform.shift, + size=(batch_size, 2))] + rand_mock = mocker.patch(f"{MODULE_PREFIX}.np.random.uniform", + side_effect=rand_ret) + + rotmat_mock = mocker.patch( + f"{MODULE_PREFIX}.cv2.getRotationMatrix2D", + return_value=np.array([[1.0, 0.0, -2.0], [-1.0, 1.0, 5.0]]).astype("float32")) + + affine_mock = mocker.patch(f"{MODULE_PREFIX}.cv2.warpAffine") + + batch = get_batch(batch_size, size) + get_instance(batch_size, size).transform(batch) + + rand_mock.assert_has_calls(rand_calls) # type:ignore + assert rotmat_mock.call_count == batch_size + assert affine_mock.call_count == batch_size + + +@pytest.mark.parametrize(("size", "batch_size"), ((64, 16), (384, 32))) +def test_image_augmentation_random_flip(size: int, + batch_size: int, + patch_config, # noqa[F811] + mocker: pytest_mock.MockerFixture) -> None: + """ Test that ImageAugmentation.flip_chance executes as expected """ + patch_config(cfg, _CONFIG) + batch = get_batch(batch_size, size) + original = batch.copy() + get_instance(batch_size, size).random_flip(batch) + + assert original.shape == batch.shape + assert original.dtype == batch.dtype + assert not np.allclose(original, batch) + + rand_ret = np.random.rand(batch_size) + rand_mock = mocker.patch(f"{MODULE_PREFIX}.np.random.rand", return_value=rand_ret) + where_mock = mocker.patch(f"{MODULE_PREFIX}.np.where") + + batch = get_batch(batch_size, size) + get_instance(batch_size, size).random_flip(batch) + + rand_mock.assert_called_once_with(batch_size) + 
where_mock.assert_called_once() + + +@pytest.mark.parametrize(("size", "batch_size"), ((64, 16), (384, 32))) +def test_image_augmentation_random_warp(size: int, + batch_size: int, + mocker: pytest_mock.MockerFixture) -> None: + """ Test that ImageAugmentation._random_warp executes as expected """ + batch = get_batch(batch_size, size) + instance = get_instance(batch_size, size) + output = instance._random_warp(batch) + + assert output.shape == batch.shape + assert output.dtype == batch.dtype + assert not np.allclose(output, batch) + + rand_ret = np.random.normal(size=(batch_size, 2, 5, 5), scale=0.02).astype("float32") + rand_mock = mocker.patch(f"{MODULE_PREFIX}.np.random.normal", return_value=rand_ret) + + eval_ret = np.ones_like(rand_ret) + eval_mock = mocker.patch(f"{MODULE_PREFIX}.ne.evaluate", return_value=eval_ret) + + resize_ret = np.ones((size, size)).astype("float32") + resize_mock = mocker.patch(f"{MODULE_PREFIX}.cv2.resize", return_value=resize_ret) + + remap_mock = mocker.patch(f"{MODULE_PREFIX}.cv2.remap") + + instance._random_warp(batch) + + rand_mock.assert_called_once_with(size=(batch_size, 2, 5, 5), + scale=instance._constants.warp.scale) + eval_mock.assert_called_once() + assert resize_mock.call_count == batch_size * 2 + assert remap_mock.call_count == batch_size + + +@pytest.mark.parametrize(("size", "batch_size"), ((64, 16), (384, 32))) +def test_image_augmentation_random_warp_landmarks(size: int, + batch_size: int, + mocker: pytest_mock.MockerFixture) -> None: + """ Test that ImageAugmentation._random_warp_landmarks executes as expected """ + src_points = np.random.random(size=(batch_size, 68, 2)).astype("float32") * size + dst_points = np.random.random(size=(batch_size, 68, 2)).astype("float32") * size + + batch = get_batch(batch_size, size) + instance = get_instance(batch_size, size) + output = instance._random_warp_landmarks(batch, src_points, dst_points) + + assert output.shape == batch.shape + assert output.dtype == batch.dtype + assert not np.allclose(output, batch) + + rand_ret = np.random.normal(size=dst_points.shape, scale=0.01) + rand_mock = mocker.patch(f"{MODULE_PREFIX}.np.random.normal", return_value=rand_ret) + + hull_ret = [cv2.convexHull(np.concatenate([src[17:], dst[17:]], axis=0)) + for src, dst in zip(src_points.astype("int32"), + (dst_points + rand_ret).astype("int32"))] + hull_mock = mocker.patch(f"{MODULE_PREFIX}.cv2.convexHull", side_effect=hull_ret) + + remap_mock = mocker.patch(f"{MODULE_PREFIX}.cv2.remap") + + instance._random_warp_landmarks(batch, src_points, dst_points) + + rand_mock.assert_called_once_with(size=(dst_points.shape), + scale=instance._constants.warp.lm_scale) + assert hull_mock.call_count == batch_size + assert remap_mock.call_count == batch_size + + +@pytest.mark.parametrize(("size", "batch_size", "to_landmarks"), + ((64, 16, True), (384, 32, False))) +def test_image_augmentation_warp(size: int, + batch_size: int, + to_landmarks: bool, + mocker: pytest_mock.MockerFixture) -> None: + """ Test that ImageAugmentation.warp executes as expected """ + kwargs = {} + if to_landmarks: + kwargs["batch_src_points"] = np.random.random( + size=(batch_size, 68, 2)).astype("float32") * size + kwargs["batch_dst_points"] = np.random.random( + size=(batch_size, 68, 2)).astype("float32") * size + batch = get_batch(batch_size, size) + output = get_instance(batch_size, size).warp(batch, to_landmarks, **kwargs) + + assert output.shape == batch.shape + assert output.dtype == batch.dtype + assert not np.allclose(output, batch) + + if to_landmarks: + 
with pytest.raises(AssertionError):
+            get_instance(batch_size, size).warp(batch,
+                                                to_landmarks,
+                                                batch_src_points=kwargs["batch_src_points"],
+                                                batch_dst_points=None)
+        with pytest.raises(AssertionError):
+            get_instance(batch_size, size).warp(batch,
+                                                to_landmarks,
+                                                batch_src_points=None,
+                                                batch_dst_points=kwargs["batch_dst_points"])
+
+    warp_mock = mocker.patch(f"{MODULE_PREFIX}.ImageAugmentation._random_warp")
+    warp_lm_mock = mocker.patch(f"{MODULE_PREFIX}.ImageAugmentation._random_warp_landmarks")
+
+    get_instance(batch_size, size).warp(batch, to_landmarks, **kwargs)
+    if to_landmarks:
+        warp_mock.assert_not_called()
+        warp_lm_mock.assert_called_once()
+    else:
+        warp_mock.assert_called_once()
+        warp_lm_mock.assert_not_called()
diff --git a/tests/lib/training/cache_test.py b/tests/lib/training/cache_test.py
new file mode 100644
index 0000000000..afd3cef079
--- /dev/null
+++ b/tests/lib/training/cache_test.py
@@ -0,0 +1,964 @@
+#!/usr/bin python3
+""" Pytest unit tests for :mod:`lib.training.cache` """
+import os
+import typing as T
+
+from threading import Lock
+
+import numpy as np
+import pytest
+import pytest_mock
+
+from lib.align.constants import LandmarkType
+from lib.training import cache as cache_mod
+from lib.utils import FaceswapError
+from plugins.train import train_config as cfg
+
+
+from tests.lib.config.helpers import patch_config  # pylint:disable=unused-import # noqa[F401]
+
+# pylint:disable=protected-access,invalid-name,redefined-outer-name
+
+
+# ## HELPERS ###
+
+MODULE_PREFIX = "lib.training.cache"
+_DUMMY_IMAGE_LIST = ["/path/to/img1.png", "~/img2.png", "img3.png"]
+
+
+def _get_config(centering="face", vertical_offset=0):
+    """ Return a fresh valid config """
+    return {"centering": centering,
+            "vertical_offset": vertical_offset}
+
+
+STANDARD_CACHE_ARGS = (_DUMMY_IMAGE_LIST, 256, 1.0)
+STANDARD_MASK_ARGS = (256, 1.0, "face")
+
+
+# ## MASK PROCESSING ###
+
+def get_mask_config(penalized_mask_loss=True,
+                    learn_mask=True,
+                    mask_type="extended",
+                    mask_dilation=1.0,
+                    mask_kernel=3,
+                    mask_threshold=4,
+                    mask_eye_multiplier=2,
+                    mask_mouth_multiplier=3):
+    """ Generate the mask config dictionary with the given arguments """
+    return {"penalized_mask_loss": penalized_mask_loss,
+            "learn_mask": learn_mask,
+            "mask_type": mask_type,
+            "mask_dilation": mask_dilation,
+            "mask_blur_kernel": mask_kernel,
+            "mask_threshold": mask_threshold,
+            "eye_multiplier": mask_eye_multiplier,
+            "mouth_multiplier": mask_mouth_multiplier}
+
+
+_MASK_CONFIG_PARAMS = (
+    (get_mask_config(True, True, "extended", 1.0, 3, 4, 2, 3), "pass-penalize|learn"),
+    (get_mask_config(True, False, "components", 0.0, 5, 4, 1, 2), "pass-penalize"),
+    (get_mask_config(False, True, "custom", -2.0, 6, 1, 3, 1), "pass-learn"),
+    (get_mask_config(True, True, None, 1.0, 6, 1, 3, 2), "pass-mask-disable1"),
+    (get_mask_config(False, False, "extended", 1.0, 6, 1, 3, 2), "pass-mask-disable2"),
+    (get_mask_config(True, True, "extended", 1.0, 1, 3, 1, 1), "pass-multiplier-disable"),
+    (get_mask_config("Error", True, "extended", 1.0, 1, 3, 2, 3), "fail-penalize"),
+    (get_mask_config(True, 1.4, "extended", 1.0, 1, 3, 2, 3), "fail-learn"),
+    (get_mask_config(True, True, 999, 1.0, 1, 3, 2, 3), "fail-type"),
+    (get_mask_config(True, True, "extended", 23, 1, 3, 2, 3), "fail-dilation"),
+    (get_mask_config(True, True, "extended", 1.0, 1.2, 3, 2, 3), "fail-kernel"),
+    (get_mask_config(True, True, "extended", 1.0, 1, "fail", 2, 3), "fail-threshold"),
+    (get_mask_config(True, True, "extended", 1.0, 1, 3, 3.9, 3),
"fail-eye-multi"), + (get_mask_config(True, True, "extended", 1.0, 1, 3, 2, "fail"), "fail-mouth-multi")) +_MASK_CONFIG_IDS = [x[-1] for x in _MASK_CONFIG_PARAMS] + + +@pytest.mark.parametrize(("config", "status"), _MASK_CONFIG_PARAMS, ids=_MASK_CONFIG_IDS) +def test_MaskConfig(config: dict[str, T.Any], + status: str, + patch_config) -> None: # noqa[F811] + """ Test that cache._MaskConfig dataclass initializes from config """ + patch_config(cfg.Loss, config) + retval = cache_mod._MaskConfig() + if status.startswith("pass-mask-disable"): + assert not retval.mask_enabled + else: + assert retval.mask_enabled + + if status == "pass-multiplier-disable" or not config["penalized_mask_loss"]: + assert not retval.multiplier_enabled + else: + assert retval.multiplier_enabled + + +_MASK_INIT_PARAMS = ((64, 0.5, "face", "pass"), + (128, 0.75, "head", "pass"), + (384, 1.0, "legacy", "pass"), + (69.42, 0.75, "head", "fail-size"), + (128, "fail", "head", "fail-coverage"), + (128, 0.75, "fail", "fail-centering")) +_MASK_INIT_IDS = [x[-1] for x in _MASK_INIT_PARAMS] + + +@pytest.mark.parametrize(("size", "coverage", "centering", "status"), + _MASK_INIT_PARAMS, ids=_MASK_INIT_IDS) +def test_MaskProcessing_init(size, + coverage, + centering, + status: str, + mocker: pytest_mock.MockerFixture) -> None: + """ Test cache._MaskProcessing correctly initializes """ + mock_maskconfig = mocker.MagicMock() + mocker.patch(f"{MODULE_PREFIX}._MaskConfig", new=mock_maskconfig) + + if not status == "pass": + with pytest.raises(AssertionError): + cache_mod._MaskProcessing(size, coverage, centering) + return + + instance = cache_mod._MaskProcessing(size, coverage, centering) + attrs = {"_size": int, + "_coverage": float, + "_centering": str, + "_config": mocker.MagicMock} # Our mocked _MaskConfig + + for attr, dtype in attrs.items(): + assert attr in instance.__dict__ + assert isinstance(instance.__dict__[attr], dtype) + assert all(x in attrs for x in instance.__dict__) + + assert instance._size == size + assert instance._coverage == coverage + assert instance._centering == centering + mock_maskconfig.assert_called_once() + + +def test_MaskProcessing_check_mask_exists(mocker: pytest_mock.MockerFixture) -> None: + """ Test cache._MaskProcessing._check_mask_exists functions as expected """ + mock_det_face = mocker.MagicMock() + mock_det_face.mask = ["extended", "components"] + + instance = cache_mod._MaskProcessing(*STANDARD_MASK_ARGS) # type:ignore[arg-type] + + instance._check_mask_exists("", mock_det_face) + + mock_det_face.mask = [] + with pytest.raises(FaceswapError): + instance._check_mask_exists("", mock_det_face) + + +@pytest.mark.parametrize(("dilation", "kernel", "threshold"), + ((1.0, 3, 4), (-2.5, 5, 2), (3.3, 7, 9))) +def test_MaskProcessing_preprocess(dilation: float, + kernel: int, + threshold: int, + mocker: pytest_mock.MockerFixture, + patch_config) -> None: # noqa[F811] + """ Test cache._MaskProcessing._preprocess functions as expected """ + mock_mask = mocker.MagicMock() + mock_det_face = mocker.MagicMock() + mock_det_face.mask = {"extended": mock_mask} + + patch_config(cfg.Loss, get_mask_config(mask_dilation=dilation, + mask_kernel=kernel, + mask_threshold=threshold)) + + instance = cache_mod._MaskProcessing(*STANDARD_MASK_ARGS) # type:ignore[arg-type] + instance._preprocess(mock_det_face, "extended") + mock_mask.set_dilation.assert_called_once_with(dilation) + mock_mask.set_blur_and_threshold.assert_called_once_with(blur_kernel=kernel, + threshold=threshold) + + +@pytest.mark.parametrize( + 
("mask_centering", "train_centering", "coverage", "y_offset", "size", "mask_size"), + (("face", "legacy", 0.75, 0.0, 256, 64), + ("legacy", "head", 0.66, -0.25, 128, 128), + ("head", "face", 1.0, 0.33, 64, 256))) +def test_MaskProcessing_crop_and_resize(mask_centering: str, # pylint:disable=too-many-locals + train_centering: T.Literal["legacy", "face", "head"], + coverage: float, + y_offset: float, + size: int, + mask_size: int, + mocker: pytest_mock.MockerFixture) -> None: + """ Test cache._MaskProcessing._crop_and_resize functions as expected """ + mock_pose = mocker.MagicMock() + mock_pose.offset = {"face": "face_centering", + "legacy": "legacy_centering", + "head": "head_centering"} + + mock_det_face = mocker.MagicMock() + mock_det_face.aligned.pose = mock_pose + mock_det_face.aligned.y_offset = y_offset + + mock_face_mask = mocker.MagicMock() + mock_face_mask.__get_item__ = mock_face_mask + mock_face_mask.shape = (mask_size, mask_size) + + mock_mask = mocker.MagicMock() + mock_mask.stored_centering = mask_centering + mock_mask.stored_size = mask_size + mock_mask.mask = mock_face_mask + + mock_cv2_resize_result = mocker.MagicMock() + mock_cv2_resize_item = mocker.MagicMock() + mock_cv2_resize = mocker.patch(f"{MODULE_PREFIX}.cv2.resize", + return_value=mock_cv2_resize_result) + mock_cv2_resize_result.__getitem__.return_value = mock_cv2_resize_item + + mock_cv2_cubic = mocker.patch(f"{MODULE_PREFIX}.cv2.INTER_CUBIC") + mock_cv2_area = mocker.patch(f"{MODULE_PREFIX}.cv2.INTER_AREA") + + instance = cache_mod._MaskProcessing(size, coverage, train_centering) + + retval = instance._crop_and_resize(mock_det_face, mock_mask) + mock_mask.set_sub_crop.assert_called_once_with(mock_pose.offset[mask_centering], + mock_pose.offset[train_centering], + train_centering, + coverage, + y_offset) + if mask_size == size: + assert retval is mock_face_mask + mock_cv2_resize.assert_not_called() + return + + assert retval is mock_cv2_resize_item + interp_used = mock_cv2_cubic if mask_size < size else mock_cv2_area + mock_cv2_resize.assert_called_once_with(mock_face_mask, + (size, size), + interpolation=interp_used) + + +@pytest.mark.parametrize("mask_type", (None, "extended", "components")) +def test_MaskProcessing_get_face_mask(mask_type: str | None, + mocker: pytest_mock.MockerFixture, + patch_config) -> None: # noqa[F811] + """ Test cache._MaskProcessing._get_face_mask functions as expected """ + patch_config(cfg, _get_config()) + patch_config(cfg.Loss, get_mask_config(mask_type=mask_type)) + instance = cache_mod._MaskProcessing(*STANDARD_MASK_ARGS) # type:ignore[arg-type] + assert instance._config.mask_type == mask_type # sanity check + + instance._check_mask_exists = mocker.MagicMock() # type:ignore[method-assign] + preprocess_return = "test_preprocess_return" + instance._preprocess = mocker.MagicMock( # type:ignore[method-assign] + return_value="test_preprocess_return") + crop_and_resize_return = mocker.MagicMock() + crop_and_resize_return.shape = (256, 256, 1) + instance._crop_and_resize = mocker.MagicMock( # type:ignore[method-assign] + return_value=crop_and_resize_return) + + filename = "test_filename" + detected_face = "test_detected_face" + + if mask_type is None: # Mask disabled + assert not instance._config.mask_enabled + retval1 = instance._get_face_mask(filename, detected_face) # type:ignore[arg-type] + assert retval1 is None + instance._check_mask_exists.assert_not_called() # type:ignore[attr-defined] + instance._preprocess.assert_not_called() # type:ignore[attr-defined] + 
instance._crop_and_resize.assert_not_called() # type:ignore[attr-defined] + else: # Mask enabled + assert instance._config.mask_enabled + retval2 = instance._get_face_mask(filename, detected_face) # type:ignore[arg-type] + assert retval2 is crop_and_resize_return + instance._check_mask_exists.assert_called_once_with( # type:ignore[attr-defined] + filename, detected_face) + + instance._preprocess.assert_called_once_with( # type:ignore[attr-defined] + detected_face, instance._config.mask_type) + + instance._crop_and_resize.assert_called_once_with( # type:ignore[attr-defined] + detected_face, preprocess_return) + + +@pytest.mark.parametrize(("eye_multiplier", "mouth_multiplier", "size", "enabled"), + ((0, 0, 64, False), + (1, 1, 64, False), + (1, 2, 64, True), + (2, 1, 96, True), + (2, 3, 128, True), + (3, 1, 256, True))) +def test_MaskProcessing_get_localized_mask(eye_multiplier: int, + mouth_multiplier: int, + size: int, + enabled: bool, + mocker: pytest_mock.MockerFixture, + patch_config) -> None: # noqa[F811] + """ Test cache._MaskProcessing._get_localized_mask functions as expected """ + args = STANDARD_MASK_ARGS[1:] + patch_config(cfg.Loss, get_mask_config(mask_eye_multiplier=eye_multiplier, + mask_mouth_multiplier=mouth_multiplier)) + instance = cache_mod._MaskProcessing(size, *args) # type:ignore[arg-type] + + filename = "filename" + detected_face = mocker.MagicMock() + landmark_mask_return_value = mocker.MagicMock() + + detected_face.get_landmark_mask = mocker.MagicMock(return_value=landmark_mask_return_value) + + for area in ("mouth", "eye"): + retval = instance._get_localized_mask(filename, detected_face, area) + if not enabled: + assert retval is None + detected_face.get_landmark_mask.assert_not_called() + else: + assert retval is landmark_mask_return_value + + if enabled: + detected_face.get_landmark_mask.assert_called_with(area, size // 16, 2.5) + if enabled: + assert detected_face.get_landmark_mask.call_count == 2 + + +def test_MaskProcessing_call(mocker: pytest_mock.MockerFixture) -> None: + """ Test cache._MaskProcessing.__call__ functions as expected """ + instance = cache_mod._MaskProcessing(*STANDARD_MASK_ARGS) # type:ignore[arg-type] + face_return = "face_mask" + area_return = "area_mask" + instance._get_face_mask = mocker.MagicMock( # type:ignore[method-assign] + return_value=face_return) # type:ignore[method-assign] + instance._get_localized_mask = mocker.MagicMock( # type:ignore[method-assign] + return_value=area_return) # type:ignore[method-assign] + + filename = "test_filename" + detected_face = mocker.MagicMock() + detected_face.store_training_masks = mocker.MagicMock() + + instance(filename, detected_face) + + instance._get_face_mask.assert_called_once_with( # type:ignore[attr-defined] + filename, detected_face) + + expected_localized_calls = [mocker.call(filename, detected_face, "eye"), + mocker.call(filename, detected_face, "mouth")] + instance._get_localized_mask.assert_has_calls( # type:ignore[attr-defined] + expected_localized_calls, any_order=False) # pyright:ignore[reportArgumentType] + assert instance._get_localized_mask.call_count == 2 # type:ignore[attr-defined] + + detected_face.store_training_masks.assert_called_once_with( + [face_return, area_return, area_return], + delete_masks=True) + + +# ## CACHE PROCESSING ### + +@pytest.fixture +def face_cache_reset_scenario(mocker: pytest_mock.MockerFixture, + request: pytest.FixtureRequest): + """ Build a scenario for cache._check_reset. 
+
+    request.param = {"caches": dict[Literal["a", "b"], bool],
+                     "side": Literal["a", "b"]}
+
+    If the key "a" or "b" exists in the caches dict, then that cache exists in the mocked
+    cache._FACE_CACHES, with a mock whose cache.Cache.check_reset() call returns the given
+    value.
+
+    The mocked Cache item for the currently testing side is returned, or a default mocked item if
+    the given side is not meant to be in the _FACE_CACHES dict.
+    """
+    cache_dict = {}
+    for side, val in request.param["caches"].items():
+        check_mock = mocker.MagicMock()
+        check_mock.check_reset.return_value = val
+        cache_dict[side] = check_mock
+    mocker.patch(f"{MODULE_PREFIX}._FACE_CACHES", new=cache_dict)
+    return cache_dict.get(request.param["side"], mocker.MagicMock())
+
+
+_RESET_PARAMS = [({"side": side, "caches": caches}, expected, f"{name}-{side}")
+                 for side in ("a", "b")
+                 for caches, expected, name in [
+                     ({}, False, "no-cache"),
+                     ({"a": False}, False, "a-exists"),
+                     ({"b": False}, False, "b-exists"),
+                     ({"a": True, "b": False}, side == "b", "a-reset"),
+                     ({"a": False, "b": True}, side == "a", "b-reset"),
+                     ({"a": True, "b": True}, True, "both-reset"),
+                     ({"a": False, "b": False}, False, "no-reset")]]
+_RESET_IDS = [x[-1] for x in _RESET_PARAMS]
+_RESET_PARAMS = [x[:-1] for x in _RESET_PARAMS]  # type:ignore[misc]
+
+
+@pytest.mark.parametrize(("face_cache_reset_scenario", "expected"),
+                         _RESET_PARAMS,
+                         ids=_RESET_IDS,
+                         indirect=["face_cache_reset_scenario"])
+def test_check_reset(face_cache_reset_scenario, expected):  # pylint:disable=redefined-outer-name
+    """ Test that cache._check_reset functions as expected """
+    this_cache = face_cache_reset_scenario
+    assert cache_mod._check_reset(this_cache) == expected
+
+
+@pytest.mark.parametrize(
+    ("filenames", "size", "coverage_ratio", "centering"),
+    [(_DUMMY_IMAGE_LIST, 256, 1.0, "face"),
+     (_DUMMY_IMAGE_LIST[:-1], 96, .75, "head"),
+     (_DUMMY_IMAGE_LIST[2:], 384, .66, "legacy")])
+def test_Cache_init(filenames, size, coverage_ratio, centering, patch_config):  # noqa[F811]
+    """ Test that cache.Cache correctly initializes """
+    attrs = {"_lock": type(Lock()),
+             "_cache_info": dict,
+             "_config": cache_mod._CacheConfig,
+             "_partially_loaded": list,
+             "_image_count": int,
+             "_cache": dict,
+             "_aligned_landmarks": dict,
+             "_extract_version": float,
+             "_mask_prepare": cache_mod._MaskProcessing}
+    patch_config(cfg, _get_config(centering=centering))
+    instance = cache_mod.Cache(filenames, size, coverage_ratio)
+
+    for attr, attr_type in attrs.items():
+        assert attr in instance.__dict__
+        assert isinstance(getattr(instance, attr), attr_type)
+    for key in instance.__dict__:
+        assert key in attrs
+
+    assert set(instance._cache_info) == {"cache_full", "has_reset"}
+    assert all(x is False for x in instance._cache_info.values())
+
+    assert not instance._partially_loaded
+    assert not instance._cache
+    assert instance._image_count == len(filenames)
+    assert not instance._aligned_landmarks
+    assert instance._extract_version == 0.0
+    assert instance._config.size == size
+    assert instance._config.centering == centering
+    assert instance._config.coverage == coverage_ratio
+
+
+def test_Cache_cache_full(mocker: pytest_mock.MockerFixture):
+    """ Test that cache.Cache.cache_full property behaves correctly """
+    instance = cache_mod.Cache(*STANDARD_CACHE_ARGS)
+    instance._lock = mocker.MagicMock()
+
+    is_full1 = instance.cache_full
+    assert not is_full1
+    instance._lock.__enter__.assert_called_once()  # type:ignore[attr-defined]
+    instance._lock.__exit__.assert_called_once()  # type:ignore[attr-defined]
+
+    instance._cache_info["cache_full"] = True
+    is_full2 = instance.cache_full
+    assert is_full2
+    # lock not called when cache is full
+    instance._lock.__enter__.assert_called_once()  # type:ignore[attr-defined]
+    instance._lock.__exit__.assert_called_once()  # type:ignore[attr-defined]
+
+
+def test_Cache_aligned_landmarks(mocker: pytest_mock.MockerFixture):
+    """ Test that cache.Cache.aligned_landmarks property behaves correctly """
+    instance = cache_mod.Cache(*STANDARD_CACHE_ARGS)
+    instance._lock = mocker.MagicMock()
+    for fname in _DUMMY_IMAGE_LIST:
+        mock_face = mocker.MagicMock()
+        mock_face.aligned.landmarks = f"landmarks_for_{fname}"
+        instance._cache[fname] = mock_face
+
+    retval1 = instance.aligned_landmarks
+    assert len(_DUMMY_IMAGE_LIST) == len(retval1)
+    assert retval1 == {fname: f"landmarks_for_{fname}" for fname in _DUMMY_IMAGE_LIST}
+    instance._lock.__enter__.assert_called_once()  # type:ignore[attr-defined]
+    instance._lock.__exit__.assert_called_once()  # type:ignore[attr-defined]
+
+    retval2 = instance.aligned_landmarks
+    assert len(_DUMMY_IMAGE_LIST) == len(retval2)
+    assert retval2 == {fname: f"landmarks_for_{fname}" for fname in _DUMMY_IMAGE_LIST}
+    # lock not called after first call has populated
+    instance._lock.__enter__.assert_called_once()  # type:ignore[attr-defined]
+    instance._lock.__exit__.assert_called_once()  # type:ignore[attr-defined]
+
+
+@pytest.mark.parametrize("size", (64, 96, 128, 256, 384))
+def test_Cache_size(size):
+    """ Test that cache.Cache.size property returns correctly """
+    instance = cache_mod.Cache(_DUMMY_IMAGE_LIST, size, 1.0)
+    assert instance.size == size
+
+
+def test_Cache_check_reset():
+    """ Test that cache.Cache.check_reset behaves correctly """
+    instance = cache_mod.Cache(*STANDARD_CACHE_ARGS)
+    retval1 = instance.check_reset()
+    assert not retval1
+    assert not instance._cache_info["has_reset"]
+
+    instance._cache_info["has_reset"] = True
+    retval2 = instance.check_reset()
+    assert retval2
+    assert not instance._cache_info["has_reset"]
+
+
+@pytest.mark.parametrize("filenames",
+                         (_DUMMY_IMAGE_LIST, _DUMMY_IMAGE_LIST[:-1], _DUMMY_IMAGE_LIST[2:]))
+def test_Cache_get_items(filenames: list[str]) -> None:
+    """ Test that cache.Cache.get_items returns correctly """
+    instance = cache_mod.Cache(filenames, 256, 1.0)
+    instance._cache = {os.path.basename(f): f"faces_for_{f}"  # type:ignore[misc]
+                       for f in filenames}
+
+    retval = instance.get_items(filenames)
+    assert retval == [f"faces_for_{f}" for f in filenames]
+
+
+@pytest.mark.parametrize("set_flag", (True, False), ids=("set-flag", "no-set-flag"))
+def test_Cache_reset_cache(set_flag: bool,
+                           mocker: pytest_mock.MockerFixture,
+                           patch_config) -> None:  # noqa[F811]
+    """ Test that cache.Cache._reset_cache functions correctly """
+    patch_config(cfg, _get_config(centering="head"))
+    mock_warn = mocker.MagicMock()
+    mocker.patch(f"{MODULE_PREFIX}.logger.warning", mock_warn)
+    instance = cache_mod.Cache(*STANDARD_CACHE_ARGS)
+    instance._cache = {"test": "cache"}  # type:ignore[dict-item]
+    instance._cache_info["cache_full"] = True
+
+    assert instance._config.centering != "legacy"
+    assert instance._cache
+    assert instance._cache_info["cache_full"]
+
+    instance._reset_cache(set_flag)
+
+    assert instance._config.centering == "legacy"
+    assert not instance._cache
+    assert instance._cache_info["cache_full"] is False
+
+    if set_flag:
+        mock_warn.assert_called_once()
+
+
+@pytest.mark.parametrize("png_meta",
+                         ({"source": {"alignments_version": 1.0}},
{"alignments_version": 1.0}}, + {"source": {"alignments_version": 2.0}}, + {"source": {"alignments_version": 2.2}}), + ids=("v1.0", "v2.0", "v2.2")) +def test_Cache_validate_version(png_meta, mocker): + """ Test that cache.Cache._validate_version executes correctly """ + instance = cache_mod.Cache(*STANDARD_CACHE_ARGS) + instance._reset_cache = mocker.MagicMock() + fname = "test_filename.png" + version = png_meta["source"]["alignments_version"] + + if version == 1.0: + for centering in ("legacy", "face"): + instance._extract_version = 0.0 + instance._config.centering = centering + instance._validate_version(png_meta, fname) + if centering == "legacy": + instance._reset_cache.assert_not_called() + else: + instance._reset_cache.assert_called_once_with(True) + assert instance._extract_version == version + else: + instance._validate_version(png_meta, fname) + instance._reset_cache.assert_not_called() + assert instance._extract_version == version + + instance._extract_version = 1.0 # Legacy alignments have been seen + if version > 1.0: # Newer alignments inbound + with pytest.raises(FaceswapError): + instance._validate_version(png_meta, fname) + else: + instance._validate_version(png_meta, fname) + + instance._extract_version = 2.0 # Newer alignments have been seen + if version < 2.0: # Legacy alignments inbound + with pytest.raises(FaceswapError): + instance._validate_version(png_meta, fname) + return # Exit early on 1.0 because cannot pass any more tests + + instance._validate_version(png_meta, fname) + if version > 2.0: + assert instance._extract_version == 2.0 # Defaulted to lowest version + + instance._extract_version = 2.5 + instance._validate_version(png_meta, fname) + assert instance._extract_version == version # Defaulted to lowest version + + +_DET_FACE_PARAMS = ((64, 0.5, 0, 1.0), + (96, 0.75, 1, 1.0), + (256, 0.66, 2, 2.0), + (384, 1.0, 3.0, 2.2)) +_DET_FACE_IDS = [f"size:{x[0]}|coverage:{x[1]}|y-offset:{x[2]}|extract-vers:{x[3]}" + for x in _DET_FACE_PARAMS] + + +@pytest.mark.parametrize(("size", "coverage", "y_offset", "extract_version"), + _DET_FACE_PARAMS, + ids=_DET_FACE_IDS) +def test_Cache_load_detected_face(size: int, + coverage: float, + y_offset: int | float, + extract_version: float, + mocker: pytest_mock.MockerFixture, + patch_config) -> None: # noqa[F811] + """ Test that cache.Cache._load_detected_faces executes correctly """ + patch_config(cfg, _get_config(vertical_offset=y_offset)) + instance = cache_mod.Cache(_DUMMY_IMAGE_LIST, size, coverage) + instance._extract_version = extract_version + alignments = {} # type:ignore[var-annotated] + + mock_det_face = mocker.MagicMock() + mock_det_face.from_png_meta = mocker.MagicMock() + mock_det_face.load_aligned = mocker.MagicMock() + mocker.patch(f"{MODULE_PREFIX}.DetectedFace", return_value=mock_det_face) + + retval = instance._load_detected_face("", alignments) # type:ignore[arg-type] + assert retval is mock_det_face + mock_det_face.from_png_meta.assert_called_once_with(alignments) + mock_det_face.load_aligned.assert_called_once_with(None, + size=instance._config.size, + centering=instance._config.centering, + coverage_ratio=instance._config.coverage, + y_offset=y_offset / 100., + is_aligned=True, + is_legacy=extract_version == 1.0) + + +@pytest.mark.parametrize("partially_loaded", (True, False), ids=("partial", "full")) +def test_Cache_populate_cache(partially_loaded: bool, + mocker: pytest_mock.MockerFixture) -> None: + """ Test that cache.Cache._populate_cache executes correctly """ + already_cached = 
["/path/to/img4.png", "/path/img5.png"] + needs_cache = _DUMMY_IMAGE_LIST + filenames = _DUMMY_IMAGE_LIST + already_cached + metadata = [{"alignments": f"{f}_alignments"} for f in filenames] + + instance = cache_mod.Cache(*STANDARD_CACHE_ARGS) + instance._validate_version = mocker.MagicMock() # type:ignore[method-assign] + instance._mask_prepare = mocker.MagicMock() + instance._cache = {os.path.basename(f): "existing" # type:ignore[misc] + for f in filenames if f not in needs_cache} + + mock_detected_faces = {f: mocker.MagicMock() for f in needs_cache} + + if partially_loaded: + instance._cache.update({os.path.basename(f): mock_detected_faces[f] for f in needs_cache}) + instance._partially_loaded = [os.path.basename(f) for f in filenames] # Add our partials + else: + instance._load_detected_face = mocker.MagicMock( # type:ignore[method-assign] + side_effect=[mock_detected_faces[f] for f in needs_cache]) + + # Call the function + instance._populate_cache(needs_cache, metadata, filenames) # type:ignore[arg-type] + + expected_validate = [mocker.call(metadata[idx], f) for idx, f in enumerate(needs_cache)] + instance._validate_version.assert_has_calls(expected_validate, # type:ignore[attr-defined] + any_order=False) + assert instance._validate_version.call_count == len(needs_cache) # type:ignore[attr-defined] + + expected_mask_prepare = [mocker.call(f, mock_detected_faces[f]) for f in needs_cache] + instance._mask_prepare.assert_has_calls(expected_mask_prepare, # type:ignore[attr-defined] + any_order=False) + assert instance._mask_prepare.call_count == len(needs_cache) # type:ignore[attr-defined] + + assert len(instance._cache) == len(filenames) + for filename in filenames: + key = os.path.basename(filename) + assert key in instance._cache + if filename in needs_cache: # item got added/updated + assert instance._cache[key] == mock_detected_faces[filename] + else: # item pre-existed + assert instance._cache[key] == "existing" + + if partially_loaded: + assert instance._partially_loaded == [os.path.basename(f) for f in filenames + if f not in needs_cache] + + +@pytest.mark.parametrize("scenario", ("read-error", "size-error", "success")) +def test_Cache_get_batch_with_metadata(scenario: str, mocker: pytest_mock.MockerFixture) -> None: + """ Test that cache.Cache._get_batch_with_metadata executes correctly """ + instance = cache_mod.Cache(*STANDARD_CACHE_ARGS) + filenames = ["list", "of", "test", "filenames"] + + mock_read_image_batch = mocker.MagicMock() + if scenario == "read-error": + mock_read_image_batch.side_effect = ValueError("inhomogeneous") + else: + mock_return = (mocker.MagicMock(), {"test": "meta"}) + if scenario == "size-error": + mock_return[0].shape = (len(filenames), ) + else: + mock_return[0].shape = (len(filenames), 64, 64, 3) + mock_read_image_batch.return_value = mock_return + + mocker.patch(f"{MODULE_PREFIX}.read_image_batch", new=mock_read_image_batch) + + if scenario != "success": + with pytest.raises(FaceswapError): + instance._get_batch_with_metadata(filenames) + mock_read_image_batch.assert_called_once_with(filenames, with_metadata=True) + return + + retval = instance._get_batch_with_metadata(filenames) + mock_read_image_batch.assert_called_once_with(filenames, with_metadata=True) + assert retval == mock_return # pyright:ignore[reportPossiblyUnboundVariable] + + +@pytest.mark.parametrize("scenario", ("full", "not-full", "partial")) +def test_Cache_update_cache_full(scenario: bool, mocker: pytest_mock.MockerFixture) -> None: + """ Test that cache.Cache._update_cache_full 
+    """ Test that cache.Cache._update_cache_full executes correctly """
+    mock_verbose = mocker.patch(f"{MODULE_PREFIX}.logger.verbose")
+    filenames = ["test", "file", "names"]
+    instance = cache_mod.Cache(*STANDARD_CACHE_ARGS)
+    instance._image_count = 10
+
+    assert instance._cache_info["cache_full"] is False
+    assert not instance._cache
+    assert not instance._partially_loaded
+
+    if scenario == "full":
+        instance._cache = {i: i for i in range(10)}  # type:ignore[misc]
+    if scenario == "partial":
+        instance._cache = {i: i for i in range(10)}  # type:ignore[misc]
+        instance._partially_loaded = filenames.copy()
+
+    instance._update_cache_full(filenames)
+
+    if scenario == "full":
+        assert instance._cache_info["cache_full"] is True
+        mock_verbose.assert_called_once()
+    else:
+        assert instance._cache_info["cache_full"] is False
+        mock_verbose.assert_not_called()
+
+
+@pytest.mark.parametrize("scenario", ("full", "partial", "empty", "needs-reset"))
+def test_Cache_cache_metadata(scenario: str, mocker: pytest_mock.MockerFixture) -> None:
+    """ Test that cache.Cache.cache_metadata executes correctly """
+    mock_check_reset = mocker.patch(f"{MODULE_PREFIX}._check_reset")
+    mock_check_reset.return_value = scenario == "needs-reset"
+    mock_return_batch = mocker.MagicMock()
+
+    mock_read_image_batch = mocker.patch(f"{MODULE_PREFIX}.read_image_batch",
+                                         return_value=mock_return_batch)
+
+    instance = cache_mod.Cache(*STANDARD_CACHE_ARGS)
+    filenames = _DUMMY_IMAGE_LIST.copy()
+
+    if scenario in ("full", "partial"):
+        instance._cache = {os.path.basename(f): f for f in filenames}  # type:ignore[misc]
+    if scenario == "partial":
+        instance._partially_loaded = [os.path.basename(f) for f in filenames]
+
+    instance._lock = mocker.MagicMock()
+    instance._reset_cache = mocker.MagicMock()  # type:ignore[method-assign]
+    returned_meta = {"test": "meta"}
+    instance._get_batch_with_metadata = mocker.MagicMock(  # type:ignore[method-assign]
+        return_value=(mock_return_batch, returned_meta))
+    instance._populate_cache = mocker.MagicMock()  # type:ignore[method-assign]
+    instance._update_cache_full = mocker.MagicMock()  # type:ignore[method-assign]
+
+    retval = instance.cache_metadata(filenames)  # Call
+
+    instance._lock.__enter__.assert_called_once()  # type:ignore[attr-defined]
+    instance._lock.__exit__.assert_called_once()  # type:ignore[attr-defined]
+
+    mock_check_reset.assert_called_once_with(instance)
+
+    if scenario == "needs-reset":
+        instance._reset_cache.assert_called_once_with(False)  # type:ignore[attr-defined]
+    else:
+        instance._reset_cache.assert_not_called()  # type:ignore[attr-defined]
+
+    if scenario == "full":
+        mock_read_image_batch.assert_called_once_with(filenames)
+        instance._get_batch_with_metadata.assert_not_called()  # type:ignore[attr-defined]
+        instance._populate_cache.assert_not_called()  # type:ignore[attr-defined]
+        instance._update_cache_full.assert_not_called()  # type:ignore[attr-defined]
+    else:
+        mock_read_image_batch.assert_not_called()
+        instance._get_batch_with_metadata.assert_called_once_with(  # type:ignore[attr-defined]
+            filenames)
+        instance._populate_cache.assert_called_once_with(  # type:ignore[attr-defined]
+            filenames, returned_meta, filenames)
+        instance._update_cache_full.assert_called_once_with(filenames)  # type:ignore[attr-defined]
+
+    assert retval is mock_return_batch
+
+
+@pytest.mark.parametrize("scenario", ("fail-meta", "fail-landmarks", "success"))
+def test_Cache_pre_fill(scenario: str, mocker: pytest_mock.MockerFixture) -> None:
+    """ Test that cache.Cache.pre_fill executes correctly """
+    filenames = _DUMMY_IMAGE_LIST.copy()
+    mock_read_image_batch = mocker.patch(f"{MODULE_PREFIX}.read_image_meta_batch")
+    side_effect_read_image_batch = [(f, {}) for f in filenames]  # type:ignore[var-annotated]
+    if scenario != "fail-meta":  # Set successful return data
+        for effect in side_effect_read_image_batch:
+            effect[1]["itxt"] = {"alignments": [1, 2, 3]}
+    mock_read_image_batch.side_effect = [side_effect_read_image_batch]
+
+    instance = cache_mod.Cache(*STANDARD_CACHE_ARGS)
+    instance._lock = mocker.MagicMock()
+    instance._validate_version = mocker.MagicMock()  # type:ignore[method-assign]
+    mock_detected_faces = [mocker.MagicMock() for _ in filenames]
+
+    for m in mock_detected_faces:
+        m.aligned.landmark_type = (LandmarkType.LM_2D_68 if scenario == "success" else "fail")
+    instance._load_detected_face = mocker.MagicMock(  # type:ignore[method-assign]
+        side_effect=mock_detected_faces)
+
+    if scenario in ("fail-meta", "fail-landmarks"):
+        with pytest.raises(FaceswapError):
+            instance.pre_fill(filenames, "a")
+        instance._lock.__enter__.assert_called_once()  # type:ignore[attr-defined]
+        instance._lock.__exit__.assert_called_once()  # type:ignore[attr-defined]
+        mock_read_image_batch.assert_called_once_with(filenames)
+        if scenario == "fail-meta":
+            instance._validate_version.assert_not_called()  # type:ignore[attr-defined]
+            instance._load_detected_face.assert_not_called()  # type:ignore[attr-defined]
+        else:
+            meta = side_effect_read_image_batch[0][1]["itxt"]
+            instance._validate_version.assert_called_once_with(  # type:ignore[attr-defined]
+                meta, filenames[0])
+            instance._load_detected_face.assert_called_once_with(  # type:ignore[attr-defined]
+                filenames[0], meta["alignments"])
+        return
+
+    # success
+    instance.pre_fill(filenames, "a")
+    instance._lock.__enter__.assert_called_once()  # type:ignore[attr-defined]
+    instance._lock.__exit__.assert_called_once()  # type:ignore[attr-defined]
+    mock_read_image_batch.assert_called_once_with(filenames)
+
+    fname_calls = [x[0] for x in side_effect_read_image_batch]
+    meta_calls = [x[1]["itxt"] for x in side_effect_read_image_batch]
+    call_validate = [mocker.call(meta, fname) for fname, meta in zip(fname_calls, meta_calls)]
+    call_det_face = [mocker.call(fname, meta["alignments"])
+                     for fname, meta in zip(fname_calls, meta_calls)]
+
+    instance._validate_version.assert_has_calls(  # type:ignore[attr-defined]
+        call_validate, any_order=False)
+    assert instance._validate_version.call_count == len(filenames)  # type:ignore[attr-defined]
+    instance._load_detected_face.assert_has_calls(  # type:ignore[attr-defined]
+        call_det_face, any_order=False)
+    assert instance._load_detected_face.call_count == len(filenames)  # type:ignore[attr-defined]
+
+    assert instance._cache == {os.path.basename(f): d for f, d in zip(filenames,
+                                                                      mock_detected_faces)}
+    assert instance._partially_loaded == [os.path.basename(f) for f in filenames]
+
+
+_PARAMS_GET = (("a", _DUMMY_IMAGE_LIST, 256, 1.),
+               ("b", _DUMMY_IMAGE_LIST, 256, 1.),
+               ("c", _DUMMY_IMAGE_LIST, 256, 1.),
+               ("a", None, 256, 1.),
+               ("a", _DUMMY_IMAGE_LIST, None, 1.),
+               ("a", _DUMMY_IMAGE_LIST, 256, None))
+_IDS_GET = ("pass-a", "pass-b", "fail-side", "fail-no-filenames",
+            "fail-no-size", "fail-no-coverage")
+
+
+@pytest.mark.parametrize(("side", "filenames", "size", "coverage_ratio", "status"),
+                         (x + (y,) for x, y in zip(_PARAMS_GET, _IDS_GET)),
+                         ids=_IDS_GET)
+def test_get_cache_initial(side: str,
+                           filenames: list[str],
+                           size: int,
+                           coverage_ratio: float,
+                           status: str,
+                           mocker: pytest_mock.MockerFixture) -> None:
+    """ Test cache.get_cache function when the cache does not yet exist """
+    mocker.patch(f"{MODULE_PREFIX}._FACE_CACHES", new={})
+    patched_cache = mocker.patch(f"{MODULE_PREFIX}.Cache")
+    if status.startswith("fail"):
+        with pytest.raises(AssertionError):
+            cache_mod.get_cache(side, filenames, size, coverage_ratio)  # type:ignore[arg-type]
+        patched_cache.assert_not_called()
+        return
+
+    retval = cache_mod.get_cache(side, filenames, size, coverage_ratio)  # type:ignore[arg-type]
+    assert side in cache_mod._FACE_CACHES
+    patched_cache.assert_called_once_with(filenames, size, coverage_ratio)
+    assert cache_mod._FACE_CACHES[side] is patched_cache.return_value
+    assert retval is patched_cache.return_value
+
+    retval2 = cache_mod.get_cache(side, filenames, size, coverage_ratio)  # type:ignore[arg-type]
+    patched_cache.assert_called_once()  # Not called again
+    assert retval2 is retval
+
+
+_IDS_GET2 = ("pass-a", "pass-b", "fail-side", "pass-no-filenames",
+             "pass-no-size", "pass-no-coverage")
+
+
+@pytest.mark.parametrize(("side", "filenames", "size", "coverage_ratio", "status"),
+                         (x + (y,) for x, y in zip(_PARAMS_GET, _IDS_GET2)),
+                         ids=_IDS_GET2)
+def test_get_cache_exists(side: str,
+                          filenames: list[str],
+                          size: int,
+                          coverage_ratio: float,
+                          status: str,
+                          mocker: pytest_mock.MockerFixture) -> None:
+    """ Test cache.get_cache function when the cache exists """
+    mocker.patch(f"{MODULE_PREFIX}._FACE_CACHES", new={"a": mocker.MagicMock(),
+                                                       "b": mocker.MagicMock()})
+    patched_cache = mocker.patch(f"{MODULE_PREFIX}.Cache")
+
+    if status.startswith("fail"):
+        with pytest.raises(AssertionError):
+            cache_mod.get_cache(side, filenames, size, coverage_ratio)  # type:ignore[arg-type]
+        patched_cache.assert_not_called()
+        return
+
+    retval = cache_mod.get_cache(side, filenames, size, coverage_ratio)  # type:ignore[arg-type]
+    patched_cache.assert_not_called()
+    assert retval is cache_mod._FACE_CACHES[side]
+
+
+# ## RING BUFFER ###
+
+_RING_BUFFER_PARAMS = ((2, (384, 384, 3), 2, "uint8"),
+                       (16, (128, 128, 3), 5, "float32"),
+                       (32, (64, 64, 3), 4, "int32"))
+_RING_BUFFER_IDS = [f"bs{x[0]}|{x[1][0]}px|buffer-size{x[2]}|dtype-{x[3]}"
+                    for x in _RING_BUFFER_PARAMS]
+
+
+@pytest.mark.parametrize(("batch_size", "image_shape", "buffer_size", "dtype"),
+                         _RING_BUFFER_PARAMS,
+                         ids=_RING_BUFFER_IDS)
+def test_RingBuffer_init(batch_size, image_shape, buffer_size, dtype):
+    """ Test cache.RingBuffer initializes correctly """
+    attrs = {"_max_index": int, "_index": int, "_buffer": list}
+    instance = cache_mod.RingBuffer(batch_size, image_shape, buffer_size, dtype)
+
+    for attr, attr_type in attrs.items():
+        assert attr in instance.__dict__
+        assert isinstance(getattr(instance, attr), attr_type)
+    for key in instance.__dict__:
+        assert key in attrs
+
+    assert instance._max_index == buffer_size - 1
+    assert instance._index == 0
+    assert len(instance._buffer) == buffer_size
+    assert all(isinstance(b, np.ndarray) for b in instance._buffer)
+    assert all(b.shape == (batch_size, *image_shape) for b in instance._buffer)
+    assert all(b.dtype == dtype for b in instance._buffer)
+
+
+@pytest.mark.parametrize(("batch_size", "image_shape", "buffer_size", "dtype"),
+                         _RING_BUFFER_PARAMS,
+                         ids=_RING_BUFFER_IDS)
+def test_RingBuffer_call(batch_size, image_shape, buffer_size, dtype):
+    """ Test calling cache.RingBuffer works correctly """
correctly """ + instance = cache_mod.RingBuffer(batch_size, image_shape, buffer_size, dtype) + for i in range(buffer_size * 3): + retval = instance() + assert isinstance(retval, np.ndarray) + assert retval.shape == (batch_size, *image_shape) + assert retval.dtype == dtype + if i % buffer_size == buffer_size - 1: + assert instance._index == 0 + else: + assert instance._index == i % buffer_size + 1 diff --git a/tests/lib/training/lr_finder_test.py b/tests/lib/training/lr_finder_test.py new file mode 100644 index 0000000000..c2914707e6 --- /dev/null +++ b/tests/lib/training/lr_finder_test.py @@ -0,0 +1,270 @@ +#! /usr/env/bin/python3 +""" Unit tests for Learning Rate Finder. """ + +import pytest +import pytest_mock + +import numpy as np + +from lib.training.lr_finder import LearningRateFinder +from plugins.train import train_config as cfg + +# pylint:disable=unused-import +from tests.lib.config.helpers import patch_config # noqa:[F401] + +# pylint:disable=protected-access,invalid-name,redefined-outer-name + + +@pytest.fixture +def _trainer_mock(patch_config, mocker: pytest_mock.MockFixture): # noqa:[F811] + """ Generate a mocked model and feeder object and patch user config items """ + def _apply_patch(iters=1000, mode="default", strength="default"): + patch_config(cfg, {"lr_finder_iterations": iters}) + patch_config(cfg, {"lr_finder_mode": mode}) + patch_config(cfg, {"lr_finder_strength": strength}) + trainer = mocker.MagicMock() + model = mocker.MagicMock() + model.name = "TestModel" + optimizer = mocker.MagicMock() + trainer._plugin.model = model + trainer._plugin.model.model.optimizer = optimizer + return trainer, model, optimizer + return _apply_patch + + +_STRENGTH_LOOKUP = {"default": 10, "aggressive": 5, "extreme": 2.5} + + +_LR_CONF = ((20, "graph_and_set", "default"), + (500, "set", "aggressive"), + (1000, "graph_and_exit", "extreme")) +_LR_CONF_PARAMS = ("iters", "mode", "strength") + +_LR_CMDS = ((4, 0.98), (8, 0.66), (2, 0.33) + ) +_LR_CMDS_PARAMS = ("stop_factor", "beta") +_LR_CMDS_IDS = [f"stop:{x[0]}|beta:{x[1]}" for x in _LR_CMDS] + + +@pytest.mark.parametrize(_LR_CONF_PARAMS, _LR_CONF) +@pytest.mark.parametrize(_LR_CMDS_PARAMS, _LR_CMDS, ids=_LR_CMDS_IDS) +def test_LearningRateFinder_init(iters, mode, strength, stop_factor, beta, _trainer_mock): + """ Test lib.train.LearingRateFinder.__init__ """ + trainer, model, optimizer = _trainer_mock(iters, mode, strength) + lrf = LearningRateFinder(trainer, stop_factor=stop_factor, beta=beta) + assert lrf._trainer is trainer + assert lrf._model is model + assert lrf._optimizer is optimizer + assert lrf._start_lr == 1e-10 + assert lrf._stop_factor == stop_factor + assert lrf._beta == beta + + +_BATCH_END = ((1, 0.01, 1e-5, 0.5), + (27, 0.01, 1e-5, 1e-6), + (42, 0.001, 1e-5, 0.002),) +_BATCH_END_PARAMS = ("iteration", "loss", "learning_rate", "best") +_BATCH_END_IDS = [f"iter:{x[0]}|loss:{x[1]}|lr:{x[2]}" for x in _BATCH_END] + + +@pytest.mark.parametrize(_LR_CMDS_PARAMS, _LR_CMDS, ids=_LR_CMDS_IDS) +@pytest.mark.parametrize(_BATCH_END_PARAMS, _BATCH_END, ids=_BATCH_END_IDS) +def test_LearningRateFinder_on_batch_end(iteration, + loss, + learning_rate, + best, + stop_factor, + beta, + _trainer_mock, + mocker): + """ Test lib.train.LearingRateFinder._on_batch_end """ + trainer, model, optimizer = _trainer_mock() + lrf = LearningRateFinder(trainer, stop_factor=stop_factor, beta=beta) + optimizer.learning_rate.assign = mocker.MagicMock() + optimizer.learning_rate.numpy = mocker.MagicMock(return_value=learning_rate) + + initial_avg = 
lrf._loss["avg"] + lrf._loss["best"] = best + lrf._on_batch_end(iteration, loss) + + assert lrf._metrics["learning_rates"][-1] == learning_rate + assert lrf._loss["avg"] == (lrf._beta * initial_avg) + ((1 - lrf._beta) * loss) + assert lrf._metrics["losses"][-1] == lrf._loss["avg"] / (1 - (lrf._beta ** iteration)) + + if iteration > 1 and lrf._metrics["losses"][-1] > lrf._stop_factor * lrf._loss["best"]: + assert model.model.stop_training is True + optimizer.learning_rate.assign.assert_not_called() + return + + if iteration == 1: + assert lrf._loss["best"] == lrf._metrics["losses"][-1] + + assert model.model.stop_training is not True + optimizer.learning_rate.assign.assert_called_with( + learning_rate * lrf._lr_multiplier) + + +@pytest.mark.parametrize(_LR_CONF_PARAMS, _LR_CONF) +def test_LearningRateFinder_train(iters, # pylint:disable=too-many-locals + mode, + strength, + _trainer_mock, + mocker): + """ Test lib.train.LearingRateFinder._train """ + trainer, _, _ = _trainer_mock(iters, mode, strength) + + mock_loss_return = np.random.rand(2).tolist() + trainer.train_one_batch = mocker.MagicMock(return_value=mock_loss_return) + + lrf = LearningRateFinder(trainer) + + lrf._on_batch_end = mocker.MagicMock() + lrf._update_description = mocker.MagicMock() + + lrf._train() + + trainer.train_one_batch.assert_called() + assert trainer.train_one_batch.call_count == iters + + train_call_args = [mocker.call(x + 1, mock_loss_return[0]) for x in range(iters)] + assert lrf._on_batch_end.call_args_list == train_call_args + + lrf._update_description.assert_called() + assert lrf._update_description.call_count == iters + + # NaN break + mock_loss_return = (np.nan, np.nan) + trainer.train_one_batch = mocker.MagicMock(return_value=mock_loss_return) + + lrf._train() + + assert trainer.train_one_batch.call_count == 1 # Called once + + assert lrf._update_description.call_count == iters # Not called + assert lrf._on_batch_end.call_count == iters # Not called + + +def test_LearningRateFinder_rebuild_optimizer(_trainer_mock): + """ Test lib.train.LearingRateFinder._rebuild_optimizer """ + trainer, _, _ = _trainer_mock() + lrf = LearningRateFinder(trainer) + + class Dummy: + """ Dummy Optimizer""" + name = "test" + + def get_config(self): + """Dummy get_config""" + return {} + + opt = Dummy() + new_opt = lrf._rebuild_optimizer(opt) + assert isinstance(new_opt, Dummy) and opt is not new_opt + + +@pytest.mark.parametrize(_LR_CONF_PARAMS, _LR_CONF) +@pytest.mark.parametrize("new_lr", (1e-4, 3.5e-5, 9.3e-6)) +def test_LearningRateFinder_reset_model(iters, mode, strength, new_lr, _trainer_mock, mocker): + """ Test lib.train.LearingRateFinder._reset_model """ + trainer, model, optimizer = _trainer_mock(iters, mode, strength) + model.state.add_lr_finder = mocker.MagicMock() + model.state.save = mocker.MagicMock() + model.model.load_weights = mocker.MagicMock() + + old_optimizer = optimizer + new_optimizer = mocker.MagicMock() + + def compile_side_effect(*args, **kwargs): # pylint:disable=unused-argument + """ Side effect for model.compile""" + model.model.optimizer = new_optimizer + + model.model.compile.side_effect = compile_side_effect + + lrf = LearningRateFinder(trainer) + lrf._rebuild_optimizer = mocker.MagicMock() + + lrf._reset_model(1e-5, new_lr) + + model.state.add_lr_finder.assert_called_with(new_lr) + model.state.save.assert_called_once() + + if mode == "graph_and_exit": + lrf._rebuild_optimizer.assert_not_called() + model.model.compile.assert_not_called() + model.model.load_weights.assert_not_called() + assert 
+        new_optimizer.learning_rate.assign.assert_not_called()
+    else:
+        lrf._rebuild_optimizer.assert_called_once_with(old_optimizer)
+        model.model.load_weights.assert_called_once()
+        model.model.compile.assert_called_once()
+        assert model.model.optimizer is new_optimizer
+        new_optimizer.learning_rate.assign.assert_called_once_with(new_lr)
+
+
+_LR_FIND = (
+    (True, [0.100, 0.050, 0.025], 0.025, [1e-5, 1e-4, 1e-3], "model_exist"),
+    (False, [0.100, 0.050, 0.025], 0.025, [1e-5, 1e-4, 1e-3], "no_model"),
+    (True, [0.100, 0.050, 0.025], 0.025, [1e-5, 1e-4, 1e-10], "low_lr"),
+    )
+_LR_PARAMS_FIND = ("exists", "losses", "best", "learning_rates")
+
+
+@pytest.mark.parametrize(_LR_PARAMS_FIND,
+                         [x[:-1] for x in _LR_FIND],
+                         ids=[x[-1] for x in _LR_FIND])
+@pytest.mark.parametrize(_LR_CONF_PARAMS, _LR_CONF)
+@pytest.mark.parametrize(_LR_CMDS_PARAMS, _LR_CMDS[0:1])
+def test_LearningRateFinder_find(iters,  # pylint:disable=too-many-arguments,too-many-positional-arguments # noqa[E501]
+                                 mode,
+                                 strength,
+                                 stop_factor,
+                                 beta,
+                                 exists,
+                                 losses,
+                                 best,
+                                 learning_rates,
+                                 _trainer_mock,
+                                 mocker):
+    """ Test lib.train.LearningRateFinder.find """
+    # pylint:disable=too-many-locals
+    trainer, model, optimizer = _trainer_mock(iters, mode, strength)
+    model.io.model_exists = exists
+    model.io.save = mocker.MagicMock()
+    original_lr = float(np.random.rand())
+    optimizer.learning_rate.numpy = mocker.MagicMock(return_value=original_lr)
+    optimizer.learning_rate.assign = mocker.MagicMock()
+    mocker.patch("shutil.rmtree")
+
+    lrf = LearningRateFinder(trainer, stop_factor=stop_factor, beta=beta)
+
+    train_mock = mocker.MagicMock()
+    plot_mock = mocker.MagicMock()
+    reset_mock = mocker.MagicMock()
+    lrf._train = train_mock
+    lrf._plot_loss = plot_mock
+    lrf._reset_model = reset_mock
+
+    lrf._metrics = {"losses": losses, "learning_rates": learning_rates}
+    lrf._loss = {"best": best}
+
+    result = lrf.find()
+
+    if exists:
+        model.io.save.assert_not_called()
+    else:
+        model.io.save.assert_called_once()
+
+    optimizer.learning_rate.assign.assert_called_with(lrf._start_lr)
+    train_mock.assert_called_once()
+
+    new_lr = learning_rates[losses.index(best)] / _STRENGTH_LOOKUP[strength]
+    if new_lr < 1e-9:
+        plot_mock.assert_not_called()
+        reset_mock.assert_not_called()
+        assert not result
+        return
+
+    plot_mock.assert_called_once()
+    reset_mock.assert_called_once_with(original_lr, new_lr)
+    assert result
diff --git a/tests/lib/training/lr_warmup_test.py b/tests/lib/training/lr_warmup_test.py
new file mode 100644
index 0000000000..c1150c8dae
--- /dev/null
+++ b/tests/lib/training/lr_warmup_test.py
@@ -0,0 +1,181 @@
+#!/usr/bin python3
+""" Pytest unit tests for :mod:`lib.training.lr_warmup` """
+
+import pytest
+import pytest_mock
+
+from keras.layers import Input, Dense
+from keras.models import Model
+from keras.optimizers import SGD
+
+from lib.training import LearningRateWarmup
+
+
+# pylint:disable=protected-access,redefined-outer-name
+
+
+@pytest.fixture
+def model_fixture():
+    """ Model fixture for testing LR Warmup """
+    inp = Input((4, 4, 3))
+    var_x = Dense(8)(inp)
+    model = Model(inputs=inp, outputs=var_x)
+    model.compile(optimizer=SGD(), loss="mse")
+    return model
+
+
+_LR_STEPS = [(1e-5, 100),
+             (3.4e-6, 250),
+             (9e-4, 599),
+             (6e-5, 1000)]
+_LR_STEPS_IDS = [f"lr:{x[0]}|steps:{x[1]}" for x in _LR_STEPS]
+
+
+@pytest.mark.parametrize(("target_lr", "steps"), _LR_STEPS, ids=_LR_STEPS_IDS)
+def test_init(model_fixture: Model, target_lr: float, steps: int) -> None:
+    """ Test
class initializes correctly """ + instance = LearningRateWarmup(model_fixture, target_lr, steps) + + attrs = ["_model", "_target_lr", "_steps", "_current_lr", "_current_step", "_reporting_points"] + assert all(a in instance.__dict__ for a in attrs) + assert all(a in attrs for a in instance.__dict__) + assert instance._current_lr == 0.0 + assert instance._current_step == 0 + + assert isinstance(instance._model, Model) + assert instance._target_lr == target_lr + assert instance._steps == steps + + assert len(instance._reporting_points) == 11 + assert all(isinstance(x, int) for x in instance._reporting_points) + assert instance._reporting_points == [int(steps * i / 10) for i in range(11)] + + +_NOTATION = [(1e-5, "1.0e-05"), + (3.45489e-6, "3.5e-06"), + (0.0004, "4.0e-04"), + (0.1234, "1.2e-01")] + + +@pytest.mark.parametrize(("value", "expected"), _NOTATION, ids=[x[1] for x in _NOTATION]) +def test_format_notation(value: float, expected: str) -> None: + """ Test floats format to string correctly """ + result = LearningRateWarmup._format_notation(value) + assert result == expected + + +_LR_STEPS_CURRENT = [(1e-5, 100, 79), + (3.4e-6, 250, 250), + (9e-4, 599, 0), + (6e-5, 1000, 12)] +_LR_STEPS_CURRENT_IDS = [f"lr:{x[0]}|steps:{x[1]}|current_step:{x[2]}" for x in _LR_STEPS_CURRENT] + + +@pytest.mark.parametrize(("target_lr", "steps", "current_step"), + _LR_STEPS_CURRENT, + ids=_LR_STEPS_CURRENT_IDS) +def test_set_current_learning_rate(model_fixture: Model, + target_lr: float, + steps: int, + current_step: int) -> None: + """ Test that learning rate is set correctly """ + instance = LearningRateWarmup(model_fixture, target_lr, steps) + instance._current_step = current_step + instance._set_learning_rate() + + assert instance._current_lr == instance._current_step / instance._steps * instance._target_lr + assert instance._model.optimizer.learning_rate.value.cpu().numpy() == instance._current_lr + + +_STEPS_CURRENT = [(1000, 1, "start"), + (250, 250, "end"), + (500, 69, "unreported"), + (1000, 200, "reported")] +_STEPS_CURRENT_ID = [f"steps:{x[0]}|current_step:{x[1]}|action:{x[2]}" for x in _STEPS_CURRENT] + + +@pytest.mark.parametrize(("steps", "current_step", "action"), + _STEPS_CURRENT, + ids=_STEPS_CURRENT_ID) +def test_output_status(model_fixture: Model, + steps: int, + current_step: int, + action: str, + mocker: pytest_mock.MockerFixture) -> None: + """ Test that information is output correctly """ + mock_logger = mocker.patch("lib.training.lr_warmup.logger.info") + mock_print = mocker.patch("builtins.print") + instance = LearningRateWarmup(model_fixture, 5e-5, steps) + instance._current_step = current_step + instance._format_notation = mocker.MagicMock() # type:ignore[method-assign] + + instance._output_status() + + if action == "unreported": + assert current_step not in instance._reporting_points + mock_logger.assert_not_called() + instance._format_notation.assert_not_called() # type:ignore[attr-defined] + mock_print.assert_not_called() + return + + mock_logger.assert_called_once() + log_message: str = mock_logger.call_args.args[0] + assert log_message.startswith("[Learning Rate Warmup] ") + + instance._format_notation.assert_called() # type:ignore[attr-defined] + notation_args = [ + x.args for x in instance._format_notation.call_args_list] # type:ignore[attr-defined] + assert all(len(a) == 1 for a in notation_args) + assert all(isinstance(a[0], float) for a in notation_args) + + if action == "start": + mock_print.assert_not_called() + assert all(x in log_message for x in ("Start: ", 
"Target: ", "Steps: ")) + assert instance._format_notation.call_count == 2 # type:ignore[attr-defined] + return + + if action == "end": + mock_print.assert_called() + assert "Final Learning Rate: " in log_message + instance._format_notation.assert_called_once() # type:ignore[attr-defined] + return + + if action == "reported": + mock_print.assert_called() + assert current_step in instance._reporting_points + assert all(x in log_message for x in ("Step: ", "Current: ", "Target: ")) + assert instance._format_notation.call_count == 2 # type:ignore[attr-defined] + + +_STEPS_CURRENT_CALL = [(0, 500, "disabled"), + (1000, 500, "progress"), + (1000, 1000, "completed"), + (1000, 1111, "completed2")] +_STEPS_CURRENT_CALL_ID = [f"steps:{x[0]}|current_step:{x[1]}|action:{x[2]}" + for x in _STEPS_CURRENT_CALL] + + +@pytest.mark.parametrize(("steps", "current_step", "action"), + _STEPS_CURRENT_CALL, + ids=_STEPS_CURRENT_CALL_ID) +def test__call__(model_fixture: Model, + steps: int, + current_step: int, + action: str, + mocker: pytest_mock.MockerFixture) -> None: + """ Test calling the instance works correctly """ + instance = LearningRateWarmup(model_fixture, 5e-5, steps) + instance._current_step = current_step + instance._set_learning_rate = mocker.MagicMock() # type:ignore[method-assign] + instance._output_status = mocker.MagicMock() # type:ignore[method-assign] + + instance() + + if action in ("disabled", "completed", "completed2"): + assert instance._current_step == current_step + instance._set_learning_rate.assert_not_called() # type:ignore[attr-defined] + instance._output_status.assert_not_called() # type:ignore[attr-defined] + else: + assert instance._current_step == current_step + 1 + instance._set_learning_rate.assert_called_once() # type:ignore[attr-defined] + instance._output_status.assert_called_once() # type:ignore[attr-defined] diff --git a/tests/lib/training/tensorboard_test.py b/tests/lib/training/tensorboard_test.py new file mode 100644 index 0000000000..aab7a9bc51 --- /dev/null +++ b/tests/lib/training/tensorboard_test.py @@ -0,0 +1,166 @@ +#! 
/usr/bin/env python3 +""" Unit test for :mod:`lib.training.tensorboard` """ +import os + +import pytest + +from keras import layers, Sequential +import numpy as np +from tensorboard.compat.proto import event_pb2 +from torch.utils.tensorboard import SummaryWriter + +from lib.training import tensorboard as mod_tb + +# pylint:disable=protected-access,invalid-name + + +@pytest.fixture() +def _gen_events_file(tmpdir): + log_dir = tmpdir.mkdir("logs") + + def _apply(keys=["test1"], # pylint:disable=dangerous-default-value + values=[0.42], + global_steps=[4]): + writer = SummaryWriter(log_dir) + for key, val, step in zip(keys, values, global_steps): + writer.add_scalar(key, val, global_step=step) + writer.flush() + return os.path.join(log_dir, os.listdir(log_dir)[0]) + + return _apply + + +@pytest.mark.parametrize("entries", ({"loss1": np.random.rand()}, + {f"test{i}": np.random.rand() for i in range(4)}, + {f"another_test{i}": np.random.rand() for i in range(10)})) +@pytest.mark.parametrize("batch", [1, 42, 69, 1024, 143432]) +@pytest.mark.parametrize("is_live", (True, False), ids=("live", "not_live")) +def test_RecordIterator(entries, batch, is_live, _gen_events_file): + """ Test that our :class:`lib.training.tensorboard.RecordIterator` returns expected results """ + keys = list(entries) + vals = list(entries.values()) + batches = [batch + i for i in range(len(keys))] + + file = _gen_events_file(keys, vals, batches) + iterator = mod_tb.RecordIterator(file, is_live=is_live) + + results = list(event_pb2.Event.FromString(v) for v in iterator) + valid = [r for r in results if r.summary.value] + + assert len(valid) == len(keys) + for entry, key, val, btc in zip(valid, keys, vals, batches): + assert len(entry.summary.value) == 1 + assert entry.step == btc + assert entry.summary.value[0].tag == key + assert np.isclose(entry.summary.value[0].simple_value, val) + + if is_live: + assert iterator._is_live is True + assert os.path.getsize(file) == iterator._position # At end of file + else: + assert iterator._is_live is False + assert iterator._position == 0 + + +@pytest.fixture() +def _get_ttb_instance(tmpdir): + log_dir = tmpdir.mkdir("logs") + + def _apply(write_graph=False, update_freq="batch"): + instance = mod_tb.TorchTensorBoard(log_dir=log_dir, + write_graph=write_graph, + update_freq=update_freq) + return log_dir, instance + + return _apply + + +def _get_logs(temp_path): + train_logs = os.path.join(temp_path, "train") + log_files = os.listdir(train_logs) + assert len(log_files) == 1 + records = [event_pb2.Event.FromString(record) + for record in mod_tb.RecordIterator(os.path.join(train_logs, log_files[0]))] + return records + + +@pytest.mark.parametrize("write_graph", (True, False), ids=("write_graph", "no_write_graph")) +def test_TorchTensorBoard_set_model(write_graph, _get_ttb_instance): + """ Test that :class:`lib.training.tensorboard.set_model` functions """ + log_dir, instance = _get_ttb_instance(write_graph=write_graph) + + model = Sequential() + model.add(layers.Input(shape=(8, ))) + model.add(layers.Dense(4)) + model.add(layers.Dense(4)) + + assert not os.path.exists(os.path.join(log_dir, "train")) + instance.set_model(model) + instance.on_save() + + logs = [x for x in _get_logs(os.path.join(log_dir)) + if x.summary.value] + + if not write_graph: + assert not logs + return + + # Only a single logged entry + assert len(logs) == 1 and len(logs[0].summary.value) == 1 + # Should be our Keras model summary + assert logs[0].summary.value[0].tag == "keras/text_summary" + + +def 
test_TorchTensorBoard_on_train_begin(_get_ttb_instance): + """ Test that :class:`lib.training.tensorboard.on_train_begin` functions """ + _, instance = _get_ttb_instance() + instance.on_train_begin() + assert instance._global_train_batch == 0 + assert instance._previous_epoch_iterations == 0 + + +@pytest.mark.parametrize("batch", (1, 3, 57, 124)) +@pytest.mark.parametrize("logs", ({"loss_a": 2.45, "loss_b": 1.56}, + {"loss_c": 0.54, "loss_d": 0.51}, + {"loss_c": 0.69, "loss_d": 0.42, "loss_g": 2.69})) +def test_TorchTensorBoard_on_train_batch_end(batch, logs, _get_ttb_instance): + """ Test that :class:`lib.training.tensorboard.on_train_batch_end` functions """ + log_dir, instance = _get_ttb_instance() + + assert not os.path.exists(os.path.join(log_dir, "train")) + + instance.on_train_batch_end(batch, logs) + instance.on_save() + + tb_logs = [x for x in _get_logs(os.path.join(log_dir)) + if x.summary.value] + + assert len(tb_logs) == len(logs) + for (k, v), out in zip(logs.items(), tb_logs): + assert len(out.summary.value) == 1 + assert out.summary.value[0].tag == f"batch_{k}" + assert np.isclose(out.summary.value[0].simple_value, v) + assert out.step == batch + + +def test_TorchTensorBoard_on_save(_get_ttb_instance, mocker): + """ Test that :class:`lib.training.tensorboard.on_save` functions """ + # Implicitly checked in other tests, so just make sure it calls flush on the writer + _, instance = _get_ttb_instance() + instance._train_writer.flush = mocker.MagicMock() + + instance.on_save() + instance._train_writer.flush.assert_called_once() + + +def test_TorchTensorBoard_on_train_end(_get_ttb_instance, mocker): + """ Test that :class:`lib.training.tensorboard.on_train_end` functions """ + # Saving is already implicitly checked in other tests, so just make sure it calls flush and + # close on the train writer + _, instance = _get_ttb_instance() + instance._train_writer.flush = mocker.MagicMock() + instance._train_writer.close = mocker.MagicMock() + + instance.on_train_end() + instance._train_writer.flush.assert_called_once() + instance._train_writer.close.assert_called_once() diff --git a/tests/lib/utils_test.py b/tests/lib/utils_test.py new file mode 100644 index 0000000000..aa4a8669db --- /dev/null +++ b/tests/lib/utils_test.py @@ -0,0 +1,641 @@ +#!/usr/bin python3 +""" Pytest unit tests for :mod:`lib.utils` """ +import os +import platform +import sys +import time +import typing as T +import types +import zipfile + +from io import StringIO +from socket import timeout as socket_timeout, error as socket_error +from shutil import rmtree +from unittest.mock import MagicMock +from urllib import error as urlliberror + +import pytest +import pytest_mock + +from lib import utils +from lib.utils import ( + _Backend, camel_case_split, convert_to_secs, DebugTimes, deprecation_warning, FaceswapError, + full_path_split, get_backend, get_dpi, get_folder, get_image_paths, get_module_objects, + get_torch_version, GetModel, safe_shutdown, set_backend) + +from lib.logger import log_setup +# Need to setup logging to avoid trace/verbose errors +log_setup("DEBUG", "pytest_utils.log", "PyTest, False") + + +# pylint:disable=protected-access + + +# Backend tests +def test_set_backend(monkeypatch: pytest.MonkeyPatch) -> None: + """ Test the :func:`~lib.utils.set_backend` function + + Parameters + ---------- + monkeypatch: :class:`pytest.MonkeyPatch` + Monkey patching _FS_BACKEND + """ + monkeypatch.setattr(utils, "_FS_BACKEND", "cpu") # _FS_BACKEND already defined + set_backend("nvidia") + assert utils._FS_BACKEND 
== "nvidia" + monkeypatch.delattr(utils, "_FS_BACKEND") # _FS_BACKEND is not already defined + set_backend("rocm") + assert utils._FS_BACKEND == "rocm" + + +def test_get_backend(monkeypatch: pytest.MonkeyPatch) -> None: + """ Test the :func:`~lib.utils.get_backend` function + + Parameters + ---------- + monkeypatch: :class:`pytest.MonkeyPatch` + Monkey patching _FS_BACKEND + """ + monkeypatch.setattr(utils, "_FS_BACKEND", "apple-silicon") + assert get_backend() == "apple-silicon" + + +def test__backend(monkeypatch: pytest.MonkeyPatch) -> None: + """ Test the :class:`~lib.utils._Backend` class + + Parameters + ---------- + monkeypatch: :class:`pytest.MonkeyPatch` + Monkey patching :func:`os.environ`, :func:`os.path.isfile`, :func:`builtins.open` and + :func:`builtins.input` + """ + monkeypatch.setattr("os.environ", {"FACESWAP_BACKEND": "nvidia"}) # Environment variable set + backend = _Backend() + assert backend.backend == "nvidia" + + monkeypatch.setattr("os.environ", {}) # Environment variable not set, dummy in config file + monkeypatch.setattr("os.path.isfile", lambda x: True) + monkeypatch.setattr("builtins.open", lambda *args, **kwargs: StringIO('{"backend": "cpu"}')) + backend = _Backend() + assert backend.backend == "cpu" + + monkeypatch.setattr("os.path.isfile", lambda x: False) # no config file, dummy in user input + monkeypatch.setattr("builtins.input", lambda x: "2") + backend = _Backend() + assert backend._configure_backend() == "nvidia" + + +# Folder and path utils +def test_get_folder(tmp_path: str) -> None: + """ Unit test for :func:`~lib.utils.get_folder` + + Parameters + ---------- + tmp_path: str + pytest temporary path to generate folders + """ + # New folder + path = os.path.join(tmp_path, "test_new_folder") + expected_output = path + assert not os.path.isdir(path) + assert get_folder(path) == expected_output + assert os.path.isdir(path) + + # Test not creating a new folder when it already exists + path = os.path.join(tmp_path, "test_new_folder") + expected_output = path + assert os.path.isdir(path) + stats = os.stat(path) + assert get_folder(path) == expected_output + assert os.path.isdir(path) + assert stats == os.stat(path) + + # Test not creating a new folder when make_folder is False + path = os.path.join(tmp_path, "test_no_folder") + expected_output = "" + assert get_folder(path, make_folder=False) == expected_output + assert not os.path.isdir(path) + + +def test_get_image_paths(tmp_path: str) -> None: + """ Unit test for :func:`~lib.utils.test_get_image_paths` + + Parameters + ---------- + tmp_path: str + pytest temporary path to generate folders + """ + # Test getting image paths from a folder with no images + test_folder = os.path.join(tmp_path, "test_image_folder") + os.makedirs(test_folder) + assert not get_image_paths(test_folder) + + # Populate 2 different image files and 1 text file + test_jpg_path = os.path.join(test_folder, "test_image.jpg") + test_png_path = os.path.join(test_folder, "test_image.png") + test_txt_path = os.path.join(test_folder, "test_file.txt") + for fname in (test_jpg_path, test_png_path, test_txt_path): + with open(fname, "a", encoding="utf-8"): + pass + + # Test getting any image paths from a folder with images and random files + exists = [os.path.join(test_folder, img) + for img in os.listdir(test_folder) if os.path.splitext(img)[-1] != ".txt"] + assert sorted(get_image_paths(test_folder)) == sorted(exists) + + # Test getting image paths from a folder with images with a specific extension + exists = [os.path.join(test_folder, img) + 
for img in os.listdir(test_folder) if os.path.splitext(img)[-1] == ".png"] + assert sorted(get_image_paths(test_folder, extension=".png")) == sorted(exists) + + +def test_get_module_objects(mocker: pytest_mock.MockerFixture): + """ Test :func:`lib.utils.get_module_objects` returns as expected """ + # pylint:disable=too-few-public-methods,missing-class-docstring + test_module = types.ModuleType("our_mod") + + class InternalPublic: + pass + InternalPublic.__module__ = "our_mod" + setattr(test_module, "InternalPublic", InternalPublic) + + class _InternalPrivate: + pass + _InternalPrivate.__module__ = "our_mod" + setattr(test_module, "_InternalPrivate", _InternalPrivate) + + class External: + pass + External.__module__ = "other_mod" + setattr(test_module, "External", External) + + def func_public(): + pass + func_public.__module__ = "our_mod" + setattr(test_module, "func_public", func_public) + + def _func_private(): + pass + _func_private.__module__ = "our_mod" + setattr(test_module, "_func_private", _func_private) + + def func_external(): + pass + func_external.__module__ = "other_mod" + setattr(test_module, "func_external", func_external) + + mocker.patch.dict(sys.modules, {"our_mod": test_module}) + + result = get_module_objects("our_mod") + assert sorted(result, key=str.casefold) == ["func_public", "InternalPublic"] + + +_PATHS = ( # type:ignore[var-annotated] + ("/path/to/file.txt", ["/", "path", "to", "file.txt"]), # Absolute + ("/path/to/directory/", ["/", "path", "to", "directory"]), + ("/path/to/directory", ["/", "path", "to", "directory"]), + ("path/to/file.txt", ["path", "to", "file.txt"]), # Relative + ("path/to/directory/", ["path", "to", "directory"]), + ("path/to/directory", ["path", "to", "directory"]), + ("", []), # Edge cases + ("/", ["/"]), + (".", ["."]), + ("..", [".."])) + + +@pytest.mark.parametrize("path,result", _PATHS, ids=[f'"{p[0]}"' for p in _PATHS]) +def test_full_path_split(path: str, result: list[str]) -> None: + """ Test the :func:`~lib.utils.full_path_split` function works correctly + + Parameters + ---------- + path: str + The path to test + result: list + The expected result from the path + """ + split = full_path_split(path) + assert isinstance(split, list) + assert split == result + + +_CASES = (("camelCase", ["camel", "Case"]), # type:ignore[var-annotated] + ("camelCaseTest", ["camel", "Case", "Test"]), + ("camelCaseTestCase", ["camel", "Case", "Test", "Case"]), + ("CamelCase", ["Camel", "Case"]), + ("CamelCaseTest", ["Camel", "Case", "Test"]), + ("CamelCaseTestCase", ["Camel", "Case", "Test", "Case"]), + ("CAmelCASETestCase", ["C", "Amel", "CASE", "Test", "Case"]), + ("camelcasetestcase", ["camelcasetestcase"]), + ("CAMELCASETESTCASE", ["CAMELCASETESTCASE"]), + ("", [])) + + +@pytest.mark.parametrize("text, result", _CASES, ids=[f'"{p[0]}"' for p in _CASES]) +def test_camel_case_split(text: str, result: list[str]) -> None: + """ Test the :func:`~lib.utils.camel_case_split` function works correctly + + Parameters + ---------- + text: str + The camel case text to test + result: list + The expected result from the text + """ + split = camel_case_split(text) + assert isinstance(split, list) + assert split == result + + +_TORCH_PARAMS = (("2.4.9", (2, 4)), ("2.6", (2, 6)), ("2.8.rc3", (2, 8))) +_TORCH_IDS = [x[0] for x in _TORCH_PARAMS] + + +# General utils +@pytest.mark.parametrize("str_vers, tuple_vers", _TORCH_PARAMS, ids=_TORCH_IDS) +def test_get_torch_version(str_vers, tuple_vers, monkeypatch: pytest.MonkeyPatch) -> None: + """ Test the 
:func:`~lib.utils.get_torch_version` function version returns correctly """ + monkeypatch.setattr("lib.utils._versions", {}) + monkeypatch.setattr("torch.__version__", str_vers) + torch_version = get_torch_version() + assert torch_version == tuple_vers + + +def test_get_dpi() -> None: + """ Test the :func:`~lib.utils.get_dpi` function version returns correctly in a sane + range """ + dpi = get_dpi() + assert isinstance(dpi, float) or dpi is None + if dpi is None: # No display detected + return + assert dpi > 0 + assert dpi < 600.0 + + +_SECPARAMS = [((1, ), 1), # 1 argument + ((10, ), 10), + ((0, 1), 1), + ((0, 60), 60), # 2 arguments + ((1, 0), 60), + ((1, 1), 61), + ((0, 0, 1), 1), + ((0, 0, 60), 60), # 3 arguments + ((0, 1, 0), 60), + ((1, 0, 0), 3600), + ((1, 1, 1), 3661)] + + +@pytest.mark.parametrize("args,result", _SECPARAMS, ids=[str(p[0]) for p in _SECPARAMS]) +def test_convert_to_secs(args: tuple[int, ...], result: int) -> None: + """ Test the :func:`~lib.utils.convert_to_secs` function works correctly + + Parameters + ---------- + args: tuple + Tuple of 1, 2 or 3 integers to pass to the function + result: int + The expected results for the args tuple + """ + secs = convert_to_secs(*args) + assert isinstance(secs, int) + assert secs == result + + +@pytest.mark.parametrize("additional_info", [None, "additional information"]) +def test_deprecation_warning(caplog: pytest.LogCaptureFixture, additional_info: str) -> None: + """ Test the :func:`~lib.utils.deprecation_warning` function works correctly + + Parameters + ---------- + caplog: :class:`pytest.LogCaptureFixture` + Pytest's log capturing fixture + additional_info: str + Additional information to pass to the warning function + """ + func_name = "function_name" + test = f"{func_name} has been deprecated and will be removed from a future update." + if additional_info: + test = f"{test} {additional_info}" + deprecation_warning(func_name, additional_info=additional_info) + assert test in caplog.text + + +@pytest.mark.parametrize("got_error", [True, False]) +def test_safe_shutdown(caplog: pytest.LogCaptureFixture, got_error: bool) -> None: + """ Test the :func:`~lib.utils.safe_shutdown` function works correctly + + Parameters + ---------- + caplog: :class:`pytest.LogCaptureFixture` + Pytest's log capturing fixture + got_error: bool + The got_error parameter to pass to safe_shutdown + """ + caplog.set_level("DEBUG") + with pytest.raises(SystemExit) as wrapped_exit: + safe_shutdown(got_error=got_error) + + exit_value = 1 if got_error else 0 + assert wrapped_exit.typename == "SystemExit" + assert wrapped_exit.value.code == exit_value + assert "Safely shutting down" in caplog.messages + assert "Cleanup complete. 
Shutting down queue manager and exiting" in caplog.messages + + +def test_faceswap_error(): + """ Test the :class:`~lib.utils.FaceswapError` raises correctly """ + with pytest.raises(Exception): + raise FaceswapError + + +# GetModel class +@pytest.fixture(name="get_model_instance") +def fixture_get_model_instance(monkeypatch: pytest.MonkeyPatch, + tmp_path: pytest.TempdirFactory, + request: pytest.FixtureRequest) -> GetModel: + """ Create a fixture of the :class:`~lib.utils.GetModel` object, prevent _get() from running at + __init__ and point the cache_dir at our local test folder """ + cache_dir = os.path.join(str(tmp_path), "get_model") + os.mkdir(cache_dir) + + model_filename = "test_model_file_v1.h5" + git_model_id = 123 + + original_get = GetModel._get + # Patch out _get() so it is not called from __init__() + monkeypatch.setattr(utils.GetModel, "_get", lambda x: None) + model_instance = GetModel(model_filename, git_model_id) + # Reinsert _get() so we can test it + monkeypatch.setattr(model_instance, "_get", original_get) + model_instance._cache_dir = cache_dir + + def teardown(): + rmtree(cache_dir) + + request.addfinalizer(teardown) + return model_instance + + +_INPUT = ("test_model_file_v3.h5", + ["test_multi_model_file_v1.1.npy", "test_multi_model_file_v1.2.npy"]) +_EXPECTED = ((["test_model_file_v3.h5"], "test_model_file_v3", "test_model_file", 3), + (["test_multi_model_file_v1.1.npy", "test_multi_model_file_v1.2.npy"], + "test_multi_model_file_v1", "test_multi_model_file", 1)) + + +@pytest.mark.parametrize("filename,results", zip(_INPUT, _EXPECTED), ids=[str(i) for i in _INPUT]) +def test_get_model_model_filename_input( + get_model_instance: GetModel, # pylint:disable=unused-argument + filename: str | list[str], + results: str | list[str]) -> None: + """ Test :class:`~lib.utils.GetModel` filename parsing works + + Parameters + --------- + get_model_instance: `~lib.utils.GetModel` + The patched instance of the class + filename: list or str + The test filenames + results: tuple + The expected results for :attr:`_model_filename`, :attr:`_model_full_name`, + :attr:`_model_name`, :attr:`_model_version` respectively + """ + model = GetModel(filename, 123) + assert model._model_filename == results[0] + assert model._model_full_name == results[1] + assert model._model_name == results[2] + assert model._model_version == results[3] + + +def test_get_model_attributes(get_model_instance: GetModel) -> None: + """ Test :class:`~lib.utils.GetModel` private attributes set correctly + + Parameters + --------- + get_model_instance: `~lib.utils.GetModel` + The patched instance of the class + """ + model = get_model_instance + assert model._git_model_id == 123 + assert model._url_base == ("https://github.com/deepfakes-models/faceswap-models" + "/releases/download") + assert model._chunk_size == 1024 + assert model._retries == 6 + + +def test_get_model_properties(get_model_instance: GetModel) -> None: + """ Test :class:`~lib.utils.GetModel` calculated attributes return correctly + + Parameters + --------- + get_model_instance: `~lib.utils.GetModel` + The patched instance of the class + """ + model = get_model_instance + assert model.model_path == os.path.join(model._cache_dir, "test_model_file_v1.h5") + assert model._model_zip_path == os.path.join(model._cache_dir, "test_model_file_v1.zip") + assert not model._model_exists + assert model._url_download == ("https://github.com/deepfakes-models/faceswap-models/releases/" + "download/v123.1/test_model_file_v1.zip") + assert model._url_partial_size == 0 
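+
+# The URL assertions above encode the release naming convention used by GetModel.
+# A minimal sketch of that convention (plain string handling; the variable names
+# here are illustrative only, not GetModel attributes):
+#
+#     url_base = "https://github.com/deepfakes-models/faceswap-models/releases/download"
+#     tag = f"v{git_model_id}.{model_version}"  # e.g. "v123.1"
+#     url_download = f"{url_base}/{tag}/{model_full_name}.zip"
+#
+# For GetModel("test_model_file_v1.h5", 123) this resolves to the
+# ".../v123.1/test_model_file_v1.zip" value checked above.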
+ + +@pytest.mark.parametrize("model_exists", (True, False)) +def test_get_model__get(mocker: pytest_mock.MockerFixture, + get_model_instance: GetModel, + model_exists: bool) -> None: + """ Test :func:`~lib.utils.GetModel._get` executes logic correctly + + Parameters + --------- + mocker: :class:`pytest_mock.MockerFixture` + Mocker for dummying in function calls + get_model_instance: `~lib.utils.GetModel` + The patched instance of the class + model_exists: bool + For testing the function when a model exists and when it does not + """ + model = get_model_instance + model._download_model = T.cast(MagicMock, mocker.MagicMock()) # type:ignore + model._unzip_model = T.cast(MagicMock, mocker.MagicMock()) # type:ignore + os_remove = mocker.patch("os.remove") + + if model_exists: # Dummy in a model file + assert isinstance(model.model_path, str) + with open(model.model_path, "a", encoding="utf-8"): + pass + + model._get(model) # type:ignore + + assert (model_exists and not model._download_model.called) or ( + not model_exists and model._download_model.called) + assert (model_exists and not model._unzip_model.called) or ( + not model_exists and model._unzip_model.called) + assert model_exists or not (model_exists and os_remove.called) + os_remove.reset_mock() + + +_DLPARAMS = [(None, None), + (socket_error, ()), + (socket_timeout, ()), + (urlliberror.URLError, ("test_reason", )), + (urlliberror.HTTPError, ("test_uri", 400, "", "", 0))] + + +@pytest.mark.parametrize("error_type,error_args", _DLPARAMS, ids=[str(p[0]) for p in _DLPARAMS]) +def test_get_model__download_model(mocker: pytest_mock.MockerFixture, + get_model_instance: GetModel, + error_type: T.Any, + error_args: tuple[str | int, ...]) -> None: + """ Test :func:`~lib.utils.GetModel._download_model` executes its logic correctly + + Parameters + --------- + mocker: :class:`pytest_mock.MockerFixture` + Mocker for dummying in function calls + get_model_instance: `~lib.utils.GetModel` + The patched instance of the class + error_type: connection error type or ``None`` + Connection error type to mock, or ``None`` for successful download + error_args: tuple + The arguments to be passed to the exception to be raised + """ + mock_urlopen = mocker.patch("urllib.request.urlopen") + if not error_type: # Model download is successful + get_model_instance._write_zipfile = T.cast(MagicMock, mocker.MagicMock()) # type:ignore + get_model_instance._download_model() + assert mock_urlopen.called + assert get_model_instance._write_zipfile.called + else: # Test that the process exits on download errors + mock_urlopen.side_effect = error_type(*error_args) + with pytest.raises(SystemExit): + get_model_instance._download_model() + mock_urlopen.reset_mock() + + +# TODO remove the next line that suppresses a weird pytest bug when it tears down the tempdir +@pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") +@pytest.mark.parametrize("dl_type", ["complete", "new", "continue"]) +def test_get_model__write_zipfile(mocker: pytest_mock.MockerFixture, + get_model_instance: GetModel, + dl_type: str) -> None: + """ Test :func:`~lib.utils.GetModel._write_zipfile` executes its logic correctly """ + response = mocker.MagicMock() + assert not os.path.isfile(get_model_instance._model_zip_path) + + downloaded = 10 if dl_type == "complete" else 0 + response.getheader.return_value = 0 + + if dl_type in ("new", "continue"): + chunks = [32, 64, 128, 256, 512, 1024] + data = [b"\x00" * size for size in chunks] + [b""] + response.getheader.return_value = 
sum(chunks) + response.read.side_effect = data + + if dl_type == "continue": # Write a partial download of the correct size + with open(get_model_instance._model_zip_path, "wb") as partial: + partial.write(b"\x00" * sum(chunks)) # type:ignore + downloaded = os.path.getsize(get_model_instance._model_zip_path) + + get_model_instance._write_zipfile(response, downloaded) + + if dl_type == "complete": # Already downloaded. No more tests + assert not response.read.called + return + + assert response.read.call_count == len(data) # all data read # type:ignore + assert os.path.isfile(get_model_instance._model_zip_path) + downloaded_size = os.path.getsize(get_model_instance._model_zip_path) + downloaded_size = downloaded_size if dl_type == "new" else downloaded_size // 2 + assert downloaded_size == sum(chunks) # type:ignore + + +# TODO remove the next line that supresses a weird pytest bug when it tears down the tempdir +@pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") +def test_get_model__unzip_model(mocker: pytest_mock.MockerFixture, + get_model_instance: GetModel) -> None: + """ Test :func:`~lib.utils.GetModel._unzip_model` executes its logic correctly + + Parameters + --------- + mocker: :class:`pytest_mock.MockerFixture` + Mocker for dummying in function calls + get_model_instance: `~lib.utils.GetModel` + The patched instance of the class + """ + mock_zipfile = mocker.patch("zipfile.ZipFile") + # Successful + get_model_instance._unzip_model() + assert mock_zipfile.called + mock_zipfile.reset_mock() + # Error + mock_zipfile.side_effect = zipfile.BadZipFile() + with pytest.raises(SystemExit): + get_model_instance._unzip_model() + mock_zipfile.reset_mock() + + +def test_get_model__write_model(mocker: pytest_mock.MockerFixture, + get_model_instance: GetModel) -> None: + """ Test :func:`~lib.utils.GetModel._write_model` executes its logic correctly + + Parameters + --------- + mocker: :class:`pytest_mock.MockerFixture` + Mocker for dummying in function calls + get_model_instance: `~lib.utils.GetModel` + The patched instance of the class + """ + out_file = os.path.join(get_model_instance._cache_dir, get_model_instance._model_filename[0]) + chunks = [8, 16, 32, 64, 128, 256, 512, 1024] + data = [b"\x00" * size for size in chunks] + [b""] + assert not os.path.isfile(out_file) + mock_zipfile = mocker.patch("zipfile.ZipFile") + mock_zipfile.namelist.return_value = get_model_instance._model_filename + mock_zipfile.open.return_value = mock_zipfile + mock_zipfile.read.side_effect = data + get_model_instance._write_model(mock_zipfile) + assert mock_zipfile.read.call_count == len(data) + assert os.path.isfile(out_file) + assert os.path.getsize(out_file) == sum(chunks) + + +# DebugTimes class +def test_debug_times(): + """ Test :class:`~lib.utils.DebugTimes` executes its logic correctly """ + debug_times = DebugTimes() + + debug_times.step_start("Test1") + time.sleep(0.1) + debug_times.step_end("Test1") + + debug_times.step_start("Test2") + time.sleep(0.2) + debug_times.step_end("Test2") + + debug_times.step_start("Test1") + time.sleep(0.1) + debug_times.step_end("Test1") + + debug_times.summary() + + # Ensure that the summary method prints the min, mean, and max times for each step + assert debug_times._display["min"] is True + assert debug_times._display["mean"] is True + assert debug_times._display["max"] is True + + # Ensure that the summary method includes the correct number of items for each step + assert len(debug_times._times["Test1"]) == 2 + assert 
len(debug_times._times["Test2"]) == 1 + + # Ensure that the summary method includes the correct min, mean, and max times for each step + # Github workflow for macos-latest can swing out a fair way + threshold = 2e-1 if platform.system() == "Darwin" else 1e-1 + assert min(debug_times._times["Test1"]) == pytest.approx(0.1, abs=threshold) + assert min(debug_times._times["Test2"]) == pytest.approx(0.2, abs=threshold) + assert max(debug_times._times["Test1"]) == pytest.approx(0.1, abs=threshold) + assert max(debug_times._times["Test2"]) == pytest.approx(0.2, abs=threshold) + assert (sum(debug_times._times["Test1"]) / + len(debug_times._times["Test1"])) == pytest.approx(0.1, abs=threshold) + assert (sum(debug_times._times["Test2"]) / + len(debug_times._times["Test2"]) == pytest.approx(0.2, abs=threshold)) diff --git a/tests/plugins/__init.__.py b/tests/plugins/__init.__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/plugins/train/__init__.py b/tests/plugins/train/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/plugins/train/trainer/__init__.py b/tests/plugins/train/trainer/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/plugins/train/trainer/test_distributed.py b/tests/plugins/train/trainer/test_distributed.py new file mode 100644 index 0000000000..0d913377b5 --- /dev/null +++ b/tests/plugins/train/trainer/test_distributed.py @@ -0,0 +1,138 @@ +#!/usr/bin python3 +""" Pytest unit tests for :mod:`plugins.train.trainer.distributed` Trainer plug in """ +# pylint:disable=protected-access, invalid-name + +import numpy as np +import pytest +import pytest_mock +import torch + +from plugins.train.trainer import distributed as mod_distributed +from plugins.train.trainer import original as mod_original +from plugins.train.trainer import _base as mod_base + + +_MODULE_PREFIX = "plugins.train.trainer.distributed" + + +@pytest.mark.parametrize("batch_size", (4, 8, 16, 32, 64)) +@pytest.mark.parametrize("outputs", (1, 2, 4)) +def test_WrappedModel(batch_size, outputs, mocker): + """ Test that the wrapped model calls preds and loss """ + model = mocker.MagicMock() + instance = mod_distributed.WrappedModel(model) + assert instance._keras_model is model + + loss_return = [torch.from_numpy((np.random.random((1, )))) for _ in range(outputs * 2)] + model.loss = [mocker.MagicMock(return_value=ret) for ret in loss_return] + + test_dims = (batch_size, 16, 16, 3) + + inp_a = torch.from_numpy(np.random.random(test_dims)) + inp_b = torch.from_numpy(np.random.random(test_dims)) + targets = [torch.from_numpy(np.random.random(test_dims)) + for _ in range(outputs * 2)] + preds = [*torch.from_numpy(np.random.random((outputs * 2, *test_dims)))] + + model.return_value = preds + + # Call forwards + result = instance.forward(inp_a, inp_b, *targets) + + # Confirm model was called once forward with correct args + model.assert_called_once() + model_args, model_kwargs = model.call_args + assert model_kwargs == {"training": True} + assert len(model_args) == 1 + assert len(model_args[0]) == 2 + for real, expected in zip(model_args[0], [inp_a, inp_b]): + assert np.allclose(real.numpy(), expected.numpy()) + + # Confirm ZeroGrad called + model.zero_grad.assert_called_once() + + # Confirm loss functions correctly called + expected_targets = targets[0::2] + targets[1::2] + + for target, pred, loss in zip(expected_targets, preds, model.loss): + loss.assert_called_once() + loss_args, loss_kwargs = loss.call_args + assert not loss_kwargs + assert 
len(loss_args) == 2 + for actual, expected in zip(loss_args, [target, pred]): + assert np.allclose(actual.numpy(), expected.numpy()) + + # Check that the result comes out as we put it in + for expected, actual in zip(loss_return, result.squeeze()): + assert np.isclose(expected.numpy(), actual.numpy()) + + +@pytest.fixture +def _trainer_mocked(mocker: pytest_mock.MockFixture): # noqa:[F811] + """ Generate a mocked model and feeder object and patch torch GPU count """ + + def _apply_patch(gpus=2, batch_size=8): + patched_cuda_device = mocker.patch(f"{_MODULE_PREFIX}.torch.cuda.device_count") + patched_cuda_device.return_value = gpus + patched_parallel = mocker.patch(f"{_MODULE_PREFIX}.torch.nn.DataParallel") + patched_parallel.return_value = mocker.MagicMock() + model = mocker.MagicMock() + instance = mod_distributed.Trainer(model, batch_size) + return instance, patched_parallel + + return _apply_patch + + +@pytest.mark.parametrize("gpu_count", (2, 3, 5, 8)) +@pytest.mark.parametrize("batch_size", (4, 8, 16, 32, 64)) +def test_Trainer(gpu_count, batch_size, _trainer_mocked): + """ Test that original trainer creates correctly """ + instance, patched_parallel = _trainer_mocked(gpus=gpu_count, batch_size=batch_size) + assert isinstance(instance, mod_base.TrainerBase) + assert isinstance(instance, mod_original.Trainer) + # Confirms that _validate_batch_size executed correctly + assert instance.batch_size == max(gpu_count, batch_size) + assert hasattr(instance, "train_batch") + # Confirms that _set_distributed executed correctly + assert instance._distributed_model is patched_parallel.return_value + + +@pytest.mark.parametrize("gpu_count", (2, 3, 5, 8), ids=[f"gpus:{x}" for x in (2, 3, 5, 8)]) +@pytest.mark.parametrize("outputs", (1, 2, 4)) +@pytest.mark.parametrize("batch_size", (4, 8, 16, 32, 64)) +def test_Trainer_forward(gpu_count, batch_size, outputs, _trainer_mocked, mocker): + """ Test that original trainer _forward calls the correct model methods """ + instance, _ = _trainer_mocked(gpus=gpu_count, batch_size=batch_size) + + test_dims = (2, batch_size, 16, 16, 3) + + inputs = torch.from_numpy(np.random.random(test_dims)) + targets = [torch.from_numpy(np.random.random(test_dims)) for _ in range(outputs)] + + loss_return = torch.rand((gpu_count * 2 * outputs)) + instance._distributed_model = mocker.MagicMock(return_value=loss_return) + + # Call the forward pass + result = instance._forward(inputs, targets).cpu().numpy() + + # Make sure multi-outs are enabled + if outputs > 1: + assert instance._is_multi_out is True + else: + assert instance._is_multi_out is False + + # Make sure that our wrapped distributed model was called in the correct order + instance._distributed_model.assert_called_once() + call_args, call_kwargs = instance._distributed_model.call_args + assert not call_kwargs + assert len(call_args) == len(inputs) + (len(targets) * 2) + + expected_tgts = [t[i].cpu().numpy() for t in targets for i in range(2)] + + for expected, actual in zip([*inputs, *expected_tgts], call_args): + assert np.allclose(expected, actual) + + # Make sure loss gets grouped, summed and scaled correctly + expected = loss_return.cpu().numpy() + expected = expected.reshape((gpu_count, 2, -1)).sum(axis=0).flatten() / gpu_count + assert np.allclose(result, expected) diff --git a/tests/plugins/train/trainer/test_original.py b/tests/plugins/train/trainer/test_original.py new file mode 100644 index 0000000000..983e691948 --- /dev/null +++ b/tests/plugins/train/trainer/test_original.py @@ -0,0 +1,122 @@ +#!/usr/bin 
python3 +""" Pytest unit tests for :mod:`plugins.train.trainer.original` Trainer plug in """ +# pylint:disable=protected-access,invalid-name + +import numpy as np +import pytest +import pytest_mock +import torch + +from plugins.train.trainer import original as mod_original +from plugins.train.trainer import _base as mod_base + + +@pytest.fixture +def _trainer_mocked(mocker: pytest_mock.MockFixture): # noqa:[F811] + """ Generate a mocked model and feeder object and patch user config items """ + + def _apply_patch(batch_size=8): + model = mocker.MagicMock() + instance = mod_original.Trainer(model, batch_size) + return instance + + return _apply_patch + + +@pytest.mark.parametrize("batch_size", (4, 8, 16, 32, 64)) +def test_Trainer(batch_size, _trainer_mocked): + """ Test that original trainer creates correctly """ + instance = _trainer_mocked(batch_size=batch_size) + assert isinstance(instance, mod_base.TrainerBase) + assert instance.batch_size == batch_size + assert hasattr(instance, "train_batch") + + +def test_Trainer_train_batch(_trainer_mocked, mocker): + """ Test that original trainer calls the forward and backwards methods """ + instance = _trainer_mocked() + loss_return = float(np.random.rand()) + instance._forward = mocker.MagicMock(return_value=loss_return) + instance._backwards_and_apply = mocker.MagicMock() + + ret_val = instance.train_batch("TEST_INPUT", "TEST_TARGET") + + assert ret_val == loss_return + instance._forward.assert_called_once_with("TEST_INPUT", "TEST_TARGET") + instance._backwards_and_apply.assert_called_once_with(loss_return) + + +@pytest.mark.parametrize("outputs", (1, 2, 4)) +@pytest.mark.parametrize("batch_size", (4, 8, 16, 32, 64)) +def test_Trainer_forward(batch_size, # pylint:disable=too-many-locals + outputs, + _trainer_mocked, + mocker): + """ Test that original trainer _forward calls the correct model methods """ + instance = _trainer_mocked(batch_size=batch_size) + + loss_returns = [torch.from_numpy(np.random.random((1, ))) for _ in range(outputs * 2)] + mock_preds = [torch.from_numpy(np.random.random((batch_size, 16, 16, 3))) + for _ in range(outputs * 2)] + instance.model.model.return_value = mock_preds + instance.model.model.zero_grad = mocker.MagicMock() + instance.model.model.loss = [mocker.MagicMock(return_value=ret) for ret in loss_returns] + + inputs = torch.from_numpy(np.random.random((2, batch_size, 16, 16, 3))) + targets = [torch.from_numpy(np.random.random((2, batch_size, 16, 16, 3))) + for _ in range(outputs)] + + # Call forwards + result = instance._forward(inputs, targets) + + # Output comes from loss functions + assert (np.allclose(e.numpy(), a.numpy()) for e, a in zip(result, loss_returns)) + + # Model was zero'd + instance.model.model.zero_grad.assert_called_once() + + # model forward pass called with inputs split + train_call = instance.model.model + + call_args, call_kwargs = train_call.call_args + assert call_kwargs == {"training": True} + expected_inputs = [a.numpy() for a in inputs] + actual_inputs = [a.numpy() for a in call_args[0]] + assert (np.allclose(e, a) for e, a in zip(expected_inputs, actual_inputs)) + + # losses called with targets split + loss_calls = instance.model.model.loss + expected_targets = [t[i].numpy() for i in range(2) for t in targets] + expected_preds = [p.numpy() for p in mock_preds] + for loss_call, pred, target in zip(loss_calls, expected_preds, expected_targets): + loss_call.assert_called_once() + call_args, call_kwargs = loss_call.call_args + assert not call_kwargs + assert len(call_args) == 2 + + 
actual_target = call_args[0].numpy() + actual_pred = call_args[1].numpy() + assert np.allclose(pred, actual_pred) + assert np.allclose(target, actual_target) + + +def test_Trainer_backwards_and_apply(_trainer_mocked, mocker): + """ Test that original trainer _backwards_and_apply calls the correct model methods """ + instance = _trainer_mocked() + + mock_loss = mocker.MagicMock() + instance.model.model.optimizer.scale_loss = mocker.MagicMock(return_value=mock_loss) + instance.model.model.optimizer.apply = mocker.MagicMock(return_value=mock_loss) + + all_loss = np.random.rand() + instance._backwards_and_apply(all_loss) + + scale_mock = instance.model.model.optimizer.scale_loss + scale_mock.assert_called_once() + assert not scale_mock.call_args[1] + assert len(scale_mock.call_args[0]) == 1 + assert np.isclose(all_loss, scale_mock.call_args[0][0].cpu().numpy()) + + mock_loss.backward.assert_called_once() + + instance.model.model.optimizer.apply.assert_called_once() diff --git a/tests/simple_tests.py b/tests/simple_tests.py new file mode 100644 index 0000000000..5db6109a29 --- /dev/null +++ b/tests/simple_tests.py @@ -0,0 +1,213 @@ +""" +Contains some simple tests. +The purpose of these tests is to detect crashes and hangs, +but NOT to guarantee the correctness of the operations. +For that we have a separate set of test cases using pytest. + +Due to my lazy coding, DON'T USE PATHS WITH BLANKS! +""" + +import sys +from subprocess import check_call, CalledProcessError +import os +from os.path import join as pathjoin, abspath, dirname + +_fail_count = 0 +_test_count = 0 +_COLORS = { + "FAIL": "\033[1;31m", + "OK": "\033[1;32m", + "STATUS": "\033[1;37m", + "BOLD": "\033[1m", + "ENDC": "\033[0m" +} + + +def print_colored(text, color="OK", bold=False): + """ Print colored text + This might not work on Windows, + although Travis runs the Windows stuff in git bash, so it might. + """ + color = _COLORS.get(color, color) + fmt = '' if not bold else _COLORS['BOLD'] + print(f"{color}{fmt}{text}{_COLORS['ENDC']}") + + +def print_ok(text): + """ Print ok in colored text """ + print_colored(text, "OK", True) + + +def print_fail(text): + """ Print fail in colored text """ + print_colored(text, "FAIL", True) + + +def print_status(text): + """ Print status in colored text """ + print_colored(text, "STATUS", True) + + +def run_test(name, cmd): + """ Run a test """ + global _fail_count, _test_count # pylint:disable=global-statement + print_status(f"[?] 
running {name}") + print(f"Cmd: {' '.join(cmd)}") + _test_count += 1 + try: + check_call(cmd) + print_ok("[+] Test success") + return True + except CalledProcessError as err: + print_fail(f"[-] Test failed with {err}") + _fail_count += 1 + return False + + +def extract_args(detector, aligner, in_path, out_path, args=None): + """ Extraction command """ + py_exe = sys.executable + _extract_args = (f"{py_exe} faceswap.py extract -i {in_path} -o {out_path} -D {detector} " + f"-A {aligner}") + if args: + _extract_args += f" {args}" + return _extract_args.split() + + +def train_args(model, model_path, faces, iterations=1, batchsize=2, extra_args=""): + """ Train command """ + py_exe = sys.executable + args = (f"{py_exe} faceswap.py train -A {faces} -B {faces} -m {model_path} -t {model} " + f"-b {batchsize} -i {iterations} {extra_args}") + return args.split() + + +def convert_args(in_path, out_path, model_path, writer, args=None): + """ Convert command """ + py_exe = sys.executable + conv_args = (f"{py_exe} faceswap.py convert -i {in_path} -o {out_path} -m {model_path} " + f"-w {writer}") + if args: + conv_args += f" {args}" + return conv_args.split() # Don't use pathes with spaces ;) + + +def sort_args(in_path, out_path, sortby="face", groupby="hist"): + """ Sort command """ + py_exe = sys.executable + _sort_args = f"{py_exe} tools.py sort -i {in_path} -o {out_path} -s {sortby} -g {groupby} -k" + return _sort_args.split() + + +def set_train_config(value): + """ Update the mixed_precision and autoclip values to given value + + Parameters + ---------- + value: bool + The value to set the config parameters to. + """ + old_val, new_val = ("False", "True") if value else ("True", "False") + base_path = os.path.split(os.path.dirname(os.path.abspath(__file__)))[0] + train_ini = os.path.join(base_path, "config", "train.ini") + try: + cmd = ["sed", "-i", f"s/autoclip = {old_val}/autoclip = {new_val}/", train_ini] + check_call(cmd) + cmd = ["sed", + "-i", + f"s/mixed_precision = {old_val}/mixed_precision = {new_val}/", + train_ini] + check_call(cmd) + print_ok(f"Set autoclip and mixed_precision to `{new_val}`") + except CalledProcessError as err: + print_fail(f"[-] Test failed with {err}") + + +def main(): + """ Main testing script """ + base_dir = pathjoin(dirname(abspath(__file__)), "data") + vid_base = pathjoin(base_dir, "vid") + img_base = pathjoin(base_dir, "imgs") + py_exe = sys.executable + was_trained = False + + vid_path = pathjoin(vid_base, "test.mp4") + vid_extract = run_test( + "Extraction video with cv2-dnn detector and cv2-dnn aligner.", + extract_args("Cv2-Dnn", "Cv2-Dnn", vid_path, pathjoin(vid_base, "faces")) + ) + + run_test( + "Extraction images with cv2-dnn detector and cv2-dnn aligner.", + extract_args("Cv2-Dnn", "Cv2-Dnn", img_base, pathjoin(img_base, "faces")) + ) + + if vid_extract: + run_test( + "Generate configs and test help output", + ( + py_exe, "faceswap.py", "-h" + ) + ) + run_test( + "Sort faces.", + sort_args( + pathjoin(vid_base, "faces"), pathjoin(vid_base, "faces_sorted"), + sortby="face" + ) + ) + + run_test( + "Rename sorted faces.", + ( + py_exe, "tools.py", "alignments", "-j", "rename", + "-a", pathjoin(vid_base, "test_alignments.fsa"), + "-c", pathjoin(vid_base, "faces_sorted"), + ) + ) + set_train_config(True) + run_test( + "Train lightweight model for 1 iteration with WTL, AutoClip, MixedPrecion", + train_args("lightweight", + pathjoin(vid_base, "model"), + pathjoin(vid_base, "faces"), + iterations=1, + batchsize=1, + extra_args="-M")) + + set_train_config(False) 
+ was_trained = run_test( + "Train lightweight model for 1 iteration WITHOUT WTL, AutoClip, MixedPrecision", + train_args("lightweight", + pathjoin(vid_base, "model"), + pathjoin(vid_base, "faces"), + iterations=1, + batchsize=1)) + + if was_trained: + run_test( + "Convert video.", + convert_args( + vid_path, pathjoin(vid_base, "conv"), + pathjoin(vid_base, "model"), "ffmpeg" + ) + ) + + run_test( + "Convert images.", + convert_args( + img_base, pathjoin(img_base, "conv"), + pathjoin(vid_base, "model"), "opencv" + ) + ) + + if _fail_count == 0: + print_ok(f"[+] Failed {_fail_count}/{_test_count} tests.") + sys.exit(0) + else: + print_fail(f"[-] Failed {_fail_count}/{_test_count} tests.") + sys.exit(1) + + +if __name__ == '__main__': + main() diff --git a/tests/startup_test.py b/tests/startup_test.py new file mode 100644 index 0000000000..20e939c84d --- /dev/null +++ b/tests/startup_test.py @@ -0,0 +1,39 @@ +#!/usr/bin/env python3 +""" Sanity checks for Faceswap. """ + +import inspect +import sys + +import pytest +import keras +import torch + +from lib.utils import get_backend +from lib.system.system import VALID_KERAS, VALID_PYTHON, VALID_TORCH + +_BACKEND = get_backend().upper() + +_LIBS = (VALID_KERAS + (keras.__version__, ), + VALID_PYTHON + (sys.version, ), + VALID_TORCH + (torch.__version__, )) +_IDS = [f"{x}[{_BACKEND}]" for x in ("keras", "python", "torch")] + + +@pytest.mark.parametrize(["min_vers", "max_vers", "installed_vers"], _LIBS, ids=_IDS) +def test_libraries(min_vers: tuple[int, int], + max_vers: tuple[int, int], + installed_vers: str) -> None: + """ Sanity check to ensure that we are running on valid library versions """ + installed = tuple(int(x) for x in installed_vers.split(".")[:2]) + assert min_vers <= installed <= max_vers + + +@pytest.mark.parametrize('dummy', [None], ids=[_BACKEND]) +def test_backend(dummy): # pylint:disable=unused-argument + """ Sanity check to ensure that Keras backend is returning the correct object type. 
""" + with keras.device("cpu"): + test_var = keras.Variable((1, 1, 4, 4), trainable=False) + mod = inspect.getmodule(test_var) + assert mod is not None + lib = mod.__name__.split(".")[0] + assert lib == "keras" diff --git a/tests/tools/__init__.py b/tests/tools/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/tools/alignments/media_test.py b/tests/tools/alignments/media_test.py new file mode 100644 index 0000000000..2639759ebb --- /dev/null +++ b/tests/tools/alignments/media_test.py @@ -0,0 +1,868 @@ +#!/usr/bin python3 +""" Pytest unit tests for :mod:`tools.alignments.media` """ +from __future__ import annotations +import os +import typing as T + +from operator import itemgetter +from unittest.mock import MagicMock + +import cv2 +import numpy as np +import pytest +import pytest_mock + +from lib.logger import log_setup +# Need to setup logging to avoid trace/verbose errors +log_setup("DEBUG", f"{__name__}.log", "PyTest, False") + +# pylint:disable=wrong-import-position,protected-access +from lib.utils import FaceswapError # noqa:E402 +from tools.alignments.media import (AlignmentData, Faces, ExtractedFaces, # noqa:E402 + Frames, MediaLoader) + +if T.TYPE_CHECKING: + from collections.abc import Generator + + +class TestAlignmentData: + """ Test for :class:`~tools.alignments.media.AlignmentData` """ + + @pytest.fixture + def alignments_file(self, tmp_path: str) -> Generator[str, None, None]: + """ Fixture for creating dummy alignments files + + Parameters + ---------- + tmp_path: str + pytest temporary path to generate folders + + Yields + ------ + str + Path to a dummy alignments file + """ + alignments_file = os.path.join(tmp_path, "alignments.fsa") + with open(alignments_file, "w", encoding="utf8") as afile: + afile.write("test") + yield alignments_file + os.remove(alignments_file) + + def test_init(self, + alignments_file: str, + mocker: pytest_mock.MockerFixture) -> None: + """ Test for :class:`~tools.alignments.media.AlignmentData` __init__ method + + Parameters + ---------- + alignments_file: str + The temporarily generated alignments file + mocker: :class:`pytest_mock.MockerFixture` + Fixture for mocking the superclass __init__ + """ + alignments_parent_init = mocker.patch("tools.alignments.media.Alignments.__init__") + mocker.patch("tools.alignments.media.Alignments.frames_count", + new_callable=mocker.PropertyMock(return_value=20)) + + AlignmentData(alignments_file) + folder, filename = os.path.split(alignments_file) + alignments_parent_init.assert_called_once_with(folder, filename=filename) + + def test_check_file_exists(self, alignments_file: str) -> None: + """ Test for :class:`~tools.alignments.media.AlignmentData` _check_file_exists method + + Parameters + ---------- + alignments_file: str + The temporarily generated alignments file + """ + assert AlignmentData.check_file_exists(alignments_file) == os.path.split(alignments_file) + fake_file = "/not/possibly/a/real/path/alignments.fsa" + with pytest.raises(SystemExit): + AlignmentData.check_file_exists(fake_file) + + def test_save(self, + alignments_file: str, + mocker: pytest_mock.MockerFixture) -> None: + """ Test for :class:`~tools.alignments.media.AlignmentData`save method + + Parameters + ---------- + alignments_file: str + The temporarily generated alignments file + mocker: :class:`pytest_mock.MockerFixture` + Fixture for mocking the superclass calls + """ + mocker.patch("tools.alignments.media.Alignments.__init__") + mocker.patch("tools.alignments.media.Alignments.frames_count", + 
new_callable=mocker.PropertyMock(return_value=20)) + alignments_parent_backup = mocker.patch("tools.alignments.media.Alignments.backup") + alignments_parent_save = mocker.patch("tools.alignments.media.Alignments.save") + align_data = AlignmentData(alignments_file) + align_data.save() + alignments_parent_backup.assert_called_once() + alignments_parent_save.assert_called_once() + + +@pytest.fixture(name="folder") +def folder_fixture(tmp_path: str) -> Generator[str, None, None]: + """ Fixture for creating dummy folders + + Parameters + ---------- + tmp_path: str + pytest temporary path to generate folders + + Yields + ------ + str + Path to a dummy folder + """ + folder = os.path.join(tmp_path, "images") + os.mkdir(folder) + for fname in (["a.png", "b.png"]): + with open(os.path.join(folder, fname), "wb"): + pass + yield folder + for fname in (["a.png", "b.png"]): + os.remove(os.path.join(folder, fname)) + os.rmdir(folder) + + +class TestMediaLoader: + """ Test for :class:`~tools.alignments.media.MediaLoader` """ + + @pytest.fixture(name="media_loader_instance") + def media_loader_fixture(self, + folder: str, + mocker: pytest_mock.MockerFixture) -> MediaLoader: + """ An instance of :class:`~tools.alignments.media.MediaLoader` with unimplemented + child methods patched out of __init__ and initialized with a dummy folder containing + 2 images + + Parameters + ---------- + folder : str + Dummy media folder + mocker: :class:`pytest_mock.MockerFixture` + Fixture for mocking subclass calls + + Returns + ------- + :class:`~tools.alignments.media.MediaLoader` + Initialized instance for testing + """ + mocker.patch("tools.alignments.media.MediaLoader.sorted_items", + return_value=os.listdir(folder)) + mocker.patch("tools.alignments.media.MediaLoader.load_items") + loader = MediaLoader(folder) + return loader + + def test_init(self, + folder: str, + mocker: pytest_mock.MockerFixture) -> None: + """ Test for :class:`~tools.alignments.media.MediaLoader`__init__ method + + Parameters + ---------- + folder : str + Dummy media folder + mocker: :class:`pytest_mock.MockerFixture` + Fixture for mocking subclass calls + """ + sort_patch = mocker.patch("tools.alignments.media.MediaLoader.sorted_items", + return_value=os.listdir(folder)) + load_patch = mocker.patch("tools.alignments.media.MediaLoader.load_items") + loader = MediaLoader(folder) + sort_patch.assert_called_once() + load_patch.assert_called_once() + assert loader.folder == folder + assert loader._count == 2 + assert loader.count == 2 + assert not loader.is_video + + def test_check_input_folder(self, media_loader_instance: MediaLoader) -> None: + """ Test for :class:`~tools.alignments.media.MediaLoader` check_input_folder method + + Parameters + ---------- + media_loader_instance: :class:`~tools.alignments.media.MediaLoader` + The class instance for testing + """ + media_loader = media_loader_instance + assert media_loader.check_input_folder() is None + media_loader.folder = "" + with pytest.raises(SystemExit): + media_loader.check_input_folder() + media_loader.folder = "/this/path/does/not/exist" + with pytest.raises(SystemExit): + media_loader.check_input_folder() + + def test_valid_extension(self, media_loader_instance: MediaLoader) -> None: + """ Test for :class:`~tools.alignments.media.MediaLoader` valid_extension method + + Parameters + ---------- + media_loader_instance: :class:`~tools.alignments.media.MediaLoader` + The class instance for testing + """ + media_loader = media_loader_instance + assert media_loader.valid_extension("test.png") + 
+        assert media_loader.valid_extension("test.PNG")
+        assert media_loader.valid_extension("test.jpg")
+        assert media_loader.valid_extension("test.JPG")
+        assert not media_loader.valid_extension("test.doc")
+        assert not media_loader.valid_extension("test.txt")
+        assert not media_loader.valid_extension("test.mp4")
+
+    def test_load_image(self,
+                        media_loader_instance: MediaLoader,
+                        mocker: pytest_mock.MockerFixture) -> None:
+        """ Test for :class:`~tools.alignments.media.MediaLoader` load_image method
+
+        Parameters
+        ----------
+        media_loader_instance: :class:`~tools.alignments.media.MediaLoader`
+            The class instance for testing
+        mocker: :class:`pytest_mock.MockerFixture`
+            Fixture for mocking loader specific calls
+        """
+        media_loader = media_loader_instance
+        expected = np.random.rand(256, 256, 3)
+        media_loader.load_video_frame = T.cast(MagicMock,  # type:ignore
+                                               mocker.MagicMock(return_value=expected))
+        read_image_patch = mocker.patch("tools.alignments.media.read_image", return_value=expected)
+        filename = "test.png"
+        output = media_loader.load_image(filename)
+        np.testing.assert_equal(expected, output)
+        read_image_patch.assert_called_once_with(os.path.join(media_loader.folder, filename),
+                                                 raise_error=True)
+
+        mocker.patch("tools.alignments.media.MediaLoader.is_video",
+                     new_callable=mocker.PropertyMock(return_value=True))
+        filename = "test.mp4"
+        output = media_loader.load_image(filename)
+        np.testing.assert_equal(expected, output)
+        media_loader.load_video_frame.assert_called_once_with(filename)
+
+    def test_load_video_frame(self,
+                              media_loader_instance: MediaLoader,
+                              mocker: pytest_mock.MockerFixture) -> None:
+        """ Test for :class:`~tools.alignments.media.MediaLoader` load_video_frame method
+
+        Parameters
+        ----------
+        media_loader_instance: :class:`~tools.alignments.media.MediaLoader`
+            The class instance for testing
+        mocker: :class:`pytest_mock.MockerFixture`
+            Fixture for mocking cv2 calls
+        """
+        media_loader = media_loader_instance
+        filename = "test_0001.png"
+        with pytest.raises(AssertionError):
+            media_loader.load_video_frame(filename)
+
+        mocker.patch("tools.alignments.media.MediaLoader.is_video",
+                     new_callable=mocker.PropertyMock(return_value=True))
+        expected = np.random.rand(256, 256, 3)
+        vid_cap = mocker.MagicMock(cv2.VideoCapture)
+        vid_cap.read.side_effect = ((1, expected), )
+
+        media_loader._vid_reader = T.cast(MagicMock, vid_cap)  # type:ignore
+        output = media_loader.load_video_frame(filename)
+        vid_cap.set.assert_called_once()
+        np.testing.assert_equal(output, expected)
+
+    # TODO remove the next line that suppresses a weird pytest bug when it tears down the tempdir
+    @pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning")
+    def test_stream(self,
+                    media_loader_instance: MediaLoader,
+                    mocker: pytest_mock.MockerFixture) -> None:
+        """ Test for :class:`~tools.alignments.media.MediaLoader` stream method
+
+        Parameters
+        ----------
+        media_loader_instance: :class:`~tools.alignments.media.MediaLoader`
+            The class instance for testing
+        mocker: :class:`pytest_mock.MockerFixture`
+            Fixture for mocking loader specific calls
+        """
+        media_loader = media_loader_instance
+
+        loader = mocker.patch("tools.alignments.media.ImagesLoader.load")
+        expected = [(fname, np.random.rand(256, 256, 3))
+                    for fname in os.listdir(media_loader.folder)]
+        loader.side_effect = [expected]
+        output = list(media_loader.stream())
+        assert output == expected
+
+        skip_call = mocker.patch("tools.alignments.media.ImagesLoader.add_skip_list")
+        skip_list = [0]
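+        # with frame 0 skipped, only the second image should be yielded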
+        expected = [expected[1]]
+        loader.side_effect = [expected]
+        output = list(media_loader.stream(skip_list))
+        assert output == expected
+        skip_call.assert_called_once_with(skip_list)
+
+    # TODO remove the next line that suppresses a weird pytest bug when it tears down the tempdir
+    @pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning")
+    def test_save_image(self,
+                        media_loader_instance: MediaLoader,
+                        mocker: pytest_mock.MockerFixture) -> None:
+        """ Test for :class:`~tools.alignments.media.MediaLoader` save_image method
+
+        Parameters
+        ----------
+        media_loader_instance: :class:`~tools.alignments.media.MediaLoader`
+            The class instance for testing
+        mocker: :class:`pytest_mock.MockerFixture`
+            Fixture for mocking saver specific calls
+        """
+        media_loader = media_loader_instance
+        out_folder = media_loader.folder
+        filename = "test_out.jpg"
+        expected_filename = os.path.join(media_loader.folder, "test_out.png")
+        img = np.random.rand(256, 256, 3)
+        metadata = {"test": "data"}
+
+        cv2_write_mock = mocker.patch("cv2.imwrite")
+        cv2_encode_mock = mocker.patch("cv2.imencode")
+        png_write_meta_mock = mocker.patch("tools.alignments.media.png_write_meta")
+        open_mock = mocker.patch("builtins.open")
+
+        media_loader.save_image(out_folder, filename, img, metadata=None)
+        cv2_write_mock.assert_called_once_with(expected_filename, img)
+        cv2_encode_mock.assert_not_called()
+        png_write_meta_mock.assert_not_called()
+
+        cv2_write_mock.reset_mock()
+
+        media_loader.save_image(out_folder, filename, img, metadata=metadata)  # type:ignore
+        cv2_write_mock.assert_not_called()
+        cv2_encode_mock.assert_called_once_with(".png", img)
+        png_write_meta_mock.assert_called_once()
+        open_mock.assert_called_once()
+
+
+class TestFaces:
+    """ Test for :class:`~tools.alignments.media.Faces` """
+
+    @pytest.fixture(name="faces_instance")
+    def faces_fixture(self,
+                      folder: str,
+                      mocker: pytest_mock.MockerFixture) -> Faces:
+        """ An instance of :class:`~tools.alignments.media.Faces` patching out
+        read_image_meta_batch so nothing is loaded
+
+        Parameters
+        ----------
+        folder : str
+            Dummy media folder
+        mocker: :class:`pytest_mock.MockerFixture`
+            Fixture for mocking read_image_meta_batch calls
+
+        Returns
+        -------
+        :class:`~tools.alignments.media.Faces`
+            Initialized instance for testing
+        """
+        mocker.patch("tools.alignments.media.read_image_meta_batch")
+        loader = Faces(folder, None)
+        return loader
+
+    def test_init(self,
+                  folder: str,
+                  mocker: pytest_mock.MockerFixture) -> None:
+        """ Test for :class:`~tools.alignments.media.Faces` __init__ method
+
+        Parameters
+        ----------
+        folder : str
+            Dummy media folder
+        mocker: :class:`pytest_mock.MockerFixture`
+            Fixture for mocking superclass calls
+        """
+        parent_mock = mocker.patch("tools.alignments.media.super")
+        alignments_mock = mocker.patch("tools.alignments.media.AlignmentData")
+        Faces(folder, alignments_mock)
+        parent_mock.assert_called_once()
+
+    def test__handle_legacy(self,
+                            faces_instance: Faces,
+                            mocker: pytest_mock.MockerFixture,
+                            caplog: pytest.LogCaptureFixture) -> None:
+        """ Test for :class:`~tools.alignments.media.Faces` _handle_legacy method
+
+        Parameters
+        ----------
+        faces_instance: :class:`~tools.alignments.media.Faces`
+            Test class instance
+        mocker: :class:`pytest_mock.MockerFixture`
+            Fixture for mocking various objects
+        caplog: :class:`pytest.LogCaptureFixture`
+            For capturing logging messages
+        """
+        faces = faces_instance
+        folder = faces.folder
+        legacy_file = os.path.join(folder, "a.png")
+
+        # No alignments file
+        with pytest.raises(FaceswapError):
+            faces._handle_legacy(legacy_file)
+
+        # No returned metadata
+        alignments_mock = mocker.patch("tools.alignments.media.AlignmentData")
+        alignments_mock.version = 2.1
+        update_mock = mocker.patch("tools.alignments.media.update_legacy_png_header",
+                                   return_value={})
+        faces = Faces(folder, alignments_mock)
+        faces.folder = folder
+        with pytest.raises(FaceswapError):
+            faces._handle_legacy(legacy_file)
+        update_mock.assert_called_once_with(legacy_file, alignments_mock)
+
+        # Correct data with logging
+        caplog.clear()
+        update_mock.reset_mock()
+        update_mock.return_value = {"test": "data"}
+        faces._handle_legacy(legacy_file, log=True)
+        assert "Legacy faces discovered" in caplog.text
+
+        # Correct data without logging
+        caplog.clear()
+        update_mock.reset_mock()
+        update_mock.return_value = {"test": "data"}
+        faces._handle_legacy(legacy_file, log=False)
+        assert "Legacy faces discovered" not in caplog.text
+
+    def test__handle_duplicate(self, faces_instance: Faces) -> None:
+        """ Test for :class:`~tools.alignments.media.Faces` _handle_duplicate method
+
+        Parameters
+        ----------
+        faces_instance: :class:`~tools.alignments.media.Faces`
+            The class instance for testing
+        """
+        faces = faces_instance
+        dupe_dir = os.path.join(faces.folder, "_duplicates")
+        src_filename = "test_0001.png"
+        src_face_idx = 0
+        paths = [os.path.join(faces.folder, fname) for fname in os.listdir(faces.folder)]
+        data = {"source": {"source_filename": src_filename,
+                           "face_index": src_face_idx}}
+        seen: dict[str, list[int]] = {}
+
+        # New item
+        is_dupe = faces._handle_duplicate(paths[0], data, seen)  # type:ignore
+        assert src_filename in seen and seen[src_filename] == [src_face_idx]
+        assert not os.path.exists(dupe_dir)
+        assert not is_dupe
+
+        # Dupe item
+        is_dupe = faces._handle_duplicate(paths[1], data, seen)  # type:ignore
+        assert src_filename in seen and seen[src_filename] == [src_face_idx]
+        assert len(seen) == 1
+        assert os.path.exists(dupe_dir)
+        assert not os.path.exists(paths[1])
+        assert is_dupe
+
+        # Move everything back for fixture cleanup
+        os.rename(os.path.join(dupe_dir, os.path.basename(paths[1])), paths[1])
+        os.rmdir(dupe_dir)
+
+    def test_process_folder(self,
+                            faces_instance: Faces,
+                            mocker: pytest_mock.MockerFixture) -> None:
+        """ Test for :class:`~tools.alignments.media.Faces` process_folder method
+
+        Parameters
+        ----------
+        faces_instance: :class:`~tools.alignments.media.Faces`
+            The class instance for testing
+        mocker: :class:`pytest_mock.MockerFixture`
+            Fixture for mocking various logic calls
+        """
+        faces = faces_instance
+        read_image_meta_mock = mocker.patch("tools.alignments.media.read_image_meta_batch")
+        img_sources = [os.path.join(faces.folder, fname) for fname in os.listdir(faces.folder)]
+        meta_data = {"itxt": {"source": ({"source_filename": "data.png"})}}
+        expected = [(fname, meta_data["itxt"]) for fname in os.listdir(faces.folder)]
+        read_image_meta_mock.side_effect = [[(src, meta_data) for src in img_sources]]
+
+        legacy_mock = mocker.patch("tools.alignments.media.Faces._handle_legacy",
+                                   return_value=meta_data["itxt"])
+        dupe_mock = mocker.patch("tools.alignments.media.Faces._handle_duplicate",
+                                 return_value=False)
+
+        # valid itxt
+        output = list(faces.process_folder())
+        assert read_image_meta_mock.call_count == 1
+        assert dupe_mock.call_count == 2
+        assert not legacy_mock.called
+        assert output == expected
+
+        dupe_mock.reset_mock()
+        read_image_meta_mock.reset_mock()
+
+        # valid itxt with alignments data
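+        # (with alignments present, frame_exists should be queried for each face's source frame)
+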
read_image_meta_mock.side_effect = [[(src, meta_data) for src in img_sources]] + faces._alignments = mocker.MagicMock(AlignmentData) + faces._alignments.version = 2.1 # type:ignore + output = list(faces.process_folder()) + assert faces._alignments.frame_exists.call_count == 2 # type:ignore + assert read_image_meta_mock.call_count == 1 + assert dupe_mock.call_count == 2 + + dupe_mock.reset_mock() + read_image_meta_mock.reset_mock() + faces._alignments = None + + # invalid itxt + read_image_meta_mock.side_effect = [[(src, {}) for src in img_sources]] + output = list(faces.process_folder()) + assert read_image_meta_mock.call_count == 1 + assert legacy_mock.call_count == 2 + assert dupe_mock.call_count == 2 + assert output == expected + + def test_load_items(self, + faces_instance: Faces) -> None: + """ Test for :class:`~tools.alignments.media.Faces` load_items method + + Parameters + ---------- + faces_instance: :class:`~tools.alignments.media.Faces` + The class instance for testing + """ + faces = faces_instance + data = [(f"file{idx}.png", {"source": {"source_filename": f"src{idx}.png", + "face_index": 0}}) + for idx in range(4)] + faces.file_list_sorted = data # type: ignore + expected = {"src0.png": [0], "src1.png": [0], "src2.png": [0], "src3.png": [0]} + result = faces.load_items() + assert result == expected + + data = [(f"file{idx}.png", {"source": {"source_filename": f"src{idx // 2}.png", + "face_index": 0 if idx % 2 == 0 else 1}}) + for idx in range(4)] + faces.file_list_sorted = data # type: ignore + expected = {"src0.png": [0, 1], "src1.png": [0, 1]} + result = faces.load_items() + assert result == expected + + def test_sorted_items(self, + faces_instance: Faces, + mocker: pytest_mock.MockerFixture) -> None: + """ Test for :class:`~tools.alignments.media.Faces` sorted_items method + + Parameters + ---------- + faces_instance: :class:`~tools.alignments.media.Faces` + The class instance for testing + mocker: :class:`pytest_mock.MockerFixture` + Fixture for mocking various logic calls + """ + faces = faces_instance + data: list[tuple[str, dict]] = [("file4.png", {}), ("file3.png", {}), + ("file1.png", {}), ("file2.png", {})] + expected = sorted(data) + process_folder_mock = mocker.patch("tools.alignments.media.Faces.process_folder", + side_effect=[data]) + result = faces.sorted_items() + assert process_folder_mock.called + assert result == expected + + +class TestFrames: + """ Test for :class:`~tools.alignments.media.Frames` """ + + def test_process_folder(self, + folder: str, + mocker: pytest_mock.MockerFixture) -> None: + """ Test for :class:`~tools.alignments.media.Frames` process_folder method + + Parameters + ---------- + folder : str + Dummy media folder + mocker: :class:`pytest_mock.MockerFixture` + Fixture for mocking superclass calls + """ + process_video_mock = mocker.patch("tools.alignments.media.Frames.process_video") + process_frames_mock = mocker.patch("tools.alignments.media.Frames.process_frames") + + frames = Frames(folder, None) + frames.process_folder() + process_frames_mock.assert_called_once() + process_video_mock.assert_not_called() + + process_frames_mock.reset_mock() + mocker.patch("tools.alignments.media.Frames.is_video", + new_callable=mocker.PropertyMock(return_value=True)) + frames = Frames(folder, None) + frames.process_folder() + process_frames_mock.assert_not_called() + process_video_mock.assert_called_once() + + def test_process_frames(self, folder: str) -> None: + """ Test for :class:`~tools.alignments.media.Frames` process_frames method + + 
Parameters
+        ----------
+        folder : str
+            Dummy media folder
+        """
+        expected = [{"frame_fullname": "a.png", "frame_name": "a", "frame_extension": ".png"},
+                    {"frame_fullname": "b.png", "frame_name": "b", "frame_extension": ".png"}]
+
+        frames = Frames(folder, None)
+        returned = sorted(list(frames.process_frames()), key=itemgetter("frame_fullname"))
+        assert returned == sorted(expected, key=itemgetter("frame_fullname"))
+
+    def test_process_video(self, folder: str) -> None:
+        """ Test for :class:`~tools.alignments.media.Frames` process_video method
+
+        Parameters
+        ----------
+        folder : str
+            Dummy media folder
+        """
+        ext = os.path.splitext(folder)[-1]
+        expected = [{"frame_fullname": f"images_000001{ext}",
+                     "frame_name": "images_000001",
+                     "frame_extension": ext},
+                    {"frame_fullname": f"images_000002{ext}",
+                     "frame_name": "images_000002",
+                     "frame_extension": ext}]
+
+        frames = Frames(folder, None)
+        returned = list(frames.process_video())
+        assert returned == expected
+
+    def test_load_items(self, folder: str) -> None:
+        """ Test for :class:`~tools.alignments.media.Frames` load_items method
+
+        Parameters
+        ----------
+        folder : str
+            Dummy media folder
+        """
+        expected = {"a.png": ("a", ".png"), "b.png": ("b", ".png")}
+        frames = Frames(folder, None)
+        result = frames.load_items()
+        assert result == expected
+
+    def test_sorted_items(self,
+                          folder: str,
+                          mocker: pytest_mock.MockerFixture) -> None:
+        """ Test for :class:`~tools.alignments.media.Frames` sorted_items method
+
+        Parameters
+        ----------
+        folder : str
+            Dummy media folder
+        mocker: :class:`pytest_mock.MockerFixture`
+            Fixture for mocking process_folder call
+        """
+        frames = Frames(folder, None)
+        data = [{"frame_fullname": "c.png", "frame_name": "c", "frame_extension": ".png"},
+                {"frame_fullname": "d.png", "frame_name": "d", "frame_extension": ".png"},
+                {"frame_fullname": "b.jpg", "frame_name": "b", "frame_extension": ".jpg"},
+                {"frame_fullname": "a.png", "frame_name": "a", "frame_extension": ".png"}]
+        expected = [{"frame_fullname": "a.png", "frame_name": "a", "frame_extension": ".png"},
+                    {"frame_fullname": "b.jpg", "frame_name": "b", "frame_extension": ".jpg"},
+                    {"frame_fullname": "c.png", "frame_name": "c", "frame_extension": ".png"},
+                    {"frame_fullname": "d.png", "frame_name": "d", "frame_extension": ".png"}]
+        process_folder_mock = mocker.patch("tools.alignments.media.Frames.process_folder",
+                                           side_effect=[data])
+        result = frames.sorted_items()
+
+        assert process_folder_mock.called
+        assert result == expected
+
+
+class TestExtractedFaces:
+    """ Test for :class:`~tools.alignments.media.ExtractedFaces` """
+
+    @pytest.fixture(name="extracted_faces_instance")
+    def extracted_faces_fixture(self, mocker: pytest_mock.MockerFixture) -> ExtractedFaces:
+        """ An instance of :class:`~tools.alignments.media.ExtractedFaces` patching out Frames and
+        AlignmentData parameters
+
+        Parameters
+        ----------
+        mocker: :class:`pytest_mock.MockerFixture`
+            Fixture for mocking read_image_meta_batch calls
+
+        Returns
+        -------
+        :class:`~tools.alignments.media.ExtractedFaces`
+            Initialized instance for testing
+        """
+        frames_mock = mocker.MagicMock(Frames)
+        alignments_mock = mocker.MagicMock(AlignmentData)
+        return ExtractedFaces(frames_mock, alignments_mock, size=512)
+
+    def test_init(self, extracted_faces_instance: ExtractedFaces) -> None:
+        """ Test for :class:`~tools.alignments.media.ExtractedFaces` __init__ method
+
+        Parameters
+        ----------
+        extracted_faces_instance: :class:`~tools.alignments.media.ExtractedFaces`
+            The class instance for testing
+        """
+        faces = extracted_faces_instance
+        assert faces.size == 512
+        assert faces.padding == int(512 * 0.1875)
+        assert faces.current_frame is None
+        assert faces.faces == []
+
+    def test_get_faces(self,
+                       extracted_faces_instance: ExtractedFaces,
+                       mocker: pytest_mock.MockerFixture) -> None:
+        """ Test for :class:`~tools.alignments.media.ExtractedFaces` get_faces method
+
+        Parameters
+        ----------
+        extracted_faces_instance: :class:`~tools.alignments.media.ExtractedFaces`
+            The class instance for testing
+        mocker: :class:`pytest_mock.MockerFixture`
+            Fixture for mocking Frames and AlignmentData classes
+        """
+        extract_face_mock = mocker.patch("tools.alignments.media.ExtractedFaces.extract_one_face")
+        faces = extracted_faces_instance
+
+        frame = "test_frame"
+        img = np.random.rand(256, 256, 3)
+
+        # No alignment data
+        faces.alignments.get_faces_in_frame.return_value = []  # type:ignore
+        faces.get_faces(frame, img)
+        faces.alignments.get_faces_in_frame.assert_called_once_with(frame)  # type:ignore
+        faces.frames.load_image.assert_not_called()  # type:ignore
+        extract_face_mock.assert_not_called()
+        assert faces.current_frame is None
+
+        faces.alignments.reset_mock()  # type:ignore
+
+        # Alignment data + image
+        faces.alignments.get_faces_in_frame.return_value = [1, 2, 3]  # type:ignore
+        faces.get_faces(frame, img)
+        faces.alignments.get_faces_in_frame.assert_called_once_with(frame)  # type:ignore
+        faces.frames.load_image.assert_not_called()  # type:ignore
+        assert extract_face_mock.call_count == 3
+        assert faces.current_frame == frame
+
+        faces.alignments.reset_mock()  # type:ignore
+        extract_face_mock.reset_mock()
+
+        # Alignment data + no image
+        faces.alignments.get_faces_in_frame.return_value = ["data1"]  # type:ignore
+        faces.get_faces(frame, None)
+        faces.alignments.get_faces_in_frame.assert_called_once_with(frame)  # type:ignore
+        faces.frames.load_image.assert_called_once_with(frame)  # type:ignore
+        assert extract_face_mock.call_count == 1
+        assert faces.current_frame == frame
+
+    # TODO remove the next line that suppresses a weird pytest bug when it tears down the tempdir
+    @pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning")
+    def test_extract_one_face(self,
+                              extracted_faces_instance: ExtractedFaces,
+                              mocker: pytest_mock.MockerFixture) -> None:
+        """ Test for :class:`~tools.alignments.media.ExtractedFaces` extract_one_face method
+
+        Parameters
+        ----------
+        extracted_faces_instance: :class:`~tools.alignments.media.ExtractedFaces`
+            The class instance for testing
+        mocker: :class:`pytest_mock.MockerFixture`
+            Fixture for mocking DetectedFace object
+        """
+        detected_face = mocker.patch("tools.alignments.media.DetectedFace")
+        thumbnail_mock = mocker.patch("tools.alignments.media.generate_thumbnail")
+        faces = extracted_faces_instance
+        alignment = {"test"}
+        img = np.random.rand(256, 256, 3)
+        returned = faces.extract_one_face(alignment, img)  # type:ignore
+        detected_face.assert_called_once()
+        detected_face.return_value.from_alignment.assert_called_once_with(alignment,
+                                                                          image=img)
+        detected_face.return_value.load_aligned.assert_called_once_with(img,
+                                                                        size=512,
+                                                                        centering="head")
+        thumbnail_mock.assert_called_once()
+        assert isinstance(returned, MagicMock)
+
+    def test_get_faces_in_frame(self,
+                                extracted_faces_instance: ExtractedFaces,
+                                mocker: pytest_mock.MockerFixture) -> None:
+        """ Test for :class:`~tools.alignments.media.ExtractedFaces` get_faces_in_frame method
+
+        Parameters
+        ----------
+        extracted_faces_instance: :class:`~tools.alignments.media.ExtractedFaces`
+            The class instance for testing
+        mocker: :class:`pytest_mock.MockerFixture`
+            Fixture for mocking get_faces method
+        """
+        faces = extracted_faces_instance
+        faces.get_faces = T.cast(MagicMock, mocker.MagicMock())  # type:ignore
+
+        frame = "test_frame"
+        img = None
+
+        faces.get_faces_in_frame(frame, update=False, image=img)
+        faces.get_faces.assert_called_once_with(frame, image=img)
+
+        faces.get_faces.reset_mock()
+
+        faces.current_frame = frame
+        faces.get_faces_in_frame(frame, update=False, image=img)
+        faces.get_faces.assert_not_called()
+
+        faces.get_faces_in_frame(frame, update=True, image=img)
+        faces.get_faces.assert_called_once_with(frame, image=img)
+
+    _params = [(np.array(([[25, 47], [32, 232], [244, 237], [240, 21]])), 216),
+               (np.array(([[127, 392], [403, 510], [32, 237], [19, 210]])), 211),
+               (np.array(([[26, 1927], [112, 1234], [1683, 1433], [78, 1155]])), 773)]
+
+    @pytest.mark.parametrize("roi,expected", _params)
+    def test_get_roi_size_for_frame(self,
+                                    extracted_faces_instance: ExtractedFaces,
+                                    mocker: pytest_mock.MockerFixture,
+                                    roi: np.ndarray,
+                                    expected: int) -> None:
+        """ Test for :class:`~tools.alignments.media.ExtractedFaces` get_roi_size_for_frame method
+
+        Parameters
+        ----------
+        extracted_faces_instance: :class:`~tools.alignments.media.ExtractedFaces`
+            The class instance for testing
+        mocker: :class:`pytest_mock.MockerFixture`
+            Fixture for mocking get_faces method and DetectedFace object
+        roi: :class:`numpy.ndarray`
+            Test ROI box to feed into the function
+        expected: int
+            The expected output for the given ROI box
+        """
+        faces = extracted_faces_instance
+        faces.get_faces = T.cast(MagicMock, mocker.MagicMock())  # type:ignore
+
+        frame = "test_frame"
+        faces.get_roi_size_for_frame(frame)
+        faces.get_faces.assert_called_once_with(frame)
+
+        faces.get_faces.reset_mock()
+
+        faces.current_frame = frame
+        faces.get_roi_size_for_frame(frame)
+        faces.get_faces.assert_not_called()
+
+        detected_face = mocker.MagicMock("tools.alignments.media.DetectedFace")
+        detected_face.aligned = detected_face
+        detected_face.original_roi = roi
+        faces.faces = [detected_face]
+        result = faces.get_roi_size_for_frame(frame)
+        assert result == [expected]
diff --git a/tests/tools/preview/__init__.py b/tests/tools/preview/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/tools/preview/viewer_test.py b/tests/tools/preview/viewer_test.py
new file mode 100644
index 0000000000..2ce814b234
--- /dev/null
+++ b/tests/tools/preview/viewer_test.py
@@ -0,0 +1,489 @@
+#!/usr/bin/env python3
+""" Pytest unit tests for :mod:`tools.preview.viewer` """
+from __future__ import annotations
+import tkinter as tk
+import typing as T
+
+from tkinter import ttk
+
+from unittest.mock import MagicMock
+
+import pytest
+import pytest_mock
+import numpy as np
+from PIL import ImageTk
+
+from lib.logger import log_setup
+# Need to setup logging to avoid trace/verbose errors
+log_setup("DEBUG", "pytest_viewer.log", "PyTest, False")
+
+from lib.utils import get_backend  # pylint:disable=wrong-import-position  # noqa
+from tools.preview.viewer import _Faces, FacesDisplay, ImagesCanvas  # pylint:disable=wrong-import-position  # noqa
+
+if T.TYPE_CHECKING:
+    from lib.align.aligned_face import CenteringType
+
+
+# pylint:disable=protected-access
+
+
+def test__faces():
+    """ Test the :class:`~tools.preview.viewer._Faces` dataclass initializes correctly """
+    faces = _Faces()
+    assert isinstance(faces.filenames, list) and not faces.filenames
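+    # each field should come back as a fresh empty list (presumably via dataclass field factories)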
+    assert isinstance(faces.matrix, list) and not faces.matrix
+    assert isinstance(faces.src, list) and not faces.src
+    assert isinstance(faces.dst, list) and not faces.dst
+
+
+_PARAMS = ((3, 448), (4, 333), (5, 254), (6, 128))  # columns/face_size
+_IDS = [f"cols:{c},size:{s}[{get_backend().upper()}]" for c, s in _PARAMS]
+
+
+class TestFacesDisplay:
+    """ Test :class:`~tools.preview.viewer.FacesDisplay` """
+    _padding = 64
+
+    def get_faces_display_instance(self, columns: int = 5, face_size: int = 256) -> FacesDisplay:
+        """ Obtain an instance of :class:`~tools.preview.viewer.FacesDisplay` with the given column
+        and face size layout.
+
+        Parameters
+        ----------
+        columns: int, optional
+            The number of columns to display in the viewer, default: 5
+        face_size: int, optional
+            The size of each face image to be displayed in the viewer, default: 256
+
+        Returns
+        -------
+        :class:`~tools.preview.viewer.FacesDisplay`
+            An instance of the FacesDisplay class at the given settings
+        """
+        app = MagicMock()
+        retval = FacesDisplay(app, face_size, self._padding)
+        retval._faces = _Faces(
+            matrix=[np.random.rand(2, 3) for _ in range(columns)],
+            src=[np.random.rand(face_size, face_size, 3) for _ in range(columns)],
+            dst=[np.random.rand(face_size, face_size, 3) for _ in range(columns)])
+        return retval
+
+    def test_init(self) -> None:
+        """ Test :class:`~tools.preview.viewer.FacesDisplay` __init__ method """
+        f_display = self.get_faces_display_instance(face_size=256)
+        assert f_display._size == 256
+        assert f_display._padding == self._padding
+        assert isinstance(f_display._app, MagicMock)
+
+        assert f_display._display_dims == (1, 1)
+        assert isinstance(f_display._faces, _Faces)
+
+        assert f_display._centering is None
+        assert f_display._faces_source.size == 0
+        assert f_display._faces_dest.size == 0
+        assert f_display._tk_image is None
+        assert f_display.update_source is False
+        assert not f_display.source and isinstance(f_display.source, list)
+        assert not f_display.destination and isinstance(f_display.destination, list)
+
+    @pytest.mark.parametrize("columns, face_size", _PARAMS, ids=_IDS)
+    def test__total_columns(self, columns: int, face_size: int) -> None:
+        """ Test :class:`~tools.preview.viewer.FacesDisplay` _total_columns property is correctly
+        calculated
+
+        Parameters
+        ----------
+        columns: int
+            The number of columns to display in the viewer
+        face_size: int
+            The size of each face image to be displayed in the viewer
+        """
+        f_display = self.get_faces_display_instance(columns, face_size)
+        f_display.source = [None for _ in range(columns)]  # type:ignore
+        assert f_display._total_columns == columns
+
+    def test_set_centering(self) -> None:
+        """ Test :class:`~tools.preview.viewer.FacesDisplay` set_centering method """
+        f_display = self.get_faces_display_instance()
+        assert f_display._centering is None
+        centering: CenteringType = "legacy"
+        f_display.set_centering(centering)
+        assert f_display._centering == centering
+
+    def test_set_display_dimensions(self) -> None:
+        """ Test :class:`~tools.preview.viewer.FacesDisplay` set_display_dimensions method """
+        f_display = self.get_faces_display_instance()
+        assert f_display._display_dims == (1, 1)
+        dimensions = (800, 600)
+        f_display.set_display_dimensions(dimensions)
+        assert f_display._display_dims == dimensions
+
+    # TODO remove the next line that suppresses a weird pytest bug when it tears down the tempdir
+    @pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning")
+    @pytest.mark.parametrize("columns, face_size",
_PARAMS, ids=_IDS) + def test_update_tk_image(self, + columns: int, + face_size: int, + mocker: pytest_mock.MockerFixture) -> None: + """ Test :class:`~tools.preview.viewer.FacesDisplay` update_tk_image method + + Parameters + ---------- + columns: int + The number of columns to display in the viewer + face_size: int + The size of each face image to be displayed in the viewer + mocker: :class:`pytest_mock.MockerFixture` + Mocker for checking _build_faces_image method called + """ + f_display = self.get_faces_display_instance(columns, face_size) + f_display._build_faces_image = T.cast(MagicMock, mocker.MagicMock()) # type:ignore + f_display._get_scale_size = T.cast(MagicMock, # type:ignore + mocker.MagicMock(return_value=(128, 128))) + f_display._faces_source = np.zeros((face_size, face_size, 3), dtype=np.uint8) + f_display._faces_dest = np.zeros((face_size, face_size, 3), dtype=np.uint8) + + try: + tk.Tk() # tkinter instance needed for image creation + except tk.TclError: + # Some Windows runners arbitrarily don't install Tk correctly + pytest.skip("Tk not available on this system") + f_display.update_tk_image() + + f_display._build_faces_image.assert_called_once() + f_display._get_scale_size.assert_called_once() + assert isinstance(f_display._tk_image, ImageTk.PhotoImage) + assert f_display._tk_image.width() == 128 + assert f_display._tk_image.height() == 128 + assert f_display.tk_image == f_display._tk_image # public property test + + @pytest.mark.parametrize("columns, face_size", _PARAMS, ids=_IDS) + def test_get_scale_size(self, columns: int, face_size: int) -> None: + """ Test :class:`~tools.preview.viewer.FacesDisplay` get_scale_size method + + Parameters + ---------- + columns: int + The number of columns to display in the viewer + face_size: int + The size of each face image to be displayed in the viewer + """ + f_display = self.get_faces_display_instance(columns, face_size) + f_display.set_display_dimensions((800, 600)) + + img = np.zeros((face_size, face_size, 3), dtype=np.uint8) + size = f_display._get_scale_size(img) + assert size == (600, 600) + + @pytest.mark.parametrize("columns, face_size", _PARAMS, ids=_IDS) + def test__build_faces_image(self, + columns: int, + face_size: int, + mocker: pytest_mock.MockerFixture) -> None: + """ Test :class:`~tools.preview.viewer.FacesDisplay` _build_faces_image method + + Parameters + ---------- + columns: int + The number of columns to display in the viewer + face_size: int + The size of each face image to be displayed in the viewer + mocker: :class:`pytest_mock.MockerFixture` + Mocker for checking internal methods called + """ + header_size = 32 + + f_display = self.get_faces_display_instance(columns, face_size) + f_display._faces_from_frames = T.cast(MagicMock, mocker.MagicMock()) # type:ignore + f_display._header_text = T.cast( # type:ignore + MagicMock, + mocker.MagicMock(return_value=np.random.rand(header_size, face_size * columns, 3))) + f_display._draw_rect = T.cast(MagicMock, # type:ignore + mocker.MagicMock(side_effect=lambda x: x)) + + # Test full update + f_display.update_source = True + f_display._build_faces_image() + + f_display._faces_from_frames.assert_called_once() + f_display._header_text.assert_called_once() + assert f_display._draw_rect.call_count == columns * 2 # src + dst + assert f_display._faces_source.shape == (face_size + header_size, face_size * columns, 3) + assert f_display._faces_dest.shape == (face_size, face_size * columns, 3) + + f_display._faces_from_frames.reset_mock() + 
f_display._header_text.reset_mock()
+        f_display._draw_rect.reset_mock()
+
+        # Test dst update only
+        f_display.update_source = False
+        f_display._build_faces_image()
+
+        f_display._faces_from_frames.assert_called_once()
+        assert not f_display._header_text.called
+        assert f_display._draw_rect.call_count == columns  # dst only
+        assert f_display._faces_dest.shape == (face_size, face_size * columns, 3)
+
+    @pytest.mark.parametrize("columns, face_size", _PARAMS, ids=_IDS)
+    def test_faces__from_frames(self,
+                                columns: int,
+                                face_size: int,
+                                mocker: pytest_mock.MockerFixture) -> None:
+        """ Test :class:`~tools.preview.viewer.FacesDisplay` _faces_from_frames method
+
+        Parameters
+        ----------
+        columns: int
+            The number of columns to display in the viewer
+        face_size: int
+            The size of each face image to be displayed in the viewer
+        mocker: :class:`pytest_mock.MockerFixture`
+            Mocker for mocking the face crop methods
+        """
+        f_display = self.get_faces_display_instance(columns, face_size)
+        f_display.source = [mocker.MagicMock() for _ in range(3)]
+        f_display.destination = [np.random.rand(face_size, face_size, 3) for _ in range(3)]
+        f_display._crop_source_faces = T.cast(MagicMock, mocker.MagicMock())  # type:ignore
+        f_display._crop_destination_faces = T.cast(MagicMock, mocker.MagicMock())  # type:ignore
+
+        # Both src + dst
+        f_display.update_source = True
+        f_display._faces_from_frames()
+        f_display._crop_source_faces.assert_called_once()
+        f_display._crop_destination_faces.assert_called_once()
+
+        f_display._crop_source_faces.reset_mock()
+        f_display._crop_destination_faces.reset_mock()
+
+        # Just dst
+        f_display.update_source = False
+        f_display._faces_from_frames()
+        assert not f_display._crop_source_faces.called
+        f_display._crop_destination_faces.assert_called_once()
+
+    @pytest.mark.parametrize("columns, face_size", _PARAMS, ids=_IDS)
+    def test__crop_source_faces(self,
+                                columns: int,
+                                face_size: int,
+                                monkeypatch: pytest.MonkeyPatch,
+                                mocker: pytest_mock.MockerFixture) -> None:
+        """ Test :class:`~tools.preview.viewer.FacesDisplay` _crop_source_faces method
+
+        Parameters
+        ----------
+        columns: int
+            The number of columns to display in the viewer
+        face_size: int
+            The size of each face image to be displayed in the viewer
+        monkeypatch: :class:`pytest.MonkeyPatch`
+            For patching the transform_image function
+        mocker: :class:`pytest_mock.MockerFixture`
+            Mocker for mocking various internal methods
+        """
+        f_display = self.get_faces_display_instance(columns, face_size)
+        f_display._centering = "face"
+        f_display.update_source = True
+        f_display._faces.src = []
+
+        transform_image_mock = mocker.MagicMock()
+        monkeypatch.setattr("tools.preview.viewer.transform_image", transform_image_mock)
+
+        f_display.source = [mocker.MagicMock() for _ in range(columns)]
+        for idx, mock in enumerate(f_display.source):
+            assert isinstance(mock, MagicMock)
+            mock.inbound.detected_faces.__getitem__ = lambda self, x, y=mock: y
+            mock.aligned.matrix = f"test_matrix_{idx}"
+            mock.inbound.filename = f"test_filename_{idx}.txt"
+
+        f_display._crop_source_faces()
+
+        assert len(f_display._faces.filenames) == columns
+        assert len(f_display._faces.matrix) == columns
+        assert len(f_display._faces.src) == columns
+        assert not f_display.update_source
+        assert transform_image_mock.call_count == columns
+
+        for idx in range(columns):
+            assert f_display._faces.filenames[idx] == f"test_filename_{idx}"
+            assert f_display._faces.matrix[idx] == f"test_matrix_{idx}"
+
+    @pytest.mark.parametrize("columns, face_size", _PARAMS,
ids=_IDS) + def test__crop_destination_faces(self, + columns: int, + face_size: int, + mocker: pytest_mock.MockerFixture) -> None: + """ Test :class:`~tools.preview.viewer.FacesDisplay` _crop_destination_faces method + + Parameters + ---------- + columns: int + The number of columns to display in the viewer + face_size: int + The size of each face image to be displayed in the viewer + mocker: :class:`pytest_mock.MockerFixture` + Mocker for dummying in full frames + """ + f_display = self.get_faces_display_instance(columns, face_size) + f_display._centering = "face" + f_display._faces.dst = [] # empty object and test populated correctly + + f_display.source = [mocker.MagicMock() for _ in range(columns)] + for item in f_display.source: # type ignore + item.inbound.image = np.random.rand(1280, 720, 3) # type:ignore + + f_display._crop_destination_faces() + assert len(f_display._faces.dst) == columns + assert all(f.shape == (face_size, face_size, 3) for f in f_display._faces.dst) + + @pytest.mark.parametrize("columns, face_size", _PARAMS, ids=_IDS) + def test__header_text(self, + columns: int, + face_size: int, + mocker: pytest_mock.MockerFixture) -> None: + """ Test :class:`~tools.preview.viewer.FacesDisplay` _header_text method + + Parameters + ---------- + columns: int + The number of columns to display in the viewer + face_size: int + The size of each face image to be displayed in the viewer + mocker: :class:`pytest_mock.MockerFixture` + Mocker for dummying in cv2 calls + """ + f_display = self.get_faces_display_instance(columns, face_size) + f_display.source = [None for _ in range(columns)] # type:ignore + f_display._faces.filenames = [f"filename_{idx}.png" for idx in range(columns)] + + cv2_mock = mocker.patch("tools.preview.viewer.cv2") + text_width, text_height = (100, 32) + cv2_mock.getTextSize.return_value = [(text_width, text_height), ] + + header_box = f_display._header_text() + assert cv2_mock.getTextSize.call_count == columns + assert cv2_mock.putText.call_count == columns + assert header_box.shape == (face_size // 8, face_size * columns, 3) + + @pytest.mark.parametrize("columns, face_size", _PARAMS, ids=_IDS) + def test__draw_rect_text(self, + columns: int, + face_size: int, + mocker: pytest_mock.MockerFixture) -> None: + """ Test :class:`~tools.preview.viewer.FacesDisplay` _draw_rect method + + Parameters + ---------- + columns: int + The number of columns to display in the viewer + face_size: int + The size of each face image to be displayed in the viewer + mocker: :class:`pytest_mock.MockerFixture` + Mocker for dummying in cv2 calls + """ + f_display = self.get_faces_display_instance(columns, face_size) + cv2_mock = mocker.patch("tools.preview.viewer.cv2") + + image = (np.random.rand(face_size, face_size, 3) * 255.0) + 50 + assert image.max() > 255.0 + output = f_display._draw_rect(image) + cv2_mock.rectangle.assert_called_once() + assert output.max() == 255.0 # np.clip + + +class TestImagesCanvas: + """ Test :class:`~tools.preview.viewer.ImagesCanvas` """ + + @pytest.fixture + def parent(self) -> MagicMock: + """ Mock object to act as the parent widget to the ImagesCanvas + + Returns + -------- + :class:`unittest.mock.MagicMock` + The mocked ttk.PanedWindow widget + """ + retval = MagicMock(spec=ttk.PanedWindow) + retval.tk = retval + retval._w = "mock_ttkPanedWindow" + retval.children = {} + retval.call = retval + retval.createcommand = retval + retval.preview_display = MagicMock(spec=FacesDisplay) + return retval + + @pytest.fixture(name="images_canvas_instance") + def 
images_canvas_fixture(self, parent) -> ImagesCanvas: + """ Fixture for creating a testing :class:`~tools.preview.viewer.ImagesCanvas` instance + + Parameters + ---------- + parent: :class:`unittest.mock.MagicMock` + The mocked ttk.PanedWindow parent + + Returns + ------- + :class:`~tools.preview.viewer.ImagesCanvas` + The class instance for testing + """ + app = MagicMock() + return ImagesCanvas(app, parent) + + def test_init(self, images_canvas_instance: ImagesCanvas, parent: MagicMock) -> None: + """ Test :class:`~tools.preview.viewer.ImagesCanvas` __init__ method + + Parameters + ---------- + images_canvas_instance: :class:`~tools.preview.viewer.ImagesCanvas` + The class instance to test + parent: :class:`unittest.mock.MagicMock` + The mocked parent ttk.PanedWindow + """ + assert images_canvas_instance._display == parent.preview_display + assert isinstance(images_canvas_instance._canvas, tk.Canvas) + assert images_canvas_instance._canvas.master == images_canvas_instance + assert images_canvas_instance._canvas.winfo_ismapped() + + def test_resize(self, + images_canvas_instance: ImagesCanvas, + parent: MagicMock, + mocker: pytest_mock.MockerFixture) -> None: + """ Test :class:`~tools.preview.viewer.ImagesCanvas` resize method + + Parameters + ---------- + images_canvas_instance: :class:`~tools.preview.viewer.ImagesCanvas` + The class instance to test + parent: :class:`unittest.mock.MagicMock` + The mocked parent ttk.PanedWindow + mocker: :class:`pytest_mock.MockerFixture` + Mocker for dummying in tk calls + """ + event_mock = mocker.MagicMock(spec=tk.Event, width=100, height=200) + images_canvas_instance.reload = T.cast(MagicMock, mocker.MagicMock()) # type:ignore + + images_canvas_instance._resize(event_mock) + + parent.preview_display.set_display_dimensions.assert_called_once_with((100, 200)) + images_canvas_instance.reload.assert_called_once() + + def test_reload(self, + images_canvas_instance: ImagesCanvas, + parent: MagicMock, + mocker: pytest_mock.MockerFixture) -> None: + """ Test :class:`~tools.preview.viewer.ImagesCanvas` reload method + + Parameters + ---------- + images_canvas_instance: :class:`~tools.preview.viewer.ImagesCanvas` + The class instance to test + parent: :class:`unittest.mock.MagicMock` + The mocked parent ttk.PanedWindow + mocker: :class:`pytest_mock.MockerFixture` + Mocker for dummying in tk calls + """ + itemconfig_mock = mocker.patch.object(tk.Canvas, "itemconfig") + + images_canvas_instance.reload() + + parent.preview_display.update_tk_image.assert_called_once() + itemconfig_mock.assert_called_once() diff --git a/tests/utils.py b/tests/utils.py new file mode 100644 index 0000000000..b357dc13c4 --- /dev/null +++ b/tests/utils.py @@ -0,0 +1,112 @@ +#!/usr/bin/env python3 +""" Utils imported from Keras as their location changes between Tensorflow Keras and standard +Keras. Also ensures testing consistency """ +import inspect + +import numpy as np + + +def generate_test_data(num_train=1000, num_test=500, input_shape=(10,), + output_shape=(2,), + classification=True, num_classes=2): + """Generates test data to train a model on. classification=True overrides output_shape (i.e. + output_shape is set to (1,)) and the output consists in integers in [0, num_classes-1]. + + Otherwise: float output with shape output_shape. 
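+
+    Example
+    -------
+    A minimal sketch of expected usage; shapes assume the default classification mode:
+
+    >>> (x_train, y_train), (x_test, y_test) = generate_test_data(
+    ...     num_train=100, num_test=50, input_shape=(10,), num_classes=3)
+    >>> x_train.shape, y_train.shape, x_test.shape
+    ((100, 10), (100,), (50, 10))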
+ """ + samples = num_train + num_test + if classification: + var_y = np.random.randint(0, num_classes, size=(samples,)) + var_x = np.zeros((samples,) + input_shape, dtype=np.float32) + for i in range(samples): + var_x[i] = np.random.normal(loc=var_y[i], scale=0.7, size=input_shape) + else: + y_loc = np.random.random((samples,)) + var_x = np.zeros((samples,) + input_shape, dtype=np.float32) + var_y = np.zeros((samples,) + output_shape, dtype=np.float32) + for i in range(samples): + var_x[i] = np.random.normal(loc=y_loc[i], scale=0.7, size=input_shape) + var_y[i] = np.random.normal(loc=y_loc[i], scale=0.7, size=output_shape) + + return (var_x[:num_train], var_y[:num_train]), (var_x[num_train:], var_y[num_train:]) + + +def to_categorical(var_y, num_classes=None, dtype='float32'): + """Converts a class vector (integers) to binary class matrix. + E.g. for use with categorical_crossentropy. + + Parameters + ---------- + var_y: int + Class vector to be converted into a matrix (integers from 0 to num_classes). + num_classes: int + Total number of classes. + dtype: str + The data type expected by the input, as a string (`float32`, `float64`, `int32`...) + + Returns + ------- + tensor + A binary matrix representation of the input. The classes axis is placed last. + + Example + ------- + >>> # Consider an array of 5 labels out of a set of 3 classes {0, 1, 2}: + >>> labels + >>> array([0, 2, 1, 2, 0]) + >>> # `to_categorical` converts this into a matrix with as many columns as there are classes. + >>> # The number of rows stays the same. + >>> to_categorical(labels) + >>> array([[ 1., 0., 0.], + >>> [ 0., 0., 1.], + >>> [ 0., 1., 0.], + >>> [ 0., 0., 1.], + >>> [ 1., 0., 0.]], dtype=float32) + """ + var_y = np.array(var_y, dtype='int') + input_shape = var_y.shape + if input_shape and input_shape[-1] == 1 and len(input_shape) > 1: + input_shape = tuple(input_shape[:-1]) + var_y = var_y.ravel() + if not num_classes: + num_classes = np.max(var_y) + 1 + var_n = var_y.shape[0] + categorical = np.zeros((var_n, num_classes), dtype=dtype) + categorical[np.arange(var_n), var_y] = 1 + output_shape = input_shape + (num_classes,) + categorical = np.reshape(categorical, output_shape) + return categorical + + +def has_arg(func, name, accept_all=False): + """Checks if a callable accepts a given keyword argument. + + For Python 2, checks if there is an argument with the given name. + For Python 3, checks if there is an argument with the given name, and also whether this + argument can be called with a keyword (i.e. if it is not a positional-only argument). + + Parameters + ---------- + func: object + Callable to inspect. + name: str + Check if `func` can be called with `name` as a keyword argument. + accept_all: bool, optional + What to return if there is no parameter called `name` but the function accepts a + `**kwargs` argument. Default: ``False`` + + Returns + ------- + bool + Whether `func` accepts a `name` keyword argument. 
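+
+    Example
+    -------
+    A short illustration with a hypothetical callable:
+
+    >>> def func(alpha, *, beta=1.0, **kwargs):
+    ...     pass
+    >>> has_arg(func, "beta")
+    True
+    >>> has_arg(func, "gamma")
+    False
+    >>> has_arg(func, "gamma", accept_all=True)  # matched by **kwargs
+    True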
+    """
+    signature = inspect.signature(func)
+    parameter = signature.parameters.get(name)
+    if parameter is None:
+        if accept_all:
+            for param in signature.parameters.values():
+                if param.kind == inspect.Parameter.VAR_KEYWORD:
+                    return True
+        return False
+    return (parameter.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD,
+                               inspect.Parameter.KEYWORD_ONLY))
diff --git a/tools.py b/tools.py
index 188a86e48d..423b58a7d8 100755
--- a/tools.py
+++ b/tools.py
@@ -1,46 +1,48 @@
 #!/usr/bin/env python3
 """ The master tools.py script """
+import gettext
+import os
 import sys
+
+from importlib import import_module
+
 # Importing the various tools
-import tools.cli as cli
-from lib.cli import FullHelpArgumentParser, GuiArgs
+from lib.cli.args import FullHelpArgumentParser
+
+# LOCALES
+_LANG = gettext.translation("tools", localedir="locales", fallback=True)
+_ = _LANG.gettext

 # Python version check
-if sys.version_info[0] < 3:
-    raise Exception("This program requires at least python3.2")
-if sys.version_info[0] == 3 and sys.version_info[1] < 2:
-    raise Exception("This program requires at least python3.2")
+if sys.version_info < (3, 11):
+    raise ValueError("This program requires at least python 3.11")


-def bad_args(args):
+def bad_args(*args):  # pylint:disable=unused-argument
     """ Print help on bad arguments """
     PARSER.print_help()
-    exit(0)
+    sys.exit(0)


-if __name__ == "__main__":
-    _tools_warning = "Please backup your data and/or test the tool you want "
-    _tools_warning += "to use with a smaller data set to make sure you "
-    _tools_warning += "understand how it works."
-    print(_tools_warning)
+def _get_cli_opts():
+    """ Obtain the subparsers and cli options for available tools """
+    base_path = os.path.realpath(os.path.dirname(sys.argv[0]))
+    tools_dir = os.path.join(base_path, "tools")
+    for tool_name in sorted(os.listdir(tools_dir)):
+        cli_file = os.path.join(tools_dir, tool_name, "cli.py")
+        if os.path.exists(cli_file):
+            mod = ".".join(("tools", tool_name, "cli"))
+            module = import_module(mod)
+            cliarg_class = getattr(module, f"{tool_name.title()}Args")
+            help_text = getattr(module, "_HELPTEXT")
+            yield tool_name, help_text, cliarg_class
+
+if __name__ == "__main__":
     PARSER = FullHelpArgumentParser()
     SUBPARSER = PARSER.add_subparsers()
-    ALIGN = cli.AlignmentsArgs(SUBPARSER,
-                               "alignments",
-                               "This command lets you perform various tasks "
-                               "pertaining to an alignments file.")
-    EFFMPEG = cli.EffmpegArgs(SUBPARSER,
-                              "effmpeg",
-                              "This command allows you to easily execute "
-                              "common ffmpeg tasks.")
-    SORT = cli.SortArgs(SUBPARSER,
-                        "sort",
-                        "This command lets you sort images using various "
-                        "methods.")
-    GUI = GuiArgs(SUBPARSER,
-                  "gui",
-                  "Launch the Faceswap Tools Graphical User Interface.")
+    for tool, helptext, cli_args in _get_cli_opts():
+        cli_args(SUBPARSER, tool, helptext)
     PARSER.set_defaults(func=bad_args)
     ARGUMENTS = PARSER.parse_args()
     ARGUMENTS.func(ARGUMENTS)
diff --git a/tools/alignments.py b/tools/alignments.py
deleted file mode 100644
index 2d5b65c86d..0000000000
--- a/tools/alignments.py
+++ /dev/null
@@ -1,42 +0,0 @@
-#!/usr/bin/env python3
-""" Tools for manipulating the alignments seralized file """
-
-from lib.utils import set_system_verbosity
-from .lib_alignments import (AlignmentData, Check, Draw,  # noqa pylint: disable=unused-import
-                             Extract, Legacy, Manual, Merge, Reformat, Rename,
-                             RemoveAlignments, Sort, Spatial, UpdateHashes)
-
-
-class Alignments():
-    """ Perform tasks relating to alignments file """
-    def __init__(self, arguments):
-        self.args = arguments
-
set_system_verbosity(self.args.loglevel) - - dest_format = self.get_dest_format() - self.alignments = AlignmentData(self.args.alignments_file, dest_format) - - def get_dest_format(self): - """ Set the destination format for Alignments """ - dest_format = None - if hasattr(self.args, 'alignment_format') and self.args.alignment_format: - dest_format = self.args.alignment_format - return dest_format - - def process(self): - """ Main processing function of the Align tool """ - if self.args.job.startswith("extract"): - job = Extract - elif self.args.job == "update-hashes": - job = UpdateHashes - elif self.args.job.startswith("remove-"): - job = RemoveAlignments - elif self.args.job.startswith("sort-"): - job = Sort - elif self.args.job in("missing-alignments", "missing-frames", - "multi-faces", "leftover-faces", "no-faces"): - job = Check - else: - job = globals()[self.args.job.title()] - job = job(self.alignments, self.args) - job.process() diff --git a/tools/alignments/__init__.py b/tools/alignments/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tools/alignments/alignments.py b/tools/alignments/alignments.py new file mode 100644 index 0000000000..d9d610f290 --- /dev/null +++ b/tools/alignments/alignments.py @@ -0,0 +1,323 @@ +#!/usr/bin/env python3 +""" Tools for manipulating the alignments serialized file """ +import logging +import os +import sys +import typing as T + +from argparse import Namespace +from multiprocessing import Process + +from lib.utils import (get_module_objects, FaceswapError, + handle_deprecated_cliopts, VIDEO_EXTENSIONS) +from .media import AlignmentData +from .jobs import Check, Export, Sort, Spatial # noqa pylint:disable=unused-import +from .jobs_faces import FromFaces, RemoveFaces, Rename # noqa pylint:disable=unused-import +from .jobs_frames import Draw, Extract # noqa pylint:disable=unused-import + + +logger = logging.getLogger(__name__) + + +class Alignments(): + """ The main entry point for Faceswap's Alignments Tool. This tool is part of the Faceswap + Tools suite and should be called from the ``python tools.py alignments`` command. + + The tool allows for manipulation, and working with Faceswap alignments files. + + This parent class handles creating the individual job arguments when running in batch-mode or + triggers the job when not running in batch mode + + Parameters + ---------- + arguments: :class:`argparse.Namespace` + The :mod:`argparse` arguments as passed in from :mod:`tools.py` + """ + def __init__(self, arguments: Namespace) -> None: + logger.debug("Initializing %s: (arguments: %s)", self.__class__.__name__, arguments) + self._requires_alignments = ["export", "sort", "spatial"] + self._requires_faces = ["extract", "from-faces"] + self._requires_frames = ["draw", + "extract", + "missing-alignments", + "missing-frames", + "no-faces"] + + self._args = handle_deprecated_cliopts(arguments) + self._batch_mode = self._validate_batch_mode() + self._locations = self._get_locations() + + def _validate_batch_mode(self) -> bool: + """ Validate that the selected job supports batch processing + + Returns + ------- + bool + ``True`` if batch mode has been selected otherwise ``False`` + """ + batch_mode: bool = self._args.batch_mode + if not batch_mode: + logger.debug("Running in standard mode") + return batch_mode + valid = self._requires_alignments + self._requires_faces + self._requires_frames + if self._args.job not in valid: + logger.error("Job '%s' does not support batch mode. 
Please select a job from %s or " + "disable batch mode", self._args.job, valid) + sys.exit(1) + logger.debug("Running in batch mode") + return batch_mode + + def _get_alignments_locations(self) -> dict[str, list[str | None]]: + """ Obtain the full path to alignments files in a parent (batch) location + + These are jobs that only require an alignments file as input, so frames and face locations + are returned as a list of ``None`` values corresponding to the number of alignments files + detected + + Returns + ------- + dict[str, list[Optional[str]]]: + The list of alignments location paths and None lists for frames and faces locations + """ + if not self._args.alignments_file: + logger.error("Please provide an 'alignments_file' location for '%s' job", + self._args.job) + sys.exit(1) + + alignments = [os.path.join(self._args.alignments_file, fname) + for fname in os.listdir(self._args.alignments_file) + if os.path.splitext(fname)[-1].lower() == ".fsa" + and os.path.splitext(fname)[0].endswith("alignments")] + if not alignments: + logger.error("No alignment files found in '%s'", self._args.alignments_file) + sys.exit(1) + + logger.info("Batch mode selected. Processing alignments: %s", alignments) + retval = {"alignments_file": alignments, + "faces_dir": [None for _ in range(len(alignments))], + "frames_dir": [None for _ in range(len(alignments))]} + return retval + + def _get_frames_locations(self) -> dict[str, list[str | None]]: + """ Obtain the full path to frame locations along with corresponding alignments file + locations contained within the parent (batch) location + + Returns + ------- + dict[str, list[Optional[str]]]: + list of frames and alignments location paths. If the job requires an output faces + location then the faces folders are also returned, otherwise the faces will be a list + of ``Nones`` corresponding to the number of jobs to run + """ + if not self._args.frames_dir: + logger.error("Please provide a 'frames_dir' location for '%s' job", self._args.job) + sys.exit(1) + + frames: list[str] = [] + alignments: list[str] = [] + candidates = [os.path.join(self._args.frames_dir, fname) + for fname in os.listdir(self._args.frames_dir) + if os.path.isdir(os.path.join(self._args.frames_dir, fname)) + or os.path.splitext(fname)[-1].lower() in VIDEO_EXTENSIONS] + logger.debug("Frame candidates: %s", candidates) + + for candidate in candidates: + fname = os.path.join(candidate, "alignments.fsa") + if os.path.isdir(candidate) and os.path.exists(fname): + frames.append(candidate) + alignments.append(fname) + continue + fname = f"{os.path.splitext(candidate)[0]}_alignments.fsa" + if os.path.isfile(candidate) and os.path.exists(fname): + frames.append(candidate) + alignments.append(fname) + continue + logger.warning("Can't locate alignments file for '%s'. Skipping.", candidate) + + if not frames: + logger.error("No valid videos or frames folders found in '%s'", self._args.frames_dir) + sys.exit(1) + + if self._args.job not in self._requires_faces: # faces not required for frames input + faces: list[str | None] = [None for _ in range(len(frames))] + else: + if not self._args.faces_dir: + logger.error("Please provide a 'faces_dir' location for '%s' job", self._args.job) + sys.exit(1) + faces = [os.path.join(self._args.faces_dir, os.path.basename(os.path.splitext(frm)[0])) + for frm in frames] + + logger.info("Batch mode selected. 
Processing frames: %s", + [os.path.basename(frame) for frame in frames]) + + return {"alignments_file": T.cast(list[str | None], alignments), + "frames_dir": T.cast(list[str | None], frames), + "faces_dir": faces} + + def _get_locations(self) -> dict[str, list[str | None]]: + """ Obtain the full path to any frame, face and alignments input locations for the + selected job when running in batch mode. If not running in batch mode, then the original + passed in values are returned in lists + + Returns + ------- + dict[str, list[Optional[str]]] + A dictionary corresponding to the alignments, frames_dir and faces_dir arguments + with a list of full paths for each job + """ + job: str = self._args.job + if not self._batch_mode: # handle with given arguments + retval = {"alignments_file": [self._args.alignments_file], + "faces_dir": [self._args.faces_dir], + "frames_dir": [self._args.frames_dir]} + + elif job in self._requires_alignments: # Jobs only requiring an alignments file location + retval = self._get_alignments_locations() + + elif job in self._requires_frames: # Jobs that require a frames folder + retval = self._get_frames_locations() + + elif job in self._requires_faces and job not in self._requires_frames: + # Jobs that require faces as input + faces = [os.path.join(self._args.faces_dir, folder) + for folder in os.listdir(self._args.faces_dir) + if os.path.isdir(os.path.join(self._args.faces_dir, folder))] + if not faces: + logger.error("No folders found in '%s'", self._args.faces_dir) + sys.exit(1) + + retval = {"faces_dir": faces, + "frames_dir": [None for _ in range(len(faces))], + "alignments_file": [None for _ in range(len(faces))]} + logger.info("Batch mode selected. Processing faces: %s", + [os.path.basename(folder) for folder in faces]) + else: + raise FaceswapError(f"Unhandled job: {self._args.job}. This is a bug. Please report " + "to the developers") + + logger.debug("File locations: %s", retval) + return retval + + @staticmethod + def _run_process(arguments) -> None: + """ The alignements tool process to be run in a spawned process. + + In some instances, batch-mode memory leaks. Launching each job in a separate process + prevents this leak. + + Parameters + ---------- + arguments: :class:`argparse.Namespace` + The :mod:`argparse` arguments to be used for the given job + """ + logger.debug("Starting process: (arguments: %s)", arguments) + tool = _Alignments(arguments) + tool.process() + logger.debug("Finished process: (arguments: %s)", arguments) + + def process(self): + """ The entry point for the Alignments tool from :mod:`lib.tools.alignments.cli`. + + Launches the selected alignments job. + """ + num_jobs = len(self._locations["frames_dir"]) + for idx, (frames, faces, alignments) in enumerate(zip(self._locations["frames_dir"], + self._locations["faces_dir"], + self._locations["alignments_file"])): + if num_jobs > 1: + logger.info("Processing job %s of %s", idx + 1, num_jobs) + + args = Namespace(**self._args.__dict__) + args.frames_dir = frames + args.faces_dir = faces + args.alignments_file = alignments + + if num_jobs > 1: + proc = Process(target=self._run_process, args=(args, )) + proc.start() + proc.join() + else: + self._run_process(args) + + +class _Alignments(): + """ The main entry point for Faceswap's Alignments Tool. This tool is part of the Faceswap + Tools suite and should be called from the ``python tools.py alignments`` command. + + The tool allows for manipulation, and working with Faceswap alignments files. 
+
+
+class _Alignments():
+    """ The main entry point for Faceswap's Alignments Tool. This tool is part of the Faceswap
+    Tools suite and should be called from the ``python tools.py alignments`` command.
+
+    The tool allows for manipulation of, and working with, Faceswap alignments files.
+
+    Parameters
+    ----------
+    arguments: :class:`argparse.Namespace`
+        The :mod:`argparse` arguments as passed in from :mod:`tools.py`
+    """
+    def __init__(self, arguments: Namespace) -> None:
+        logger.debug("Initializing %s: (arguments: '%s')", self.__class__.__name__, arguments)
+        self._args = arguments
+        job = self._args.job
+
+        if job == "from-faces":
+            self.alignments = None
+        else:
+            self.alignments = AlignmentData(self._find_alignments())
+
+        if (self.alignments is not None and
+                arguments.frames_dir and
+                os.path.isfile(arguments.frames_dir)):
+            self.alignments.update_legacy_has_source(os.path.basename(arguments.frames_dir))
+
+        logger.debug("Initialized %s", self.__class__.__name__)
+
+    def _find_alignments(self) -> str:
+        """ If an alignments file location hasn't been provided, scan for one based on the frames
+        folder/video location.
+
+        Exits if an alignments file cannot be located
+
+        Returns
+        -------
+        str
+            The full path to an alignments file
+        """
+        fname = self._args.alignments_file
+        frames = self._args.frames_dir
+        if fname and os.path.isfile(fname) and os.path.splitext(fname)[-1].lower() == ".fsa":
+            return fname
+        if fname:
+            logger.error("Not a valid alignments file: '%s'", fname)
+            sys.exit(1)
+
+        if not frames or not os.path.exists(frames):
+            logger.error("Not a valid frames folder: '%s'. Can't scan for alignments.", frames)
+            sys.exit(1)
+
+        fname = "alignments.fsa"
+        if os.path.isdir(frames) and os.path.exists(os.path.join(frames, fname)):
+            return os.path.join(frames, fname)
+
+        if os.path.isdir(frames) or os.path.splitext(frames)[-1] not in VIDEO_EXTENSIONS:
+            logger.error("Can't find a valid alignments file in location: %s", frames)
+            sys.exit(1)
+
+        fname = f"{os.path.splitext(frames)[0]}_{fname}"
+        if not os.path.exists(fname):
+            logger.error("Can't find a valid alignments file for video: %s", frames)
+            sys.exit(1)
+
+        return fname
+
+    def process(self) -> None:
+        """ The entry point for the Alignments tool from :mod:`lib.tools.alignments.cli`.
+
+        Launches the selected alignments job.
+        """
+        if self._args.job in ("missing-alignments", "missing-frames", "multi-faces", "no-faces"):
+            job: T.Any = Check
+        else:
+            job = globals()[self._args.job.title().replace("-", "")]
+        job = job(self.alignments, self._args)
+        logger.debug(job)
+        job.process()
+
+
+__all__ = get_module_objects(__name__)
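Editor's note: the dispatch in `process` above maps a hyphenated CLI job name onto a CamelCase class name in the module globals (the four check jobs are special-cased to `Check`). A one-liner shows the mapping:

```python
# Hyphenated, lower-case job names resolve to CamelCase class names
for job in ("from-faces", "remove-faces", "sort"):
    print(job.title().replace("-", ""))  # FromFaces, RemoveFaces, Sort
```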
diff --git a/tools/alignments/cli.py b/tools/alignments/cli.py
new file mode 100644
index 0000000000..510b8eba53
--- /dev/null
+++ b/tools/alignments/cli.py
@@ -0,0 +1,210 @@
+#!/usr/bin/env python3
+""" Command Line Arguments for tools """
+import sys
+import gettext
+import typing as T
+
+from lib.cli.args import FaceSwapArgs
+from lib.cli.actions import DirOrFileFullPaths, DirFullPaths, FileFullPaths, Radio, Slider
+from lib.utils import get_module_objects
+
+# LOCALES
+_LANG = gettext.translation("tools.alignments.cli", localedir="locales", fallback=True)
+_ = _LANG.gettext
+
+
+_HELPTEXT = _("This command lets you perform various tasks pertaining to an alignments file.")
+
+
+class AlignmentsArgs(FaceSwapArgs):
+    """ Class to parse the command line arguments for the Alignments tool """
+
+    @staticmethod
+    def get_info() -> str:
+        """ Obtain command information.
+
+        Returns
+        -------
+        str
+            The help text for displaying in argparse's help output
+        """
+        return _("Alignments tool\nThis tool allows you to perform numerous actions on or using "
+                 "an alignments file against its corresponding faceset/frame source.")
+
+    @staticmethod
+    def get_argument_list() -> list[dict[str, T.Any]]:
+        """ Collect the argparse argument options.
+
+        Returns
+        -------
+        list
+            The argparse command line options for processing by argparse
+        """
+        frames_dir = _(" Must pass in a frames folder/source video file (-r).")
+        faces_dir = _(" Must pass in a faces folder (-c).")
+        frames_or_faces_dir = _(" Must pass in either a frames folder/source video file OR a "
+                                "faces folder (-r or -c).")
+        frames_and_faces_dir = _(" Must pass in a frames folder/source video file AND a faces "
+                                 "folder (-r and -c).")
+        output_opts = _(" Use the output option (-o) to process results.")
+        argument_list = []
+        argument_list.append({
+            "opts": ("-j", "--job"),
+            "action": Radio,
+            "type": str,
+            "choices": ("draw", "extract", "export", "from-faces", "missing-alignments",
+                        "missing-frames", "multi-faces", "no-faces", "remove-faces", "rename",
+                        "sort", "spatial"),
+            "group": _("processing"),
+            "required": True,
+            "help": _(
+                "R|Choose which action you want to perform. NB: All actions require an "
+                "alignments file (-a) to be passed in."
+                "\nL|'draw': Draw landmarks on frames in the selected folder/video. A "
+                "subfolder will be created within the frames folder to hold the output.{0}"
+                "\nL|'export': Export the contents of an alignments file to a json file. Can be "
+                "used for editing alignment information in external tools and then re-importing "
+                "by using Faceswap's Extract 'Import' plugins. Note: masks and identity vectors "
+                "will not be included in the exported file, so will be re-generated when the json "
+                "file is imported back into Faceswap. All data is exported with the origin (0, 0) "
+                "at the top left of the canvas."
+                "\nL|'extract': Re-extract faces from the source frames/video based on "
+                "alignment data. This is a lot quicker than re-detecting faces. Can pass in "
+                "the '-N' (--extract-every-n) parameter to only extract every nth frame.{1}"
+                "\nL|'from-faces': Generate alignment file(s) from a folder of extracted "
+                "faces. If the folder of faces comes from multiple sources, then multiple "
+                "alignments files will be created. NB: For faces which have been extracted "
+                "from folders of source images, rather than a video, a single alignments file "
+                "will be created as there is no way for the process to know how many folders "
+                "of images were originally used. You do not need to provide an alignments file "
+                "path to run this job. {3}"
+                "\nL|'missing-alignments': Identify frames that do not exist in the alignments "
+                "file.{2}{0}"
+                "\nL|'missing-frames': Identify frames in the alignments file that do not "
+                "appear within the frames folder/video.{2}{0}"
+                "\nL|'multi-faces': Identify where multiple faces exist within the alignments "
+                "file.{2}{4}"
+                "\nL|'no-faces': Identify frames that exist within the alignment file but no "
+                "faces were detected.{2}{0}"
+                "\nL|'remove-faces': Remove deleted faces from an alignments file. The "
+                "original alignments file will be backed up.{3}"
+                "\nL|'rename': Rename faces to correspond with their parent frame and "
+                "position index in the alignments file (i.e. how they are named after running "
+                "extract).{3}"
+                "\nL|'sort': Re-index the alignments from left to right. For alignments with "
+                "multiple faces this will ensure that the left-most face is at index 0."
+                "\nL|'spatial': Perform spatial and temporal filtering to smooth alignments "
+                "(EXPERIMENTAL!)").format(frames_dir, frames_and_faces_dir, output_opts,
+                                          faces_dir, frames_or_faces_dir)})
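Editor's note: each dictionary in `argument_list` is eventually handed to :mod:`argparse`; roughly, the faceswap-specific keys (`opts`, `group`) are consumed by the framework and the rest become keyword arguments. A hedged sketch of that translation (the real wiring lives in `lib.cli.args.FaceSwapArgs` and is assumed here):

```python
import argparse


def add_options(parser: argparse.ArgumentParser, argument_list: list[dict]) -> None:
    """ Illustrative only: strip the faceswap-specific keys, pass the rest to argparse """
    for option in argument_list:
        kwargs = {key: val for key, val in option.items() if key not in ("opts", "group")}
        parser.add_argument(*option["opts"], **kwargs)


parser = argparse.ArgumentParser()
add_options(parser, [{"opts": ("-o", "--output"), "type": str, "default": "console"}])
print(parser.parse_args(["-o", "file"]).output)  # file
```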
+        argument_list.append({
+            "opts": ("-o", "--output"),
+            "action": Radio,
+            "type": str,
+            "choices": ("console", "file", "move"),
+            "group": _("processing"),
+            "default": "console",
+            "help": _(
+                "R|How to output discovered items ('faces' and 'frames' only):"
+                "\nL|'console': Print the list of frames to the screen. (DEFAULT)"
+                "\nL|'file': Output the list of frames to a text file (stored within the "
+                "source directory)."
+                "\nL|'move': Move the discovered items to a sub-folder within the source "
+                "directory.")})
+        argument_list.append({
+            "opts": ("-a", "--alignments_file"),
+            "action": FileFullPaths,
+            "dest": "alignments_file",
+            "type": str,
+            "group": _("data"),
+            # hacky solution to not require alignments file if creating alignments from faces:
+            "required": not any(val in sys.argv for val in ["from-faces",
+                                                            "-r",
+                                                            "-frames_folder"]),
+            "filetypes": "alignments",
+            "help": _(
+                "Full path to the alignments file to be processed. If you have input a "
+                "'frames_dir' and don't provide this option, the process will try to find the "
+                "alignments file at the default location. All jobs require an alignments file "
+                "with the exception of 'from-faces', where the alignments file will be generated "
+                "in the specified faces folder.")})
+        argument_list.append({
+            "opts": ("-c", "-faces_folder"),
+            "action": DirFullPaths,
+            "dest": "faces_dir",
+            "group": _("data"),
+            "help": _("Directory containing extracted faces.")})
+        argument_list.append({
+            "opts": ("-r", "-frames_folder"),
+            "action": DirOrFileFullPaths,
+            "dest": "frames_dir",
+            "filetypes": "video",
+            "group": _("data"),
+            "help": _("Directory containing source frames that faces were extracted from.")})
+        argument_list.append({
+            "opts": ("-B", "--batch-mode"),
+            "action": "store_true",
+            "dest": "batch_mode",
+            "default": False,
+            "group": _("data"),
+            "help": _(
+                "R|Run the alignments tool on multiple sources. The following jobs support "
+                "batch mode:"
+                "\nL|draw, extract, from-faces, missing-alignments, missing-frames, no-faces, "
+                "sort, spatial."
+                "\nIf batch mode is selected then the other options should be set as follows:"
+                "\nL|alignments_file: For 'sort' and 'spatial' this should point to the parent "
+                "folder containing the alignments files to be processed. For all other jobs "
+                "this option is ignored, and the alignments files must exist at their default "
+                "location relative to the original frames folder/video."
+                "\nL|faces_dir: For 'from-faces' this should be a parent folder, containing "
+                "sub-folders of extracted faces from which to generate alignments files. For "
+                "'extract' this should be a parent folder where sub-folders will be created "
+                "for each extraction to be run. For all other jobs this option is ignored."
+                "\nL|frames_dir: For 'draw', 'extract', 'missing-alignments', 'missing-frames' "
+                "and 'no-faces' this should be a parent folder containing video files or sub-"
+                "folders of images to perform the alignments job on. The alignments file "
+                "should exist at the default location. For all other jobs this option is "
+                "ignored.")})
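Editor's note: for the frame-based jobs, a batch `frames_dir` parent is therefore expected to look something like the following (names illustrative):

```
frames_parent/
├── video_a.mp4
├── video_a_alignments.fsa    # default location for a video source
├── video_b/                  # a folder of extracted frames
│   └── alignments.fsa        # default location for a frames folder
└── video_c.mp4               # skipped with a warning: no alignments file found
```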
+        argument_list.append({
+            "opts": ("-N", "--extract-every-n"),
+            "type": int,
+            "action": Slider,
+            "dest": "extract_every_n",
+            "min_max": (1, 100),
+            "default": 1,
+            "rounding": 1,
+            "group": _("extract"),
+            "help": _(
+                "[Extract only] Extract every 'nth' frame. This option will skip frames when "
+                "extracting faces. For example a value of 1 will extract faces from every frame, "
+                "a value of 10 will extract faces from every 10th frame.")})
+        argument_list.append({
+            "opts": ("-z", "--size"),
+            "type": int,
+            "action": Slider,
+            "min_max": (256, 1024),
+            "rounding": 64,
+            "default": 512,
+            "group": _("extract"),
+            "help": _("[Extract only] The output size of extracted faces.")})
+        argument_list.append({
+            "opts": ("-m", "--min-size"),
+            "type": int,
+            "action": Slider,
+            "min_max": (0, 200),
+            "rounding": 1,
+            "default": 0,
+            "dest": "min_size",
+            "group": _("extract"),
+            "help": _(
+                "[Extract only] Only extract faces that have been resized by this percent or "
+                "more to meet the specified extract size (`-z`, `--size`). Useful for "
+                "excluding low-res images from a training set. Set to 0 to extract all faces. "
+                "E.g. for an extract size of 512px, a setting of 50 will only include faces "
+                "that have been resized from 256px or above. Setting to 100 will only extract "
+                "faces that have been resized from 512px or above. A setting of 200 will only "
+                "extract faces that have been downscaled from 1024px or above.")})
+        return argument_list
+
+
+__all__ = get_module_objects(__name__)
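Editor's note: the `-m` percentage maps to a pixel threshold as `size * min_size / 100` (the tool itself applies a 4px floor, as seen in `Extract._get_min_size` further down). The worked examples from the help text check out:

```python
def min_source_resolution(extract_size: int, min_size_pct: int) -> int:
    """ Pixel resolution a face must have been resized from in order to be kept """
    return 0 if min_size_pct == 0 else max(4, int(extract_size * min_size_pct / 100))


assert min_source_resolution(512, 50) == 256
assert min_source_resolution(512, 100) == 512
assert min_source_resolution(512, 200) == 1024
```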
diff --git a/tools/alignments/jobs.py b/tools/alignments/jobs.py
new file mode 100644
index 0000000000..78ccce3ba2
--- /dev/null
+++ b/tools/alignments/jobs.py
@@ -0,0 +1,737 @@
+#!/usr/bin/env python3
+""" Tools for manipulating the alignments serialized file """
+from __future__ import annotations
+import logging
+import os
+import sys
+import typing as T
+
+from datetime import datetime
+
+import numpy as np
+from scipy import signal
+from sklearn import decomposition
+from tqdm import tqdm
+
+from lib.logger import parse_class_init
+from lib.serializer import get_serializer
+from lib.utils import get_module_objects, FaceswapError
+
+from .media import Faces, Frames
+from .jobs_faces import FaceToFile
+
+if T.TYPE_CHECKING:
+    from collections.abc import Generator
+    from argparse import Namespace
+    from lib.align.alignments import AlignmentFileDict, PNGHeaderDict
+    from .media import AlignmentData
+
+logger = logging.getLogger(__name__)
+
+
+class Check:
+    """ Frames and faces checking tasks.
+
+    Parameters
+    ----------
+    alignments : :class:`tools.alignments.media.AlignmentData`
+        The loaded alignments corresponding to the frames/faces to be checked
+    arguments : :class:`argparse.Namespace`
+        The command line arguments that have called this job
+    """
+    def __init__(self, alignments: AlignmentData, arguments: Namespace) -> None:
+        logger.debug(parse_class_init(locals()))
+        self._alignments = alignments
+        self._job = arguments.job
+        self._type: T.Literal["faces", "frames"] | None = None
+        self._is_video = False  # Set when getting items
+        self._output = arguments.output
+        self._source_dir = self._get_source_dir(arguments)
+        self._validate()
+        self._items = self._get_items()
+
+        self.output_message = ""
+        logger.debug("Initialized %s", self.__class__.__name__)
+
+    def _get_source_dir(self, arguments: Namespace) -> str:
+        """ Set the correct source folder
+
+        Parameters
+        ----------
+        arguments : :class:`argparse.Namespace`
+            The command line arguments for the Alignments tool
+
+        Returns
+        -------
+        str
+            Full path to the source folder
+        """
+        if (hasattr(arguments, "faces_dir") and arguments.faces_dir and
+                hasattr(arguments, "frames_dir") and arguments.frames_dir):
+            logger.error("Only select a source frames (-r) or source faces (-c) folder")
+            sys.exit(1)
+        elif hasattr(arguments, "faces_dir") and arguments.faces_dir:
+            self._type = "faces"
+            source_dir = arguments.faces_dir
+        elif hasattr(arguments, "frames_dir") and arguments.frames_dir:
+            self._type = "frames"
+            source_dir = arguments.frames_dir
+        else:
+            logger.error("No source folder (-r or -c) was provided")
+            sys.exit(1)
+        logger.debug("type: '%s', source_dir: '%s'", self._type, source_dir)
+        return source_dir
+
+    def _get_items(self) -> list[dict[str, str]] | list[tuple[str, PNGHeaderDict]]:
+        """ Set the correct items to process
+
+        Returns
+        -------
+        list[dict[str, str]] | list[tuple[str, :class:`~lib.align.alignments.PNGHeaderDict`]]
+            Sorted list of dictionaries for either faces or frames. If faces, the dictionaries
+            have the current filename as key, with the header source data as value. If frames,
+            the dictionaries will contain the keys 'frame_fullname', 'frame_name', 'extension'.
+        """
+        assert self._type is not None
+        items: Frames | Faces = globals()[self._type.title()](self._source_dir)
+        self._is_video = items.is_video
+        return T.cast(list[dict[str, str]] | list[tuple[str, "PNGHeaderDict"]],
+                      items.file_list_sorted)
+
+    def process(self) -> None:
+        """ Process the frames check against the alignments file """
+        assert self._type is not None
+        logger.info("[CHECK %s]", self._type.upper())
+        items_output = self._compile_output()
+
+        if self._type == "faces":
+            filelist = T.cast(list[tuple[str, "PNGHeaderDict"]], self._items)
+            check_update = FaceToFile(self._alignments, [val[1] for val in filelist])
+            if check_update():
+                self._alignments.save()
+
+        self._output_results(items_output)
+
+    def _validate(self) -> None:
+        """ Check that the selected type is valid for the selected task and job """
+        if self._job == "missing-frames" and self._output == "move":
+            logger.warning("'missing-frames' was selected with move output, but there will "
+                           "be nothing to move. Defaulting to output: console")
+            self._output = "console"
+        if self._type == "faces" and self._job != "multi-faces":
+            logger.error("The selected folder is not valid. Faces folder (-c) is only "
+                         "supported for 'multi-faces'")
+            sys.exit(1)
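Editor's note: the check jobs that follow are dispatched dynamically; the hyphenated job name is converted into a `_get_*` generator method name. A minimal sketch of the pattern:

```python
class Dispatcher:
    """ Illustrative stand-in for the Check job dispatch below """
    def _get_no_faces(self):
        yield "frame_0001.png"

    def run(self, job: str) -> list[str]:
        # "no-faces" -> self._get_no_faces
        processor = getattr(self, f"_get_{job.replace('-', '_')}")
        return list(processor())


print(Dispatcher().run("no-faces"))  # ['frame_0001.png']
```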
+    def _compile_output(self) -> list[str] | list[tuple[str, int]]:
+        """ Compile list of frames that meet criteria
+
+        Returns
+        -------
+        list[str] | list[tuple[str, int]]
+            List of filenames or filenames and face indices for the selected criteria
+        """
+        action = self._job.replace("-", "_")
+        processor = getattr(self, f"_get_{action}")
+        logger.debug("Processor: %s", processor)
+        return list(processor())
+
+    def _get_no_faces(self) -> Generator[str, None, None]:
+        """ Yield each frame that has no face match in the alignments file
+
+        Yields
+        ------
+        str
+            The frame name of any frames which have no faces
+        """
+        self.output_message = "Frames with no faces"
+        for frame in tqdm(T.cast(list[dict[str, str]], self._items),
+                          desc=self.output_message,
+                          leave=False):
+            logger.trace(frame)  # type:ignore
+            frame_name = frame["frame_fullname"]
+            if not self._alignments.frame_has_faces(frame_name):
+                logger.debug("Returning: '%s'", frame_name)
+                yield frame_name
+
+    def _get_multi_faces(self) -> (Generator[str, None, None] |
+                                   Generator[tuple[str, int], None, None]):
+        """ Yield each frame or face that has multiple faces matched in the alignments file
+
+        Yields
+        ------
+        str | tuple
+            The frame name of any frames which have multiple faces and potentially the face id
+        """
+        process_type = getattr(self, f"_get_multi_faces_{self._type}")
+        yield from process_type()
+
+    def _get_multi_faces_frames(self) -> Generator[str, None, None]:
+        """ Return frames that contain multiple faces
+
+        Yields
+        ------
+        str
+            The frame name of any frames which have multiple faces
+        """
+        self.output_message = "Frames with multiple faces"
+        for item in tqdm(T.cast(list[dict[str, str]], self._items),
+                         desc=self.output_message,
+                         leave=False):
+            filename = item["frame_fullname"]
+            if not self._alignments.frame_has_multiple_faces(filename):
+                continue
+            logger.trace("Returning: '%s'", filename)  # type:ignore
+            yield filename
+
+    def _get_multi_faces_faces(self) -> Generator[tuple[str, int], None, None]:
+        """ Return faces where there are multiple faces in a frame
+
+        Yields
+        ------
+        tuple[str, int]
+            The frame name and the face id of any frames which have multiple faces
+        """
+        self.output_message = "Multiple faces in frame"
+        for item in tqdm(T.cast(list[tuple[str, "PNGHeaderDict"]], self._items),
+                         desc=self.output_message,
+                         leave=False):
+            src = item[1]["source"]
+            if not self._alignments.frame_has_multiple_faces(src["source_filename"]):
+                continue
+            retval = (item[0], src["face_index"])
+            logger.trace("Returning: '%s'", retval)  # type:ignore
+            yield retval
+
+    def _get_missing_alignments(self) -> Generator[str, None, None]:
+        """ Yield each frame that does not exist in the alignments file
+
+        Yields
+        ------
+        str
+            The frame name of any frames missing alignments
+        """
+        self.output_message = "Frames missing from alignments file"
+        exclude_filetypes = {"yaml", "yml", "p", "json", "txt"}
+        for frame in tqdm(T.cast(list[dict[str, str]], self._items),
+                          desc=self.output_message,
+                          leave=False):
+            frame_name = frame["frame_fullname"]
+            if (frame["frame_extension"] not in exclude_filetypes
+                    and not self._alignments.frame_exists(frame_name)):
+                logger.debug("Returning: '%s'", frame_name)
+                yield frame_name
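Editor's note: `_get_missing_frames` below is effectively a set difference between the alignments file's keys and the frames found on disk. A compact equivalent:

```python
def missing_frames(alignment_keys: set[str], disk_frames: set[str]) -> list[str]:
    """ Frames recorded in the alignments file but absent from the frames source """
    return sorted(alignment_keys - disk_frames)


print(missing_frames({"a.png", "b.png"}, {"a.png"}))  # ['b.png']
```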
+    def _get_missing_frames(self) -> Generator[str, None, None]:
+        """ Yield each frame in the alignments file that does not have a matching file
+
+        Yields
+        ------
+        str
+            The frame name of any frames in alignments with no matching file
+        """
+        self.output_message = "Missing frames that are in alignments file"
+        frames = {item["frame_fullname"] for item in T.cast(list[dict[str, str]], self._items)}
+        for frame in tqdm(self._alignments.data.keys(), desc=self.output_message, leave=False):
+            if frame not in frames:
+                logger.debug("Returning: '%s'", frame)
+                yield frame
+
+    def _output_results(self, items_output: list[str] | list[tuple[str, int]]) -> None:
+        """ Output the results in the requested format
+
+        Parameters
+        ----------
+        items_output : list[str] | list[tuple[str, int]]
+            The list of frame names, and potentially face ids, of any items which met the
+            selection criteria
+        """
+        logger.trace("items_output: %s", items_output)  # type:ignore
+        if self._output == "move" and self._is_video and self._type == "frames":
+            logger.warning("Move was selected with an input video. This is not possible so "
+                           "falling back to console output")
+            self._output = "console"
+        if not items_output:
+            logger.info("No %s were found meeting the criteria", self._type)
+            return
+        if self._output == "move":
+            self._move_file(items_output)
+            return
+        if self._job == "multi-faces" and self._type == "faces":
+            # Strip the index for printed/file output
+            final_output = [item[0] for item in items_output]
+        else:
+            final_output = T.cast(list[str], items_output)
+        output_message = "-----------------------------------------------\r\n"
+        output_message += f" {self.output_message} ({len(final_output)})\r\n"
+        output_message += "-----------------------------------------------\r\n"
+        output_message += "\r\n".join(final_output)
+        if self._output == "console":
+            for line in output_message.splitlines():
+                logger.info(line)
+        if self._output == "file":
+            self.output_file(output_message, len(final_output))
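Editor's note: both the file and move outputs that follow derive their names from the check message plus a timestamp, so repeated runs never collide. The naming convention reduces to:

```python
from datetime import datetime


def output_name(message: str, prefix: str = "") -> str:
    """ e.g. output_name("Frames with no faces", "myvid_")
    -> 'myvid_frames_with_no_faces_20240101_120000' (timestamp varies) """
    now = datetime.now().strftime("%Y%m%d_%H%M%S")
    return f"{prefix}{message.replace(' ', '_').lower()}_{now}"
```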
+    def _get_output_folder(self) -> str:
+        """ Return the output folder. Needs to be in the root if the input is a video and we are
+        processing frames
+
+        Returns
+        -------
+        str
+            Full path to the output folder
+        """
+        if self._is_video and self._type == "frames":
+            return os.path.dirname(self._source_dir)
+        return self._source_dir
+
+    def _get_filename_prefix(self) -> str:
+        """ The video name needs to be prefixed to the filename if the input is a video and we are
+        processing frames
+
+        Returns
+        -------
+        str
+            The common filename prefix to use
+        """
+        if self._is_video and self._type == "frames":
+            return f"{os.path.basename(self._source_dir)}_"
+        return ""
+
+    def output_file(self, output_message: str, items_discovered: int) -> None:
+        """ Save the output to a text file in the frames directory
+
+        Parameters
+        ----------
+        output_message : str
+            The message to write out to file
+        items_discovered : int
+            The number of items which matched the criteria
+        """
+        now = datetime.now().strftime("%Y%m%d_%H%M%S")
+        dst_dir = self._get_output_folder()
+        filename = (f"{self._get_filename_prefix()}{self.output_message.replace(' ', '_').lower()}"
+                    f"_{now}.txt")
+        output_file = os.path.join(dst_dir, filename)
+        logger.info("Saving %s result(s) to '%s'", items_discovered, output_file)
+        with open(output_file, "w", encoding="utf8") as f_output:
+            f_output.write(output_message)
+
+    def _move_file(self, items_output: list[str] | list[tuple[str, int]]) -> None:
+        """ Move the identified frames to a new sub folder
+
+        Parameters
+        ----------
+        items_output : list[str] | list[tuple[str, int]]
+            List of items to move
+        """
+        now = datetime.now().strftime("%Y%m%d_%H%M%S")
+        folder_name = (f"{self._get_filename_prefix()}"
+                       f"{self.output_message.replace(' ', '_').lower()}_{now}")
+        dst_dir = self._get_output_folder()
+        output_folder = os.path.join(dst_dir, folder_name)
+        logger.debug("Creating folder: '%s'", output_folder)
+        os.makedirs(output_folder)
+        move = getattr(self, f"_move_{self._type}")
+        logger.debug("Move function: %s", move)
+        move(output_folder, items_output)
+
+    def _move_frames(self, output_folder: str, items_output: list[str]) -> None:
+        """ Move frames into a single sub folder
+
+        Parameters
+        ----------
+        output_folder : str
+            The folder to move the output to
+        items_output : list
+            List of items to move
+        """
+        logger.info("Moving %s frame(s) to '%s'", len(items_output), output_folder)
+        for frame in items_output:
+            src = os.path.join(self._source_dir, frame)
+            dst = os.path.join(output_folder, frame)
+            logger.debug("Moving: '%s' to '%s'", src, dst)
+            os.rename(src, dst)
+
+    def _move_faces(self, output_folder: str, items_output: list[tuple[str, int]]) -> None:
+        """ Make additional sub folders for each face that appears. Enables easier manual sorting
+
+        Parameters
+        ----------
+        output_folder : str
+            The folder to move the output to
+        items_output : list
+            List of items and face indices to move
+        """
+        logger.info("Moving %s face(s) to '%s'", len(items_output), output_folder)
+        for frame, idx in items_output:
+            src = os.path.join(self._source_dir, frame)
+            dst_folder = os.path.join(output_folder, str(idx)) if idx != -1 else output_folder
+            if not os.path.isdir(dst_folder):
+                logger.debug("Creating folder: '%s'", dst_folder)
+                os.makedirs(dst_folder)
+            dst = os.path.join(dst_folder, frame)
+            logger.debug("Moving: '%s' to '%s'", src, dst)
+            os.rename(src, dst)
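Editor's note: the `Export` job that follows writes one json record per face, holding the detection box as corner co-ordinates plus the 2D landmarks. A sketch of the output shape (values invented):

```python
# Shape of the exported json; "detected" is [x1, y1, x2, y2] with origin (0, 0)
# at the top left of the canvas, "landmarks_2d" is a list of [x, y] points.
exported = {"frame_0001.png": [{"detected": [140, 92, 396, 348],
                                "landmarks_2d": [[151.2, 180.9], [162.0, 205.4]]}]}
```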
+
+
+class Export:
+    """ Export alignments from a Faceswap .fsa file to a json formatted file.
+
+    Parameters
+    ----------
+    alignments : :class:`tools.alignments.media.AlignmentData`
+        The alignments data loaded from an alignments file for this job
+    arguments : :class:`argparse.Namespace`
+        The :mod:`argparse` arguments as passed in from :mod:`tools.py`. Unused
+    """
+    def __init__(self,
+                 alignments: AlignmentData,
+                 arguments: Namespace) -> None:  # pylint:disable=unused-argument
+        logger.debug(parse_class_init(locals()))
+        self._alignments = alignments
+        self._serializer = get_serializer("json")
+        self._output_file = self._get_output_file()
+        logger.debug("Initialized %s", self.__class__.__name__)
+
+    def _get_output_file(self) -> str:
+        """ Obtain the name of an output file. If a file of the requested name exists, then append
+        a digit to the end until a unique filename is found
+
+        Returns
+        -------
+        str
+            Full path to an output json file
+        """
+        in_file = self._alignments.file
+        base_filename = f"{os.path.splitext(in_file)[0]}_export"
+        out_file = f"{base_filename}.json"
+        idx = 1
+        while True:
+            if not os.path.exists(out_file):
+                break
+            logger.debug("Output file exists: '%s'", out_file)
+            out_file = f"{base_filename}_{idx}.json"
+            idx += 1
+        logger.debug("Setting output file to '%s'", out_file)
+        return out_file
+
+    @classmethod
+    def _format_face(cls, face: AlignmentFileDict) -> dict[str, list[int] | list[list[float]]]:
+        """ Format the relevant keys from an alignment file's face into the correct format for
+        export/import
+
+        Parameters
+        ----------
+        face : :class:`~lib.align.alignments.AlignmentFileDict`
+            The alignment dictionary for a face to process
+
+        Returns
+        -------
+        dict[str, list[int] | list[list[float]]]
+            The face formatted for exporting to a json file
+        """
+        lms = face["landmarks_xy"]
+        assert isinstance(lms, np.ndarray)
+        retval = {"detected": [int(round(face["x"], 0)),
+                               int(round(face["y"], 0)),
+                               int(round(face["x"] + face["w"], 0)),
+                               int(round(face["y"] + face["h"], 0))],
+                  "landmarks_2d": lms.tolist()}
+        return retval
+
+    def process(self) -> None:
+        """ Parse the loaded alignments file and output the relevant information to a json file """
+        logger.info("[EXPORTING ALIGNMENTS]")  # Tidy up cli output
+        formatted = {key: [self._format_face(face) for face in val["faces"]]
+                     for key, val in self._alignments.data.items()}
+        logger.info("Saving export alignments to '%s'...", self._output_file)
+        self._serializer.save(self._output_file, formatted)
+
+
+class Sort:
+    """ Sort the alignments' face indices into the order in which the faces appear within each
+    frame, from left to right.
+
+    Parameters
+    ----------
+    alignments : :class:`tools.alignments.media.AlignmentData`
+        The alignments data loaded from an alignments file for this job
+    arguments : :class:`argparse.Namespace`
+        The :mod:`argparse` arguments as passed in from :mod:`tools.py`.
Unused + """ + def __init__(self, + alignments: AlignmentData, + arguments: Namespace) -> None: # pylint:disable=unused-argument + logger.debug(parse_class_init(locals())) + self._alignments = alignments + logger.debug("Initialized %s", self.__class__.__name__) + + def process(self) -> None: + """ Execute the sort process """ + logger.info("[SORT INDEXES]") # Tidy up cli output + reindexed = self.reindex_faces() + if reindexed: + self._alignments.save() + logger.warning("If you have a face-set corresponding to the alignment file you " + "processed then you should run the 'Extract' job to regenerate it.") + + def reindex_faces(self) -> int: + """ Re-Index the faces + + Returns + ------- + int + The count of re-indexed faces + """ + reindexed = 0 + for alignment in tqdm(self._alignments.yield_faces(), + desc="Sort alignment indexes", + total=self._alignments.frames_count, + leave=False): + frame, alignments, count, key = alignment + if count <= 1: + logger.trace("0 or 1 face in frame. Not sorting: '%s'", frame) # type:ignore + continue + sorted_alignments = sorted(alignments, key=lambda x: (x["x"])) + if sorted_alignments == alignments: + logger.trace("Alignments already in correct order. Not " # type:ignore + "sorting: '%s'", frame) + continue + logger.trace("Sorting alignments for frame: '%s'", frame) # type:ignore + self._alignments.data[key]["faces"] = sorted_alignments + reindexed += 1 + logger.info("%s Frames had their faces reindexed", reindexed) + return reindexed + + +class Spatial: + """ Apply spatial temporal filtering to landmarks + + Parameters + ---------- + alignments : :class:`tools.lib_alignments.media.AlignmentData` + The alignments data loaded from an alignments file for this rename job + arguments : :class:`argparse.Namespace` + The :mod:`argparse` arguments as passed in from :mod:`tools.py` + + Reference + --------- + https://www.kaggle.com/selfishgene/animating-and-smoothing-3d-facial-keypoints/notebook + """ + def __init__(self, alignments: AlignmentData, arguments: Namespace) -> None: + logger.debug(parse_class_init(locals())) + self.arguments = arguments + self._alignments = alignments + self._mappings: dict[int, str] = {} + self._normalized: dict[str, np.ndarray] = {} + self._shapes_model: decomposition.PCA | None = None + logger.debug("Initialized %s", self.__class__.__name__) + + def process(self) -> None: + """ Perform spatial filtering """ + logger.info("[SPATIO-TEMPORAL FILTERING]") # Tidy up cli output + logger.info("NB: The process only processes the alignments for the first " + "face it finds for any given frame. 
For best results only run this when " + "there is only a single face in the alignments file and all false positives " + "have been removed") + + self._normalize() + self._shape_model() + landmarks = self._spatially_filter() + landmarks = self._temporally_smooth(landmarks) + self._update_alignments(landmarks) + self._alignments.save() + logger.warning("If you have a face-set corresponding to the alignment file you " + "processed then you should run the 'Extract' job to regenerate it.") + + # Define shape normalization utility functions + @staticmethod + def _normalize_shapes(shapes_im_coords: np.ndarray + ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """ Normalize a 2D or 3D shape + + Parameters + ---------- + shaped_im_coords : :class:`numpy.ndarray` + The facial landmarks + + Returns + ------- + shapes_normalized : :class:`numpy.ndarray` + The normalized shapes + scale_factors : :class:`numpy.ndarray` + The scale factors + mean_coords : :class:`numpy.ndarray` + The mean coordinates + """ + logger.debug("Normalize shapes") + (num_pts, num_dims, _) = shapes_im_coords.shape + + # Calculate mean coordinates and subtract from shapes + mean_coords = shapes_im_coords.mean(axis=0) + shapes_centered = np.zeros(shapes_im_coords.shape) + shapes_centered = shapes_im_coords - np.tile(mean_coords, [num_pts, 1, 1]) + + # Calculate scale factors and divide shapes + scale_factors = np.sqrt((shapes_centered**2).sum(axis=1)).mean(axis=0) + shapes_normalized = np.zeros(shapes_centered.shape) + shapes_normalized = shapes_centered / np.tile(scale_factors, [num_pts, num_dims, 1]) + + logger.debug("Normalized shapes: (shapes_normalized: %s, scale_factors: %s, mean_coords: " + "%s", shapes_normalized, scale_factors, mean_coords) + return shapes_normalized, scale_factors, mean_coords + + @staticmethod + def _normalized_to_original(shapes_normalized: np.ndarray, + scale_factors: np.ndarray, + mean_coords: np.ndarray) -> np.ndarray: + """ Transform a normalized shape back to original image coordinates + + Parameters + ---------- + shapes_normalized : :class:`numpy.ndarray` + The normalized shapes + scale_factors : :class:`numpy.ndarray` + The scale factors + mean_coords : :class:`numpy.ndarray` + The mean coordinates + + Returns + ------- + :class:`numpy.ndarray` + The normalized shape transformed back to original coordinates + """ + logger.debug("Normalize to original") + (num_pts, num_dims, _) = shapes_normalized.shape + + # move back to the correct scale + shapes_centered = shapes_normalized * np.tile(scale_factors, [num_pts, num_dims, 1]) + # move back to the correct location + shapes_im_coords = shapes_centered + np.tile(mean_coords, [num_pts, 1, 1]) + + logger.debug("Normalized to original: %s", shapes_im_coords) + return shapes_im_coords + + def _normalize(self) -> None: + """ Compile all original and normalized alignments """ + logger.debug("Normalize") + count = sum(1 for val in self._alignments.data.values() if val["faces"]) + + sample_lm = next((val["faces"][0]["landmarks_xy"] + for val in self._alignments.data.values() if val["faces"]), 68) + assert isinstance(sample_lm, np.ndarray) + lm_count = sample_lm.shape[0] + if lm_count != 68: + raise FaceswapError("Spatial smoothing only supports 68 point facial landmarks") + + landmarks_all = np.zeros((lm_count, 2, int(count))) + + end = 0 + for key in tqdm(sorted(self._alignments.data.keys()), desc="Compiling", leave=False): + val = self._alignments.data[key]["faces"] + if not val: + continue + # We should only be normalizing a single face, so just take + 
# the first landmarks found + landmarks = np.array(val[0]["landmarks_xy"]).reshape((lm_count, 2, 1)) + start = end + end = start + landmarks.shape[2] + # Store in one big array + landmarks_all[:, :, start:end] = landmarks + # Make sure we keep track of the mapping to the original frame + self._mappings[start] = key + + # Normalize shapes + normalized_shape = self._normalize_shapes(landmarks_all) + self._normalized["landmarks"] = normalized_shape[0] + self._normalized["scale_factors"] = normalized_shape[1] + self._normalized["mean_coords"] = normalized_shape[2] + logger.debug("Normalized: %s", self._normalized) + + def _shape_model(self) -> None: + """ build 2D shape model """ + logger.debug("Shape model") + landmarks_norm = self._normalized["landmarks"] + num_components = 20 + normalized_shapes_tbl = np.reshape(landmarks_norm, [68*2, landmarks_norm.shape[2]]).T + self._shapes_model = decomposition.PCA(n_components=num_components, + whiten=True, + random_state=1).fit(normalized_shapes_tbl) + explained = self._shapes_model.explained_variance_ratio_.sum() + logger.info("Total explained percent by PCA model with %s components is %s%%", + num_components, round(100 * explained, 1)) + logger.debug("Shaped model") + + def _spatially_filter(self) -> np.ndarray: + """ interpret the shapes using our shape model (project and reconstruct) + + Returns + ------- + :class:`numpy.ndarray` + The filtered landmarks in original coordinate space + """ + logger.debug("Spatially Filter") + assert self._shapes_model is not None + landmarks_norm = self._normalized["landmarks"] + # Convert to matrix form + landmarks_norm_table = np.reshape(landmarks_norm, [68 * 2, landmarks_norm.shape[2]]).T + # Project onto shapes model and reconstruct + landmarks_norm_table_rec = self._shapes_model.inverse_transform( + self._shapes_model.transform(landmarks_norm_table)) + # Convert back to shapes (numKeypoint, num_dims, numFrames) + landmarks_norm_rec = np.reshape(landmarks_norm_table_rec.T, + [68, 2, landmarks_norm.shape[2]]) + # Transform back to image co-ordinates + retval = self._normalized_to_original(landmarks_norm_rec, + self._normalized["scale_factors"], + self._normalized["mean_coords"]) + + logger.debug("Spatially Filtered: %s", retval) + return retval + + @staticmethod + def _temporally_smooth(landmarks: np.ndarray) -> np.ndarray: + """ apply temporal filtering on the 2D points + + Parameters + ---------- + landmarks : :class:`numpy.ndarray` + 68 point landmarks to be temporally smoothed + + Returns + ------- + :class: `numpy.ndarray` + The temporally smoothed landmarks + """ + logger.debug("Temporally Smooth") + filter_half_length = 2 + temporal_filter = np.ones((1, 1, 2 * filter_half_length + 1)) + temporal_filter = temporal_filter / temporal_filter.sum() + + start_tileblock = np.tile(landmarks[:, :, 0][:, :, np.newaxis], [1, 1, filter_half_length]) + end_tileblock = np.tile(landmarks[:, :, -1][:, :, np.newaxis], [1, 1, filter_half_length]) + landmarks_padded = np.dstack((start_tileblock, landmarks, end_tileblock)) + + retval = signal.convolve(landmarks_padded, temporal_filter, mode='valid', method='fft') + logger.debug("Temporally Smoothed: %s", retval) + return retval + + def _update_alignments(self, landmarks: np.ndarray) -> None: + """ Update smoothed landmarks back to alignments + + Parameters + ---------- + landmarks : :class:`numpy.ndarray` + The smoothed landmarks + """ + logger.debug("Update alignments") + for idx, frame in tqdm(self._mappings.items(), desc="Updating", leave=False): + logger.trace("Updating: 
(frame: %s)", frame) # type:ignore + landmarks_update = landmarks[:, :, idx] + landmarks_xy = landmarks_update.reshape(68, 2).tolist() + self._alignments.data[frame]["faces"][0]["landmarks_xy"] = landmarks_xy + logger.trace("Updated: (frame: '%s', landmarks: %s)", # type:ignore + frame, landmarks_xy) + logger.debug("Updated alignments") + + +__all__ = get_module_objects(__name__) diff --git a/tools/alignments/jobs_faces.py b/tools/alignments/jobs_faces.py new file mode 100644 index 0000000000..066558c39a --- /dev/null +++ b/tools/alignments/jobs_faces.py @@ -0,0 +1,488 @@ +#!/usr/bin/env python3 +""" Tools for manipulating the alignments using extracted Faces as a source """ +from __future__ import annotations +import logging +import os +import typing as T + +from argparse import Namespace +from operator import itemgetter + +import numpy as np +from tqdm import tqdm + +from lib.align import DetectedFace +from lib.image import update_existing_metadata # TODO remove +from lib.utils import get_module_objects +from scripts.fsmedia import Alignments + +from .media import Faces + +if T.TYPE_CHECKING: + from .media import AlignmentData + from lib.align.alignments import (AlignmentDict, AlignmentFileDict, + PNGHeaderDict, PNGHeaderAlignmentsDict) + +logger = logging.getLogger(__name__) + + +class FromFaces(): + """ Scan a folder of Faceswap Extracted Faces and re-create the associated alignments file(s) + + Parameters + ---------- + alignments: NoneType + Parameter included for standard job naming convention, but not used for this process. + arguments: :class:`argparse.Namespace` + The :mod:`argparse` arguments as passed in from :mod:`tools.py` + """ + def __init__(self, alignments: None, arguments: Namespace) -> None: + logger.debug("Initializing %s: (alignments: %s, arguments: %s)", + self.__class__.__name__, alignments, arguments) + self._faces_dir = arguments.faces_dir + self._faces = Faces(arguments.faces_dir) + logger.debug("Initialized %s", self.__class__.__name__) + + def process(self) -> None: + """ Run the job to read faces from a folder to create alignments file(s). """ + logger.info("[CREATE ALIGNMENTS FROM FACES]") # Tidy up cli output + + all_versions: dict[str, list[float]] = {} + d_align: dict[str, dict[str, list[tuple[int, AlignmentFileDict, str, dict]]]] = {} + filelist = T.cast(list[tuple[str, "PNGHeaderDict"]], self._faces.file_list_sorted) + for filename, meta in tqdm(filelist, + desc="Generating Alignments", + total=len(filelist), + leave=False): + + align_fname = self._get_alignments_filename(meta["source"]) + source_name, f_idx, alignment = self._extract_alignment(meta) + full_info = (f_idx, alignment, filename, meta["source"]) + + d_align.setdefault(align_fname, {}).setdefault(source_name, []).append(full_info) + all_versions.setdefault(align_fname, []).append(meta["source"]["alignments_version"]) + + versions = {k: min(v) for k, v in all_versions.items()} + alignments = self._sort_alignments(d_align) + self._save_alignments(alignments, versions) + + @classmethod + def _get_alignments_filename(cls, source_data: dict) -> str: + """ Obtain the name of the alignments file from the source information contained within the + PNG metadata. + + Parameters + ---------- + source_data: dict + The source information contained within a Faceswap extracted PNG + + Returns + ------- + str: + If the face was generated from a video file, the filename will be + `'_alignments.fsa'`. 
If it was extracted from an image file it will be
+            `'alignments.fsa'`
+        """
+        is_video = source_data["source_is_video"]
+        src_name = source_data["source_filename"]
+        prefix = f"{src_name.rpartition('_')[0]}_" if is_video else ""
+        retval = f"{prefix}alignments.fsa"
+        logger.trace("Extracted alignments file filename: '%s'", retval)  # type:ignore
+        return retval
+
+    def _extract_alignment(self, metadata: dict) -> tuple[str, int, AlignmentFileDict]:
+        """ Extract alignment data from a PNG image's iTXt header.
+
+        Formats the landmarks into a numpy array and adds in mask centering information if it is
+        from an older extract.
+
+        Parameters
+        ----------
+        metadata: dict
+            An extracted face's PNG header data
+
+        Returns
+        -------
+        tuple
+            The alignment's source frame name in position 0. The index of the face within the
+            alignment file in position 1. The alignment data correctly formatted for writing to an
+            alignments file in position 2
+        """
+        alignment = metadata["alignments"]
+        alignment["landmarks_xy"] = np.array(alignment["landmarks_xy"], dtype="float32")
+
+        src = metadata["source"]
+        frame_name = src["source_filename"]
+        face_index = int(src["face_index"])
+
+        logger.trace("Extracted alignment for frame: '%s', face index: %s",  # type:ignore
+                     frame_name, face_index)
+        return frame_name, face_index, alignment
+
+    def _sort_alignments(self,
+                         alignments: dict[str, dict[str, list[tuple[int,
+                                                                    AlignmentFileDict,
+                                                                    str,
+                                                                    dict]]]]
+                         ) -> dict[str, dict[str, AlignmentDict]]:
+        """ Sort the faces into face index order as they appeared in the original alignments file.
+
+        If the face index stored in the png header does not match its position in the alignments
+        file (i.e. a face has been removed from a frame) then update the header of the
+        corresponding png to the correct index as exists in the newly created alignments file.
+
+        Parameters
+        ----------
+        alignments: dict
+            The unsorted alignments file(s) as generated from the face PNG headers, including the
+            face index of the face within its respective frame, the original face filename and
+            the original face header source information
+
+        Returns
+        -------
+        dict
+            The alignments file dictionaries sorted into the correct face order, ready for saving
+        """
+        logger.info("Sorting and checking faces...")
+        aln_sorted: dict[str, dict[str, AlignmentDict]] = {}
+        for fname, frames in alignments.items():
+            this_file: dict[str, AlignmentDict] = {}
+            for frame in tqdm(sorted(frames), desc=f"Sorting {fname}", leave=False):
+                this_file[frame] = {"video_meta": {}, "faces": []}
+                for real_idx, (f_id, almt, f_path, f_src) in enumerate(sorted(frames[frame],
+                                                                              key=itemgetter(0))):
+                    if real_idx != f_id:
+                        full_path = os.path.join(self._faces_dir, f_path)
+                        self._update_png_header(full_path, real_idx, almt, f_src)
+                    this_file[frame]["faces"].append(almt)
+            aln_sorted[fname] = this_file
+        return aln_sorted
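Editor's note: the re-indexing rule used by `_sort_alignments` above is simply "sort by the stored index, then renumber contiguously; any face whose stored index no longer matches its new position needs its header rewritten". A tiny worked example:

```python
stored = [0, 2, 3]  # face indices read from PNG headers; face 1 was deleted
for new_index, old_index in enumerate(sorted(stored)):
    if new_index != old_index:
        print(f"face {old_index} -> {new_index}: header update required")
# face 2 -> 1: header update required
# face 3 -> 2: header update required
```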
+    @classmethod
+    def _update_png_header(cls,
+                           face_path: str,
+                           new_index: int,
+                           alignment: AlignmentFileDict,
+                           source_info: dict) -> None:
+        """ Update the PNG header for faces where the stored index does not correspond with the
+        alignments file. This can occur when frames with multiple faces have had some faces deleted
+        from the faces folder.
+
+        Updates the original filename and index in the png header.
+
+        Parameters
+        ----------
+        face_path: str
+            Full path to the saved face image that requires updating
+        new_index: int
+            The new index as it appears in the newly generated alignments file
+        alignment: dict
+            The alignment information to store in the png header
+        source_info: dict
+            The face source information as extracted from the original face png file
+        """
+        face = DetectedFace()
+        face.from_alignment(alignment)
+        new_filename = f"{os.path.splitext(source_info['source_filename'])[0]}_{new_index}.png"
+
+        logger.trace("Updating png header for '%s': (face index from %s to %s, "  # type:ignore
+                     "original filename from '%s' to '%s')", face_path, source_info["face_index"],
+                     new_index, source_info["original_filename"], new_filename)
+
+        source_info["face_index"] = new_index
+        source_info["original_filename"] = new_filename
+        meta = {"alignments": face.to_png_meta(), "source": source_info}
+        update_existing_metadata(face_path, meta)
+
+    def _save_alignments(self,
+                         all_alignments: dict[str, dict[str, AlignmentDict]],
+                         versions: dict[str, float]) -> None:
+        """ Save the newly generated alignments file(s).
+
+        If an alignments file already exists in the source faces folder, back it up rather than
+        overwriting it
+
+        Parameters
+        ----------
+        all_alignments: dict
+            The alignment(s) dictionaries found in the faces folder. Alignment filename as key,
+            corresponding alignments as value.
+        versions: dict
+            The minimum version number that exists in a face set for each alignments file to be
+            generated
+        """
+        for fname, alignments in all_alignments.items():
+            version = versions[fname]
+            alignments_path = os.path.join(self._faces_dir, fname)
+            dummy_args = Namespace(alignments_path=alignments_path)
+            aln = Alignments(dummy_args, is_extract=True)
+            aln.update_from_dict(alignments)
+            aln._io._version = version  # pylint:disable=protected-access
+            aln._io.update_legacy()  # pylint:disable=protected-access
+            aln.backup()
+            aln.save()
+
+
+class Rename():
+    """ Rename faces in a folder to match their filename as stored in an alignments file.
+
+    Parameters
+    ----------
+    alignments: :class:`tools.alignments.media.AlignmentData`
+        The alignments data loaded from an alignments file for this rename job
+    arguments: :class:`argparse.Namespace`
+        The :mod:`argparse` arguments as passed in from :mod:`tools.py`
+    faces: :class:`tools.alignments.media.Faces`, Optional
+        An optional faces object, if the rename task is being called by another job.
+ Default: ``None`` + """ + def __init__(self, + alignments: AlignmentData, + arguments: Namespace | None, + faces: Faces | None = None) -> None: + logger.debug("Initializing %s: (arguments: %s, faces: %s)", + self.__class__.__name__, arguments, faces) + self._alignments = alignments + + kwargs = {} + if alignments.version < 2.1: + # Update headers of faces generated with hash based alignments + kwargs["alignments"] = alignments + if faces: + self._faces = faces + else: + assert arguments is not None + self._faces = Faces(arguments.faces_dir, **kwargs) # type:ignore # needs TypedDict :/ + logger.debug("Initialized %s", self.__class__.__name__) + + def process(self) -> None: + """ Process the face renaming """ + logger.info("[RENAME FACES]") # Tidy up cli output + filelist = T.cast(list[tuple[str, "PNGHeaderDict"]], self._faces.file_list_sorted) + rename_mappings = sorted([(face[0], face[1]["source"]["original_filename"]) + for face in filelist + if face[0] != face[1]["source"]["original_filename"]], + key=lambda x: x[1]) + rename_count = self._rename_faces(rename_mappings) + logger.info("%s faces renamed", rename_count) + + filelist = T.cast(list[tuple[str, "PNGHeaderDict"]], self._faces.file_list_sorted) + copyback = FaceToFile(self._alignments, [val[1] for val in filelist]) + if copyback(): + self._alignments.save() + + def _rename_faces(self, filename_mappings: list[tuple[str, str]]) -> int: + """ Rename faces back to their original name as exists in the alignments file. + + If the source and destination filename are the same then skip that file. + + Parameters + ---------- + filename_mappings: list + List of tuples of (`source filename`, `destination filename`) ordered by destination + filename + + Returns + ------- + int + The number of faces that have been renamed + """ + if not filename_mappings: + return 0 + + rename_count = 0 + conflicts = [] + for src, dst in tqdm(filename_mappings, desc="Renaming Faces", leave=False): + old = os.path.join(self._faces.folder, src) + new = os.path.join(self._faces.folder, dst) + + if os.path.exists(new): + # Interim add .tmp extension to files that will cause a rename conflict, to + # process afterwards + logger.debug("interim renaming file to avoid conflict: (src: '%s', dst: '%s')", + src, dst) + new = new + ".tmp" + conflicts.append(new) + + logger.verbose("Renaming '%s' to '%s'", old, new) # type:ignore + os.rename(old, new) + rename_count += 1 + if conflicts: + for old in tqdm(conflicts, desc="Renaming Faces", leave=False): + new = old[:-4] # Remove .tmp extension + if os.path.exists(new): + # This should only be running on faces. If there is still a conflict + # then the user has done something stupid, so we will delete the file and + # replace. They can always re-extract :/ + os.remove(new) + logger.verbose("Renaming '%s' to '%s'", old, new) # type:ignore + os.rename(old, new) + return rename_count + + +class RemoveFaces(): + """ Remove items from alignments file. 
+
+    Parameters
+    ----------
+    alignments: :class:`tools.alignments.media.AlignmentData`
+        The loaded alignments containing faces to be removed
+    arguments: :class:`argparse.Namespace`
+        The command line arguments that have called this job
+    """
+    def __init__(self, alignments: AlignmentData, arguments: Namespace) -> None:
+        logger.debug("Initializing %s: (arguments: %s)", self.__class__.__name__, arguments)
+        self._alignments = alignments
+
+        self._items = Faces(arguments.faces_dir, alignments=alignments)
+        logger.debug("Initialized %s", self.__class__.__name__)
+
+    def process(self) -> None:
+        """ Run the job to remove faces from an alignments file that do not exist within a faces
+        folder. """
+        logger.info("[REMOVE FACES FROM ALIGNMENTS]")  # Tidy up cli output
+
+        if not self._items.items:
+            logger.error("No matching faces found in your faces folder. This would remove all "
+                         "faces from your alignments file. Process aborted.")
+            return
+
+        items = T.cast(dict[str, list[int]], self._items.items)
+        pre_face_count = self._alignments.faces_count
+        self._alignments.filter_faces(items, filter_out=False)
+        del_count = pre_face_count - self._alignments.faces_count
+        if del_count == 0:
+            logger.info("No changes made to alignments file. Exiting")
+            return
+
+        logger.info("%s alignment(s) were removed from alignments file", del_count)
+
+        self._update_png_headers()
+        self._alignments.save()
+
+        rename = Rename(self._alignments, None, self._items)
+        rename.process()
+
+    def _update_png_headers(self) -> None:
+        """ Update the EXIF iTXt field of any face PNGs that have had their face index changed.
+
+        Notes
+        -----
+        This could be quicker if parallelized across threads, however, Windows (at least) does not
+        seem to like this and has a tendency to throw permission errors, so this remains single
+        threaded for now.
+        """
+        items = T.cast(dict[str, list[int]], self._items.items)
+        srcs = [(x[0], x[1]["source"])
+                for x in T.cast(list[tuple[str, "PNGHeaderDict"]], self._items.file_list_sorted)]
+        to_update = [  # Items whose face index has changed
+            x for x in srcs
+            if x[1]["face_index"] != items[x[1]["source_filename"]].index(x[1]["face_index"])]
+
+        for item in tqdm(to_update, desc="Updating PNG Headers", leave=False):
+            filename, file_info = item
+            frame = file_info["source_filename"]
+            face_index = file_info["face_index"]
+            new_index = items[frame].index(face_index)
+
+            fullpath = os.path.join(self._items.folder, filename)
+            logger.debug("Updating png header for '%s': face index from %s to %s",
+                         fullpath, face_index, new_index)
+
+            # Update file_list_sorted for rename task
+            orig_filename = f"{os.path.splitext(frame)[0]}_{new_index}.png"
+            file_info["face_index"] = new_index
+            file_info["original_filename"] = orig_filename
+
+            face = DetectedFace()
+            face.from_alignment(self._alignments.get_faces_in_frame(frame)[new_index])
+            meta = {"alignments": face.to_png_meta(),
+                    "source": {"alignments_version": file_info["alignments_version"],
+                               "original_filename": orig_filename,
+                               "face_index": new_index,
+                               "source_filename": frame,
+                               "source_is_video": file_info["source_is_video"],
+                               "source_frame_dims": file_info.get("source_frame_dims")}}
+            update_existing_metadata(fullpath, meta)
+
+        logger.info("%s Extracted face(s) had their header information updated", len(to_update))
+
+
+class FaceToFile():
+    """ Updates any optional/missing keys in the alignments file with any data that has been
+    populated in a PNGHeader. Includes masks and identity fields.
+ + Parameters + --------- + alignments: :class:`tools.alignments.media.AlignmentsData` + The loaded alignments containing faces to be removed + face_data: list + List of :class:`PNGHeaderDict` objects + """ + def __init__(self, alignments: AlignmentData, face_data: list[PNGHeaderDict]) -> None: + logger.debug("Initializing %s: alignments: %s, face_data: %s", + self.__class__.__name__, alignments, len(face_data)) + self._alignments = alignments + self._face_alignments = face_data + self._updatable_keys: list[T.Literal["identity", "mask"]] = ["identity", "mask"] + self._counts: dict[str, int] = {} + logger.debug("Initialized %s", self.__class__.__name__) + + def _check_and_update(self, + alignment: PNGHeaderAlignmentsDict, + face: AlignmentFileDict) -> None: + """ Check whether the key requires updating and update it. + + alignment: dict + The alignment dictionary from the PNG Header + face: dict + The alignment dictionary for the face from the alignments file + """ + for key in self._updatable_keys: + if key == "mask": + exist_masks = face["mask"] + for mask_name, mask_data in alignment["mask"].items(): + if mask_name in exist_masks: + continue + exist_masks[mask_name] = mask_data + count_key = f"mask_{mask_name}" + self._counts[count_key] = self._counts.get(count_key, 0) + 1 + continue + + if not face.get(key, {}) and alignment.get(key): + face[key] = alignment[key] + self._counts[key] = self._counts.get(key, 0) + 1 + + def __call__(self) -> bool: + """ Parse through the face data updating any entries in the alignments file. + + Returns + ------- + bool + ``True`` if any alignment information was updated otherwise ``False`` + """ + for meta in tqdm(self._face_alignments, + desc="Updating Alignments File from PNG Header", + leave=False): + src = meta["source"] + alignment = meta["alignments"] + if not any(alignment.get(key, {}) for key in self._updatable_keys): + continue + + faces = self._alignments.get_faces_in_frame(src["source_filename"]) + if len(faces) < src["face_index"] + 1: # list index out of range + logger.debug("Skipped face '%s'. Index does not exist in alignments file", + src["original_filename"]) + continue + + face = faces[src["face_index"]] + self._check_and_update(alignment, face) + + retval = False + if self._counts: + retval = True + logger.info("Updated alignments file from PNG Data: %s", self._counts) + return retval + + +__all__ = get_module_objects(__name__) diff --git a/tools/alignments/jobs_frames.py b/tools/alignments/jobs_frames.py new file mode 100644 index 0000000000..fcedc13065 --- /dev/null +++ b/tools/alignments/jobs_frames.py @@ -0,0 +1,480 @@ +#!/usr/bin/env python3 +""" Tools for manipulating the alignments using Frames as a source """ +from __future__ import annotations +import logging +import os +import sys +import typing as T + +from datetime import datetime + +import cv2 +import numpy as np +from tqdm import tqdm + +from lib.align import DetectedFace, EXTRACT_RATIOS, LANDMARK_PARTS, LandmarkType +from lib.align.alignments import _VERSION, PNGHeaderDict +from lib.image import encode_image, generate_thumbnail, ImagesSaver +from lib.utils import get_module_objects +from plugins.extract import ExtractMedia, Extractor +from .media import ExtractedFaces, Frames + +if T.TYPE_CHECKING: + from argparse import Namespace + from .media import AlignmentData + +logger = logging.getLogger(__name__) + + +class Draw(): + """ Draws annotations onto original frames and saves into a sub-folder next to the original + frames. 
+
+    Parameters
+    ----------
+    alignments: :class:`tools.alignments.media.AlignmentData`
+        The loaded alignments corresponding to the frames to be annotated
+    arguments: :class:`argparse.Namespace`
+        The command line arguments that have called this job
+    """
+    def __init__(self, alignments: AlignmentData, arguments: Namespace) -> None:
+        logger.debug("Initializing %s: (arguments: %s)", self.__class__.__name__, arguments)
+        self._alignments = alignments
+        self._frames = Frames(arguments.frames_dir)
+        self._output_folder = self._set_output()
+        logger.debug("Initialized %s", self.__class__.__name__)
+
+    def _set_output(self) -> str:
+        """ Set the output folder path.
+
+        If annotating a folder of frames, output will be placed in a sub folder within the frames
+        folder. If annotating a video, output will be a folder next to the original video.
+
+        Returns
+        -------
+        str
+            Full path to the output folder
+        """
+        now = datetime.now().strftime("%Y%m%d_%H%M%S")
+        folder_name = f"drawn_landmarks_{now}"
+        if self._frames.is_video:
+            dest_folder = os.path.dirname(self._frames.folder)
+        else:
+            dest_folder = self._frames.folder
+        output_folder = os.path.join(dest_folder, folder_name)
+        logger.debug("Creating folder: '%s'", output_folder)
+        os.makedirs(output_folder)
+        return output_folder
+
+    def process(self) -> None:
+        """ Runs the process to draw face annotations onto original source frames. """
+        logger.info("[DRAW LANDMARKS]")  # Tidy up cli output
+        frames_drawn = 0
+        for frame in tqdm(self._frames.file_list_sorted, desc="Drawing landmarks", leave=False):
+            frame_name = frame["frame_fullname"]
+
+            if not self._alignments.frame_exists(frame_name):
+                logger.verbose("Skipping '%s' - Alignments not found", frame_name)  # type:ignore
+                continue
+
+            self._annotate_image(frame_name)
+            frames_drawn += 1
+        logger.info("%s Frame(s) output", frames_drawn)
+
+    def _annotate_image(self, frame_name: str) -> None:
+        """ Annotate the frame with each face that appears in the alignments file.
+
+        Parameters
+        ----------
+        frame_name: str
+            The full path to the original frame
+        """
+        logger.trace("Annotating frame: '%s'", frame_name)  # type:ignore
+        image = self._frames.load_image(frame_name)
+
+        for idx, alignment in enumerate(self._alignments.get_faces_in_frame(frame_name)):
+            face = DetectedFace()
+            face.from_alignment(alignment, image=image)
+            # Bounding Box
+            assert face.left is not None
+            assert face.top is not None
+            cv2.rectangle(image, (face.left, face.top), (face.right, face.bottom), (255, 0, 0), 1)
+            self._annotate_landmarks(image, np.rint(face.landmarks_xy).astype("int32"))
+            self._annotate_extract_boxes(image, face, idx)
+            self._annotate_pose(image, face)  # Pose (head is still loaded)
+
+        self._frames.save_image(self._output_folder, frame_name, image)
+
+    def _annotate_landmarks(self, image: np.ndarray, landmarks: np.ndarray) -> None:
+        """ Annotate the landmarks and connecting mesh onto the frame.
+
+        Parameters
+        ----------
+        image: :class:`numpy.ndarray`
+            The frame that landmarks are to be annotated on to
+        landmarks: :class:`numpy.ndarray`
+            The facial landmarks that are to be annotated onto the frame
+        """
+        # Mesh
+        for start, end, fill in LANDMARK_PARTS[LandmarkType.from_shape(landmarks.shape)].values():
+            cv2.polylines(image, [landmarks[start:end]], fill, (255, 255, 0), 1)
+        # Landmarks
+        for (pos_x, pos_y) in landmarks:
+            cv2.circle(image, (pos_x, pos_y), 1, (0, 255, 255), -1)
+
+    @classmethod
+    def _annotate_extract_boxes(cls, image: np.ndarray, face: DetectedFace, index: int) -> None:
+        """ Annotate the extract boxes onto the frame.
+
+        Parameters
+        ----------
+        image: :class:`numpy.ndarray`
+            The frame that extract boxes are to be annotated on to
+        face: :class:`lib.align.DetectedFace`
+            The aligned face
+        index: int
+            The face index for the given face
+        """
+        for area in T.get_args(T.Literal["face", "head"]):
+            face.load_aligned(image, centering=area, force=True)
+            color = (0, 255, 0) if area == "face" else (0, 0, 255)
+            top_left = face.aligned.original_roi[0]
+            top_left = (top_left[0], top_left[1] - 10)
+            cv2.putText(image, str(index), top_left, cv2.FONT_HERSHEY_DUPLEX, 1.0, color, 1)
+            cv2.polylines(image, [face.aligned.original_roi], True, color, 1)
+
+    @classmethod
+    def _annotate_pose(cls, image: np.ndarray, face: DetectedFace) -> None:
+        """ Annotate the pose onto the frame.
+
+        Parameters
+        ----------
+        image: :class:`numpy.ndarray`
+            The frame that pose is to be annotated on to
+        face: :class:`lib.align.DetectedFace`
+            The aligned face loaded for head centering
+        """
+        center = np.array((face.aligned.size / 2,
+                           face.aligned.size / 2)).astype("int32").reshape(1, 2)
+        center = np.rint(face.aligned.transform_points(center, invert=True)).astype("int32")
+        points = face.aligned.pose.xyz_2d * face.aligned.size
+        points = np.rint(face.aligned.transform_points(points, invert=True)).astype("int32")
+        cv2.line(image, tuple(center), tuple(points[1]), (0, 255, 0), 2)
+        cv2.line(image, tuple(center), tuple(points[0]), (255, 0, 0), 2)
+        cv2.line(image, tuple(center), tuple(points[2]), (0, 0, 255), 2)
+
+
+class Extract():
+    """ Re-extract faces from source frames based on alignment data
+
+    Parameters
+    ----------
+    alignments: :class:`tools.alignments.media.AlignmentData`
+        The alignments data loaded from an alignments file for this job
+    arguments: :class:`argparse.Namespace`
+        The :mod:`argparse` arguments as passed in from :mod:`tools.py`
+    """
+    def __init__(self, alignments: AlignmentData, arguments: Namespace) -> None:
+        logger.debug("Initializing %s: (arguments: %s)", self.__class__.__name__, arguments)
+        self._arguments = arguments
+        self._alignments = alignments
+        self._is_legacy = self._alignments.version == 1.0  # pylint:disable=protected-access
+        self._mask_pipeline: Extractor | None = None
+        self._faces_dir = arguments.faces_dir
+        self._min_size = self._get_min_size(arguments.size, arguments.min_size)
+
+        self._frames = Frames(arguments.frames_dir, self._get_count())
+        self._extracted_faces = ExtractedFaces(self._frames,
+                                               self._alignments,
+                                               size=arguments.size)
+        self._saver: ImagesSaver | None = None
+        logger.debug("Initialized %s", self.__class__.__name__)
+
+    @classmethod
+    def _get_min_size(cls, extract_size: int, min_size: int) -> int:
+        """ Obtain the minimum size that a face has been resized from to be included as a valid
+        extract.
+ + Parameters + ---------- + extract_size: int + The requested size of the extracted images + min_size: int + The percentage amount that has been supplied for valid faces (as a percentage of + extract size) + + Returns + ------- + int + The minimum size, in pixels, that a face is resized from to be considered valid + """ + retval = 0 if min_size == 0 else max(4, int(extract_size * (min_size / 100.))) + logger.debug("Extract size: %s, min percentage size: %s, min_size: %s", + extract_size, min_size, retval) + return retval + + def _get_count(self) -> int | None: + """ If the alignments file has been run through the manual tool, then it will hold video + meta information, meaning that the count of frames in the alignment file can be relied + on to be accurate. + + Returns + ------- + int or ``None`` + For video input which contain video meta-data in the alignments file then the count of + frames is returned. In all other cases ``None`` is returned + """ + meta = self._alignments.video_meta_data + has_meta = all(val is not None for val in meta.values()) + if has_meta: + retval: int | None = len(T.cast(dict[str, list[int] | list[float]], meta["pts_time"])) + else: + retval = None + logger.debug("Frame count from alignments file: (has_meta: %s, %s", has_meta, retval) + return retval + + def process(self) -> None: + """ Run the re-extraction from Alignments file process""" + logger.info("[EXTRACT FACES]") # Tidy up cli output + self._check_folder() + if self._is_legacy: + self._legacy_check() + self._saver = ImagesSaver(self._faces_dir, as_bytes=True) + + if self._min_size > 0: + logger.info("Only selecting faces that have been resized from a minimum resolution " + "of %spx", self._min_size) + + self._export_faces() + + def _check_folder(self) -> None: + """ Check that the faces folder doesn't pre-exist and create. """ + err = None + if not self._faces_dir: + err = "ERROR: Output faces folder not provided." + elif not os.path.isdir(self._faces_dir): + logger.debug("Creating folder: '%s'", self._faces_dir) + os.makedirs(self._faces_dir) + elif os.listdir(self._faces_dir): + err = f"ERROR: Output faces folder should be empty: '{self._faces_dir}'" + if err: + logger.error(err) + sys.exit(0) + logger.verbose("Creating output folder at '%s'", self._faces_dir) # type:ignore + + def _legacy_check(self) -> None: + """ Check whether the alignments file was created with the legacy extraction method. + + If so, force user to re-extract all faces if any options have been specified, otherwise + raise the appropriate warnings and set the legacy options. 
+ """ + if self._min_size > 0 or self._arguments.extract_every_n != 1: + logger.warning("This alignments file was generated with the legacy extraction method.") + logger.warning("You should run this extraction job, but with 'min_size' set to 0 and " + "'extract-every-n' set to 1 to update the alignments file.") + logger.warning("You can then re-run this extraction job with your chosen options.") + sys.exit(0) + + maskers = ["components", "extended"] + nn_masks = [mask for mask in list(self._alignments.mask_summary) if mask not in maskers] + logtype = logger.warning if nn_masks else logger.info + logtype("This alignments file was created with the legacy extraction method and will be " + "updated.") + logtype("Faces will be extracted using the new method and landmarks based masks will be " + "regenerated.") + if nn_masks: + logtype("However, the NN based masks '%s' will be cropped to the legacy extraction " + "method, so you may want to run the mask tool to regenerate these " + "masks.", "', '".join(nn_masks)) + self._mask_pipeline = Extractor(None, None, maskers, multiprocess=True) + self._mask_pipeline.launch() + # Update alignments versioning + self._alignments._io._version = _VERSION # pylint:disable=protected-access + + def _export_faces(self) -> None: + """ Export the faces to the output folder. """ + extracted_faces = 0 + skip_list = self._set_skip_list() + count = self._frames.count if skip_list is None else self._frames.count - len(skip_list) + + for filename, image in tqdm(self._frames.stream(skip_list=skip_list), + total=count, desc="Saving extracted faces", + leave=False): + frame_name = os.path.basename(filename) + if not self._alignments.frame_exists(frame_name): + logger.verbose("Skipping '%s' - Alignments not found", frame_name) # type:ignore + continue + extracted_faces += self._output_faces(frame_name, image) + if self._is_legacy and extracted_faces != 0 and self._min_size == 0: + self._alignments.save() + logger.info("%s face(s) extracted", extracted_faces) + + def _set_skip_list(self) -> list[int] | None: + """ Set the indices for frames that should be skipped based on the `extract_every_n` + command line option. 
+ + Returns + ------- + list or ``None`` + A list of indices to be skipped if extract_every_n is not `1` otherwise + returns ``None`` + """ + skip_num = self._arguments.extract_every_n + if skip_num == 1: + logger.debug("Not skipping any frames") + return None + skip_list = [] + for idx, item in enumerate(T.cast(list[dict[str, str]], self._frames.file_list_sorted)): + if idx % skip_num != 0: + logger.trace("Adding image '%s' to skip list due to " # type:ignore + "extract_every_n = %s", item["frame_fullname"], skip_num) + skip_list.append(idx) + logger.debug("Adding skip list: %s", skip_list) + return skip_list + + def _output_faces(self, filename: str, image: np.ndarray) -> int: + """ For each frame save out the faces + + Parameters + ---------- + filename: str + The filename (without the full path) of the current frame + image: :class:`numpy.ndarray` + The full frame that faces are to be extracted from + + Returns + ------- + int + The total number of faces that have been extracted + """ + logger.trace("Outputting frame: %s", filename) # type:ignore + face_count = 0 + frame_name = os.path.splitext(filename)[0] + faces = self._select_valid_faces(filename, image) + assert self._saver is not None + if not faces: + return face_count + if self._is_legacy: + faces = self._process_legacy(filename, image, faces) + + for idx, face in enumerate(faces): + output = f"{frame_name}_{idx}.png" + meta: PNGHeaderDict = { + "alignments": face.to_png_meta(), + "source": {"alignments_version": self._alignments.version, + "original_filename": output, + "face_index": idx, + "source_filename": filename, + "source_is_video": self._frames.is_video, + "source_frame_dims": T.cast(tuple[int, int], image.shape[:2])}} + assert face.aligned.face is not None + self._saver.save(output, encode_image(face.aligned.face, ".png", metadata=meta)) + if self._min_size == 0 and self._is_legacy: + face.thumbnail = generate_thumbnail(face.aligned.face, size=96, quality=60) + self._alignments.data[filename]["faces"][idx] = face.to_alignment() + face_count += 1 + self._saver.close() + return face_count + + def _select_valid_faces(self, frame: str, image: np.ndarray) -> list[DetectedFace]: + """ Return the aligned faces from a frame that meet the selection criteria, + + Parameters + ---------- + frame: str + The filename (without the full path) of the current frame + image: :class:`numpy.ndarray` + The full frame that faces are to be extracted from + + Returns + ------- + list: + List of valid :class:`lib,align.DetectedFace` objects + """ + faces = self._extracted_faces.get_faces_in_frame(frame, image=image) + if self._min_size == 0: + valid_faces = faces + else: + sizes = self._extracted_faces.get_roi_size_for_frame(frame) + valid_faces = [faces[idx] for idx, size in enumerate(sizes) + if size >= self._min_size] + logger.trace("frame: '%s', total_faces: %s, valid_faces: %s", # type:ignore + frame, len(faces), len(valid_faces)) + return valid_faces + + def _process_legacy(self, + filename: str, + image: np.ndarray, + detected_faces: list[DetectedFace]) -> list[DetectedFace]: + """ Process legacy face extractions to new extraction method. 
+ + Updates stored masks to new extract size + + Parameters + ---------- + filename: str + The current frame filename + image: :class:`numpy.ndarray` + The current image the contains the faces + detected_faces: list + list of :class:`lib.align.DetectedFace` objects for the current frame + + Returns + ------- + list + The updated list of :class:`lib.align.DetectedFace` objects for the current frame + """ + # Update landmarks based masks for face centering + assert self._mask_pipeline is not None + mask_item = ExtractMedia(filename, image, detected_faces=detected_faces) + self._mask_pipeline.input_queue.put(mask_item) + faces = next(self._mask_pipeline.detected_faces()).detected_faces + + # Pad and shift Neural Network based masks to face centering + for face in faces: + self._pad_legacy_masks(face) + return faces + + @classmethod + def _pad_legacy_masks(cls, detected_face: DetectedFace) -> None: + """ Recenter legacy Neural Network based masks from legacy centering to face centering + and pad accordingly. + + Update the masks back into the detected face objects. + + Parameters + ---------- + detected_face: :class:`lib.align.DetectedFace` + The detected face to update the masks for + """ + offset = detected_face.aligned.pose.offset["face"] + for name, mask in detected_face.mask.items(): # Re-center mask and pad to face size + if name in ("components", "extended"): + continue + old_mask = mask.mask.astype("float32") / 255.0 + size = old_mask.shape[0] + new_size = int(size + (size * EXTRACT_RATIOS["face"]) / 2) + + shift = np.rint(offset * (size - (size * EXTRACT_RATIOS["face"]))).astype("int32") + pos = np.array([(new_size // 2 - size // 2) - shift[1], + (new_size // 2) + (size // 2) - shift[1], + (new_size // 2 - size // 2) - shift[0], + (new_size // 2) + (size // 2) - shift[0]]) + bounds = np.array([max(0, pos[0]), min(new_size, pos[1]), + max(0, pos[2]), min(new_size, pos[3])]) + + slice_in = [slice(0 - (pos[0] - bounds[0]), size - (pos[1] - bounds[1])), + slice(0 - (pos[2] - bounds[2]), size - (pos[3] - bounds[3]))] + slice_out = [slice(bounds[0], bounds[1]), slice(bounds[2], bounds[3])] + + new_mask = np.zeros((new_size, new_size, 1), dtype="float32") + new_mask[slice_out[0], slice_out[1], :] = old_mask[slice_in[0], slice_in[1], :] + + mask.replace_mask(new_mask) + # Get the affine matrix from recently generated components mask + # pylint:disable=protected-access + mask._affine_matrix = detected_face.mask["components"].affine_matrix + + +__all__ = get_module_objects(__name__) diff --git a/tools/alignments/media.py b/tools/alignments/media.py new file mode 100644 index 0000000000..b92d233ca0 --- /dev/null +++ b/tools/alignments/media.py @@ -0,0 +1,645 @@ +#!/usr/bin/env python3 +""" Media items (Alignments, Faces, Frames) + for alignments tool """ +from __future__ import annotations +import logging +from operator import itemgetter +import os +import sys +import typing as T + +import cv2 +from tqdm import tqdm + +# TODO imageio single frame seek seems slow. 
Look into this +# import imageio + +from lib.align import Alignments, DetectedFace, update_legacy_png_header +from lib.image import (count_frames, generate_thumbnail, ImagesLoader, + png_write_meta, read_image, read_image_meta_batch) +from lib.utils import get_module_objects, IMAGE_EXTENSIONS, VIDEO_EXTENSIONS, FaceswapError + +if T.TYPE_CHECKING: + from collections.abc import Generator + import numpy as np + from lib.align.alignments import AlignmentFileDict, PNGHeaderDict + +logger = logging.getLogger(__name__) + + +class AlignmentData(Alignments): + """ Class to hold the alignment data + + Parameters + ---------- + alignments_file: str + Full path to an alignments file + """ + def __init__(self, alignments_file: str) -> None: + logger.debug("Initializing %s: (alignments file: '%s')", + self.__class__.__name__, alignments_file) + logger.info("[ALIGNMENT DATA]") # Tidy up cli output + folder, filename = self.check_file_exists(alignments_file) + super().__init__(folder, filename=filename) + logger.verbose("%s items loaded", self.frames_count) # type: ignore + logger.debug("Initialized %s", self.__class__.__name__) + + @staticmethod + def check_file_exists(alignments_file: str) -> tuple[str, str]: + """ Check if the alignments file exists, and returns a tuple of the folder and filename. + + Parameters + ---------- + alignments_file: str + Full path to an alignments file + + Returns + ------- + folder: str + The full path to the folder containing the alignments file + filename: str + The filename of the alignments file + """ + folder, filename = os.path.split(alignments_file) + if not os.path.isfile(alignments_file): + logger.error("ERROR: alignments file not found at: '%s'", alignments_file) + sys.exit(0) + if folder: + logger.verbose("Alignments file exists at '%s'", alignments_file) # type: ignore + return folder, filename + + def save(self) -> None: + """ Backup copy of old alignments and save new alignments """ + self.backup() + super().save() + + +class MediaLoader(): + """ Class to load images. + + Parameters + ---------- + folder: str + The folder of images or video file to load images from + count: int or ``None``, optional + If the total frame count is known it can be passed in here which will skip + analyzing a video file. If the count is not passed in, it will be calculated. + Default: ``None`` + """ + def __init__(self, folder: str, count: int | None = None): + logger.debug("Initializing %s: (folder: '%s')", self.__class__.__name__, folder) + logger.info("[%s DATA]", self.__class__.__name__.upper()) + self._count = count + self.folder = folder + self._vid_reader = self.check_input_folder() + self.file_list_sorted = self.sorted_items() + self.items = self.load_items() + logger.verbose("%s items loaded", self.count) # type: ignore + logger.debug("Initialized %s", self.__class__.__name__) + + @property + def is_video(self) -> bool: + """ bool: Return whether source is a video or not """ + return self._vid_reader is not None + + @property + def count(self) -> int: + """ int: Number of faces or frames """ + if self._count is not None: + return self._count + if self.is_video: + self._count = int(count_frames(self.folder)) + else: + self._count = len(self.file_list_sorted) + return self._count + + def check_input_folder(self) -> cv2.VideoCapture | None: + """ Ensure that the frames or faces folder exists and is valid. 
+ If frames folder contains a video file return imageio reader object + + Returns + ------- + :class:`cv2.VideoCapture` + Object for reading a video stream + """ + err = None + loadtype = self.__class__.__name__ + if not self.folder: + err = f"ERROR: A {loadtype} folder must be specified" + elif not os.path.exists(self.folder): + err = f"ERROR: The {loadtype} location {self.folder} could not be found" + if err: + logger.error(err) + sys.exit(0) + + if (loadtype == "Frames" and + os.path.isfile(self.folder) and + os.path.splitext(self.folder)[1].lower() in VIDEO_EXTENSIONS): + logger.verbose("Video exists at: '%s'", self.folder) # type: ignore + retval = cv2.VideoCapture(self.folder) # pylint:disable=no-member + # TODO ImageIO single frame seek seems slow. Look into this + # retval = imageio.get_reader(self.folder, "ffmpeg") + else: + logger.verbose("Folder exists at '%s'", self.folder) # type: ignore + retval = None + return retval + + @staticmethod + def valid_extension(filename) -> bool: + """ bool: Check whether passed in file has a valid extension """ + extension = os.path.splitext(filename)[1] + retval = extension.lower() in IMAGE_EXTENSIONS + logger.trace("Filename has valid extension: '%s': %s", filename, retval) # type: ignore + return retval + + def sorted_items(self) -> list[dict[str, str]] | list[tuple[str, PNGHeaderDict]]: + """ Override for specific folder processing """ + raise NotImplementedError() + + def process_folder(self) -> (Generator[dict[str, str], None, None] | + Generator[tuple[str, PNGHeaderDict], None, None]): + """ Override for specific folder processing """ + raise NotImplementedError() + + def load_items(self) -> dict[str, list[int]] | dict[str, tuple[str, str]]: + """ Override for specific item loading """ + raise NotImplementedError() + + def load_image(self, filename: str) -> np.ndarray: + """ Load an image + + Parameters + ---------- + filename: str + The filename of the image to load + + Returns + ------- + :class:`numpy.ndarray` + The loaded image + """ + if self.is_video: + image = self.load_video_frame(filename) + else: + src = os.path.join(self.folder, filename) + logger.trace("Loading image: '%s'", src) # type: ignore + image = read_image(src, raise_error=True) + return image + + def load_video_frame(self, filename: str) -> np.ndarray: + """ Load a requested frame from video + + Parameters + ---------- + filename: str + The frame name to load + + Returns + ------- + :class:`numpy.ndarray` + The loaded image + """ + assert self._vid_reader is not None + frame = os.path.splitext(filename)[0] + logger.trace("Loading video frame: '%s'", frame) # type: ignore + frame_no = int(frame[frame.rfind("_") + 1:]) - 1 + self._vid_reader.set(cv2.CAP_PROP_POS_FRAMES, frame_no) # pylint:disable=no-member + + _, image = self._vid_reader.read() + # TODO imageio single frame seek seems slow. Look into this + # self._vid_reader.set_image_index(frame_no) + # image = self._vid_reader.get_next_data()[:, :, ::-1] + return image + + def stream(self, skip_list: list[int] | None = None + ) -> Generator[tuple[str, np.ndarray], None, None]: + """ Load the images in :attr:`folder` in the order they are received from + :class:`lib.image.ImagesLoader` in a background thread. + + Parameters + ---------- + skip_list: list, optional + A list of frame indices that should not be loaded. Pass ``None`` if all images should + be loaded. 
Default: ``None`` + + Yields + ------ + str + The filename of the image that is being returned + numpy.ndarray + The image that has been loaded from disk + """ + loader = ImagesLoader(self.folder, queue_size=32, count=self._count) + if skip_list is not None: + loader.add_skip_list(skip_list) + for filename, image in loader.load(): + yield filename, image + + @staticmethod + def save_image(output_folder: str, + filename: str, + image: np.ndarray, + metadata: PNGHeaderDict | None = None) -> None: + """ Save an image """ + output_file = os.path.join(output_folder, filename) + output_file = os.path.splitext(output_file)[0] + ".png" + logger.trace("Saving image: '%s'", output_file) # type: ignore + if metadata: + encoded = cv2.imencode(".png", image)[1] + encoded_image = png_write_meta(encoded.tobytes(), metadata) + with open(output_file, "wb") as out_file: + out_file.write(encoded_image) + else: + cv2.imwrite(output_file, image) # pylint:disable=no-member + + +class Faces(MediaLoader): + """ Object to load Extracted Faces from a folder. + + Parameters + ---------- + folder: str + The folder to load faces from + alignments: :class:`lib.align.Alignments`, optional + The alignments object that contains the faces. This can be used for 2 purposes: + - To update legacy hash based faces for None: + self._alignments = alignments + super().__init__(folder) + + def _handle_legacy(self, fullpath: str, log: bool = False) -> PNGHeaderDict: + """Handle facesets that are legacy (i.e. do not contain alignment information in the + header data) + + Parameters + ---------- + fullpath : str + The full path to the extracted face image + log : bool, optional + Whether to log a message that legacy updating is occurring + + Returns + ------- + :class:`~lib.align.alignments.PNGHeaderDict` + The Alignments information from the face in PNG Header dict format + + Raises + ------ + FaceswapError + If legacy faces can't be updated because the alignments file does not exist or some of + the faces do not appear in the provided alignments file + """ + if self._alignments is None: # Can't update legacy + raise FaceswapError(f"The folder '{self.folder}' contains images that do not include " + "Faceswap metadata.\nAll images in the provided folder should " + "contain faces generated from Faceswap's extraction process.\n" + "Please double check the source and try again.") + if log: + logger.warning("Legacy faces discovered. These faces will be updated") + + data = update_legacy_png_header(fullpath, self._alignments) + if not data: + raise FaceswapError( + f"Some of the faces being passed in from '{self.folder}' could not be " + f"matched to the alignments file '{self._alignments.file}'\nPlease double " + "check your sources and try again.") + return data + + def _handle_duplicate(self, + fullpath: str, + header_dict: PNGHeaderDict, + seen: dict[str, list[int]]) -> bool: + """ Check whether the given face has already been seen for the source frame and face index + from an existing face. Can happen when filenames have changed due to sorting etc. 
and users + have done multiple extractions/copies and placed all of the faces in the same folder + + Parameters + ---------- + fullpath : str + The full path to the face image that is being checked + header_dict : class:`~lib.align.alignments.PNGHeaderDict` + The PNG header dictionary for the given face + seen : dict[str, list[int]] + Dictionary of original source filename and face indices that have already been seen and + will be updated with the face processing now + + Returns + ------- + bool + ``True`` if the face was a duplicate and has been removed, otherwise ``False`` + """ + src_filename = header_dict["source"]["source_filename"] + face_index = header_dict["source"]["face_index"] + + if src_filename in seen and face_index in seen[src_filename]: + dupe_dir = os.path.join(self.folder, "_duplicates") + os.makedirs(dupe_dir, exist_ok=True) + filename = os.path.basename(fullpath) + logger.trace("Moving duplicate: %s", filename) # type:ignore + os.rename(fullpath, os.path.join(dupe_dir, filename)) + return True + + seen.setdefault(src_filename, []).append(face_index) + return False + + def process_folder(self) -> Generator[tuple[str, PNGHeaderDict], None, None]: + """ Iterate through the faces folder pulling out various information for each face. + + Yields + ------ + dict + A dictionary for each face found containing the keys returned from + :class:`lib.image.read_image_meta_batch` + """ + logger.info("Loading file list from %s", self.folder) + filter_count = 0 + dupe_count = 0 + seen: dict[str, list[int]] = {} + + if self._alignments is not None and self._alignments.version < 2.1: # Legacy updating + filelist = [os.path.join(self.folder, face) + for face in os.listdir(self.folder) + if self.valid_extension(face)] + else: + filelist = [os.path.join(self.folder, face) + for face in os.listdir(self.folder) + if os.path.splitext(face)[-1] == ".png"] + + log_once = False + for fullpath, metadata in tqdm(read_image_meta_batch(filelist), + total=len(filelist), + desc="Reading Face Data"): + + if "itxt" not in metadata or "source" not in metadata["itxt"]: + sub_dict = self._handle_legacy(fullpath, not log_once) + log_once = True + else: + sub_dict = T.cast("PNGHeaderDict", metadata["itxt"]) + + if self._handle_duplicate(fullpath, sub_dict, seen): + dupe_count += 1 + continue + + if (self._alignments is not None and # filter existing + not self._alignments.frame_exists(sub_dict["source"]["source_filename"])): + filter_count += 1 + continue + + retval = (os.path.basename(fullpath), sub_dict) + yield retval + + if self._alignments is not None: + logger.debug("Faces filtered out that did not exist in alignments file: %s", + filter_count) + + if dupe_count > 0: + logger.warning("%s Duplicate face images were found. These files have been moved to " + "'%s' from where they can be safely deleted", + dupe_count, os.path.join(self.folder, "_duplicates")) + + def load_items(self) -> dict[str, list[int]]: + """ Load the face names into dictionary. + + Returns + ------- + dict + The source filename as key with list of face indices for the frame as value + """ + faces: dict[str, list[int]] = {} + for face in T.cast(list[tuple[str, "PNGHeaderDict"]], self.file_list_sorted): + src = face[1]["source"] + faces.setdefault(src["source_filename"], []).append(src["face_index"]) + logger.trace(faces) # type: ignore + return faces + + def sorted_items(self) -> list[tuple[str, PNGHeaderDict]]: + """ Return the items sorted by the saved file name. 
+ + Returns + -------- + list + List of `dict` objects for each face found, sorted by the face's current filename + """ + items = sorted(self.process_folder(), key=itemgetter(0)) + logger.trace(items) # type: ignore + return items + + +class Frames(MediaLoader): + """ Object to hold the frames that are to be checked against """ + + def process_folder(self) -> Generator[dict[str, str], None, None]: + """ Iterate through the frames folder pulling the base filename + + Yields + ------ + dict + The full framename, the filename and the file extension of the frame + """ + iterator = self.process_video if self.is_video else self.process_frames + yield from iterator() + + def process_frames(self) -> Generator[dict[str, str], None, None]: + """ Process exported Frames + + Yields + ------ + dict + The full framename, the filename and the file extension of the frame + """ + logger.info("Loading file list from %s", self.folder) + for frame in os.listdir(self.folder): + if not self.valid_extension(frame): + continue + filename = os.path.splitext(frame)[0] + file_extension = os.path.splitext(frame)[1] + + retval = {"frame_fullname": frame, + "frame_name": filename, + "frame_extension": file_extension} + logger.trace(retval) # type: ignore + yield retval + + def process_video(self) -> Generator[dict[str, str], None, None]: + """Dummy in frames for video + + Yields + ------ + dict + The full framename, the filename and the file extension of the frame + """ + logger.info("Loading video frames from %s", self.folder) + vidname, ext = os.path.splitext(os.path.basename(self.folder)) + for i in range(self.count): + idx = i + 1 + # Keep filename format for outputted face + filename = f"{vidname}_{idx:06d}" + retval = {"frame_fullname": f"{filename}{ext}", + "frame_name": filename, + "frame_extension": ext} + logger.trace(retval) # type: ignore + yield retval + + def load_items(self) -> dict[str, tuple[str, str]]: + """ Load the frame info into dictionary + + Returns + ------- + dict + Fullname as key, tuple of frame name and extension as value + """ + frames: dict[str, tuple[str, str]] = {} + for frame in T.cast(list[dict[str, str]], self.file_list_sorted): + frames[frame["frame_fullname"]] = (frame["frame_name"], + frame["frame_extension"]) + logger.trace(frames) # type: ignore + return frames + + def sorted_items(self) -> list[dict[str, str]]: + """ Return the items sorted by filename + + Returns + ------- + list + The sorted list of frame information + """ + items = sorted(self.process_folder(), key=lambda x: (x["frame_name"])) + logger.trace(items) # type: ignore + return items + + +class ExtractedFaces(): + """ Holds the extracted faces and matrix for alignments + + Parameters + ---------- + frames: :class:`Frames` + The frames object to extract faces from + alignments: :class:`AlignmentData` + The alignment data corresponding to the frames + size: int, optional + The extract face size. 
Default: 512 + """ + def __init__(self, frames: Frames, alignments: AlignmentData, size: int = 512) -> None: + logger.trace("Initializing %s: size: %s", # type: ignore + self.__class__.__name__, size) + self.size = size + self.padding = int(size * 0.1875) + self.alignments = alignments + self.frames = frames + self.current_frame: str | None = None + self.faces: list[DetectedFace] = [] + logger.trace("Initialized %s", self.__class__.__name__) # type: ignore + + def get_faces(self, frame: str, image: np.ndarray | None = None) -> None: + """ Obtain faces and transformed landmarks for each face in a given frame with its + alignments + + Parameters + ---------- + frame: str + The frame name to obtain faces for + image: :class:`numpy.ndarray`, optional + The image to extract the face from, if we already have it, otherwise ``None`` to + load the image. Default: ``None`` + """ + logger.trace("Getting faces for frame: '%s'", frame) # type: ignore + self.current_frame = None + alignments = self.alignments.get_faces_in_frame(frame) + logger.trace("Alignments for frame: (frame: '%s', alignments: %s)", # type: ignore + frame, alignments) + if not alignments: + self.faces = [] + return + image = self.frames.load_image(frame) if image is None else image + self.faces = [self.extract_one_face(alignment, image) for alignment in alignments] + self.current_frame = frame + + def extract_one_face(self, + alignment: AlignmentFileDict, + image: np.ndarray) -> DetectedFace: + """ Extract one face from image + + Parameters + ---------- + alignment: dict + The alignment for a single face + image: :class:`numpy.ndarray` + The image to extract the face from + + Returns + ------- + :class:`~lib.align.DetectedFace` + The detected face object for the given alignment with the aligned face loaded + """ + logger.trace("Extracting one face: (frame: '%s', alignment: %s)", # type: ignore + self.current_frame, alignment) + face = DetectedFace() + face.from_alignment(alignment, image=image) + face.load_aligned(image, size=self.size, centering="head") + face.thumbnail = generate_thumbnail(face.aligned.face, size=80, quality=60) + return face + + def get_faces_in_frame(self, + frame: str, + update: bool = False, + image: np.ndarray | None = None) -> list[DetectedFace]: + """ Return the faces for the selected frame + + Parameters + ---------- + frame: str + The frame name to get the faces for + update: bool, optional + ``True`` if the faces should be refreshed regardless of current frame. ``False`` to not + force a refresh. Default ``False`` + image: :class:`numpy.ndarray`, optional + Image to load faces from if it exists, otherwise ``None`` to load the image. + Default: ``None`` + + Returns + ------- + list + List of :class:`~lib.align.DetectedFace` objects for the frame, with the aligned face + loaded + """ + logger.trace("frame: '%s', update: %s", frame, update) # type: ignore + if self.current_frame != frame or update: + self.get_faces(frame, image=image) + return self.faces + + def get_roi_size_for_frame(self, frame: str) -> list[int]: + """ Return the size of the original extract box for the selected frame. 
+ + Parameters + ---------- + frame: str + The frame to obtain the original sized bounding boxes for + + Returns + ------- + list + List of original pixel sizes of faces held within the frame + """ + logger.trace("frame: '%s'", frame) # type: ignore + if self.current_frame != frame: + self.get_faces(frame) + sizes = [] + for face in self.faces: + roi = face.aligned.original_roi.squeeze() + top_left, top_right = roi[0], roi[3] + len_x = top_right[0] - top_left[0] + len_y = top_right[1] - top_left[1] + if top_left[1] == top_right[1]: + length = len_y + else: + length = int(((len_x ** 2) + (len_y ** 2)) ** 0.5) + sizes.append(length) + logger.trace("sizes: '%s'", sizes) # type: ignore + return sizes + + +__all__ = get_module_objects(__name__) diff --git a/tools/cli.py b/tools/cli.py deleted file mode 100644 index d243561424..0000000000 --- a/tools/cli.py +++ /dev/null @@ -1,497 +0,0 @@ -#!/usr/bin/env python3 -""" Command Line Arguments for tools """ -from lib.cli import FaceSwapArgs -from lib.cli import (ContextFullPaths, DirFullPaths, - FileFullPaths, SaveFileFullPaths, Slider) -from lib.utils import _image_extensions - - -class AlignmentsArgs(FaceSwapArgs): - """ Class to parse the command line arguments for Aligments tool """ - - def get_argument_list(self): - frames_dir = "\n\tMust Pass in a frames folder/source video file (-fr)." - faces_dir = "\n\tMust Pass in a faces folder (-fc)." - frames_or_faces_dir = ("\n\tMust Pass in either a frames folder/source video file OR a" - "\n\tfaces folder (-fr or -fc).") - frames_and_faces_dir = ("\n\tMust Pass in a frames folder/source video file AND a faces " - "\n\tfolder (-fr and -fc).") - output_opts = "\n\tUse the output option (-o) to process results." - align_eyes = "\n\tCan optionally use the align-eyes switch (-ae)." - argument_list = list() - argument_list.append({ - - "opts": ("-j", "--job"), - "type": str, - "choices": ("draw", "extract", "extract-large", "manual", "merge", - "missing-alignments", "missing-frames", "legacy", "leftover-faces", - "multi-faces", "no-faces", "reformat", "remove-faces", "remove-frames", - "rename", "sort-x", "sort-y", "spatial", "update-hashes"), - "required": True, - "help": "R|Choose which action you want to perform.\n" - "NB: All actions require an alignments file (-a) to be passed in." - "\n'draw': Draw landmarks on frames in the selected folder/video. A subfolder" - "\n\twill be created within the frames folder to hold the output." + - frames_dir + align_eyes + - "\n'extract': Re-extract faces from the source frames/video based on " - "\n\talignment data. This is a lot quicker than re-detecting faces." + - frames_and_faces_dir + align_eyes + - "\n'extract-large' - Extract all faces that have not been upscaled. Useful" - "\n\tfor excluding low-res images from a training set." + - frames_and_faces_dir + align_eyes + - "\n'manual': Manually view and edit landmarks." + frames_dir + align_eyes + - "\n'merge': Merge multiple alignment files into one. Specify the main" - "\n\talignments file with the -a flag and the file to be merged with the" - "\n\t-a2 flag." - "\n'missing-alignments': Identify frames that do not exist in the alignments" - "\n\tfile." + output_opts + frames_dir + - "\n'missing-frames': Identify frames in the alignments file that do no " - "\n\tappear within the frames folder/video." + output_opts + frames_dir + - "\n'legacy': This updates legacy alignments to the latest format by rotating" - "\n\tthe landmarks and bounding boxes and adding face_hashes." 
+ - frames_and_faces_dir + - "\n'leftover-faces': Identify faces in the faces folder that do not exist in" - "\n\tthe alignments file." + output_opts + faces_dir + - "\n'multi-faces': Identify where multiple faces exist within the alignments" - "\n\tfile." + output_opts + frames_or_faces_dir + - "\n'no-faces': Identify frames that exist within the alignment file but no" - "\n\tfaces were detected." + output_opts + frames_dir + - "\n'reformat': Save a copy of alignments file in a different format. Specify" - "\n\ta format with the -fmt option." - "\n\tAlignments can be converted from DeepFaceLab by specifing:" - "\n\t -a dfl" - "\n\t -fc " - "\n'remove-faces': Remove deleted faces from an alignments file. The original" - "\n\talignments file will be backed up. A different file format for the" - "\n\talignments file can optionally be specified (-fmt)." + faces_dir + - "\n'remove-frames': Remove deleted frames from an alignments file. The" - "\n\toriginal alignments file will be backed up. A different file format for" - "\n\tthe alignments file can optionally be specified (-fmt)." + frames_dir + - "\n'rename' - Rename faces to correspond with their parent frame and position" - "\n\tindex in the alignments file (i.e. how they are named after running" - "\n\textract)." + faces_dir + - "\n'sort-x': Re-index the alignments from left to right. For alignments with" - "\n\tmultiple faces this will ensure that the left-most face is at index 0" - "\n\tOptionally pass in a faces folder (-fc) to also rename extracted faces." - "\n'sort-y': Re-index the alignments from top to bottom. For alignments with" - "\n\tmultiple faces this will ensure that the top-most face is at index 0" - "\n\tOptionally pass in a faces folder (-fc) to also rename extracted faces." - "\n'spatial': Perform spatial and temporal filtering to smooth alignments" - "\n\t(EXPERIMENTAL!)" - "\n'update-hashes': Recalculate the face hashes. Only use this if you have " - "\n\taltered the extracted faces (e.g. colour adjust). The files MUST be " - "\n\tnamed '_face index' (i.e. how they are named after running" - "\n\textract)." + faces_dir}) - argument_list.append({"opts": ("-a", "--alignments_file"), - "action": FileFullPaths, - "dest": "alignments_file", - "required": True, - "filetypes": "alignments", - "help": "Full path to the alignments " - "file to be processed."}) - argument_list.append({"opts": ("-a2", "--alignments_file2"), - "action": FileFullPaths, - "dest": "alignments_file2", - "required": False, - "filetypes": "alignments", - "help": "Full path to the alignments file to " - "be merged into the main alignments " - "file (merge only)"}) - argument_list.append({"opts": ("-fc", "-faces_folder"), - "action": DirFullPaths, - "dest": "faces_dir", - "help": "Directory containing extracted faces."}) - argument_list.append({"opts": ("-fr", "-frames_folder"), - "action": DirFullPaths, - "dest": "frames_dir", - "help": "Directory containing source frames " - "that faces were extracted from."}) - argument_list.append({"opts": ("-fmt", "--alignment_format"), - "type": str, - "choices": ("json", "pickle", "yaml"), - "help": "The file format to save the alignment " - "data in. Defaults to same as source."}) - argument_list.append({ - "opts": ("-o", "--output"), - "type": str, - "choices": ("console", "file", "move"), - "default": "console", - "help": "R|How to output discovered items ('faces' and" - "\n'frames' only):" - "\n'console': Print the list of frames to the screen. 
(DEFAULT)" - "\n'file': Output the list of frames to a text file (stored within the source" - "\n\tdirectory)." - "\n'move': Move the discovered items to a sub-folder within the source" - "\n\tdirectory."}) - argument_list.append({"opts": ("-sz", "--size"), - "type": int, - "action": Slider, - "min_max": (128, 512), - "default": 256, - "rounding": 64, - "help": "The output size of extracted faces. (extract only)"}) - argument_list.append({"opts": ("-ae", "--align-eyes"), - "action": "store_true", - "dest": "align_eyes", - "default": False, - "help": "Perform extra alignment to ensure " - "left/right eyes are at the same " - "height. (Draw, Extract and manual " - "only)"}) - argument_list.append({"opts": ("-dm", "--disable-monitor"), - "action": "store_true", - "dest": "disable_monitor", - "default": False, - "help": "Enable this option if manual " - "alignments window is closing " - "instantly. (Manual only)"}) - return argument_list - - -class EffmpegArgs(FaceSwapArgs): - """ Class to parse the command line arguments for EFFMPEG tool """ - - @staticmethod - def __parse_transpose(value): - index = 0 - opts = ["(0, 90CounterClockwise&VerticalFlip)", - "(1, 90Clockwise)", - "(2, 90CounterClockwise)", - "(3, 90Clockwise&VerticalFlip)"] - if len(value) == 1: - index = int(value) - else: - for i in range(5): - if value in opts[i]: - index = i - break - return opts[index] - - def get_argument_list(self): - argument_list = list() - argument_list.append({"opts": ('-a', '--action'), - "dest": "action", - "choices": ("extract", "gen-vid", "get-fps", - "get-info", "mux-audio", "rescale", - "rotate", "slice"), - "default": "extract", - "help": "Choose which action you want ffmpeg " - "ffmpeg to do.\n" - "'slice' cuts a portion of the video " - "into a separate video file.\n" - "'get-fps' returns the chosen video's " - "fps."}) - - argument_list.append({"opts": ('-i', '--input'), - "action": ContextFullPaths, - "dest": "input", - "default": "input", - "help": "Input file.", - "required": True, - "action_option": "-a", - "filetypes": "video"}) - - argument_list.append({"opts": ('-o', '--output'), - "action": ContextFullPaths, - "dest": "output", - "default": "", - "help": "Output file. If no output is " - "specified then: if the output is " - "meant to be a video then a video " - "called 'out.mkv' will be created in " - "the input directory; if the output is " - "meant to be a directory then a " - "directory called 'out' will be " - "created inside the input " - "directory.\n" - "Note: the chosen output file " - "extension will determine the file " - "encoding.", - "action_option": "-a", - "filetypes": "video"}) - - argument_list.append({"opts": ('-r', '--reference-video'), - "action": FileFullPaths, - "dest": "ref_vid", - "default": None, - "help": "Path to reference video if 'input' " - "was not a video.", - "filetypes": "video"}) - - argument_list.append({"opts": ('-fps', '--fps'), - "type": str, - "dest": "fps", - "default": "-1.0", - "help": "Provide video fps. Can be an integer, " - "float or fraction. Negative values " - "will make the program try to get the " - "fps from the input or reference " - "videos."}) - - argument_list.append({"opts": ("-ef", "--extract-filetype"), - "choices": _image_extensions, - "dest": "extract_ext", - "default": ".png", - "help": "Image format that extracted images " - "should be saved as. '.bmp' will offer " - "the fastest extraction speed, but " - "will take the most storage space. 
" - "'.png' will be slower but will take " - "less storage."}) - - argument_list.append({"opts": ('-s', '--start'), - "type": str, - "dest": "start", - "default": "00:00:00", - "help": "Enter the start time from which an " - "action is to be applied.\n" - "Default: 00:00:00, in HH:MM:SS " - "format. You can also enter the time " - "with or without the colons, e.g. " - "00:0000 or 026010."}) - - argument_list.append({"opts": ('-e', '--end'), - "type": str, - "dest": "end", - "default": "00:00:00", - "help": "Enter the end time to which an action " - "is to be applied. If both an end time " - "and duration are set, then the end " - "time will be used and the duration " - "will be ignored.\n" - "Default: 00:00:00, in HH:MM:SS."}) - - argument_list.append({"opts": ('-d', '--duration'), - "type": str, - "dest": "duration", - "default": "00:00:00", - "help": "Enter the duration of the chosen " - "action, for example if you enter " - "00:00:10 for slice, then the first 10 " - "seconds after and including the start " - "time will be cut out into a new " - "video.\n" - "Default: 00:00:00, in HH:MM:SS " - "format. You can also enter the time " - "with or without the colons, e.g. " - "00:0000 or 026010."}) - - argument_list.append({"opts": ('-m', '--mux-audio'), - "action": "store_true", - "dest": "mux_audio", - "default": False, - "help": "Mux the audio from the reference " - "video into the input video. This " - "option is only used for the 'gen-vid' " - "action. 'mux-audio' action has this " - "turned on implicitly."}) - - argument_list.append( - {"opts": ('-tr', '--transpose'), - "choices": ("(0, 90CounterClockwise&VerticalFlip)", - "(1, 90Clockwise)", - "(2, 90CounterClockwise)", - "(3, 90Clockwise&VerticalFlip)"), - "type": lambda v: self.__parse_transpose(v), - "dest": "transpose", - "default": None, - "help": "Transpose the video. If transpose is " - "set, then degrees will be ignored. For " - "cli you can enter either the number " - "or the long command name, " - "e.g. to use (1, 90Clockwise) " - "-tr 1 or -tr 90Clockwise"}) - - argument_list.append({"opts": ('-de', '--degrees'), - "type": str, - "dest": "degrees", - "default": None, - "help": "Rotate the video clockwise by the " - "given number of degrees."}) - - argument_list.append({"opts": ('-sc', '--scale'), - "type": str, - "dest": "scale", - "default": "1920x1080", - "help": "Set the new resolution scale if the " - "chosen action is 'rescale'."}) - - argument_list.append({"opts": ('-pr', '--preview'), - "action": "store_true", - "dest": "preview", - "default": False, - "help": "Uses ffplay to preview the effects of " - "actions that have a video output. " - "Currently preview does not work when " - "muxing audio."}) - - argument_list.append({"opts": ('-q', '--quiet'), - "action": "store_true", - "dest": "quiet", - "default": False, - "help": "Reduces output verbosity so that only " - "serious errors are printed. If both " - "quiet and verbose are set, verbose " - "will override quiet."}) - - argument_list.append({"opts": ('-v', '--verbose'), - "action": "store_true", - "dest": "verbose", - "default": False, - "help": "Increases output verbosity. 
If both " - "quiet and verbose are set, verbose " - "will override quiet."}) - - return argument_list - - -class SortArgs(FaceSwapArgs): - """ Class to parse the command line arguments for sort tool """ - - @staticmethod - def get_argument_list(): - """ Put the arguments in a list so that they are accessible from both - argparse and gui """ - argument_list = list() - argument_list.append({"opts": ('-i', '--input'), - "action": DirFullPaths, - "dest": "input_dir", - "default": "input_dir", - "help": "Input directory of aligned faces.", - "required": True}) - - argument_list.append({"opts": ('-o', '--output'), - "action": DirFullPaths, - "dest": "output_dir", - "default": "_output_dir", - "help": "Output directory for sorted aligned " - "faces."}) - - argument_list.append({"opts": ('-fp', '--final-process'), - "type": str, - "choices": ("folders", "rename"), - "dest": 'final_process', - "default": "rename", - "help": "R|\n'folders': files are sorted using " - "the -s/--sort-by\n\tmethod, then they " - "are organized into\n\tfolders using " - "the -g/--group-by grouping\n\tmethod." - "\n'rename': files are sorted using " - "the -s/--sort-by\n\tthen they are " - "renamed.\nDefault: rename"}) - - argument_list.append({"opts": ('-k', '--keep'), - "action": 'store_true', - "dest": 'keep_original', - "default": False, - "help": "Keeps the original files in the input " - "directory. Be careful when using this " - "with rename grouping and no specified " - "output directory as this would keep " - "the original and renamed files in the " - "same directory."}) - - argument_list.append({"opts": ('-s', '--sort-by'), - "type": str, - "choices": ("blur", "face", "face-cnn", - "face-cnn-dissim", "face-dissim", - "face-yaw", "hist", - "hist-dissim"), - "dest": 'sort_method', - "default": "hist", - "help": "Sort by method. " - "Choose how images are sorted. " - "Default: hist"}) - - argument_list.append({"opts": ('-g', '--group-by'), - "type": str, - "choices": ("blur", "face", "face-cnn", - "face-yaw", "hist"), - "dest": 'group_method', - "default": "hist", - "help": "Group by method. " - "When -fp/--final-processing by " - "folders choose the how the images are " - "grouped after sorting. " - "Default: hist"}) - - argument_list.append({"opts": ('-t', '--ref_threshold'), - "action": Slider, - "min_max": (-1.0, 10.0), - "rounding": 2, - "type": float, - "dest": 'min_threshold', - "default": -1.0, - "help": "Float value. " - "Minimum threshold to use for grouping " - "comparison with 'face' and 'hist' " - "methods. The lower the value the more " - "discriminating the grouping is. " - "Leaving -1.0 will make the program " - "set the default value automatically. " - "For face 0.6 should be enough, with " - "0.5 being very discriminating. " - "For face-cnn 7.2 should be enough, " - "with 4 being very discriminating. " - "For hist 0.3 should be enough, with " - "0.2 being very discriminating. " - "Be careful setting a value that's too " - "low in a directory with many images, " - "as this could result in a lot of " - "directories being created. " - "Defaults: face 0.6, face-cnn 7.2, " - "hist 0.3"}) - - argument_list.append({"opts": ('-b', '--bins'), - "action": Slider, - "min_max": (1, 100), - "rounding": 1, - "type": int, - "dest": 'num_bins', - "default": 5, - "help": "Integer value. " - "Number of folders that will be used " - "to group by blur and face-yaw. " - "For blur folder 0 will be the least " - "blurry, while the last folder will be " - "the blurriest. 
" - "For face-yaw the number of bins is by " - "how much 180 degrees is divided. So " - "if you use 18, then each folder will " - "be a 10 degree increment. Folder 0 " - "will contain faces looking the most " - "to the left whereas the last folder " - "will contain the faces looking the " - "most to the right. " - "If the number of images doesn't " - "divide evenly into the number of " - "bins, the remaining images get put in " - "the last bin." - "Default value: 5"}) - - argument_list.append({"opts": ('-l', '--log-changes'), - "action": 'store_true', - "dest": 'log_changes', - "default": False, - "help": "Logs file renaming changes if " - "grouping by renaming, or it logs the " - "file copying/movement if grouping by " - "folders. If no log file is specified " - "with '--log-file', then a " - "'sort_log.json' file will be created " - "in the input directory."}) - - argument_list.append({"opts": ('-lf', '--log-file'), - "action": SaveFileFullPaths, - "filetypes": "alignments", - "dest": 'log_file_path', - "default": 'sort_log.json', - "help": "Specify a log file to use for saving " - "the renaming or grouping information. " - "If specified extension isn't 'json' " - "or 'yaml', then json will be used as " - "the serializer, with the supplied " - "filename. " - "Default: sort_log.json"}) - - return argument_list diff --git a/tools/effmpeg/__init__.py b/tools/effmpeg/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tools/effmpeg/cli.py b/tools/effmpeg/cli.py new file mode 100644 index 0000000000..8238855606 --- /dev/null +++ b/tools/effmpeg/cli.py @@ -0,0 +1,206 @@ +#!/usr/bin/env python3 +""" Command Line Arguments for tools """ +import gettext + +from lib.cli.args import FaceSwapArgs +from lib.cli.actions import ContextFullPaths, FileFullPaths, Radio +from lib.utils import get_module_objects, IMAGE_EXTENSIONS + + +# LOCALES +_LANG = gettext.translation("tools.effmpeg.cli", localedir="locales", fallback=True) +_ = _LANG.gettext + +_HELPTEXT = _("This command allows you to easily execute common ffmpeg tasks.") + + +def __parse_transpose(value: str) -> str: + """ Parse transpose option + + Parameters + ---------- + value: str + The value to parse + + Returns + ------- + str + The option item for the given value + """ + index = 0 + opts = ["(0, 90CounterClockwise&VerticalFlip)", + "(1, 90Clockwise)", + "(2, 90CounterClockwise)", + "(3, 90Clockwise&VerticalFlip)"] + if len(value) == 1: + index = int(value) + else: + for i in range(5): + if value in opts[i]: + index = i + break + return opts[index] + + +class EffmpegArgs(FaceSwapArgs): + """ Class to parse the command line arguments for EFFMPEG tool """ + + @staticmethod + def get_info(): + """ Return command information """ + return _("A wrapper for ffmpeg for performing image <> video converting.") + + @staticmethod + def get_argument_list(): + argument_list = [] + argument_list.append({ + "opts": ('-a', '--action'), + "action": Radio, + "dest": "action", + "choices": ("extract", "gen-vid", "get-fps", "get-info", "mux-audio", "rescale", + "rotate", "slice"), + "default": "extract", + "help": _("R|Choose which action you want ffmpeg ffmpeg to do." + "\nL|'extract': turns videos into images " + "\nL|'gen-vid': turns images into videos " + "\nL|'get-fps' returns the chosen video's fps." + "\nL|'get-info' returns information about a video." + "\nL|'mux-audio' add audio from one video to another." + "\nL|'rescale' resize video." + "\nL|'rotate' rotate video." 
+ "\nL|'slice' cuts a portion of the video into a separate video file.")}) + argument_list.append({ + "opts": ('-i', '--input'), + "action": ContextFullPaths, + "dest": "input", + "default": "input", + "help": _("Input file."), + "group": _("data"), + "required": True, + "action_option": "-a", + "filetypes": "video"}) + argument_list.append({ + "opts": ('-o', '--output'), + "action": ContextFullPaths, + "group": _("data"), + "default": "", + "dest": "output", + "help": _("Output file. If no output is specified then: if the output is meant to be " + "a video then a video called 'out.mkv' will be created in the input " + "directory; if the output is meant to be a directory then a directory " + "called 'out' will be created inside the input directory. Note: the chosen " + "output file extension will determine the file encoding."), + "action_option": "-a", + "filetypes": "video"}) + argument_list.append({ + "opts": ('-r', '--reference-video'), + "action": FileFullPaths, + "dest": "ref_vid", + "group": _("data"), + "default": None, + "help": _("Path to reference video if 'input' was not a video."), + "filetypes": "video"}) + argument_list.append({ + "opts": ('-R', '--fps'), + "type": str, + "dest": "fps", + "group": _("output"), + "default": "-1.0", + "help": _("Provide video fps. Can be an integer, float or fraction. Negative values " + "will will make the program try to get the fps from the input or reference " + "videos.")}) + argument_list.append({ + "opts": ("-E", "--extract-filetype"), + "action": Radio, + "choices": IMAGE_EXTENSIONS, + "dest": "extract_ext", + "group": _("output"), + "default": ".png", + "help": _("Image format that extracted images should be saved as. '.bmp' will offer " + "the fastest extraction speed, but will take the most storage space. '.png' " + "will be slower but will take less storage.")}) + argument_list.append({ + "opts": ('-s', '--start'), + "type": str, + "dest": "start", + "group": _("clip"), + "default": "00:00:00", + "help": _("Enter the start time from which an action is to be applied. Default: " + "00:00:00, in HH:MM:SS format. You can also enter the time with or without " + "the colons, e.g. 00:0000 or 026010.")}) + argument_list.append({ + "opts": ('-e', '--end'), + "type": str, + "dest": "end", + "group": _("clip"), + "default": "00:00:00", + "help": _("Enter the end time to which an action is to be applied. If both an end " + "time and duration are set, then the end time will be used and the duration " + "will be ignored. Default: 00:00:00, in HH:MM:SS.")}) + argument_list.append({ + "opts": ('-d', '--duration'), + "type": str, + "dest": "duration", + "group": _("clip"), + "default": "00:00:00", + "help": _("Enter the duration of the chosen action, for example if you enter 00:00:10 " + "for slice, then the first 10 seconds after and including the start time " + "will be cut out into a new video. Default: 00:00:00, in HH:MM:SS format. " + "You can also enter the time with or without the colons, e.g. 00:0000 or " + "026010.")}) + argument_list.append({ + "opts": ('-m', '--mux-audio'), + "action": "store_true", + "dest": "mux_audio", + "group": _("output"), + "default": False, + "help": _("Mux the audio from the reference video into the input video. This option " + "is only used for the 'gen-vid' action. 
'mux-audio' action has this turned " + "on implicitly.")}) + argument_list.append({ + "opts": ('-T', '--transpose'), + "choices": ("(0, 90CounterClockwise&VerticalFlip)", + "(1, 90Clockwise)", + "(2, 90CounterClockwise)", + "(3, 90Clockwise&VerticalFlip)"), + "type": lambda v: _parse_transpose(v), # pylint:disable=unnecessary-lambda + "dest": "transpose", + "group": _("rotate"), + "default": None, + "help": _("Transpose the video. If transpose is set, then degrees will be ignored. " + "For cli you can enter either the number or the long command name, e.g. to " + "use (1, 90Clockwise) enter -T 1 or -T 90Clockwise")}) + argument_list.append({ + "opts": ('-D', '--degrees'), + "type": str, + "dest": "degrees", + "default": None, + "group": _("rotate"), + "help": _("Rotate the video clockwise by the given number of degrees.")}) + argument_list.append({ + "opts": ('-S', '--scale'), + "type": str, + "dest": "scale", + "group": _("output"), + "default": "1920x1080", + "help": _("Set the new resolution scale if the chosen action is 'rescale'.")}) + argument_list.append({ + "opts": ('-q', '--quiet'), + "action": "store_true", + "dest": "quiet", + "group": _("settings"), + "default": False, + "help": _("Reduces output verbosity so that only serious errors are printed. If both " + "quiet and verbose are set, verbose will override quiet.")}) + argument_list.append({ + "opts": ('-v', '--verbose'), + "action": "store_true", + "dest": "verbose", + "group": _("settings"), + "default": False, + "help": _("Increases output verbosity. If both quiet and verbose are set, verbose " + "will override quiet.")}) + return argument_list + + +__all__ = get_module_objects(__name__) diff --git a/tools/effmpeg.py b/tools/effmpeg/effmpeg.py similarity index 69% rename from tools/effmpeg.py rename to tools/effmpeg/effmpeg.py index 5930a4add8..28cce637c0 100644 --- a/tools/effmpeg.py +++ b/tools/effmpeg/effmpeg.py @@ -5,28 +5,22 @@ @author: Lev Velykoivanenko (velykoivanenko.lev@gmail.com) """ -# TODO: integrate preview into gui window -# TODO: add preview support when muxing audio -# -> figure out if ffmpeg | ffplay would work on windows and mac import logging import os -import sys import subprocess +import sys import datetime +from collections import OrderedDict -from ffmpy import FFprobe, FFmpeg, FFRuntimeError +import imageio +import imageio_ffmpeg as im_ffm +from ffmpy import FFmpeg, FFRuntimeError # faceswap imports -from lib.cli import FullHelpArgumentParser -from lib.utils import _image_extensions, _video_extensions -from . import cli +from lib.utils import (get_module_objects, handle_deprecated_cliopts, IMAGE_EXTENSIONS, + VIDEO_EXTENSIONS) -if sys.version_info[0] < 3: - raise Exception("This program requires at least python3.2") -if sys.version_info[0] == 3 and sys.version_info[1] < 2: - raise Exception("This program requires at least python3.2") - -logger = logging.getLogger(__name__) # pylint: disable=invalid-name +logger = logging.getLogger(__name__) class DataItem(): @@ -34,10 +28,10 @@ class DataItem(): A simple class used for storing the media data items and directories that Effmpeg uses for 'input', 'output' and 'ref_vid'. 
""" - vid_ext = _video_extensions + vid_ext = VIDEO_EXTENSIONS # future option in effmpeg to use audio file for muxing - audio_ext = ['.aiff', '.flac', '.mp3', '.wav'] - img_ext = _image_extensions + audio_ext = [".aiff", ".flac", ".mp3", ".wav"] + img_ext = IMAGE_EXTENSIONS def __init__(self, path=None, name=None, item_type=None, ext=None, fps=None): @@ -75,16 +69,14 @@ def set_type_ext(self, path=None): if self.path is not None: item_ext = os.path.splitext(self.path)[1].lower() if item_ext in DataItem.vid_ext: - item_type = 'vid' + item_type = "vid" elif item_ext in DataItem.audio_ext: - item_type = 'audio' + item_type = "audio" else: - item_type = 'dir' + item_type = "dir" self.type = item_type self.ext = item_ext logger.debug("path: '%s', type: '%s', ext: '%s'", self.path, self.type, self.ext) - else: - return def set_dirname(self, path=None): """ Set the folder name """ @@ -132,8 +124,6 @@ class Effmpeg(): _actions_req_fps = ["extract", "gen_vid"] _actions_req_ref_video = ["mux_audio"] - _actions_can_preview = ["gen_vid", "mux_audio", "rescale", "rotate", - "slice"] _actions_can_use_ref_video = ["gen_vid"] _actions_have_dir_output = ["extract"] _actions_have_vid_output = ["gen_vid", "mux_audio", "rescale", "rotate", @@ -144,22 +134,22 @@ class Effmpeg(): "rotate", "slice"] # Class variable that stores the target executable (ffmpeg or ffplay) - _executable = 'ffmpeg' + _executable = im_ffm.get_ffmpeg_exe() # Class variable that stores the common ffmpeg arguments based on verbosity __common_ffmpeg_args_dict = {"normal": "-hide_banner ", "quiet": "-loglevel panic -hide_banner ", - "verbose": ''} + "verbose": ""} # _common_ffmpeg_args is the class variable that will get used by various # actions and it will be set by the process_arguments() method based on # passed verbosity - _common_ffmpeg_args = '' + _common_ffmpeg_args = "" def __init__(self, arguments): logger.debug("Initializing %s: (arguments: %s)", self.__class__.__name__, arguments) - self.args = arguments - self.exe = "ffmpeg" + self.args = handle_deprecated_cliopts(arguments) + self.exe = im_ffm.get_ffmpeg_exe() self.input = DataItem() self.output = DataItem() self.ref_vid = DataItem() @@ -169,17 +159,8 @@ def __init__(self, arguments): self.print_ = False logger.debug("Initialized %s", self.__class__.__name__) - def process(self): - """ EFFMPEG Process """ - logger.debug("Running Effmpeg") - # Format action to match the method name - self.args.action = self.args.action.replace('-', '_') - logger.debug("action: '%s", self.args.action) - - # Instantiate input DataItem object - self.input = DataItem(path=self.args.input) - - # Instantiate output DataItem object + def _set_output(self) -> None: + """ Set :attr:`output` based on input arguments """ if self.args.action in self._actions_have_dir_output: self.output = DataItem(path=self.__get_default_output()) elif self.args.action in self._actions_have_vid_output: @@ -189,70 +170,101 @@ def process(self): else: self.output = DataItem(path=self.__get_default_output()) - if self.args.ref_vid is None \ - or self.args.ref_vid == '': + def _set_ref_video(self) -> None: + """ Set :attr:`ref_vid` based on input arguments """ + if self.args.ref_vid is None or self.args.ref_vid == "": self.args.ref_vid = None - # Instantiate ref_vid DataItem object self.ref_vid = DataItem(path=self.args.ref_vid) - # Check that correct input and output arguments were provided + def _check_inputs(self) -> None: + """ Validate provided arguments are valid + + Raises + ------ + ValueError + If provided arguments 
are not valid + """ + if self.args.action in self._actions_have_dir_input and not self.input.is_type("dir"): - raise ValueError("The chosen action requires a directory as its " - "input, but you entered: " - "{}".format(self.input.path)) + raise ValueError("The chosen action requires a directory as its input, but you " + f"entered: {self.input.path}") if self.args.action in self._actions_have_vid_input and not self.input.is_type("vid"): - raise ValueError("The chosen action requires a video as its " - "input, but you entered: " - "{}".format(self.input.path)) + raise ValueError("The chosen action requires a video as its input, but you entered: " + f"{self.input.path}") if self.args.action in self._actions_have_dir_output and not self.output.is_type("dir"): - raise ValueError("The chosen action requires a directory as its " - "output, but you entered: " - "{}".format(self.output.path)) + raise ValueError("The chosen action requires a directory as its output, but you " + f"entered: {self.output.path}") if self.args.action in self._actions_have_vid_output and not self.output.is_type("vid"): - raise ValueError("The chosen action requires a video as its " - "output, but you entered: " - "{}".format(self.output.path)) + raise ValueError("The chosen action requires a video as its output, but you entered: " + f"{self.output.path}") # Check that ref_vid is a video when it needs to be if self.args.action in self._actions_req_ref_video: if self.ref_vid.is_type("none"): - raise ValueError("The file chosen as the reference video is " - "not a video, either leave the field blank " - "or type 'None': " - "{}".format(self.ref_vid.path)) + raise ValueError("The file chosen as the reference video is not a video, either " + f"leave the field blank or type 'None': {self.ref_vid.path}") elif self.args.action in self._actions_can_use_ref_video: if self.ref_vid.is_type("none"): logger.warning("Warning: no reference video was supplied, even though " "one may be used with the chosen action. If this is " "intentional then ignore this warning.") - # Process start and duration arguments + def _set_times(self) -> None: + """Set start, end and duration attributes """ self.start = self.parse_time(self.args.start) self.end = self.parse_time(self.args.end) if not self.__check_equals_time(self.args.end, "00:00:00"): self.duration = self.__get_duration(self.start, self.end) else: self.duration = self.parse_time(str(self.args.duration)) + + def _set_fps(self) -> None: + """ Set :attr:`arguments.fps` based on input arguments""" # If fps was left blank in gui, set it to default -1.0 value - if self.args.fps == '': + if self.args.fps == "": self.args.fps = str(-1.0) # Try to set fps automatically if needed and not supplied by user if self.args.action in self._actions_req_fps \ and self.__convert_fps(self.args.fps) <= 0: - if self.__check_have_fps(['r', 'i']): + if self.__check_have_fps(["r", "i"]): _error_str = "No fps, input or reference video was supplied, " _error_str += "hence it's not possible to " - _error_str += "'{}'.".format(self.args.action) + _error_str += f"'{self.args.action}'." 
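+ # __check_have_fps(["r", "i"]) is True when neither the reference video + # nor the input video carries a usable fps, so with no explicit fps + # supplied there is nothing left to infer the frame rate from.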
raise ValueError(_error_str) - elif self.output.fps is not None and self.__check_have_fps(['r', 'i']): + if self.output.fps is not None and self.__check_have_fps(["r", "i"]): self.args.fps = self.output.fps - elif self.ref_vid.fps is not None and self.__check_have_fps(['i']): + elif self.ref_vid.fps is not None and self.__check_have_fps(["i"]): self.args.fps = self.ref_vid.fps - elif self.input.fps is not None and self.__check_have_fps(['r']): + elif self.input.fps is not None and self.__check_have_fps(["r"]): self.args.fps = self.input.fps + def process(self): + """ EFFMPEG Process """ + logger.debug("Running Effmpeg") + # Format action to match the method name + self.args.action = self.args.action.replace("-", "_") + logger.debug("action: '%s'", self.args.action) + + # Instantiate input DataItem object + self.input = DataItem(path=self.args.input) + + # Instantiate output DataItem object + self._set_output() + + # Instantiate ref_vid DataItem object + self._set_ref_video() + + # Check that correct input and output arguments were provided + self._check_inputs() + + # Process start and duration arguments + self._set_times() + + # Set fps + self._set_fps() + # Processing transpose if self.args.transpose is None or \ self.args.transpose.lower() == "none": @@ -263,7 +275,7 @@ def process(self): # Processing degrees if self.args.degrees is None \ or self.args.degrees.lower() == "none" \ - or self.args.degrees == '': + or self.args.degrees == "": self.args.degrees = None elif self.args.transpose is None: try: @@ -271,12 +283,7 @@ def process(self): except ValueError: logger.error("You have entered an invalid value for degrees: %s", self.args.degrees) - exit(1) - - # Set executable based on whether previewing or not - if self.args.preview and self.args.action in self._actions_can_preview: - self.exe = 'ffplay' - self.output = DataItem() + sys.exit(1) # Set verbosity of output self.__set_verbosity(self.args.quiet, self.args.verbose) @@ -302,20 +309,19 @@ def effmpeg_process(self): "transpose": self.args.transpose, "scale": self.args.scale, "print_": self.print_, - "preview": self.args.preview, "exe": self.exe} action = getattr(self, self.args.action) action(**kwargs) @staticmethod - def extract(input_=None, output=None, fps=None, extract_ext=None, start=None, duration=None, - **kwargs): + def extract(input_=None, output=None, fps=None, # pylint:disable=unused-argument + extract_ext=None, start=None, duration=None, **kwargs): """ Extract video to image frames """ logger.debug("input_: %s, output: %s, fps: %s, extract_ext: '%s', start: %s, duration: %s", input_, output, fps, extract_ext, start, duration) _input_opts = Effmpeg._common_ffmpeg_args[:] if start is not None and duration is not None: - _input_opts += '-ss {} -t {}'.format(start, duration) + _input_opts += f"-ss {start} -t {duration}" _input = {input_.path: _input_opts} _output_opts = '-y -vf fps="' + str(fps) + '" -q:v 1' _output_path = output.path + "/" + input_.name + "_%05d" + extract_ext @@ -325,24 +331,21 @@ def extract(input_=None, output=None, fps=None, extract_ext=None, start=None, du Effmpeg.__run_ffmpeg(inputs=_input, outputs=_output) @staticmethod - def gen_vid(input_=None, output=None, fps=None, mux_audio=False, - ref_vid=None, preview=False, exe=None, **kwargs): + def gen_vid(input_=None, output=None, fps=None, # pylint:disable=unused-argument + mux_audio=False, ref_vid=None, exe=None, **kwargs): """ Generate Video """ - logger.debug("input: %s, output: %s, fps: %s, mux_audio: %s, ref_vid: '%s', preview: %s, " - "exe: '%s'", 
input, output, fps, mux_audio, ref_vid, preview, exe) + logger.debug("input: %s, output: %s, fps: %s, mux_audio: %s, ref_vid: '%s', exe: '%s'", + input_, output, fps, mux_audio, ref_vid, exe) filename = Effmpeg.__get_extracted_filename(input_.path) _input_opts = Effmpeg._common_ffmpeg_args[:] _input_path = os.path.join(input_.path, filename) - _output_opts = '-vf fps="' + str(fps) + '" ' - if not preview: - _output_opts = '-y ' + _output_opts + ' -c:v libx264' + _fps_arg = "-r " + str(fps) + " " + _input_opts += _fps_arg + "-f image2 " + _output_opts = "-y " + _fps_arg + " -c:v libx264" if mux_audio: - _ref_vid_opts = '-c copy -map 0:0 -map 1:1' - if preview: - raise ValueError("Preview for gen-vid with audio muxing is " - "not supported.") - _output_opts = _ref_vid_opts + ' ' + _output_opts - _inputs = {_input_path: _input_opts, ref_vid.path: None} + _ref_vid_opts = "-c copy -map 0:0 -map 1:1" + _output_opts = _ref_vid_opts + " " + _output_opts + _inputs = OrderedDict([(_input_path, _input_opts), (ref_vid.path, None)]) else: _inputs = {_input_path: _input_opts} _outputs = {output.path: _output_opts} @@ -352,60 +355,55 @@ def gen_vid(input_=None, output=None, fps=None, mux_audio=False, @staticmethod def get_fps(input_=None, print_=False, **kwargs): """ Get Frames per Second """ - _input_opts = '-v error -select_streams v -of ' - _input_opts += 'default=noprint_wrappers=1:nokey=1 ' - _input_opts += '-show_entries stream=r_frame_rate' - if isinstance(input_, str): - _inputs = {input_: _input_opts} - else: - _inputs = {input_.path: _input_opts} - ffp = FFprobe(inputs=_inputs) - _fps = ffp.run(stdout=subprocess.PIPE)[0].decode("utf-8") - _fps = _fps.strip() + logger.debug("input_: %s, print_: %s, kwargs: %s", input_, print_, kwargs) + input_ = input_ if isinstance(input_, str) else input_.path + logger.debug("input: %s", input_) + reader = imageio.get_reader(input_, "ffmpeg") + _fps = reader.get_meta_data()["fps"] + logger.debug(_fps) + reader.close() if print_: logger.info("Video fps: %s", _fps) - logger.debug(_fps) return _fps @staticmethod def get_info(input_=None, print_=False, **kwargs): """ Get video Info """ - _input_opts = Effmpeg._common_ffmpeg_args[:] - _inputs = {input_.path: _input_opts} - ffp = FFprobe(inputs=_inputs) - out = ffp.run(stdout=subprocess.PIPE, - stderr=subprocess.STDOUT)[0].decode('utf-8') - if print_: - logger.info(out) + logger.debug("input_: %s, print_: %s, kwargs: %s", input_, print_, kwargs) + input_ = input_ if isinstance(input_, str) else input_.path + logger.debug("input: %s", input_) + reader = imageio.get_reader(input_, "ffmpeg") + out = reader.get_meta_data() + logger.debug(out) + reader.close() + if print_: + logger.info("======== Video Info ========") + logger.info("path: %s", input_) + for key, val in out.items(): + logger.info("%s: %s", key, val) + return out @staticmethod - def rescale(input_=None, output=None, scale=None, preview=False, exe=None, - **kwargs): + def rescale(input_=None, output=None, scale=None, # pylint:disable=unused-argument + exe=None, **kwargs): """ Rescale Video """ _input_opts = Effmpeg._common_ffmpeg_args[:] - _output_opts = '-vf scale="' + str(scale) + '"' - if not preview: - _output_opts = '-y ' + _output_opts + _output_opts = '-y -vf scale="' + str(scale) + '"' _inputs = {input_.path: _input_opts} _outputs = {output.path: _output_opts} Effmpeg.__run_ffmpeg(exe=exe, inputs=_inputs, outputs=_outputs) @staticmethod - def rotate(input_=None, output=None, degrees=None, transpose=None, - preview=None, exe=None, **kwargs): + def 
rotate(input_=None, output=None, degrees=None, # pylint:disable=unused-argument + transpose=None, exe=None, **kwargs): """ Rotate Video """ if transpose is None and degrees is None: - raise ValueError("You have not supplied a valid transpose or " - "degrees value:\ntranspose: {}\ndegrees: " - "{}".format(transpose, degrees)) + raise ValueError("You have not supplied a valid transpose or degrees value:\n" + f"transpose: {transpose}\ndegrees: {degrees}") _input_opts = Effmpeg._common_ffmpeg_args[:] - _output_opts = '-vf ' - if not preview: - _output_opts = '-y -c:a copy ' + _output_opts - _bilinear = '' + _output_opts = "-y -c:a copy -vf " + _bilinear = "" if transpose is not None: _output_opts += 'transpose="' + str(transpose) + '"' elif int(degrees) != 0: @@ -419,29 +417,23 @@ def rotate(input_=None, output=None, degrees=None, transpose=None, Effmpeg.__run_ffmpeg(exe=exe, inputs=_inputs, outputs=_outputs) @staticmethod - def mux_audio(input_=None, output=None, ref_vid=None, preview=None, + def mux_audio(input_=None, output=None, ref_vid=None, # pylint:disable=unused-argument exe=None, **kwargs): """ Mux Audio """ _input_opts = Effmpeg._common_ffmpeg_args[:] _ref_vid_opts = None - _output_opts = '-y -c copy -map 0:0 -map 1:1 -shortest' - if preview: - raise ValueError("Preview with audio muxing is not supported.") - # if not preview: - # _output_opts = '-y ' + _output_opts - _inputs = {input_.path: _input_opts, ref_vid.path: _ref_vid_opts} + _output_opts = "-y -c copy -map 0:0 -map 1:1 -shortest" + _inputs = OrderedDict([(input_.path, _input_opts), (ref_vid.path, _ref_vid_opts)]) _outputs = {output.path: _output_opts} Effmpeg.__run_ffmpeg(exe=exe, inputs=_inputs, outputs=_outputs) @staticmethod - def slice(input_=None, output=None, start=None, duration=None, - preview=None, exe=None, **kwargs): + def slice(input_=None, output=None, start=None, # pylint:disable=unused-argument + duration=None, exe=None, **kwargs): """ Slice Video """ _input_opts = Effmpeg._common_ffmpeg_args[:] _input_opts += "-ss " + start _output_opts = "-t " + duration + " " - if not preview: - _output_opts = '-y ' + _output_opts + "-vcodec copy -acodec copy" _inputs = {input_.path: _input_opts} _output = {output.path: _output_opts} Effmpeg.__run_ffmpeg(exe=exe, inputs=_inputs, outputs=_output) @@ -459,35 +451,36 @@ def __set_verbosity(cls, quiet, verbose): def __get_default_output(self): """ Set output to the same directory as input if the user didn't specify it. 
""" + retval = "" if self.args.output == "": if self.args.action in self._actions_have_dir_output: - retval = os.path.join(self.input.dirname, 'out') + retval = os.path.join(self.input.dirname, "out") elif self.args.action in self._actions_have_vid_output: if self.input.is_type("media"): # Using the same extension as input leads to very poor # output quality, hence the default is mkv for now retval = os.path.join(self.input.dirname, "out.mkv") # + self.input.ext) else: # case if input was a directory - retval = os.path.join(self.input.dirname, 'out.mkv') + retval = os.path.join(self.input.dirname, "out.mkv") else: retval = self.args.output logger.debug(retval) return retval def __check_have_fps(self, items): - items_to_check = list() + items_to_check = [] for i in items: - if i == 'r': - items_to_check.append('ref_vid') - elif i == 'i': - items_to_check.append('input') - elif i == 'o': - items_to_check.append('output') + if i == "r": + items_to_check.append("ref_vid") + elif i == "i": + items_to_check.append("input") + elif i == "o": + items_to_check.append("output") return all(getattr(self, i).fps is None for i in items_to_check) @staticmethod - def __run_ffmpeg(exe="ffmpeg", inputs=None, outputs=None): + def __run_ffmpeg(exe=im_ffm.get_ffmpeg_exe(), inputs=None, outputs=None): """ Run ffmpeg """ logger.debug("Running ffmpeg: (exe: '%s', inputs: %s, outputs: %s", exe, inputs, outputs) ffm = FFmpeg(executable=exe, inputs=inputs, outputs=outputs) @@ -498,8 +491,7 @@ def __run_ffmpeg(exe="ffmpeg", inputs=None, outputs=None): if ffe.exit_code == 255: pass else: - raise ValueError("An unexpected FFRuntimeError occurred: " - "{}".format(ffe)) + raise ValueError(f"An unexpected FFRuntimeError occurred: {ffe}") from ffe except KeyboardInterrupt: pass # Do nothing if voluntary interruption logger.debug("ffmpeg finished") @@ -507,8 +499,8 @@ def __run_ffmpeg(exe="ffmpeg", inputs=None, outputs=None): @staticmethod def __convert_fps(fps): """ Convert to Frames per Second """ - if '/' in fps: - _fps = fps.split('/') + if "/" in fps: + _fps = fps.split("/") retval = float(_fps[0]) / float(_fps[1]) else: retval = float(fps) @@ -518,15 +510,13 @@ def __convert_fps(fps): @staticmethod def __get_duration(start_time, end_time): """ Get the duration """ - start = [int(i) for i in start_time.split(':')] - end = [int(i) for i in end_time.split(':')] + start = [int(i) for i in start_time.split(":")] + end = [int(i) for i in end_time.split(":")] start = datetime.timedelta(hours=start[0], minutes=start[1], seconds=start[2]) end = datetime.timedelta(hours=end[0], minutes=end[1], seconds=end[2]) delta = end - start secs = delta.total_seconds() - retval = '{:02}:{:02}:{:02}'.format(int(secs // 3600), - int(secs % 3600 // 60), - int(secs % 60)) + retval = f"{int(secs // 3600):02}:{int(secs % 3600 // 60):02}:{int(secs % 60):02}" logger.debug(retval) return retval @@ -534,17 +524,16 @@ def __get_duration(start_time, end_time): def __get_extracted_filename(path): """ Get the extracted filename """ logger.debug("path: '%s'", path) - filename = '' + filename = "" for file in os.listdir(path): if any(i in file for i in DataItem.img_ext): filename = file break logger.debug("sample filename: '%s'", filename) - filename = filename.split('.') - img_ext = filename[-1] - zero_pad = Effmpeg.__get_zero_pad(filename[-2]) - name = filename[-2][:-zero_pad] - retval = "{}%{}d.{}".format(name, zero_pad, img_ext) + filename, img_ext = os.path.splitext(filename) + zero_pad = Effmpeg.__get_zero_pad(filename) + name = filename[:-zero_pad] 
+ retval = f"{name}%{zero_pad}d{img_ext}" logger.debug("filename: %s, img_ext: '%s', zero_pad: %s, name: '%s'", filename, img_ext, zero_pad, name) logger.debug(retval) @@ -554,26 +543,19 @@ def __get_extracted_filename(path): def __get_zero_pad(filename): """ Return the starting position of zero padding from a filename """ chkstring = filename[::-1] + logger.trace("filename: %s, chkstring: %s", filename, chkstring) pos = 0 - for pos in range(len(chkstring)): - if not chkstring[pos].isdigit(): + for char in chkstring: + if not char.isdigit(): break logger.debug("filename: '%s', pos: %s", filename, pos) return pos - @staticmethod - def __check_is_valid_time(value): - """ Check valid time """ - val = value.replace(':', '') - retval = val.isdigit() - logger.debug("value: '%s', retval: %s", value, retval) - return retval - @staticmethod def __check_equals_time(value, time): """ Check equals time """ - val = value.replace(':', '') - tme = time.replace(':', '') + val = value.replace(":", "") + tme = time.replace(":", "") retval = val.zfill(6) == tme.zfill(6) logger.debug("value: '%s', time: %s, retval: %s", value, time, retval) return retval @@ -581,28 +563,13 @@ def __check_equals_time(value, time): @staticmethod def parse_time(txt): """ Parse Time """ - clean_txt = txt.replace(':', '') + clean_txt = txt.replace(":", "") hours = clean_txt[0:2] minutes = clean_txt[2:4] seconds = clean_txt[4:6] - retval = hours + ':' + minutes + ':' + seconds + retval = hours + ":" + minutes + ":" + seconds logger.debug("txt: '%s', retval: %s", txt, retval) return retval -def bad_args(args): # pylint: disable=unused-argument - """ Print help on bad arguments """ - PARSER.print_help() - exit(0) - - -if __name__ == "__main__": - print('"Easy"-ffmpeg wrapper.\n') - - PARSER = FullHelpArgumentParser() - SUBPARSER = PARSER.add_subparsers() - EFFMPEG = cli.EffmpegArgs( - SUBPARSER, "effmpeg", "Wrapper for various common ffmpeg commands.") - PARSER.set_defaults(func=bad_args) - ARGUMENTS = PARSER.parse_args() - ARGUMENTS.func(ARGUMENTS) +__all__ = get_module_objects(__name__) diff --git a/tools/lib_alignments/__init__.py b/tools/lib_alignments/__init__.py deleted file mode 100644 index 95a74e602b..0000000000 --- a/tools/lib_alignments/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from tools.lib_alignments.media import AlignmentData, ExtractedFaces, Faces, Frames -from tools.lib_alignments.annotate import Annotate -from tools.lib_alignments.jobs import Check, Draw, Extract, Legacy, Merge, Reformat, RemoveAlignments, Rename, Sort, Spatial, UpdateHashes -from tools.lib_alignments.jobs_manual import Manual diff --git a/tools/lib_alignments/annotate.py b/tools/lib_alignments/annotate.py deleted file mode 100644 index 274c86015f..0000000000 --- a/tools/lib_alignments/annotate.py +++ /dev/null @@ -1,114 +0,0 @@ -#!/usr/bin/env python3 -""" Tools for annotating an input image """ - -import logging - -import cv2 -import numpy as np - -from lib.align_eyes import FACIAL_LANDMARKS_IDXS - -logger = logging.getLogger(__name__) # pylint: disable=invalid-name - - -class Annotate(): - """ Annotate an input image """ - - def __init__(self, image, alignments, original_roi=None): - logger.debug("Initializing %s: (alignments: %s, original_roi: %s)", - self.__class__.__name__, alignments, original_roi) - self.image = image - self.alignments = alignments - self.roi = original_roi - self.colors = {1: (255, 0, 0), - 2: (0, 255, 0), - 3: (0, 0, 255), - 4: (255, 255, 0), - 5: (255, 0, 255), - 6: (0, 255, 255)} - logger.debug("Initialized %s", 
self.__class__.__name__) - - def draw_black_image(self): - """ Change image to black at correct dimensions """ - logger.trace("Drawing black image") - height, width = self.image.shape[:2] - self.image = np.zeros((height, width, 3), np.uint8) - - def draw_bounding_box(self, color_id=1, thickness=1): - """ Draw the bounding box around faces """ - color = self.colors[color_id] - for alignment in self.alignments: - top_left = (alignment["x"], alignment["y"]) - bottom_right = (alignment["x"] + alignment["w"], alignment["y"] + alignment["h"]) - logger.trace("Drawing bounding box: (top_left: %s, bottom_right: %s, color: %s, " - "thickness: %s)", top_left, bottom_right, color, thickness) - cv2.rectangle(self.image, # pylint: disable=no-member - top_left, - bottom_right, - color, - thickness) - - def draw_extract_box(self, color_id=2, thickness=1): - """ Draw the extracted face box """ - if not self.roi: - return - color = self.colors[color_id] - for idx, roi in enumerate(self.roi): - logger.trace("Drawing Extract Box: (idx: %s, roi: %s)", idx, roi) - top_left = [point for point in roi.squeeze()[0]] - top_left = (top_left[0], top_left[1] - 10) - cv2.putText(self.image, # pylint: disable=no-member - str(idx), - top_left, - cv2.FONT_HERSHEY_DUPLEX, # pylint: disable=no-member - 1.0, - color, - thickness) - cv2.polylines(self.image, [roi], True, color, thickness) # pylint: disable=no-member - - def draw_landmarks(self, color_id=3, radius=1): - """ Draw the facial landmarks """ - color = self.colors[color_id] - for alignment in self.alignments: - landmarks = alignment["landmarksXY"] - logger.trace("Drawing Landmarks: (landmarks: %s, color: %s, radius: %s)", - landmarks, color, radius) - for (pos_x, pos_y) in landmarks: - cv2.circle(self.image, # pylint: disable=no-member - (pos_x, pos_y), - radius, - color, - -1) - - def draw_landmarks_mesh(self, color_id=4, thickness=1): - """ Draw the facial landmarks """ - color = self.colors[color_id] - for alignment in self.alignments: - landmarks = alignment["landmarksXY"] - logger.trace("Drawing Landmarks Mesh: (landmarks: %s, color: %s, thickness: %s)", - landmarks, color, thickness) - for key, val in FACIAL_LANDMARKS_IDXS.items(): - points = np.array([landmarks[val[0]:val[1]]], np.int32) - fill_poly = bool(key in ("right_eye", "left_eye", "mouth")) - cv2.polylines(self.image, # pylint: disable=no-member - points, - fill_poly, - color, - thickness) - - def draw_grey_out_faces(self, live_face): - """ Grey out all faces except target """ - if not self.roi: - return - alpha = 0.6 - overlay = self.image.copy() - for idx, roi in enumerate(self.roi): - if idx != int(live_face): - logger.trace("Greying out face: (idx: %s, roi: %s)", idx, roi) - cv2.fillPoly(overlay, roi, (0, 0, 0)) # pylint: disable=no-member - cv2.addWeighted(overlay, # pylint: disable=no-member - alpha, - self.image, - 1 - alpha, - 0, - self.image) diff --git a/tools/lib_alignments/jobs.py b/tools/lib_alignments/jobs.py deleted file mode 100644 index ac81170637..0000000000 --- a/tools/lib_alignments/jobs.py +++ /dev/null @@ -1,1008 +0,0 @@ -#!/usr/bin/env python3 -""" Tools for manipulating the alignments serialized file """ - -import logging -import os -import pickle -import struct -from datetime import datetime - -import numpy as np -from scipy import signal -from sklearn import decomposition -from tqdm import tqdm - -from . 
import AlignmentData, Annotate, ExtractedFaces, Faces, Frames - -logger = logging.getLogger(__name__) # pylint: disable=invalid-name - - -class Check(): - """ Frames and faces checking tasks """ - def __init__(self, alignments, arguments): - logger.debug("Initializing %s: (arguments: %s)", self.__class__.__name__, arguments) - self.alignments = alignments - self.job = arguments.job - self.type = None - self.output = arguments.output - self.source_dir = self.get_source_dir(arguments) - self.validate() - self.items = self.get_items() - - self.output_message = "" - logger.debug("Initialized %s", self.__class__.__name__) - - def get_source_dir(self, arguments): - """ Set the correct source folder """ - if (hasattr(arguments, "faces_dir") and arguments.faces_dir and - hasattr(arguments, "frames_dir") and arguments.frames_dir): - logger.error("Only select a source frames (-fr) or source faces (-fc) folder") - exit(0) - elif hasattr(arguments, "faces_dir") and arguments.faces_dir: - self.type = "faces" - source_dir = arguments.faces_dir - elif hasattr(arguments, "frames_dir") and arguments.frames_dir: - self.type = "frames" - source_dir = arguments.frames_dir - else: - logger.error("No source folder (-fr or -fc) was provided") - exit(0) - logger.debug("type: '%s', source_dir: '%s'", self.type, source_dir) - return source_dir - - def get_items(self): - """ Set the correct items to process """ - items = globals()[self.type.title()] - return items(self.source_dir).file_list_sorted - - def process(self): - """ Process the frames check against the alignments file """ - logger.info("[CHECK %s]", self.type.upper()) - items_output = self.compile_output() - self.output_results(items_output) - - def validate(self): - """ Check that the selected type is valid for - selected task and job """ - if self.job == "missing-frames" and self.output == "move": - logger.warning("Missing_frames was selected with move output, but there will " - "be nothing to move. Defaulting to output: console") - self.output = "console" - if self.type == "faces" and self.job not in ("multi-faces", "leftover-faces"): - logger.warning("The selected folder is not valid. 
Faces folder (-fc) is only " - "supported for 'multi-faces' and 'leftover-faces'") - exit(0) - - def compile_output(self): - """ Compile list of frames that meet criteria """ - action = self.job.replace("-", "_") - processor = getattr(self, "get_{}".format(action)) - logger.debug("Processor: %s", processor) - return [item for item in processor()] - - def get_no_faces(self): - """ yield each frame that has no face match in alignments file """ - self.output_message = "Frames with no faces" - for frame in tqdm(self.items, desc=self.output_message): - logger.trace(frame) - frame_name = frame["frame_fullname"] - if not self.alignments.frame_has_faces(frame_name): - logger.debug("Returning: '%s'", frame_name) - yield frame_name - - def get_multi_faces(self): - """ yield each frame that has multiple faces - matched in alignments file """ - if self.type == "faces": - self.output_message = "Multiple faces in frame" - frame_key = "face_hash" - retval_key = "face_fullname" - elif self.type == "frames": - self.output_message = "Frames with multiple faces" - frame_key = "frame_fullname" - retval_key = "frame_fullname" - logger.debug("frame_key: '%s', retval_key: '%s'", frame_key, retval_key) - - for item in tqdm(self.items, desc=self.output_message): - frame = item[frame_key] - if self.type == "faces": - frame_idx = [(frame, idx) - for frame, idx in self.alignments.hashes_to_frame[frame].items()] - retval = item[retval_key] - for frame, idx in frame_idx: - if self.alignments.frame_has_multiple_faces(frame): - if self.type == "faces": - # Add correct alignments index for moving faces - retval = (retval, idx) - logger.trace("Returning: '%s'", retval) - yield retval - - def get_missing_alignments(self): - """ yield each frame that does not exist in alignments file """ - self.output_message = "Frames missing from alignments file" - exclude_filetypes = ["yaml", "yml", "p", "json", "txt"] - for frame in tqdm(self.items, desc=self.output_message): - frame_name = frame["frame_fullname"] - if (frame["frame_extension"] not in exclude_filetypes - and not self.alignments.frame_exists(frame_name)): - logger.debug("Returning: '%s'", frame_name) - yield frame_name - - def get_missing_frames(self): - """ yield each frame in alignments that does - not have a matching file """ - self.output_message = "Missing frames that are in alignments file" - frames = [item["frame_fullname"] for item in self.items] - for frame in tqdm(self.alignments.data.keys(), desc=self.output_message): - if frame not in frames: - logger.debug("Returning: '%s'", frame) - yield frame - - def get_leftover_faces(self): - """yield each face that isn't in the alignments file.""" - self.output_message = "Faces missing from the alignments file" - for face in tqdm(self.items, desc=self.output_message): - f_hash = face["face_hash"] - if not self.alignments.hashes_to_frame.get(f_hash, None): - logger.debug("Returning: '%s'", face["face_fullname"]) - yield face["face_fullname"] - - def output_results(self, items_output): - """ Output the results in the requested format """ - if not items_output: - logger.info("No %s were found meeting the criteria", self.type) - return - if self.output == "move": - self.move_file(items_output) - return - if self.job == "multi-faces": - # Strip the index for printed/file output - items_output = [item[0] for item in items_output] - output_message = "-----------------------------------------------\r\n" - output_message += " {} ({})\r\n".format(self.output_message, - len(items_output)) - output_message += 
"-----------------------------------------------\r\n" - output_message += "\r\n".join([frame for frame in items_output]) - if self.output == "console": - for line in output_message.splitlines(): - logger.info(line) - if self.output == "file": - self.output_file(output_message, len(items_output)) - - def output_file(self, output_message, items_discovered): - """ Save the output to a text file in the frames directory """ - now = datetime.now().strftime("%Y%m%d_%H%M%S") - filename = "{}_{}.txt".format(self.output_message.replace(" ", "_").lower(), now) - output_file = os.path.join(self.source_dir, filename) - logger.info("Saving %s result(s) to '%s'", items_discovered, output_file) - with open(output_file, "w") as f_output: - f_output.write(output_message) - - def move_file(self, items_output): - """ Move the identified frames to a new subfolder """ - now = datetime.now().strftime("%Y%m%d_%H%M%S") - folder_name = "{}_{}".format(self.output_message.replace(" ", "_").lower(), now) - output_folder = os.path.join(self.source_dir, folder_name) - logger.debug("Creating folder: '%s'", output_folder) - os.makedirs(output_folder) - move = getattr(self, "move_{}".format(self.type)) - logger.debug("Move function: %s", move) - move(output_folder, items_output) - - def move_frames(self, output_folder, items_output): - """ Move frames into single subfolder """ - logger.info("Moving %s frame(s) to '%s'", len(items_output), output_folder) - for frame in items_output: - src = os.path.join(self.source_dir, frame) - dst = os.path.join(output_folder, frame) - logger.debug("Moving: '%s' to '%s'", src, dst) - os.rename(src, dst) - - def move_faces(self, output_folder, items_output): - """ Make additional subfolders for each face that appears - Enables easier manual sorting """ - logger.info("Moving %s faces(s) to '%s'", len(items_output), output_folder) - for frame, idx in items_output: - src = os.path.join(self.source_dir, frame) - dst_folder = os.path.join(output_folder, str(idx)) - if not os.path.isdir(dst_folder): - logger.debug("Creating folder: '%s'", dst_folder) - os.makedirs(dst_folder) - dst = os.path.join(dst_folder, frame) - logger.debug("Moving: '%s' to '%s'", src, dst) - os.rename(src, dst) - - -class Draw(): - """ Draw Alignments on passed in images """ - def __init__(self, alignments, arguments): - logger.debug("Initializing %s: (arguments: %s)", self.__class__.__name__, arguments) - self.arguments = arguments - self.alignments = alignments - self.frames = Frames(arguments.frames_dir) - self.output_folder = self.set_output() - self.extracted_faces = None - logger.debug("Initialized %s", self.__class__.__name__) - - def set_output(self): - """ Set the output folder path """ - now = datetime.now().strftime("%Y%m%d_%H%M%S") - folder_name = "drawn_landmarks_{}".format(now) - if self.frames.vid_cap: - dest_folder = os.path.split(self.frames.folder)[0] - else: - dest_folder = self.frames.folder - output_folder = os.path.join(dest_folder, folder_name) - logger.debug("Creating folder: '%s'", output_folder) - os.makedirs(output_folder) - return output_folder - - def process(self): - """ Run the draw alignments process """ - legacy = Legacy(self.alignments, None, frames=self.frames, child_process=True) - legacy.process() - - logger.info("[DRAW LANDMARKS]") # Tidy up cli output - self.extracted_faces = ExtractedFaces(self.frames, self.alignments, size=256, - align_eyes=self.arguments.align_eyes) - frames_drawn = 0 - for frame in tqdm(self.frames.file_list_sorted, desc="Drawing landmarks"): - frame_name = 
frame["frame_fullname"] - - if not self.alignments.frame_exists(frame_name): - logger.verbose("Skipping '%s' - Alignments not found", frame_name) - continue - - self.annotate_image(frame_name) - frames_drawn += 1 - logger.info("%s Frame(s) output", frames_drawn) - - def annotate_image(self, frame): - """ Draw the alignments """ - logger.trace("Annotating frame: '%s'", frame) - alignments = self.alignments.get_faces_in_frame(frame) - image = self.frames.load_image(frame) - self.extracted_faces.get_faces_in_frame(frame) - original_roi = [face.original_roi - for face in self.extracted_faces.faces] - annotate = Annotate(image, alignments, original_roi) - annotate.draw_bounding_box(1, 1) - annotate.draw_extract_box(2, 1) - annotate.draw_landmarks(3, 1) - annotate.draw_landmarks_mesh(4, 1) - - image = annotate.image - self.frames.save_image(self.output_folder, frame, image) - - -class Extract(): - """ Re-extract faces from source frames based on - Alignment data """ - def __init__(self, alignments, arguments): - logger.debug("Initializing %s: (arguments: %s)", self.__class__.__name__, arguments) - self.alignments = alignments - self.arguments = arguments - self.type = arguments.job.replace("extract-", "") - self.faces_dir = arguments.faces_dir - self.frames = Frames(arguments.frames_dir) - self.extracted_faces = ExtractedFaces(self.frames, self.alignments, size=arguments.size, - align_eyes=arguments.align_eyes) - logger.debug("Initialized %s", self.__class__.__name__) - - def process(self): - """ Run extraction """ - logger.info("[EXTRACT FACES]") # Tidy up cli output - self.check_folder() - self.export_faces() - - def check_folder(self): - """ Check that the faces folder doesn't pre-exist - and create """ - err = None - if not self.faces_dir: - err = "ERROR: Output faces folder not provided." 
- elif not os.path.isdir(self.faces_dir): - logger.debug("Creating folder: '%s'", self.faces_dir) - os.makedirs(self.faces_dir) - elif os.listdir(self.faces_dir): - err = "ERROR: Output faces folder should be empty: '{}'".format(self.faces_dir) - if err: - logger.error(err) - exit(0) - logger.verbose("Creating output folder at '%s'", self.faces_dir) - - def export_faces(self): - """ Export the faces """ - extracted_faces = 0 - - for frame in tqdm(self.frames.file_list_sorted, desc="Saving extracted faces"): - frame_name = frame["frame_fullname"] - - if not self.alignments.frame_exists(frame_name): - logger.verbose("Skipping '%s' - Alignments not found", frame_name) - continue - extracted_faces += self.output_faces(frame) - - if extracted_faces != 0 and self.type != "large": - self.alignments.save() - logger.info("%s face(s) extracted", extracted_faces) - - def output_faces(self, frame): - """ Output the frame's faces to file """ - logger.trace("Outputting frame: %s", frame) - face_count = 0 - frame_fullname = frame["frame_fullname"] - frame_name = frame["frame_name"] - extension = os.path.splitext(frame_fullname)[1] - faces = self.select_valid_faces(frame_fullname) - - for idx, face in enumerate(faces): - output = "{}_{}{}".format(frame_name, str(idx), extension) - if self.type == "large": - self.frames.save_image(self.faces_dir, output, face.aligned_face) - else: - output = os.path.join(self.faces_dir, output) - f_hash = self.extracted_faces.save_face_with_hash(output, - extension, - face.aligned_face) - self.alignments.data[frame_fullname][idx]["hash"] = f_hash - face_count += 1 - return face_count - - def select_valid_faces(self, frame): - """ Return valid faces for extraction """ - faces = self.extracted_faces.get_faces_in_frame(frame) - if self.type != "large": - valid_faces = faces - else: - sizes = self.extracted_faces.get_roi_size_for_frame(frame) - valid_faces = [faces[idx] for idx, size in enumerate(sizes) - if size >= self.extracted_faces.size] - logger.trace("frame: '%s', total_faces: %s, valid_faces: %s", - frame, len(faces), len(valid_faces)) - return valid_faces - - -class Legacy(): - """ Update legacy alignments: - - Rotate landmarks and bounding boxes on legacy alignments - and remove the 'r' parameter - - Add face hashes to alignments file - """ - - def __init__(self, alignments, arguments, frames=None, faces=None, child_process=False): - logger.debug("Initializing %s: (arguments: %s, child_process: %s)", - self.__class__.__name__, arguments, child_process) - self.alignments = alignments - if child_process: - self.frames = frames - self.faces = faces - else: - self.frames = Frames(arguments.frames_dir) - self.faces = Faces(arguments.faces_dir) - logger.debug("Initialized %s", self.__class__.__name__) - - def process(self): - """ Run the rotate alignments process """ - rotated = self.alignments.get_legacy_rotation() - hashes = self.alignments.get_legacy_no_hashes() - if (not self.frames or not rotated) and (not self.faces or not hashes): - return - logger.info("[UPDATE LEGACY LANDMARKS]") # Tidy up cli output - if rotated and self.frames: - logger.info("Legacy rotated frames found. Converting...") - self.rotate_landmarks(rotated) - self.alignments.save() - if hashes and self.faces: - logger.info("Legacy alignments found. 
Adding Face Hashes...") - self.add_hashes(hashes) - self.alignments.save() - - def rotate_landmarks(self, rotated): - """ Rotate the landmarks """ - for rotate_item in tqdm(rotated, desc="Rotating Landmarks"): - frame = self.frames.get(rotate_item, None) - if frame is None: - continue - self.alignments.rotate_existing_landmarks(rotate_item, frame) - - def add_hashes(self, hashes): - """ Add Face Hashes to the alignments file """ - all_faces = dict() - logger.info("Getting original filenames, indexes and hashes...") - for face in self.faces.file_list_sorted: - filename = face["face_name"] - extension = face["face_extension"] - if "_" not in face["face_name"]: - logger.warning("Unable to determine index of file. Skipping: '%s'", filename) - continue - index = filename[filename.rfind("_") + 1:] - if not index.isdigit(): - logger.warning("Unable to determine index of file. Skipping: '%s'", filename) - continue - orig_frame = filename[:filename.rfind("_")] + extension - all_faces.setdefault(orig_frame, dict())[int(index)] = face["face_hash"] - - logger.info("Updating hashes to alignments...") - for frame in hashes: - if frame not in all_faces.keys(): - logger.warning("Skipping missing frame: '%s'", frame) - continue - self.alignments.add_face_hashes(frame, all_faces[frame]) - - -class Merge(): - """ Merge two alignments files into one """ - def __init__(self, alignments, arguments): - self.alignments = alignments - self.alignments2 = AlignmentData(arguments.alignments_file2, "json") - - def process(self): - """Process the alignments file merge """ - logger.info("[MERGE ALIGNMENTS]") # Tidy up cli output - skip_count = 0 - merge_count = 0 - for _, src_alignments, _, frame in tqdm(self.alignments2.yield_faces(), - desc="Merging Alignments", - total=self.alignments2.frames_count): - for idx, alignment in enumerate(src_alignments): - if not alignment.get("hash", None): - logger.warning("Alignment '%s':%s has no Hash! Skipping", frame, idx) - skip_count += 1 - continue - if self.check_exists(frame, alignment, idx): - skip_count += 1 - continue - self.merge_alignment(frame, alignment, idx) - merge_count += 1 - logger.info("Alignments Merged: %s", merge_count) - logger.info("Alignments Skipped: %s", skip_count) - if merge_count != 0: - self.set_destination_filename() - self.alignments.save() - - def check_exists(self, frame, alignment, idx): - """ Check whether this face already exists """ - existing_frame = self.alignments.hashes_to_frame.get(alignment["hash"], None) - if not existing_frame: - return False - if frame in existing_frame.keys(): - logger.verbose("Face '%s': %s already exists in destination at position %s. " - "Skipping", frame, idx, existing_frame[frame]) - elif frame not in existing_frame.keys(): - logger.verbose("Face '%s': %s exists in destination as: %s. 
" - "Skipping", frame, idx, existing_frame) - return True - - def merge_alignment(self, frame, alignment, idx): - """ Merge the source alignment into the destination """ - logger.debug("Merging alignment: (frame: %s, src_idx: %s, hash: %s)", - frame, idx, alignment["hash"]) - self.alignments.data.setdefault(frame, list()).append(alignment) - - def set_destination_filename(self): - """ Set the destination filename """ - orig, ext = os.path.splitext(self.alignments.file) - filename = "{}_merged{}".format(orig, ext) - logger.debug("Output set to: '%s'", filename) - self.alignments.file = filename - - -class Reformat(): - """ Reformat Alignment file """ - def __init__(self, alignments, arguments): - logger.debug("Initializing %s: (arguments: %s)", self.__class__.__name__, arguments) - self.alignments = alignments - if self.alignments.file == "dfl.json": - logger.debug("Loading DFL faces") - self.faces = Faces(arguments.faces_dir) - logger.debug("Initialized %s", self.__class__.__name__) - - def process(self): - """ Run reformat """ - logger.info("[REFORMAT ALIGNMENTS]") # Tidy up cli output - if self.alignments.file == "dfl.json": - self.alignments.data = self.load_dfl() - self.alignments.file = self.alignments.get_location(self.faces.folder, "alignments") - self.alignments.save() - - def load_dfl(self): - """ Load alignments from DeepFaceLab and format for Faceswap """ - alignments = dict() - for face in tqdm(self.faces.file_list_sorted, desc="Converting DFL Faces"): - if face["face_extension"] != ".png": - logger.verbose("'%s' is not a png. Skipping", face["face_fullname"]) - continue - f_hash = face["face_hash"] - fullpath = os.path.join(self.faces.folder, face["face_fullname"]) - dfl = self.get_dfl_alignment(fullpath) - - if not dfl: - continue - - self.convert_dfl_alignment(dfl, f_hash, alignments) - return alignments - - @staticmethod - def get_dfl_alignment(filename): - """ Process the alignment of one face """ - with open(filename, "rb") as dfl: - header = dfl.read(8) - if header != b"\x89PNG\r\n\x1a\n": - logger.error("No Valid PNG header: %s", filename) - return None - while True: - chunk_start = dfl.tell() - chunk_hdr = dfl.read(8) - if not chunk_hdr: - break - chunk_length, chunk_name = struct.unpack("!I4s", chunk_hdr) - dfl.seek(chunk_start, os.SEEK_SET) - if chunk_name == b"fcWp": - chunk = dfl.read(chunk_length + 12) - retval = pickle.loads(chunk[8:-4]) - logger.trace("Loaded DFL Alignment: (filename: '%s', alignment: %s", - filename, retval) - return retval - dfl.seek(chunk_length+12, os.SEEK_CUR) - logger.error("Couldn't find DFL alignments: %s", filename) - - @staticmethod - def convert_dfl_alignment(dfl_alignments, f_hash, alignments): - """ Add DFL Alignments to alignments in Faceswap format """ - sourcefile = dfl_alignments["source_filename"] - left, top, right, bottom = dfl_alignments["source_rect"] - alignment = {"x": left, - "w": right - left, - "y": top, - "h": bottom - top, - "hash": f_hash, - "landmarksXY": dfl_alignments["source_landmarks"]} - logger.trace("Adding alignment: (frame: '%s', alignment: %s", sourcefile, alignment) - alignments.setdefault(sourcefile, list()).append(alignment) - - -class RemoveAlignments(): - """ Remove items from alignments file """ - def __init__(self, alignments, arguments): - logger.debug("Initializing %s: (arguments: %s)", self.__class__.__name__, arguments) - self.alignments = alignments - self.type = arguments.job.replace("remove-", "") - self.items = self.get_items(arguments) - self.removed = set() - logger.debug("Initialized %s", 
self.__class__.__name__) - - def get_items(self, arguments): - """ Set the correct items to process """ - retval = None - if self.type == "frames": - retval = list(Frames(arguments.frames_dir).items.keys()) - elif self.type == "faces": - retval = Faces(arguments.faces_dir) - return retval - - def process(self): - """ run removal """ - if self.type == "faces": - legacy = Legacy(self.alignments, None, faces=self.items, child_process=True) - legacy.process() - - logger.info("[REMOVE ALIGNMENTS DATA]") # Tidy up cli output - del_count = 0 - task = getattr(self, "remove_{}".format(self.type)) - - if self.type == "frames": - logger.debug("Removing Frames") - for frame in tqdm(list(item[3] for item in self.alignments.yield_faces()), - desc="Removing Frames", - total=self.alignments.frames_count): - del_count += task(frame) - else: - logger.debug("Removing Faces") - del_count = task() - - if del_count == 0: - logger.info("No changes made to alignments file. Exiting") - return - - logger.info("%s alignment(s) were removed from alignments file", del_count) - self.alignments.save() - - if self.type == "faces": - rename = Rename(self.alignments, None, self.items) - rename.process() - - def remove_frames(self, frame): - """ Process to remove frames from an alignments file """ - if frame in self.items: - logger.trace("Not deleting frame: '%s'", frame) - return 0 - logger.debug("Deleting frame: '%s'", frame) - del self.alignments.data[frame] - return 1 - - def remove_faces(self): - """ Process to remove faces from an alignments file """ - face_hashes = list(self.items.items.keys()) - if not face_hashes: - logger.error("No face hashes. This would remove all faces from your alignments file.") - return 0 - pre_face_count = self.alignments.faces_count - self.alignments.filter_hashes(face_hashes, filter_out=False) - post_face_count = self.alignments.faces_count - return pre_face_count - post_face_count - - def remove_alignment(self, item): - """ Remove the alignment from the alignments file """ - del_count = 0 - frame_name, alignments, number_alignments = item[:3] - for idx in self.alignments.yield_original_index_reverse(alignments, number_alignments): - face_indexes = self.items.items.get(frame_name, [-1]) - if idx not in face_indexes: - del alignments[idx] - self.removed.add(frame_name) - logger.verbose("Removed alignment data for image: '%s' index: %s", - frame_name, str(idx)) - del_count += 1 - else: - logger.trace("Not removing alignment data for image: '%s' index: %s", - frame_name, str(idx)) - return del_count - - -class Rename(): - """ Rename faces to match their source frame and position index """ - def __init__(self, alignments, arguments, faces=None): - logger.debug("Initializing %s: (arguments: %s, faces: %s)", - self.__class__.__name__, arguments, faces) - self.alignments = alignments - self.faces = faces if faces else Faces(arguments.faces_dir) - self.seen_multihash = set() - logger.debug("Initialized %s", self.__class__.__name__) - - def process(self): - """ Process the face renaming """ - logger.info("[RENAME FACES]") # Tidy up cli output - rename_count = 0 - for frame, _, _, frame_fullname in tqdm(self.alignments.yield_faces(), - desc="Renaming Faces", - total=self.alignments.frames_count): - rename_count += self.rename_faces(frame, frame_fullname) - logger.info("%s faces renamed", rename_count) - - def rename_faces(self, frame, frame_fullname): - """ Rename faces - Done in 2 iterations as two files cannot share the same name """ - logger.trace("Renaming faces for frame: '%s'", 
frame_fullname) - temp_ext = ".temp_move" - frame_faces = list() - frame_faces = [(f_hash, idx) - for f_hash, details in self.alignments.hashes_to_frame.items() - for frame_name, idx in details.items() - if frame_name == frame_fullname] - rename_count = 0 - rename_files = list() - for f_hash, idx in frame_faces: - faces = self.faces.items[f_hash] - if len(faces) == 1: - face_name, face_ext = faces[0] - else: - face_name, face_ext = self.check_multi_hashes(faces, frame, idx) - old = face_name + face_ext - new = "{}_{}{}".format(frame, idx, face_ext) - if old == new: - logger.trace("Face does not require renaming: '%s'", old) - continue - rename_files.append((old, new)) - for action in ("temp", "final"): - for files in rename_files: - old, new = files - old_file = old if action == "temp" else old + temp_ext - new_file = old + temp_ext if action == "temp" else new - src = os.path.join(self.faces.folder, old_file) - dst = os.path.join(self.faces.folder, new_file) - logger.trace("Renaming: '%s' to '%s'", old_file, new_file) - os.rename(src, dst) - if action == "final": - rename_count += 1 - logger.verbose("Renamed '%s' to '%s'", old, new) - return rename_count - - def check_multi_hashes(self, faces, frame, idx): - """ Check filenames for where multiple faces have the - same hash (e.g. for freeze frames) """ - frame_idx = "{}_{}".format(frame, idx) - for face_name, extension in faces: - if (face_name, extension) in self.seen_multihash: - # Don't return a filename that has already been processed - continue - if face_name == frame_idx: - # If a matching filename already exists return that - self.seen_multihash.add((face_name, extension)) - return face_name, extension - if face_name.startswith(frame): - # If a matching framename already exists return that - self.seen_multihash.add((face_name, extension)) - return face_name, extension - # If no matches, just pop the first filename - face_name, extension = faces[0] - self.seen_multihash.add((face_name, extension)) - return face_name, extension - - -class Sort(): - """ Sort alignments' index by the order they appear in - an image """ - def __init__(self, alignments, arguments): - logger.debug("Initializing %s: (arguments: %s)", self.__class__.__name__, arguments) - self.alignments = alignments - self.axis = arguments.job.replace("sort-", "") - self.faces = self.get_faces(arguments) - logger.debug("Initialized %s", self.__class__.__name__) - - def get_faces(self, arguments): - """ If faces argument is specified, load faces_dir - otherwise return None """ - if not hasattr(arguments, "faces_dir") or not arguments.faces_dir: - return None - faces = Faces(arguments.faces_dir) - legacy = Legacy(self.alignments, None, faces=faces, child_process=True) - legacy.process() - return faces - - def process(self): - """ Execute the sort process """ - logger.info("[SORT INDEXES]") # Tidy up cli output - self.check_legacy() - reindexed = self.reindex_faces() - if reindexed: - self.alignments.save() - if self.faces: - rename = Rename(self.alignments, None, self.faces) - rename.process() - - def check_legacy(self): - """ Legacy rotated alignments will not have the correct x, y - positions. Faces without hashes won't process. - Check for these and generate a warning and exit """ - rotated = self.alignments.get_legacy_rotation() - hashes = self.alignments.get_legacy_no_hashes() - if rotated or hashes: - logger.error("Legacy alignments found. Sort cannot continue. 
You should run legacy " - "tool to update the file prior to running sort: 'python tools.py " - "alignments -j legacy -a -fr -fc " - "'") - exit(0) - - def reindex_faces(self): - """ Re-Index the faces """ - reindexed = 0 - for alignment in tqdm(self.alignments.yield_faces(), - desc="Sort alignment indexes", total=self.alignments.frames_count): - frame, alignments, count, key = alignment - if count <= 1: - logger.trace("0 or 1 face in frame. Not sorting: '%s'", frame) - continue - sorted_alignments = sorted([item for item in alignments], key=lambda x: (x[self.axis])) - if sorted_alignments == alignments: - logger.trace("Alignments already in correct order. Not sorting: '%s'", frame) - continue - logger.trace("Sorting alignments for frame: '%s'", frame) - self.alignments.data[key] = sorted_alignments - reindexed += 1 - logger.info("%s Frames had their faces reindexed", reindexed) - return reindexed - - -class Spatial(): - """ Apply spatial temporal filtering to landmarks - Adapted from: - https://www.kaggle.com/selfishgene/animating-and-smoothing-3d-facial-keypoints/notebook """ - - def __init__(self, alignments, arguments): - logger.debug("Initializing %s: (arguments: %s)", self.__class__.__name__, arguments) - self.arguments = arguments - self.alignments = alignments - self.mappings = dict() - self.normalized = dict() - self.shapes_model = None - logger.debug("Initialized %s", self.__class__.__name__) - - def process(self): - """ Perform spatial filtering """ - logger.info("[SPATIO-TEMPORAL FILTERING]") # Tidy up cli output - logger.info("NB: The process only processes the alignments for the first " - "face it finds for any given frame. For best results only run this when " - "there is only a single face in the alignments file and all false positives " - "have been removed") - - self.normalize() - self.shape_model() - landmarks = self.spatially_filter() - landmarks = self.temporally_smooth(landmarks) - self.update_alignments(landmarks) - self.alignments.save() - - logger.info("Done! 
To re-extract faces run: python tools.py " - "alignments -j extract -a %s -fr -fc " - "", self.arguments.alignments_file) - - # Define shape normalization utility functions - @staticmethod - def normalize_shapes(shapes_im_coords): - """ Normalize a 2D or 3D shape """ - logger.debug("Normalize shapes") - (num_pts, num_dims, _) = shapes_im_coords.shape - - # Calculate mean coordinates and subtract from shapes - mean_coords = shapes_im_coords.mean(axis=0) - shapes_centered = np.zeros(shapes_im_coords.shape) - shapes_centered = shapes_im_coords - np.tile(mean_coords, [num_pts, 1, 1]) - - # Calculate scale factors and divide shapes - scale_factors = np.sqrt((shapes_centered**2).sum(axis=1)).mean(axis=0) - shapes_normalized = np.zeros(shapes_centered.shape) - shapes_normalized = shapes_centered / np.tile(scale_factors, [num_pts, num_dims, 1]) - - logger.debug("Normalized shapes: (shapes_normalized: %s, scale_factors: %s, mean_coords: " - "%s", shapes_normalized, scale_factors, mean_coords) - return shapes_normalized, scale_factors, mean_coords - - @staticmethod - def normalized_to_original(shapes_normalized, scale_factors, mean_coords): - """ Transform a normalized shape back to original image coordinates """ - logger.debug("Normalize to original") - (num_pts, num_dims, _) = shapes_normalized.shape - - # move back to the correct scale - shapes_centered = shapes_normalized * np.tile(scale_factors, [num_pts, num_dims, 1]) - # move back to the correct location - shapes_im_coords = shapes_centered + np.tile(mean_coords, [num_pts, 1, 1]) - - logger.debug("Normalized to original: %s", shapes_im_coords) - return shapes_im_coords - - def normalize(self): - """ Compile all original and normalized alignments """ - logger.debug("Normalize") - count = sum(1 for val in self.alignments.data.values() if val) - landmarks_all = np.zeros((68, 2, int(count))) - - end = 0 - for key in tqdm(sorted(self.alignments.data.keys()), desc="Compiling"): - val = self.alignments.data[key] - if not val: - continue - # We should only be normalizing a single face, so just take - # the first landmarks found - landmarks = np.array(val[0]["landmarksXY"]).reshape(68, 2, 1) - start = end - end = start + landmarks.shape[2] - # Store in one big array - landmarks_all[:, :, start:end] = landmarks - # Make sure we keep track of the mapping to the original frame - self.mappings[start] = key - - # Normalize shapes - normalized_shape = self.normalize_shapes(landmarks_all) - self.normalized["landmarks"] = normalized_shape[0] - self.normalized["scale_factors"] = normalized_shape[1] - self.normalized["mean_coords"] = normalized_shape[2] - logger.debug("Normalized: %s", self.normalized) - - def shape_model(self): - """ build 2D shape model """ - logger.debug("Shape model") - landmarks_norm = self.normalized["landmarks"] - num_components = 20 - normalized_shapes_tbl = np.reshape(landmarks_norm, [68*2, landmarks_norm.shape[2]]).T - self.shapes_model = decomposition.PCA(n_components=num_components, - whiten=True, - random_state=1).fit(normalized_shapes_tbl) - explained = self.shapes_model.explained_variance_ratio_.sum() - logger.info("Total explained percent by PCA model with %s components is %s%%", - num_components, round(100 * explained, 1)) - logger.debug("Shaped model") - - def spatially_filter(self): - """ interpret the shapes using our shape model - (project and reconstruct) """ - logger.debug("Spatially Filter") - landmarks_norm = self.normalized["landmarks"] - # Convert to matrix form - landmarks_norm_table = np.reshape(landmarks_norm, [68 
* 2, landmarks_norm.shape[2]]).T - # Project onto shapes model and reconstruct - landmarks_norm_table_rec = self.shapes_model.inverse_transform( - self.shapes_model.transform(landmarks_norm_table)) - # Convert back to shapes (numKeypoint, num_dims, numFrames) - landmarks_norm_rec = np.reshape(landmarks_norm_table_rec.T, - [68, 2, landmarks_norm.shape[2]]) - # Transform back to image coords - retval = self.normalized_to_original(landmarks_norm_rec, - self.normalized["scale_factors"], - self.normalized["mean_coords"]) - - logger.debug("Spatially Filtered: %s", retval) - return retval - - @staticmethod - def temporally_smooth(landmarks): - """ apply temporal filtering on the 2D points """ - logger.debug("Temporally Smooth") - filter_half_length = 2 - temporal_filter = np.ones((1, 1, 2 * filter_half_length + 1)) - temporal_filter = temporal_filter / temporal_filter.sum() - - start_tileblock = np.tile(landmarks[:, :, 0][:, :, np.newaxis], [1, 1, filter_half_length]) - end_tileblock = np.tile(landmarks[:, :, -1][:, :, np.newaxis], [1, 1, filter_half_length]) - landmarks_padded = np.dstack((start_tileblock, landmarks, end_tileblock)) - - retval = signal.convolve(landmarks_padded, temporal_filter, mode='valid', method='fft') - logger.debug("Temporally Smoothed: %s", retval) - return retval - - def update_alignments(self, landmarks): - """ Update smoothed landmarks back to alignments """ - logger.debug("Update alignments") - for idx, frame in tqdm(self.mappings.items(), desc="Updating"): - logger.trace("Updating: (frame: %s)", frame) - landmarks_update = landmarks[:, :, idx].astype(int) - landmarks_xy = landmarks_update.reshape(68, 2).tolist() - self.alignments.data[frame][0]["landmarksXY"] = landmarks_xy - logger.trace("Updated: (frame: '%s', landmarks: %s)", frame, landmarks_xy) - logger.debug("Updated alignments") - - -class UpdateHashes(): - """ Update hashes in an alignments file """ - def __init__(self, alignments, arguments): - logger.debug("Initializing %s: (arguments: %s)", self.__class__.__name__, arguments) - self.alignments = alignments - self.faces = Faces(arguments.faces_dir).file_list_sorted - self.face_hashes = dict() - logger.debug("Initialized %s", self.__class__.__name__) - - def process(self): - """ Update Face Hashes to the alignments file """ - logger.info("[UPDATE FACE HASHES]") # Tidy up cli output - self.get_hashes() - updated = self.update_hashes() - if updated == 0: - logger.info("No hashes were updated. Exiting") - return - self.alignments.save() - logger.info("%s frame(s) had their face hashes updated.", updated) - - def get_hashes(self): - """ Read the face hashes from the faces """ - logger.info("Getting original filenames, indexes and hashes...") - for face in self.faces: - filename = face["face_name"] - extension = face["face_extension"] - if "_" not in face["face_name"]: - logger.warning("Unable to determine index of file. Skipping: '%s'", filename) - continue - index = filename[filename.rfind("_") + 1:] - if not index.isdigit(): - logger.warning("Unable to determine index of file. Skipping: '%s'", filename) - continue - orig_frame = filename[:filename.rfind("_")] + extension - self.face_hashes.setdefault(orig_frame, dict())[int(index)] = face["face_hash"] - - def update_hashes(self): - """ Update hashes to alignments """ - logger.info("Updating hashes to alignments...") - updated = 0 - for frame, hashes in self.face_hashes.items(): - if not self.alignments.frame_exists(frame): - logger.warning("Frame not found in alignments file. 
Skipping: '%s'", frame) - continue - if not self.alignments.frame_has_faces(frame): - logger.warning("Frame does not have faces. Skipping: '%s'", frame) - continue - existing = [face.get("hash", None) - for face in self.alignments.get_faces_in_frame(frame)] - if any(hsh not in existing for hsh in list(hashes.values())): - self.alignments.add_face_hashes(frame, hashes) - updated += 1 - return updated diff --git a/tools/lib_alignments/jobs_manual.py b/tools/lib_alignments/jobs_manual.py deleted file mode 100644 index 3a5ba010c5..0000000000 --- a/tools/lib_alignments/jobs_manual.py +++ /dev/null @@ -1,993 +0,0 @@ -#!/usr/bin/env python3 -""" Manual processing of alignments """ - -import logging -import platform -import sys -import cv2 -import numpy as np - -from lib.multithreading import SpawnProcess -from lib.queue_manager import queue_manager, QueueEmpty -from plugins.plugin_loader import PluginLoader -from . import Annotate, ExtractedFaces, Frames, Legacy - -logger = logging.getLogger(__name__) # pylint: disable=invalid-name - - -class Interface(): - """ Key controls and interfacing options for OpenCV """ - def __init__(self, alignments, frames): - logger.debug("Initializing %s: (alignments: %s, frames: %s)", - self.__class__.__name__, alignments, frames) - self.alignments = alignments - self.frames = frames - self.controls = self.set_controls() - self.state = self.set_state() - self.skip_mode = {1: "Standard", - 2: "No Faces", - 3: "Multi-Faces", - 4: "Has Faces"} - logger.debug("Initialized %s", self.__class__.__name__) - - def set_controls(self): - """ Set keyboard controls, destination and help text """ - controls = {"z": {"action": self.iterate_frame, - "args": ("navigation", - 1), - "help": "Previous Frame"}, - "x": {"action": self.iterate_frame, - "args": ("navigation", 1), - "help": "Next Frame"}, - "[": {"action": self.iterate_frame, - "args": ("navigation", - 100), - "help": "100 Frames Back"}, - "]": {"action": self.iterate_frame, - "args": ("navigation", 100), - "help": "100 Frames Forward"}, - "{": {"action": self.iterate_frame, - "args": ("navigation", "first"), - "help": "Go to First Frame"}, - "}": {"action": self.iterate_frame, - "args": ("navigation", "last"), - "help": "Go to Last Frame"}, - 27: {"action": "quit", - "key_text": "ESC", - "args": ("navigation", None), - "help": "Exit", - "key_type": ord}, - "/": {"action": self.iterate_state, - "args": ("navigation", "frame-size"), - "help": "Cycle Frame Zoom"}, - "s": {"action": self.iterate_state, - "args": ("navigation", "skip-mode"), - "help": ("Skip Mode (All, No Faces, Multi Faces, Has Faces)")}, - " ": {"action": self.save_alignments, - "key_text": "SPACE", - "args": ("edit", None), - "help": "Save Alignments"}, - "r": {"action": self.reload_alignments, - "args": ("edit", None), - "help": "Reload Alignments (Discard all changes)"}, - "d": {"action": self.delete_alignment, - "args": ("edit", None), - "help": "Delete Selected Alignment"}, - "m": {"action": self.toggle_state, - "args": ("edit", "active"), - "help": "Change Mode (View, Edit)"}, - range(10): {"action": self.set_state_value, - "key_text": "0 to 9", - "args": ["edit", "selected"], - "help": "Select/Deselect Face at this Index", - "key_type": range}, - "c": {"action": self.copy_alignments, - "args": ("edit", -1), - "help": "Copy Previous Frame's Alignments"}, - "v": {"action": self.copy_alignments, - "args": ("edit", 1), - "help": "Copy Next Frame's Alignments"}, - "y": {"action": self.toggle_state, - "args": ("image", "display"), - "help": "Toggle Image"}, 
- "u": {"action": self.iterate_state, - "args": ("bounding_box", "color"), - "help": "Cycle Bounding Box Color"}, - "i": {"action": self.iterate_state, - "args": ("extract_box", "color"), - "help": "Cycle Extract Box Color"}, - "o": {"action": self.iterate_state, - "args": ("landmarks", "color"), - "help": "Cycle Landmarks Color"}, - "p": {"action": self.iterate_state, - "args": ("landmarks_mesh", "color"), - "help": "Cycle Landmarks Mesh Color"}, - "h": {"action": self.iterate_state, - "args": ("bounding_box", "size"), - "help": "Cycle Bounding Box thickness"}, - "j": {"action": self.iterate_state, - "args": ("extract_box", "size"), - "help": "Cycle Extract Box thickness"}, - "k": {"action": self.iterate_state, - "args": ("landmarks", "size"), - "help": "Cycle Landmarks - point size"}, - "l": {"action": self.iterate_state, - "args": ("landmarks_mesh", "size"), - "help": "Cycle Landmarks Mesh - thickness"}} - - logger.debug("Controls: %s", controls) - return controls - - @staticmethod - def set_state(): - """ Set the initial display state """ - state = {"bounding_box": dict(), - "extract_box": dict(), - "landmarks": dict(), - "landmarks_mesh": dict(), - "image": dict(), - "navigation": {"skip-mode": 1, - "frame-size": 1, - "frame_idx": 0, - "max_frame": 0, - "last_request": 0, - "frame_name": None}, - "edit": {"updated": False, - "update_faces": False, - "selected": None, - "active": 0, - "redraw": False}} - - # See lib_alignments/annotate.py for color mapping - color = 0 - for key in sorted(state.keys()): - if key not in ("bounding_box", "extract_box", "landmarks", "landmarks_mesh", "image"): - continue - state[key]["display"] = True - if key == "image": - continue - color += 1 - state[key]["size"] = 1 - state[key]["color"] = color - logger.debug("State: %s", state) - return state - - def save_alignments(self, *args): # pylint: disable=unused-argument - """ Save alignments """ - logger.debug("Saving Alignments") - if not self.state["edit"]["updated"]: - logger.debug("Save received, but state not updated. Not saving") - return - self.alignments.save() - self.state["edit"]["updated"] = False - self.set_redraw(True) - - def reload_alignments(self, *args): # pylint: disable=unused-argument - """ Reload alignments """ - logger.debug("Reloading Alignments") - if not self.state["edit"]["updated"]: - logger.debug("Reload received, but state not updated. Not reloading") - return - self.alignments.reload() - self.state["edit"]["updated"] = False - self.state["edit"]["update_faces"] = True - self.set_redraw(True) - - def delete_alignment(self, *args): # pylint: disable=unused-argument - """ Save alignments """ - logger.debug("Deleting Alignments") - selected_face = self.get_selected_face_id() - if self.get_edit_mode() == "View" or selected_face is None: - logger.debug("Delete received, but edit mode is 'View'. Not deleting") - return - frame = self.get_frame_name() - if self.alignments.delete_face_at_index(frame, selected_face): - self.state["edit"]["selected"] = None - self.state["edit"]["updated"] = True - self.state["edit"]["update_faces"] = True - self.set_redraw(True) - - def copy_alignments(self, *args): - """ Copy the alignments from the previous or next frame - to the current frame """ - logger.debug("Copying Alignments") - if self.get_edit_mode() != "Edit": - logger.debug("Copy received, but edit mode is not 'Edit'. 
Not copying") - return - frame_id = self.state["navigation"]["frame_idx"] + args[1] - if not 0 <= frame_id <= self.state["navigation"]["max_frame"]: - return - current_frame = self.get_frame_name() - get_frame = self.frames.file_list_sorted[frame_id]["frame_fullname"] - alignments = self.alignments.get_faces_in_frame(get_frame) - for alignment in alignments: - self.alignments. add_face(current_frame, alignment) - self.state["edit"]["updated"] = True - self.state["edit"]["update_faces"] = True - self.set_redraw(True) - - def toggle_state(self, item, category): - """ Toggle state of requested item """ - logger.debug("Toggling state: (item: %s, category: %s)", item, category) - self.state[item][category] = not self.state[item][category] - logger.debug("State toggled: (item: %s, category: %s, value: %s)", - item, category, self.state[item][category]) - self.set_redraw(True) - - def iterate_state(self, item, category): - """ Cycle through options (6 possible or 3 currently supported) """ - logger.debug("Cycling state: (item: %s, category: %s)", item, category) - if category == "color": - max_val = 7 - elif category == "frame-size": - max_val = 6 - elif category == "skip-mode": - max_val = 4 - else: - max_val = 3 - val = self.state[item][category] - val = val + 1 if val != max_val else 1 - self.state[item][category] = val - logger.debug("Cycled state: (item: %s, category: %s, value: %s)", - item, category, self.state[item][category]) - self.set_redraw(True) - - def set_state_value(self, item, category, value): - """ Set state of requested item or toggle off """ - logger.debug("Setting state value: (item: %s, category: %s, value: %s)", - item, category, value) - state = self.state[item][category] - value = str(value) if value is not None else value - if state == value: - self.state[item][category] = None - else: - self.state[item][category] = value - logger.debug("Setting state value: (item: %s, category: %s, value: %s)", - item, category, self.state[item][category]) - self.set_redraw(True) - - def iterate_frame(self, *args): - """ Iterate frame up or down, stopping at either end """ - logger.debug("Iterating frame: (args: %s)", args) - iteration = args[1] - max_frame = self.state["navigation"]["max_frame"] - if iteration in ("first", "last"): - next_frame = 0 if iteration == "first" else max_frame - self.state["navigation"]["frame_idx"] = next_frame - self.state["navigation"]["last_request"] = 0 - self.set_redraw(True) - return - - current_frame = self.state["navigation"]["frame_idx"] - next_frame = current_frame + iteration - end = 0 if iteration < 0 else max_frame - if (max_frame == 0 or - (end > 0 and next_frame >= end) or - (end == 0 and next_frame <= end)): - next_frame = end - self.state["navigation"]["frame_idx"] = next_frame - self.state["navigation"]["last_request"] = iteration - self.set_state_value("edit", "selected", None) - - def get_color(self, item): - """ Return color for selected item """ - return self.state[item]["color"] - - def get_size(self, item): - """ Return size for selected item """ - return self.state[item]["size"] - - def get_frame_scaling(self): - """ Return frame scaling factor for requested item """ - factors = (1, 1.25, 1.5, 2, 0.5, 0.75) - idx = self.state["navigation"]["frame-size"] - 1 - return factors[idx] - - def get_edit_mode(self): - """ Return text version and border color for edit mode """ - if self.state["edit"]["active"]: - return "Edit" - return "View" - - def get_skip_mode(self): - """ Return text version of skip mode """ - return 
self.skip_mode[self.state["navigation"]["skip-mode"]] - - def get_state_color(self): - """ Return a color based on current state - white - View Mode - yellow - Edit Mode - red - Unsaved alignments """ - color = (255, 255, 255) - if self.state["edit"]["updated"]: - color = (0, 0, 255) - elif self.state["edit"]["active"]: - color = (0, 255, 255) - return color - - def get_frame_name(self): - """ Return the current frame number """ - return self.state["navigation"]["frame_name"] - - def get_selected_face_id(self): - """ Return the index of the currently selected face """ - try: - return int(self.state["edit"]["selected"]) - except TypeError: - return None - - def redraw(self): - """ Return whether a redraw is required """ - return self.state["edit"]["redraw"] - - def set_redraw(self, request): - """ Turn redraw requirement on or off """ - self.state["edit"]["redraw"] = request - - -class Help(): - """ Generate and display help in cli and in window """ - def __init__(self, interface): - logger.debug("Initializing %s: (interface: %s)", self.__class__.__name__, interface) - self.interface = interface - self.helptext = self.generate() - logger.debug("Initialized %s", self.__class__.__name__) - - def generate(self): - """ Generate help output """ - logger.debug("Generating help") - sections = ("navigation", "display", "color", "size", "edit") - helpout = {section: list() for section in sections} - helptext = "" - for key, val in self.interface.controls.items(): - logger.trace("Generating help for:(key: '%s', val: '%s'", key, val) - help_section = val["args"][0] - if help_section not in ("navigation", "edit"): - help_section = val["args"][1] - key_text = val.get("key_text", None) - key_text = key_text if key_text else key - logger.trace("Adding help for:(section: '%s', val: '%s', text: '%s'", - help_section, val["help"], key_text) - helpout[help_section].append((val["help"], key_text)) - - helpout["edit"].append(("Bounding Box - Move", "Left Click")) - helpout["edit"].append(("Bounding Box - Resize", "Middle Click")) - - for section in sections: - spacer = "=" * int((40 - len(section)) / 2) - display = "\n{} {} {}\n".format(spacer, section.upper(), spacer) - helpsection = sorted(helpout[section]) - if section == "navigation": - helpsection = sorted(helpout[section], reverse=True) - display += "\n".join(" - '{}': {}".format(item[1], item[0]) - for item in helpsection) - - helptext += display - logger.debug("Added helptext: '%s'", helptext) - return helptext - - def render(self): - """ Render help text to image window """ - # pylint: disable=no-member - logger.trace("Rendering help text") - image = self.background() - display_text = self.helptext + self.compile_status() - self.text_to_image(image, display_text) - cv2.namedWindow("Help") - cv2.imshow("Help", image) - logger.trace("Rendered help text") - - def background(self): - """ Create an image to hold help text """ - # pylint: disable=no-member - logger.trace("Creating help text canvas") - height = 880 - width = 480 - image = np.zeros((height, width, 3), np.uint8) - color = self.interface.get_state_color() - cv2.rectangle(image, (0, 0), (width - 1, height - 1), color, 2) - logger.trace("Created help text canvas") - return image - - def compile_status(self): - """ Render the status text """ - logger.trace("Compiling Status text") - status = "\n=== STATUS\n" - navigation = self.interface.state["navigation"] - frame_scale = int(self.interface.get_frame_scaling() * 100) - status += " File: {}\n".format(self.interface.get_frame_name()) - status += " 
Frame: {} / {}\n".format( - navigation["frame_idx"] + 1, navigation["max_frame"] + 1) - status += " Frame Size: {}%\n".format(frame_scale) - status += " Skip Mode: {}\n".format(self.interface.get_skip_mode()) - status += " View Mode: {}\n".format(self.interface.get_edit_mode()) - if self.interface.get_selected_face_id() is not None: - status += " Selected Face Index: {}\n".format(self.interface.get_selected_face_id()) - if self.interface.state["edit"]["updated"]: - status += " Warning: There are unsaved changes\n" - - logger.trace("Compiled Status text") - return status - - @staticmethod - def text_to_image(image, display_text): - """ Write out and format help text to image """ - # pylint: disable=no-member - logger.trace("Converting help text to image") - pos_y = 0 - for line in display_text.split("\n"): - if line.startswith("==="): - pos_y += 10 - line = line.replace("=", "").strip() - line = line.replace("- '", "[ ").replace("':", " ]") - cv2.putText(image, line, (20, pos_y), - cv2.FONT_HERSHEY_SIMPLEX, 0.43, (255, 255, 255), 1) - pos_y += 20 - logger.trace("Converted help text to image") - - -class Manual(): - """ Manually adjust or create landmarks data """ - def __init__(self, alignments, arguments): - logger.debug("Initializing %s: (alignments: %s, arguments: %s)", - self.__class__.__name__, alignments, arguments) - self.arguments = arguments - self.alignments = alignments - self.align_eyes = arguments.align_eyes - self.frames = Frames(arguments.frames_dir) - self.extracted_faces = None - self.interface = None - self.help = None - self.mouse_handler = None - logger.debug("Initialized %s", self.__class__.__name__) - - def process(self): - """ Process manual extraction """ - legacy = Legacy(self.alignments, self.arguments, - frames=self.frames, child_process=True) - legacy.process() - - logger.info("[MANUAL PROCESSING]") # Tidy up cli output - self.extracted_faces = ExtractedFaces(self.frames, self.alignments, size=256, - align_eyes=self.align_eyes) - self.interface = Interface(self.alignments, self.frames) - self.help = Help(self.interface) - self.mouse_handler = MouseHandler(self.interface, self.arguments.loglevel) - - print(self.help.helptext) - max_idx = self.frames.count - 1 - self.interface.state["navigation"]["max_frame"] = max_idx - self.display_frames() - - def display_frames(self): - """ Iterate through frames """ - # pylint: disable=no-member - logger.debug("Display frames") - is_windows = True if platform.system() == "Windows" else False - is_conda = True if "conda" in sys.version.lower() else False - logger.debug("is_windows: %s, is_conda: %s", is_windows, is_conda) - cv2.namedWindow("Frame") - cv2.namedWindow("Faces") - cv2.setMouseCallback('Frame', self.mouse_handler.on_event) - - frame, faces = self.get_frame() - press = self.get_keys() - - while True: - self.help.render() - cv2.imshow("Frame", frame) - cv2.imshow("Faces", faces) - key = cv2.waitKey(1) - - if self.window_closed(is_windows, is_conda, key): - queue_manager.terminate_queues() - break - - if key: - logger.trace("Keypress received: '%s'", key) - if key in press.keys(): - action = press[key]["action"] - logger.debug("Keypress action: key: ('%s', action: '%s')", key, action) - if action == "quit": - break - - if press[key].get("key_type") == range: - args = press[key]["args"] + [chr(key)] - else: - args = press[key]["args"] - action(*args) - - if not self.interface.redraw(): - continue - - logger.trace("Redraw requested") - frame, faces = self.get_frame() - self.interface.set_redraw(False) - - 
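# The display loop above follows the standard OpenCV polling pattern: show the
# current images, poll cv2.waitKey for ~1ms, dispatch any keypress through the
# controls table, and rebuild the frame only when a redraw was requested. A
# minimal self-contained sketch of that pattern (window name, canvas and exit
# key are illustrative only, not taken from the tool):
def _sketch_display_loop():
    import cv2
    import numpy as np

    canvas = np.zeros((240, 320, 3), np.uint8)  # stand-in for a rendered frame
    cv2.namedWindow("Demo")
    while True:
        cv2.imshow("Demo", canvas)
        key = cv2.waitKey(1)  # returns -1 when no key was pressed
        if key == 27:  # ESC exits the loop
            break
    cv2.destroyAllWindows()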
cv2.destroyAllWindows() - - def window_closed(self, is_windows, is_conda, key): - """ Check whether the window has been closed - - MS Windows doesn't appear to read the window state property - properly, so we check for a negative key press. - - Conda (tested on Windows) doesn't appear to read the window - state property or negative key press properly, so we arbitrarily - use another property """ - # pylint: disable=no-member - logger.trace("Commencing closed window check") - closed = False - prop_autosize = cv2.getWindowProperty('Frame', cv2.WND_PROP_AUTOSIZE) - prop_visible = cv2.getWindowProperty('Frame', cv2.WND_PROP_VISIBLE) - if self.arguments.disable_monitor: - closed = False - elif is_conda and prop_autosize < 1: - closed = True - elif is_windows and not is_conda and key == -1: - closed = True - elif not is_windows and not is_conda and prop_visible < 1: - closed = True - logger.trace("Completed closed window check. Closed is %s", closed) - if closed: - logger.debug("Window closed detected") - return closed - - def get_keys(self): - """ Convert keys dict into something useful - for OpenCV """ - keys = dict() - for key, val in self.interface.controls.items(): - if val.get("key_type", str) == range: - for range_key in key: - keys[ord(str(range_key))] = val - elif val.get("key_type", str) == ord: - keys[key] = val - else: - keys[ord(key)] = val - - return keys - - def get_frame(self): - """ Compile the frame and get faces """ - image = self.frame_selector() - frame_name = self.interface.get_frame_name() - logger.debug("Frame Name: '%s'", frame_name) - alignments = self.alignments.get_faces_in_frame(frame_name) - faces_updated = self.interface.state["edit"]["update_faces"] - logger.debug("Faces Updated: %s", faces_updated) - self.extracted_faces.get_faces(frame_name) - roi = [face.original_roi for face in self.extracted_faces.faces] - - if faces_updated: - self.interface.state["edit"]["update_faces"] = False - - frame = FrameDisplay(image, alignments, roi, self.interface).image - faces = self.set_faces(frame_name).image - return frame, faces - - def frame_selector(self): - """ Return frame at given index """ - navigation = self.interface.state["navigation"] - frame_list = self.frames.file_list_sorted - frame = frame_list[navigation["frame_idx"]]["frame_fullname"] - skip_mode = self.interface.get_skip_mode().lower() - logger.debug("navigation: %s, frame: '%s', skip_mode: '%s'", navigation, frame, skip_mode) - - while True: - if navigation["last_request"] == 0: - break - elif navigation["frame_idx"] in (0, navigation["max_frame"]): - break - elif skip_mode == "standard": - break - elif (skip_mode == "no faces" - and not self.alignments.frame_has_faces(frame)): - break - elif (skip_mode == "multi-faces" - and self.alignments.frame_has_multiple_faces(frame)): - break - elif (skip_mode == "has faces" - and self.alignments.frame_has_faces(frame)): - break - else: - self.interface.iterate_frame("navigation", - navigation["last_request"]) - frame = frame_list[navigation["frame_idx"]]["frame_fullname"] - - image = self.frames.load_image(frame) - navigation["last_request"] = 0 - navigation["frame_name"] = frame - return image - - def set_faces(self, frame): - """ Pass the current frame faces to faces window """ - faces = self.extracted_faces.get_faces_in_frame(frame) - landmarks = [{"landmarksXY": face.aligned_landmarks} - for face in self.extracted_faces.faces] - return FacesDisplay(faces, landmarks, self.extracted_faces.size, self.interface) - - -class FrameDisplay(): - """" Window that holds 
the frame """ - def __init__(self, image, alignments, roi, interface): - logger.trace("Initializing %s: (alignments: %s, roi: %s, interface: %s)", - self.__class__.__name__, alignments, roi, interface) - self.image = image - self.roi = roi - self.alignments = alignments - self.interface = interface - self.annotate_frame() - logger.trace("Initialized %s", self.__class__.__name__) - - def annotate_frame(self): - """ Annotate the frame """ - state = self.interface.state - logger.trace("State: %s", state) - annotate = Annotate(self.image, self.alignments, self.roi) - if not state["image"]["display"]: - annotate.draw_black_image() - - for item in ("bounding_box", "extract_box", "landmarks", "landmarks_mesh"): - color = self.interface.get_color(item) - size = self.interface.get_size(item) - state[item]["display"] = False if color == 7 else True - if not state[item]["display"]: - continue - logger.trace("Annotating: '%s'", item) - annotation = getattr(annotate, "draw_{}".format(item)) - annotation(color, size) - - selected_face = self.interface.get_selected_face_id() - if (selected_face is not None and - int(selected_face) < len(self.alignments)): - annotate.draw_grey_out_faces(selected_face) - - self.image = self.resize_frame(annotate.image) - - def resize_frame(self, image): - """ Set the displayed frame size and add state border""" - # pylint: disable=no-member - logger.trace("Resizing frame") - height, width = image.shape[:2] - color = self.interface.get_state_color() - cv2.rectangle(image, (0, 0), (width - 1, height - 1), color, 1) - scaling = self.interface.get_frame_scaling() - image = cv2.resize(image, (0, 0), fx=scaling, fy=scaling) - logger.trace("Resized frame") - return image - - -class FacesDisplay(): - """ Window that holds faces thumbnail """ - def __init__(self, extracted_faces, landmarks, size, interface): - logger.trace("Initializing %s: (extracted_faces: %s, landmarks: %s, size: %s, " - "interface: %s)", self.__class__.__name__, extracted_faces, - landmarks, size, interface) - self.row_length = 4 - self.faces = self.copy_faces(extracted_faces) - self.roi = self.set_full_roi(size) - self.landmarks = landmarks - self.interface = interface - - self.annotate_faces() - - self.image = self.build_faces_image(size) - logger.trace("Initialized %s", self.__class__.__name__) - - @staticmethod - def copy_faces(faces): - """ Copy the extracted faces so as not to save the annotations back """ - return [face.aligned_face.copy() for face in faces] - - @staticmethod - def set_full_roi(size): - """ ROI is the full frame for faces, so set based on size """ - return [np.array([[(0, 0), (0, size - 1), (size - 1, size - 1), (size - 1, 0)]], np.int32)] - - def annotate_faces(self): - """ Annotate each of the faces """ - state = self.interface.state - selected_face = self.interface.get_selected_face_id() - logger.trace("State: %s, Selected Face ID: %s", state, selected_face) - for idx, face in enumerate(self.faces): - annotate = Annotate(face, [self.landmarks[idx]], self.roi) - if not state["image"]["display"]: - annotate.draw_black_image() - - for item in ("landmarks", "landmarks_mesh"): - if not state[item]["display"]: - continue - logger.trace("Annotating: '%s'", item) - color = self.interface.get_color(item) - size = self.interface.get_size(item) - annotation = getattr(annotate, "draw_{}".format(item)) - annotation(color, size) - - if (selected_face is not None - and int(selected_face) < len(self.faces) - and int(selected_face) != idx): - annotate.draw_grey_out_faces(1) - - self.faces[idx] = 
annotate.image - - def build_faces_image(self, size): - """ Display associated faces """ - total_faces = len(self.faces) - logger.trace("Building faces panel. (total_faces: %s)", total_faces) - if not total_faces: - logger.trace("Returning empty row") - image = self.build_faces_row(list(), size) - return image - total_rows = int(total_faces / self.row_length) + 1 - for idx in range(total_rows): - logger.trace("Building row %s", idx) - face_idx = idx * self.row_length - row_faces = self.faces[face_idx:face_idx + self.row_length] - if not row_faces: - break - row = self.build_faces_row(row_faces, size) - image = row if idx == 0 else np.concatenate((image, row), axis=0) - return image - - def build_faces_row(self, faces, size): - """ Build a row of 4 faces """ - # pylint: disable=no-member - logger.trace("Building row for %s faces", len(faces)) - if len(faces) != 4: - remainder = 4 - (len(faces) % self.row_length) - for _ in range(remainder): - faces.append(np.zeros((size, size, 3), np.uint8)) - for idx, face in enumerate(faces): - color = self.interface.get_state_color() - cv2.rectangle(face, (0, 0), (size - 1, size - 1), - color, 1) - if idx == 0: - row = face - else: - row = np.concatenate((row, face), axis=1) - return row - - -class MouseHandler(): - """ Manual Extraction """ - def __init__(self, interface, loglevel): - logger.debug("Initializing %s: (interface: %s)", self.__class__.__name__, interface) - self.interface = interface - self.alignments = interface.alignments - self.frames = interface.frames - - self.extractor = dict() - self.init_extractor(loglevel) - - self.mouse_state = None - self.last_move = None - self.center = None - self.dims = None - self.media = {"frame_id": None, - "image": None, - "bounding_box": list(), - "bounding_last": list(), - "bounding_box_orig": list()} - logger.debug("Initialized %s", self.__class__.__name__) - - def init_extractor(self, loglevel): - """ Initialize FAN """ - logger.debug("Initialize Extractor") - out_queue = queue_manager.get_queue("out") - - d_kwargs = {"in_queue": queue_manager.get_queue("in"), - "out_queue": queue_manager.get_queue("align")} - a_kwargs = {"in_queue": queue_manager.get_queue("align"), - "out_queue": out_queue} - - detector = PluginLoader.get_detector("manual")(loglevel=loglevel) - detect_process = SpawnProcess(detector.run, **d_kwargs) - d_event = detect_process.event - detect_process.start() - - for plugin in ("fan", "dlib"): - aligner = PluginLoader.get_aligner(plugin)(loglevel=loglevel) - align_process = SpawnProcess(aligner.run, **a_kwargs) - a_event = align_process.event - align_process.start() - - # Wait for Aligner to initialize - # The first ever load of the model for FAN has reportedly taken - # up to 3-4 minutes, hence high timeout. - a_event.wait(300) - if not a_event.is_set(): - if plugin == "fan": - align_process.join() - logger.error("Error initializing FAN. Trying Dlib") - continue - else: - raise ValueError("Error initializing Aligner") - if plugin == "dlib": - break - - try: - err = None - err = out_queue.get(True, 1) - except QueueEmpty: - pass - if not err: - break - align_process.join() - logger.error("Error initializing FAN. 
Trying Dlib") - - d_event.wait(10) - if not d_event.is_set(): - raise ValueError("Error inititalizing Detector") - - self.extractor["detect"] = detector - self.extractor["align"] = aligner - logger.debug("Initialized Extractor") - - def on_event(self, event, x, y, flags, param): # pylint: disable=unused-argument,invalid-name - """ Handle the mouse events """ - # pylint: disable=no-member - if self.interface.get_edit_mode() != "Edit": - return - logger.trace("Mouse event: (event: %s, x: %s, y: %s, flags: %s, param: %s", - event, x, y, flags, param) - if not self.mouse_state and event not in (cv2.EVENT_LBUTTONDOWN, cv2.EVENT_MBUTTONDOWN): - return - - self.initialize() - - if event in (cv2.EVENT_LBUTTONUP, cv2.EVENT_MBUTTONUP): - self.mouse_state = None - self.last_move = None - elif event == cv2.EVENT_LBUTTONDOWN: - self.mouse_state = "left" - self.set_bounding_box(x, y) - elif event == cv2.EVENT_MBUTTONDOWN: - self.mouse_state = "middle" - self.set_bounding_box(x, y) - elif event == cv2.EVENT_MOUSEMOVE: - if self.mouse_state == "left": - self.move_bounding_box(x, y) - elif self.mouse_state == "middle": - self.resize_bounding_box(x, y) - - def initialize(self): - """ Update changed parameters """ - frame = self.interface.get_frame_name() - if frame == self.media["frame_id"]: - return - logger.debug("Initialize frame: '%s'", frame) - self.media["frame_id"] = frame - self.media["image"] = self.frames.load_image(frame) - self.dims = None - self.center = None - self.last_move = None - self.mouse_state = None - self.media["bounding_box"] = list() - self.media["bounding_box_orig"] = list() - - def set_bounding_box(self, pt_x, pt_y): - """ Select or create bounding box """ - if self.interface.get_selected_face_id() is None: - self.check_click_location(pt_x, pt_y) - - if self.interface.get_selected_face_id() is not None: - self.dims_from_alignment() - else: - self.dims_from_image() - - self.move_bounding_box(pt_x, pt_y) - - def check_click_location(self, pt_x, pt_y): - """ Check whether the point clicked is within an existing - bounding box and set face_id """ - frame = self.media["frame_id"] - alignments = self.alignments.get_faces_in_frame(frame) - scale = self.interface.get_frame_scaling() - pt_x = int(pt_x / scale) - pt_y = int(pt_y / scale) - - for idx, alignment in enumerate(alignments): - left = alignment["x"] - right = alignment["x"] + alignment["w"] - top = alignment["y"] - bottom = alignment["y"] + alignment["h"] - - if left <= pt_x <= right and top <= pt_y <= bottom: - self.interface.set_state_value("edit", "selected", idx) - break - - def dims_from_alignment(self): - """ Set the height and width of bounding box from alignment """ - frame = self.media["frame_id"] - face_id = self.interface.get_selected_face_id() - alignment = self.alignments.get_faces_in_frame(frame)[face_id] - self.dims = (alignment["w"], alignment["h"]) - - def dims_from_image(self): - """ Set the height and width of bounding - box at 10% of longest axis """ - size = max(self.media["image"].shape[:2]) - dim = int(size / 10.00) - self.dims = (dim, dim) - - def bounding_from_center(self): - """ Get bounding X Y from center """ - pt_x, pt_y = self.center - width, height = self.dims - scale = self.interface.get_frame_scaling() - self.media["bounding_box"] = [int((pt_x / scale) - width / 2), - int((pt_y / scale) - height / 2), - int((pt_x / scale) + width / 2), - int((pt_y / scale) + height / 2)] - - def move_bounding_box(self, pt_x, pt_y): - """ Move the bounding box """ - self.center = (pt_x, pt_y) - 
self.bounding_from_center() - self.update_landmarks() - - def resize_bounding_box(self, pt_x, pt_y): - """ Resize the bounding box """ - scale = self.interface.get_frame_scaling() - - if not self.last_move: - self.last_move = (pt_x, pt_y) - self.media["bounding_box_orig"] = self.media["bounding_box"] - - move_x = int(pt_x - self.last_move[0]) - move_y = int(self.last_move[1] - pt_y) - - original = self.media["bounding_box_orig"] - updated = self.media["bounding_box"] - - minsize = int(10 / scale) - center = (int(self.center[0] / scale), int(self.center[1] / scale)) - updated[0] = min(center[0] - minsize, original[0] - move_x) - updated[1] = min(center[1] - minsize, original[1] - move_y) - updated[2] = max(center[0] + minsize, original[2] + move_x) - updated[3] = max(center[1] + minsize, original[3] + move_y) - self.update_landmarks() - self.last_move = (pt_x, pt_y) - - def update_landmarks(self): - """ Update the landmarks """ - queue_manager.get_queue("in").put({"image": self.media["image"], - "filename": self.media["frame_id"], - "face": self.media["bounding_box"]}) - landmarks = queue_manager.get_queue("out").get() - - if isinstance(landmarks, dict) and landmarks.get("exception"): - cv2.destroyAllWindows() # pylint: disable=no-member - pid = landmarks["exception"][0] - t_back = landmarks["exception"][1].getvalue() - err = "Error in child process {}. {}".format(pid, t_back) - raise Exception(err) - if landmarks == "EOF": - exit(0) - - alignment = self.extracted_to_alignment((landmarks["detected_faces"][0], - landmarks["landmarks"][0])) - frame = self.media["frame_id"] - - if self.interface.get_selected_face_id() is None: - idx = self.alignments.add_face(frame, alignment) - self.interface.set_state_value("edit", "selected", idx) - else: - self.alignments.update_face(frame, - self.interface.get_selected_face_id(), - alignment) - self.interface.set_redraw(True) - - self.interface.state["edit"]["updated"] = True - self.interface.state["edit"]["update_faces"] = True - - @staticmethod - def extracted_to_alignment(extract_data): - """ Convert Extracted Tuple to Alignments data """ - alignment = dict() - d_rect, landmarks = extract_data - alignment["x"] = d_rect.left() - alignment["w"] = d_rect.right() - d_rect.left() - alignment["y"] = d_rect.top() - alignment["h"] = d_rect.bottom() - d_rect.top() - alignment["landmarksXY"] = landmarks - return alignment diff --git a/tools/lib_alignments/media.py b/tools/lib_alignments/media.py deleted file mode 100644 index ed4c547dda..0000000000 --- a/tools/lib_alignments/media.py +++ /dev/null @@ -1,347 +0,0 @@ -#!/usr/bin/env python3 -""" Media items (Alignments, Faces, Frames) - for alignments tool """ - -import logging -import os -from tqdm import tqdm - -import cv2 - -from lib.alignments import Alignments -from lib.faces_detect import DetectedFace -from lib.utils import _image_extensions, _video_extensions, hash_image_file, hash_encode_image - -logger = logging.getLogger(__name__) # pylint: disable=invalid-name - - -class AlignmentData(Alignments): - """ Class to hold the alignment data """ - - def __init__(self, alignments_file, destination_format): - logger.debug("Initializing %s: (alignments file: '%s', destination_format: '%s')", - self.__class__.__name__, alignments_file, destination_format) - logger.info("[ALIGNMENT DATA]") # Tidy up cli output - folder, filename = self.check_file_exists(alignments_file) - if filename.lower() == "dfl": - self.set_dfl(destination_format) - return - super().__init__(folder, filename=filename) - 
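# set_destination_format (below) standardises the save format from the
# alignments file extension. A stand-alone restatement of that mapping; the
# ValueError fallback here is illustrative only (the tool itself logs an
# error and exits):
def _sketch_destination_format(filename, destination_format=None):
    import os
    extensions = {".json": "json", ".p": "pickle", ".yml": "yaml", ".yaml": "yaml"}
    if destination_format is not None:
        return destination_format  # an explicitly requested format wins
    if filename == "dfl":
        return "json"  # DFL alignments default to json
    file_ext = os.path.splitext(filename)[1].lower()
    if file_ext not in extensions:
        raise ValueError("'{}' is not a supported serializer".format(file_ext))
    return extensions[file_ext]

# e.g. _sketch_destination_format("alignments.yml") returns "yaml"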
self.set_destination_format(destination_format) - logger.verbose("%s items loaded", self.frames_count) - logger.debug("Initialized %s", self.__class__.__name__) - - @staticmethod - def check_file_exists(alignments_file): - """ Check the alignments file exists""" - folder, filename = os.path.split(alignments_file) - if filename.lower() == "dfl": - folder = None - filename = "dfl" - logger.info("Using extracted pngs for alignments") - elif not os.path.isfile(alignments_file): - logger.error("ERROR: alignments file not found at: '%s'", alignments_file) - exit(0) - if folder: - logger.verbose("Alignments file exists at '%s'", alignments_file) - return folder, filename - - def set_dfl(self, destination_format): - """ Set the alignments for dfl alignments """ - logger.debug("Alignments are DFL format") - self.file = "dfl" - self.set_destination_format(destination_format) - - def set_destination_format(self, destination_format): - """ Standardize the destination format to the correct extension """ - extensions = {".json": "json", - ".p": "pickle", - ".yml": "yaml", - ".yaml": "yaml"} - dst_fmt = None - file_ext = os.path.splitext(self.file)[1].lower() - logger.debug("File extension: '%s'", file_ext) - - if destination_format is not None: - dst_fmt = destination_format - elif self.file == "dfl": - dst_fmt = "json" - elif file_ext in extensions.keys(): - dst_fmt = extensions[file_ext] - else: - logger.error("'%s' is not a supported serializer. Exiting", file_ext) - exit(0) - - logger.verbose("Destination format set to '%s'", dst_fmt) - - self.serializer = self.get_serializer("", dst_fmt) - filename = os.path.splitext(self.file)[0] - self.file = "{}.{}".format(filename, self.serializer.ext) - logger.debug("Destination file: '%s'", self.file) - - def save(self): - """ Backup copy of old alignments and save new alignments """ - self.backup() - super().save() - - -class MediaLoader(): - """ Class to load filenames from folder """ - def __init__(self, folder): - logger.debug("Initializing %s: (folder: '%s')", self.__class__.__name__, folder) - logger.info("[%s DATA]", self.__class__.__name__.upper()) - self.folder = folder - self.vid_cap = self.check_input_folder() - self.file_list_sorted = self.sorted_items() - self.items = self.load_items() - logger.verbose("%s items loaded", self.count) - logger.debug("Initialized %s", self.__class__.__name__) - - @property - def count(self): - """ Number of faces or frames """ - if self.vid_cap: - retval = int(self.vid_cap.get(cv2.CAP_PROP_FRAME_COUNT)) # pylint: disable=no-member - else: - retval = len(self.file_list_sorted) - return retval - - def check_input_folder(self): - """ makes sure that the frames or faces folder exists - If frames folder contains a video file return video capture object """ - err = None - loadtype = self.__class__.__name__ - if not self.folder: - err = "ERROR: A {} folder must be specified".format(loadtype) - elif not os.path.exists(self.folder): - err = ("ERROR: The {} location {} could not be " - "found".format(loadtype, self.folder)) - if err: - logger.error(err) - exit(0) - - if (loadtype == "Frames" and - os.path.isfile(self.folder) and - os.path.splitext(self.folder)[1] in _video_extensions): - logger.verbose("Video exists at : '%s'", self.folder) - retval = cv2.VideoCapture(self.folder) # pylint: disable=no-member - else: - logger.verbose("Folder exists at '%s'", self.folder) - retval = None - return retval - - @staticmethod - def valid_extension(filename): - """ Check whether passed in file has a valid extension """ - extension = 
os.path.splitext(filename)[1] - retval = extension in _image_extensions - logger.trace("Filename has valid extension: '%s': %s", filename, retval) - return retval - - @staticmethod - def sorted_items(): - """ Override for specific folder processing """ - return list() - - @staticmethod - def process_folder(): - """ Override for specific folder processing """ - return list() - - @staticmethod - def load_items(): - """ Override for specific item loading """ - return dict() - - def load_image(self, filename): - """ Load an image """ - if self.vid_cap: - image = self.load_video_frame(filename) - else: - src = os.path.join(self.folder, filename) - logger.trace("Loading image: '%s'", src) - image = cv2.imread(src) # pylint: disable=no-member - return image - - def load_video_frame(self, filename): - """ Load a requested frame from video """ - frame = os.path.splitext(filename)[0] - logger.trace("Loading video frame: '%s'", frame) - frame_no = int(frame[frame.rfind("_") + 1:]) - 1 - self.vid_cap.set(cv2.CAP_PROP_POS_FRAMES, frame_no) # pylint: disable=no-member - _, image = self.vid_cap.read() - return image - - @staticmethod - def save_image(output_folder, filename, image): - """ Save an image """ - output_file = os.path.join(output_folder, filename) - logger.trace("Saving image: '%s'", output_file) - cv2.imwrite(output_file, image) # pylint: disable=no-member - - -class Faces(MediaLoader): - """ Object to hold the faces that are to be swapped out """ - - def process_folder(self): - """ Iterate through the faces dir pulling out various information """ - logger.info("Loading file list from %s", self.folder) - for face in tqdm(os.listdir(self.folder), desc="Reading Face Hashes"): - if not self.valid_extension(face): - continue - filename = os.path.splitext(face)[0] - file_extension = os.path.splitext(face)[1] - face_hash = hash_image_file(os.path.join(self.folder, face)) - retval = {"face_fullname": face, - "face_name": filename, - "face_extension": file_extension, - "face_hash": face_hash} - logger.trace(retval) - yield retval - - def load_items(self): - """ Load the face names into dictionary """ - faces = dict() - for face in self.file_list_sorted: - faces.setdefault(face["face_hash"], list()).append((face["face_name"], - face["face_extension"])) - logger.trace(faces) - return faces - - def sorted_items(self): - """ Return the items sorted by face name """ - items = sorted([item for item in self.process_folder()], - key=lambda x: (x["face_name"])) - logger.trace(items) - return items - - -class Frames(MediaLoader): - """ Object to hold the frames that are to be checked against """ - - def process_folder(self): - """ Iterate through the frames dir pulling the base filename """ - iterator = self.process_video if self.vid_cap else self.process_frames - for item in iterator(): - yield item - - def process_frames(self): - """ Process exported Frames """ - logger.info("Loading file list from %s", self.folder) - for frame in os.listdir(self.folder): - if not self.valid_extension(frame): - continue - filename = os.path.splitext(frame)[0] - file_extension = os.path.splitext(frame)[1] - - retval = {"frame_fullname": frame, - "frame_name": filename, - "frame_extension": file_extension} - logger.trace(retval) - yield retval - - def process_video(self): - """Dummy in frames for video """ - logger.info("Loading video frames from %s", self.folder) - vidname = os.path.splitext(os.path.basename(self.folder))[0] - for i in range(self.count): - idx = i + 1 - # Keep filename format for outputted face - filename = 
"{}_{:06d}".format(vidname, idx) - retval = {"frame_fullname": "{}.png".format(filename), - "frame_name": filename, - "frame_extension": ".png"} - logger.trace(retval) - yield retval - - def load_items(self): - """ Load the frame info into dictionary """ - frames = dict() - for frame in self.file_list_sorted: - frames[frame["frame_fullname"]] = (frame["frame_name"], - frame["frame_extension"]) - logger.trace(frames) - return frames - - def sorted_items(self): - """ Return the items sorted by filename """ - items = sorted([item for item in self.process_folder()], - key=lambda x: (x["frame_name"])) - logger.trace(items) - return items - - -class ExtractedFaces(): - """ Holds the extracted faces and matrix for - alignments """ - def __init__(self, frames, alignments, size=256, align_eyes=False): - logger.trace("Initializing %s: (size: %s, align_eyes: %s)", - self.__class__.__name__, size, align_eyes) - self.size = size - self.padding = int(size * 0.1875) - self.align_eyes = align_eyes - self.alignments = alignments - self.frames = frames - - self.current_frame = None - self.faces = list() - logger.trace("Initialized %s", self.__class__.__name__) - - def get_faces(self, frame): - """ Return faces and transformed landmarks - for each face in a given frame with it's alignments""" - logger.trace("Getting faces for frame: '%s'", frame) - self.current_frame = None - alignments = self.alignments.get_faces_in_frame(frame) - logger.trace("Alignments for frame: (frame: '%s', alignments: %s)", frame, alignments) - if not alignments: - self.faces = list() - return - image = self.frames.load_image(frame) - self.faces = [self.extract_one_face(alignment, image.copy()) - for alignment in alignments] - self.current_frame = frame - - def extract_one_face(self, alignment, image): - """ Extract one face from image """ - logger.trace("Extracting one face: (frame: '%s', alignment: %s)", - self.current_frame, alignment) - face = DetectedFace() - face.from_alignment(alignment, image=image) - face.load_aligned(image, size=self.size, align_eyes=self.align_eyes) - return face - - def get_faces_in_frame(self, frame, update=False): - """ Return the faces for the selected frame """ - logger.trace("frame: '%s', update: %s", frame, update) - if self.current_frame != frame or update: - self.get_faces(frame) - return self.faces - - def get_roi_size_for_frame(self, frame): - """ Return the size of the original extract box for - the selected frame """ - logger.trace("frame: '%s'", frame) - if self.current_frame != frame: - self.get_faces(frame) - sizes = list() - for face in self.faces: - top_left, top_right = face.original_roi[0], face.original_roi[3] - len_x = top_right[0] - top_left[0] - len_y = top_right[1] - top_left[1] - if top_left[1] == top_right[1]: - length = len_y - else: - length = int(((len_x ** 2) + (len_y ** 2)) ** 0.5) - sizes.append(length) - logger.trace("sizes: '%s'", sizes) - return sizes - - @staticmethod - def save_face_with_hash(filename, extension, face): - """ Save a face and return it's hash """ - f_hash, img = hash_encode_image(face, extension) - logger.trace("Saving face: '%s'", filename) - with open(filename, "wb") as out_file: - out_file.write(img) - return f_hash diff --git a/tools/manual/__init__.py b/tools/manual/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tools/manual/cli.py b/tools/manual/cli.py new file mode 100644 index 0000000000..bb34c007ba --- /dev/null +++ b/tools/manual/cli.py @@ -0,0 +1,71 @@ +#!/usr/bin/env python3 +""" The Command Line Arguments for the 
Manual Editor tool. """ +import gettext + +from lib.cli.args import FaceSwapArgs +from lib.cli.actions import DirOrFileFullPaths, FileFullPaths +from lib.utils import get_module_objects + +# LOCALES +_LANG = gettext.translation("tools.manual", localedir="locales", fallback=True) +_ = _LANG.gettext + +_HELPTEXT = _("This command lets you perform various actions on frames, " + "faces and alignments files using visual tools.") + + +class ManualArgs(FaceSwapArgs): + """ Generate the command line options for the Manual Editor Tool.""" + + @staticmethod + def get_info(): + """ Obtain the information about what the Manual Tool does. """ + return _("A tool to perform various actions on frames, faces and alignments files using " + "visual tools") + + @staticmethod + def get_argument_list(): + """ Generate the command line argument list for the Manual Tool. """ + argument_list = [] + argument_list.append({ + "opts": ("-a", "--alignments"), + "action": FileFullPaths, + "filetypes": "alignments", + "type": str, + "group": _("data"), + "dest": "alignments_path", + "help": _( + "Path to the alignments file for the input, if not at the default location")}) + argument_list.append({ + "opts": ("-f", "--frames"), + "action": DirOrFileFullPaths, + "filetypes": "video", + "required": True, + "group": _("data"), + "help": _( + "Video file or directory containing source frames that faces were extracted " + "from.")}) + argument_list.append({ + "opts": ("-t", "--thumb-regen"), + "action": "store_true", + "dest": "thumb_regen", + "default": False, + "group": _("options"), + "help": _( + "Force regeneration of the low resolution jpg thumbnails in the alignments " + "file.")}) + argument_list.append({ + "opts": ("-s", "--single-process"), + "action": "store_true", + "dest": "single_process", + "default": False, + "group": _("options"), + "help": _( + "The process attempts to speed up generation of thumbnails by extracting from the " + "video in parallel threads. For some videos, this causes the caching process to " + "hang. If this happens, then set this option to generate the thumbnails in a " + "slower, but more stable single thread.")}) + return argument_list + + +__all__ = get_module_objects(__name__) diff --git a/tools/manual/detected_faces.py b/tools/manual/detected_faces.py new file mode 100644 index 0000000000..50129b4f83 --- /dev/null +++ b/tools/manual/detected_faces.py @@ -0,0 +1,939 @@ +#!/usr/bin/env python3 +""" Alignments handling for Faceswap's Manual Adjustments tool. Handles the conversion of +alignments data to :class:`~lib.align.DetectedFace` objects, and the update of these faces +when edits are made in the GUI. """ +from __future__ import annotations +import logging +import os +import sys +import tkinter as tk +import typing as T +from copy import deepcopy +from queue import Queue, Empty + +import cv2 +import numpy as np + +from lib.align import Alignments, AlignedFace, DetectedFace +from lib.gui.custom_widgets import PopupProgress +from lib.gui.utils import FileHandler +from lib.image import ImagesLoader, ImagesSaver, encode_image, generate_thumbnail +from lib.multithreading import MultiThread +from lib.utils import get_folder, get_module_objects + +if T.TYPE_CHECKING: + from . import manual + from lib.align.alignments import AlignmentFileDict, PNGHeaderDict + +logger = logging.getLogger(__name__) + + +class DetectedFaces(): + """ Handles the manipulation of :class:`~lib.align.DetectedFace` objects stored + in the alignments file. 
Acts as a parent class for the IO operations (saving and loading from + an alignments file), the face update operations (when changes are made to alignments in the + GUI) and the face filters (when a user changes the filter navigation mode.) + + Parameters + ---------- + tk_globals: :class:`~tools.manual.manual.TkGlobals` + The tkinter variables that apply to the whole of the GUI + alignments_path: str + The full path to the alignments file + input_location: str + The location of the input folder of frames or video file + extractor: :class:`~tools.manual.manual.Aligner` + The pipeline for passing faces through the aligner and retrieving results + """ + def __init__(self, + tk_globals: manual.TkGlobals, + alignments_path: str, + input_location: str, + extractor: manual.Aligner) -> None: + logger.debug("Initializing %s: (tk_globals: %s. alignments_path: %s, input_location: %s " + "extractor: %s)", self.__class__.__name__, tk_globals, alignments_path, + input_location, extractor) + self._globals = tk_globals + self._frame_faces: list[list[DetectedFace]] = [] + self._updated_frame_indices: set[int] = set() + + self._alignments: Alignments = self._get_alignments(alignments_path, input_location) + self._alignments.update_legacy_has_source(os.path.basename(input_location)) + + self._extractor = extractor + self._tk_vars = self._set_tk_vars() + + self._io = _DiskIO(self, input_location) + self._update = FaceUpdate(self) + self._filter = Filter(self) + logger.debug("Initialized %s", self.__class__.__name__) + + # <<<< PUBLIC PROPERTIES >>>> # + @property + def extractor(self) -> manual.Aligner: + """ :class:`~tools.manual.manual.Aligner`: The pipeline for passing faces through the + aligner and retrieving results. """ + return self._extractor + + @property + def filter(self) -> Filter: + """ :class:`Filter`: Handles returning of faces and stats based on the current user set + navigation mode filter. """ + return self._filter + + @property + def update(self) -> FaceUpdate: + """ :class:`FaceUpdate`: Handles the adding, removing and updating of + :class:`~lib.align.DetectedFace` stored within the alignments file. """ + return self._update + + # << TKINTER VARIABLES >> # + @property + def tk_unsaved(self) -> tk.BooleanVar: + """ :class:`tkinter.BooleanVar`: The variable indicating whether the alignments have been + updated since the last save. """ + return self._tk_vars["unsaved"] + + @property + def tk_edited(self) -> tk.BooleanVar: + """ :class:`tkinter.BooleanVar`: The variable indicating whether an edit has occurred + meaning a GUI redraw needs to be triggered. """ + return self._tk_vars["edited"] + + @property + def tk_face_count_changed(self) -> tk.BooleanVar: + """ :class:`tkinter.BooleanVar`: The variable indicating whether a face has been added or + removed meaning the :class:`FaceViewer` grid redraw needs to be triggered. """ + return self._tk_vars["face_count_changed"] + + # << STATISTICS >> # + @property + def frame_list(self) -> list[str]: + """ list[str]: The list of all frame names that appear in the alignments file """ + return list(self._alignments.data) + + @property + def available_masks(self) -> dict[str, int]: + """ dict[str, int]: The mask type names stored in the alignments; type as key with the + number of faces which possess the mask type as value. """ + return self._alignments.mask_summary + + @property + def current_faces(self) -> list[list[DetectedFace]]: + """ list[list[:class:`~lib.align.DetectedFace`]]: The most up to date full list of detected + face objects. 
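+
+        Indexed as ``current_faces[frame_index][face_index]``. An illustrative
+        sketch (assuming an initialized :class:`DetectedFaces` instance named
+        ``detected_faces``)::
+
+            faces_in_frame = detected_faces.current_faces[10]  # all faces in frame 10
+            first_face = faces_in_frame[0]  # a lib.align.DetectedFace object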
""" + return self._frame_faces + + @property + def video_meta_data(self) -> dict[str, list[int] | list[float] | None]: + """ dict[str, list[int] | list[float] | None]: The frame meta data stored in the alignments + file. If data does not exist in the alignments file then ``None`` is returned for each + Key """ + return self._alignments.video_meta_data + + @property + def face_count_per_index(self) -> list[int]: + """ list[int]: Count of faces for each frame. List is in frame index order. + + The list needs to be calculated on the fly as the number of faces in a frame + can change based on user actions. """ + return [len(faces) for faces in self._frame_faces] + + # <<<< PUBLIC METHODS >>>> # + def is_frame_updated(self, frame_index: int) -> bool: + """ Check whether the given frame index has been updated + + Parameters + ---------- + frame_index: int + The frame index to check + + Returns + ------- + bool: + ``True`` if the given frame index has updated faces within it otherwise ``False`` + """ + return frame_index in self._updated_frame_indices + + def load_faces(self) -> None: + """ Load the faces as :class:`~lib.align.DetectedFace` objects from the alignments + file. """ + self._io.load() + + def save(self) -> None: + """ Save the alignments file with the latest edits. """ + self._io.save() + + def revert_to_saved(self, frame_index): + """ Revert the frame's alignments to their saved version for the given frame index. + + Parameters + ---------- + frame_index: int + The frame that should have their faces reverted to their saved version + """ + self._io.revert_to_saved(frame_index) + + def extract(self) -> None: + """ Extract the faces in the current video to a user supplied folder. """ + self._io.extract() + + def save_video_meta_data(self, pts_time: list[float], keyframes: list[int]) -> None: + """ Save video meta data to the alignments file. This is executed if the video meta data + does not already exist in the alignments file, so the video does not need to be scanned + on every use of the Manual Tool. + + Parameters + ---------- + pts_time: list[float] + A list of presentation timestamps in frame index order for every frame in the input + video + keyframes: list[int] + A list of frame indices corresponding to the key frames in the input video. + """ + if self._globals.is_video: + self._alignments.save_video_meta_data(pts_time, keyframes) + + # <<<< PRIVATE METHODS >>> # + # << INIT >> # + @staticmethod + def _set_tk_vars() -> dict[T.Literal["unsaved", "edited", "face_count_changed"], + tk.BooleanVar]: + """ Set the required tkinter variables. + + The alignments specific `unsaved` and `edited` are set here. + The global variables are added into the dictionary with `None` as value, so the + objects exist. Their actual variables are populated during :func:`load_faces`. + + Returns + ------- + dict + The internal variable name as key with the tkinter variable as value + """ + retval = {} + for name in T.get_args(T.Literal["unsaved", "edited", "face_count_changed"]): + var = tk.BooleanVar() + var.set(False) + retval[name] = var + logger.debug(retval) + return retval + + def _get_alignments(self, alignments_path: str, input_location: str) -> Alignments: + """ Get the :class:`~lib.align.Alignments` object for the given location. + + Parameters + ---------- + alignments_path: str + Full path to the alignments file. 
If empty string is passed then location is calculated + from the source folder + input_location: str + The location of the input folder of frames or video file + + Returns + ------- + :class:`~lib.align.Alignments` + The alignments object for the given input location + """ + logger.debug("alignments_path: %s, input_location: %s", alignments_path, input_location) + if alignments_path: + folder, filename = os.path.split(alignments_path) + else: + filename = "alignments.fsa" + if self._globals.is_video: + folder, vid = os.path.split(os.path.splitext(input_location)[0]) + filename = f"{vid}_{filename}" + else: + folder = input_location + retval = Alignments(folder, filename) + if retval.version == 1.0: + logger.error("The Manual Tool is not compatible with legacy Alignments files.") + logger.info("You can update legacy Alignments files by using the Extract job in the " + "Alignments tool to re-extract the faces in full-head format.") + sys.exit(0) + logger.debug("folder: %s, filename: %s, alignments: %s", folder, filename, retval) + return retval + + +class _DiskIO(): + """ Handles the loading of :class:`~lib.align.DetectedFaces` from the alignments file + into :class:`DetectedFaces` and the saving of this data (in the opposite direction) to an + alignments file. + + Parameters + ---------- + detected_faces: :class:`DetectedFaces` + The parent :class:`DetectedFaces` object + input_location: str + The location of the input folder of frames or video file + """ + def __init__(self, detected_faces: DetectedFaces, input_location: str) -> None: + logger.debug("Initializing %s: (detected_faces: %s, input_location: %s)", + self.__class__.__name__, detected_faces, input_location) + self._input_location = input_location + self._alignments = detected_faces._alignments + self._frame_faces = detected_faces._frame_faces + self._updated_frame_indices = detected_faces._updated_frame_indices + self._tk_unsaved = detected_faces.tk_unsaved + self._tk_edited = detected_faces.tk_edited + self._tk_face_count_changed = detected_faces.tk_face_count_changed + self._globals = detected_faces._globals + + # Must be populated after loading faces as video_meta_data may have increased frame count + self._sorted_frame_names: list[str] = [] + logger.debug("Initialized %s", self.__class__.__name__) + + def load(self) -> None: + """ Load the faces from the alignments file, convert to + :class:`~lib.align.DetectedFace`. objects and add to :attr:`_frame_faces`. """ + for key in sorted(self._alignments.data): + this_frame_faces: list[DetectedFace] = [] + for item in self._alignments.data[key]["faces"]: + face = DetectedFace() + face.from_alignment(item, with_thumb=True) + face.load_aligned(None) + _ = face.aligned.average_distance # cache the distances + this_frame_faces.append(face) + self._frame_faces.append(this_frame_faces) + self._sorted_frame_names = sorted(self._alignments.data) + + def save(self) -> None: + """ Convert updated :class:`~lib.align.DetectedFace` objects to alignments format + and save the alignments file. """ + if not self._tk_unsaved.get(): + logger.debug("Alignments not updated. 
Returning") + return + frames = list(self._updated_frame_indices) + logger.verbose("Saving alignments for %s updated frames", # type:ignore[attr-defined] + len(frames)) + + for idx, faces in zip(frames, + np.array(self._frame_faces, dtype="object")[np.array(frames)]): + frame = self._sorted_frame_names[idx] + self._alignments.data[frame]["faces"] = [face.to_alignment() for face in faces] + + self._alignments.backup() + self._alignments.save() + self._updated_frame_indices.clear() + self._tk_unsaved.set(False) + + def revert_to_saved(self, frame_index: int) -> None: + """ Revert the frame's alignments to their saved version for the given frame index. + + Parameters + ---------- + frame_index: int + The frame that should have their faces reverted to their saved version + """ + if frame_index not in self._updated_frame_indices: + logger.debug("Alignments not amended. Returning") + return + logger.verbose("Reverting alignments for frame_index %s", # type:ignore[attr-defined] + frame_index) + alignments = self._alignments.data[self._sorted_frame_names[frame_index]]["faces"] + faces = self._frame_faces[frame_index] + + reset_grid = self._add_remove_faces(alignments, faces) + + for detected_face, face in zip(faces, alignments): + detected_face.from_alignment(face, with_thumb=True) + detected_face.load_aligned(None, force=True) + _ = detected_face.aligned.average_distance # cache the distances + + self._updated_frame_indices.remove(frame_index) + if not self._updated_frame_indices: + self._tk_unsaved.set(False) + + if reset_grid: + self._tk_face_count_changed.set(True) + else: + self._tk_edited.set(True) + self._globals.var_full_update.set(True) + + @classmethod + def _add_remove_faces(cls, + alignments: list[AlignmentFileDict], + faces: list[DetectedFace]) -> bool: + """ On a revert, ensure that the alignments and detected face object counts for each frame + are in sync. + + Parameters + ---------- + alignments: list[:class:`~lib.align.alignments.AlignmentFileDict`] + Alignments stored for a frame + + faces: list[:class:`~lib.align.DetectedFace`] + List of detected faces for a frame + + Returns + ------- + bool + ``True`` if a face was added or removed otherwise ``False`` + """ + num_alignments = len(alignments) + num_faces = len(faces) + if num_alignments == num_faces: + retval = False + elif num_alignments > num_faces: + faces.extend([DetectedFace() for _ in range(num_faces, num_alignments)]) + retval = True + else: + del faces[num_alignments:] + retval = True + return retval + + def extract(self) -> None: + """ Extract the current faces to a folder. + + To stop the GUI becoming completely unresponsive (particularly in Windows) the extract is + done in a background thread, with the process count passed back in a queue to the main + thread to update the progress bar. + """ + dirname = FileHandler("dir", None, + initial_folder=os.path.dirname(self._input_location), + title="Select output folder...").return_file + if not dirname: + return + logger.debug(dirname) + + queue: Queue = Queue() + pbar = PopupProgress("Extracting Faces...", self._alignments.frames_count + 1) + thread = MultiThread(self._background_extract, dirname, queue) + thread.start() + self._monitor_extract(thread, queue, pbar) + + def _monitor_extract(self, + thread: MultiThread, + queue: Queue, + progress_bar: PopupProgress) -> None: + """ Monitor the extraction thread, and update the progress bar. + + On completion, save alignments and clear progress bar. 
+
+        Parameters
+        ----------
+        thread: :class:`~lib.multithreading.MultiThread`
+            The thread that is performing the extraction task
+        queue: :class:`queue.Queue`
+            The queue that the worker thread is putting its incremental counts to
+        progress_bar: :class:`~lib.gui.custom_widgets.PopupProgress`
+            The popped up progress bar
+        """
+        thread.check_and_raise_error()
+        if not thread.is_alive():
+            thread.join()
+            progress_bar.stop()
+            return
+
+        while True:
+            try:
+                progress_bar.step(queue.get(False, 0))
+            except Empty:
+                break
+        progress_bar.after(100, self._monitor_extract, thread, queue, progress_bar)
+
+    def _background_extract(self, output_folder: str, progress_queue: Queue) -> None:
+        """ Perform the background extraction in a thread so GUI doesn't become unresponsive.
+
+        Parameters
+        ----------
+        output_folder: str
+            The location to save the output faces to
+        progress_queue: :class:`queue.Queue`
+            The queue to place incremental counts to for updating the GUI's progress bar
+        """
+        saver = ImagesSaver(get_folder(output_folder), as_bytes=True)
+        loader = ImagesLoader(self._input_location, count=self._alignments.frames_count)
+        for frame_idx, (filename, image) in enumerate(loader.load()):
+            logger.trace("Outputting frame: %s: %s",  # type:ignore[attr-defined]
+                         frame_idx, filename)
+            src_filename = os.path.basename(filename)
+            progress_queue.put(1)
+
+            for face_idx, face in enumerate(self._frame_faces[frame_idx]):
+                output = f"{os.path.splitext(src_filename)[0]}_{face_idx}.png"
+                aligned = AlignedFace(face.landmarks_xy,
+                                      image=image,
+                                      centering="head",
+                                      size=512)  # TODO user selectable size
+                meta: PNGHeaderDict = {"alignments": face.to_png_meta(),
+                                       "source": {"alignments_version": self._alignments.version,
+                                                  "original_filename": output,
+                                                  "face_index": face_idx,
+                                                  "source_filename": src_filename,
+                                                  "source_is_video": self._globals.is_video,
+                                                  "source_frame_dims": image.shape[:2]}}
+
+                assert aligned.face is not None
+                b_image = encode_image(aligned.face, ".png", metadata=meta)
+                saver.save(output, b_image)
+        saver.close()
+
+
+class Filter():
+    """ Returns stats and frames for filtered frames based on the user selected navigation mode
+    filter.
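+
+    As an illustration, the "Multiple Faces" filter reduces to a simple count
+    predicate over the per-frame face lists (a sketch of the logic used by
+    :attr:`frames_list`)::
+
+        frames = [idx for idx, faces in enumerate(current_faces) if len(faces) > 1]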
+ + Parameters + ---------- + detected_faces: :class:`DetectedFaces` + The parent :class:`DetectedFaces` object + """ + def __init__(self, detected_faces: DetectedFaces) -> None: + logger.debug("Initializing %s: (detected_faces: %s)", + self.__class__.__name__, detected_faces) + self._globals = detected_faces._globals + self._detected_faces = detected_faces + logger.debug("Initialized %s", self.__class__.__name__) + + @property + def frame_meets_criteria(self) -> bool: + """ bool: ``True`` if the current frame meets the selected filter criteria otherwise + ``False`` """ + filter_mode = self._globals.var_filter_mode.get() + frame_faces = self._detected_faces.current_faces[self._globals.frame_index] + distance = self._filter_distance + + retval = ( + filter_mode == "All Frames" or + (filter_mode == "No Faces" and not frame_faces) or + (filter_mode == "Has Face(s)" and len(frame_faces) > 0) or + (filter_mode == "Multiple Faces" and len(frame_faces) > 1) or + (filter_mode == "Misaligned Faces" and any(face.aligned.average_distance > distance + for face in frame_faces))) + assert isinstance(retval, bool) + logger.trace("filter_mode: %s, frame meets criteria: %s", # type:ignore[attr-defined] + filter_mode, retval) + return retval + + @property + def _filter_distance(self) -> float: + """ float: The currently selected distance when Misaligned Faces filter is selected. """ + try: + retval = self._globals.var_filter_distance.get() + except tk.TclError: + # Suppress error when distance box is empty + retval = 0 + return retval / 100. + + @property + def count(self) -> int: + """ int: The number of frames that meet the filter criteria returned by + :attr:`~tools.manual.manual.TkGlobals.var_filter_mode.get()`. """ + face_count_per_index = self._detected_faces.face_count_per_index + if self._globals.var_filter_mode.get() == "No Faces": + retval = sum(1 for fcount in face_count_per_index if fcount == 0) + elif self._globals.var_filter_mode.get() == "Has Face(s)": + retval = sum(1 for fcount in face_count_per_index if fcount != 0) + elif self._globals.var_filter_mode.get() == "Multiple Faces": + retval = sum(1 for fcount in face_count_per_index if fcount > 1) + elif self._globals.var_filter_mode.get() == "Misaligned Faces": + distance = self._filter_distance + retval = sum(1 for frame in self._detected_faces.current_faces + if any(face.aligned.average_distance > distance for face in frame)) + else: + retval = len(face_count_per_index) + logger.trace("filter mode: %s, frame count: %s", # type:ignore[attr-defined] + self._globals.var_filter_mode.get(), retval) + return retval + + @property + def raw_indices(self) -> dict[T.Literal["frame", "face"], list[int]]: + """ dict[str, int]: The frame and face indices that meet the current filter criteria for + each displayed face. """ + frame_indices: list[int] = [] + face_indices: list[int] = [] + face_counts = self._detected_faces.face_count_per_index # Copy to avoid recalculations + + for frame_idx in self.frames_list: + for face_idx in range(face_counts[frame_idx]): + frame_indices.append(frame_idx) + face_indices.append(face_idx) + + retval: dict[T.Literal["frame", "face"], list[int]] = {"frame": frame_indices, + "face": face_indices} + logger.trace("frame_indices: %s, face_indices: %s", # type:ignore[attr-defined] + frame_indices, face_indices) + return retval + + @property + def frames_list(self) -> list[int]: + """ list[int]: The list of frame indices that meet the filter criteria returned by + :attr:`~tools.manual.manual.TkGlobals.var_filter_mode.get()`. 
""" + face_count_per_index = self._detected_faces.face_count_per_index + if self._globals.var_filter_mode.get() == "No Faces": + retval = [idx for idx, count in enumerate(face_count_per_index) if count == 0] + elif self._globals.var_filter_mode.get() == "Multiple Faces": + retval = [idx for idx, count in enumerate(face_count_per_index) if count > 1] + elif self._globals.var_filter_mode.get() == "Has Face(s)": + retval = [idx for idx, count in enumerate(face_count_per_index) if count != 0] + elif self._globals.var_filter_mode.get() == "Misaligned Faces": + distance = self._filter_distance + retval = [idx for idx, frame in enumerate(self._detected_faces.current_faces) + if any(face.aligned.average_distance > distance for face in frame)] + else: + retval = list(range(len(face_count_per_index))) + logger.trace("filter mode: %s, number_frames: %s", # type:ignore[attr-defined] + self._globals.var_filter_mode.get(), len(retval)) + return retval + + +class FaceUpdate(): + """ Perform updates on :class:`~lib.align.DetectedFace` objects stored in + :class:`DetectedFaces` when changes are made within the GUI. + + Parameters + ---------- + detected_faces: :class:`DetectedFaces` + The parent :class:`DetectedFaces` object + """ + def __init__(self, detected_faces: DetectedFaces) -> None: + logger.debug("Initializing %s: (detected_faces: %s)", + self.__class__.__name__, detected_faces) + self._detected_faces = detected_faces + self._globals = detected_faces._globals + self._frame_faces = detected_faces._frame_faces + self._updated_frame_indices = detected_faces._updated_frame_indices + self._tk_unsaved = detected_faces.tk_unsaved + self._extractor = detected_faces.extractor + logger.debug("Initialized %s", self.__class__.__name__) + + @property + def _tk_edited(self) -> tk.BooleanVar: + """ :class:`tkinter.BooleanVar`: The variable indicating whether an edit has occurred + meaning a GUI redraw needs to be triggered. + + Notes + ----- + The variable is still a ``None`` when this class is initialized, so referenced explicitly. + """ + return self._detected_faces.tk_edited + + @property + def _tk_face_count_changed(self) -> tk.BooleanVar: + """ :class:`tkinter.BooleanVar`: The variable indicating whether an edit has occurred + meaning a GUI redraw needs to be triggered. + + Notes + ----- + The variable is still a ``None`` when this class is initialized, so referenced explicitly. + """ + return self._detected_faces.tk_face_count_changed + + def _faces_at_frame_index(self, frame_index: int) -> list[DetectedFace]: + """ Checks whether the frame has already been added to :attr:`_updated_frame_indices` and + adds it. Triggers the unsaved variable if this is the first edited frame. Returns the + detected face objects for the given frame. + + Parameters + ---------- + frame_index: int + The frame index to check whether there are updated alignments available + + Returns + ------- + list + The :class:`~lib.align.DetectedFace` objects for the requested frame + """ + if not self._updated_frame_indices and not self._tk_unsaved.get(): + self._tk_unsaved.set(True) + self._updated_frame_indices.add(frame_index) + retval = self._frame_faces[frame_index] + return retval + + def add(self, frame_index: int, pnt_x: int, width: int, pnt_y: int, height: int) -> None: + """ Add a :class:`~lib.align.DetectedFace` object to the current frame with the + given dimensions. 
+ + Parameters + ---------- + frame_index: int + The frame that the face is being set for + pnt_x: int + The left point of the bounding box + width: int + The width of the bounding box + pnt_y: int + The top point of the bounding box + height: int + The height of the bounding box + """ + face = DetectedFace() + faces = self._faces_at_frame_index(frame_index) + faces.append(face) + face_index = len(faces) - 1 + + self.bounding_box(frame_index, face_index, pnt_x, width, pnt_y, height, aligner="cv2-dnn") + face.load_aligned(None) + self._tk_face_count_changed.set(True) + + def delete(self, frame_index: int, face_index: int) -> None: + """ Delete the :class:`~lib.align.DetectedFace` object for the given frame and face + indices. + + Parameters + ---------- + frame_index: int + The frame that the face is being set for + face_index: int + The face index within the frame + """ + logger.debug("Deleting face at frame index: %s face index: %s", frame_index, face_index) + faces = self._faces_at_frame_index(frame_index) + del faces[face_index] + self._tk_face_count_changed.set(True) + self._globals.var_full_update.set(True) + + def bounding_box(self, + frame_index: int, + face_index: int, + pnt_x: int, + width: int, + pnt_y: int, + height: int, + aligner: manual.TypeManualExtractor = "FAN") -> None: + """ Update the bounding box for the :class:`~lib.align.DetectedFace` object at the + given frame and face indices, with the given dimensions and update the 68 point landmarks + from the :class:`~tools.manual.manual.Aligner` for the updated bounding box. + + Parameters + ---------- + frame_index: int + The frame that the face is being set for + face_index: int + The face index within the frame + pnt_x: int + The left point of the bounding box + width: int + The width of the bounding box + pnt_y: int + The top point of the bounding box + height: int + The height of the bounding box + aligner: ["cv2-dnn", "FAN"], optional + The aligner to use to generate the landmarks. Default: "FAN" + """ + logger.trace("frame_index: %s, face_index %s, pnt_x %s, " # type:ignore[attr-defined] + "width %s, pnt_y %s, height %s, aligner: %s", + frame_index, face_index, pnt_x, width, pnt_y, height, aligner) + face = self._faces_at_frame_index(frame_index)[face_index] + face.left = pnt_x + face.width = width + face.top = pnt_y + face.height = height + face.add_landmarks_xy(self._extractor.get_landmarks(frame_index, face_index, aligner)) + self._globals.var_full_update.set(True) + + def landmark(self, + frame_index: int, face_index: int, + landmark_index: int, + shift_x: int, + shift_y: int, + is_zoomed: bool) -> None: + """ Shift a single landmark point for the :class:`~lib.align.DetectedFace` object + at the given frame and face indices by the given x and y values. + + Parameters + ---------- + frame_index: int + The frame that the face is being set for + face_index: int + The face index within the frame + landmark_index: int or list + The landmark index to shift. 
If a list is provided, this should be a list of landmark + indices to be shifted + shift_x: int + The amount to shift the landmark by along the x axis + shift_y: int + The amount to shift the landmark by along the y axis + is_zoomed: bool + ``True`` if landmarks are being adjusted on a zoomed image otherwise ``False`` + """ + face = self._faces_at_frame_index(frame_index)[face_index] + if is_zoomed: + aligned = AlignedFace(face.landmarks_xy, + centering="face", + size=min(self._globals.frame_display_dims)) + landmark = aligned.landmarks[landmark_index] + landmark += (shift_x, shift_y) + matrix = aligned.adjusted_matrix + matrix = cv2.invertAffineTransform(matrix) + if landmark.ndim == 1: + landmark = np.reshape(landmark, (1, 1, 2)) + landmark = cv2.transform(landmark, matrix, landmark.shape).squeeze() + face.landmarks_xy[landmark_index] = landmark + else: + for lmk, idx in zip(landmark, landmark_index): # type:ignore[call-overload] + lmk = np.reshape(lmk, (1, 1, 2)) + lmk = cv2.transform(lmk, matrix, lmk.shape).squeeze() + face.landmarks_xy[idx] = lmk + else: + face.landmarks_xy[landmark_index] += (shift_x, shift_y) + self._globals.var_full_update.set(True) + + def landmarks(self, frame_index: int, face_index: int, shift_x: int, shift_y: int) -> None: + """ Shift all of the landmarks and bounding box for the + :class:`~lib.align.DetectedFace` object at the given frame and face indices by the + given x and y values and update the masks. + + Parameters + ---------- + frame_index: int + The frame that the face is being set for + face_index: int + The face index within the frame + shift_x: int + The amount to shift the landmarks by along the x axis + shift_y: int + The amount to shift the landmarks by along the y axis + + Notes + ----- + Whilst the bounding box does not need to be shifted, it is anyway, to ensure that it is + aligned with the newly adjusted landmarks. + """ + face = self._faces_at_frame_index(frame_index)[face_index] + assert face.left is not None and face.top is not None + face.left += shift_x + face.top += shift_y + face.add_landmarks_xy(face.landmarks_xy + (shift_x, shift_y)) + self._globals.var_full_update.set(True) + + def landmarks_rotate(self, + frame_index: int, + face_index: int, + angle: float, + center: np.ndarray) -> None: + """ Rotate the landmarks on an Extract Box rotate for the + :class:`~lib.align.DetectedFace` object at the given frame and face indices for the + given angle from the given center point. + + Parameters + ---------- + frame_index: int + The frame that the face is being set for + face_index: int + The face index within the frame + angle: float + The angle, in radians to rotate the points by + center: :class:`numpy.ndarray` + The center point of the Landmark's Extract Box + """ + face = self._faces_at_frame_index(frame_index)[face_index] + rot_mat = cv2.getRotationMatrix2D(tuple(center.astype("float32")), angle, 1.) + face.add_landmarks_xy(cv2.transform(np.expand_dims(face.landmarks_xy, axis=0), + rot_mat).squeeze()) + self._globals.var_full_update.set(True) + + def landmarks_scale(self, + frame_index: int, + face_index: int, + scale: np.ndarray, + center: np.ndarray) -> None: + """ Scale the landmarks on an Extract Box resize for the + :class:`~lib.align.DetectedFace` object at the given frame and face indices from the + given center point. 
+
+        Parameters
+        ----------
+        frame_index: int
+            The frame that the face is being set for
+        face_index: int
+            The face index within the frame
+        scale: :class:`numpy.ndarray`
+            The amount to scale the landmarks by
+        center: :class:`numpy.ndarray`
+            The center point of the Landmark's Extract Box
+        """
+        face = self._faces_at_frame_index(frame_index)[face_index]
+        face.add_landmarks_xy(((face.landmarks_xy - center) * scale) + center)
+        self._globals.var_full_update.set(True)
+
+    def mask(self, frame_index: int, face_index: int, mask: np.ndarray, mask_type: str) -> None:
+        """ Update the mask on an edit for the :class:`~lib.align.DetectedFace` object at
+        the given frame and face indices, for the given mask and mask type.
+
+        Parameters
+        ----------
+        frame_index: int
+            The frame that the face is being set for
+        face_index: int
+            The face index within the frame
+        mask: :class:`numpy.ndarray`
+            The mask to replace
+        mask_type: str
+            The name of the mask that is to be replaced
+        """
+        face = self._faces_at_frame_index(frame_index)[face_index]
+        face.mask[mask_type].replace_mask(mask)
+        self._tk_edited.set(True)
+        self._globals.var_full_update.set(True)
+
+    def copy(self, frame_index: int, direction: T.Literal["prev", "next"]) -> None:
+        """ Copy the alignments from the previous or next frame that has alignments
+        to the current frame.
+
+        Parameters
+        ----------
+        frame_index: int
+            The frame that needs to have alignments copied to it
+        direction: ["prev", "next"]
+            Whether to copy alignments from the previous frame with alignments, or the next
+            frame with alignments
+        """
+        logger.debug("frame: %s, direction: %s", frame_index, direction)
+        faces = self._faces_at_frame_index(frame_index)
+        frames_with_faces = [idx for idx, faces in enumerate(self._detected_faces.current_faces)
+                             if len(faces) > 0]
+        if direction == "prev":
+            idx = next((idx for idx in reversed(frames_with_faces)
+                        if idx < frame_index), None)
+        else:
+            idx = next((idx for idx in frames_with_faces
+                        if idx > frame_index), None)
+        if idx is None:
+            # No previous/next frame available
+            return
+        logger.debug("Copying alignments from frame %s to frame: %s", idx, frame_index)
+
+        # aligned_face cannot be deep copied, so remove and recreate
+        to_copy = self._faces_at_frame_index(idx)
+        for face in to_copy:
+            face._aligned = None  # pylint:disable=protected-access
+        copied = deepcopy(to_copy)
+
+        for old_face, new_face in zip(to_copy, copied):
+            old_face.load_aligned(None)
+            new_face.load_aligned(None)
+
+        faces.extend(copied)
+        self._tk_face_count_changed.set(True)
+        self._globals.var_full_update.set(True)
+
+    def post_edit_trigger(self, frame_index: int, face_index: int) -> None:
+        """ Update the jpg thumbnail, the viewport thumbnail, the landmark masks and the aligned
+        face on a face edit.
+ + Parameters + ---------- + frame_index: int + The frame that the face is being set for + face_index: int + The face index within the frame + """ + face = self._frame_faces[frame_index][face_index] + face.load_aligned(None, force=True) # Update average distance + face.mask = self._extractor.get_masks(frame_index, face_index) + face.clear_all_identities() + + aligned = AlignedFace(face.landmarks_xy, + image=self._globals.current_frame.image, + centering="head", + size=96) + assert aligned.face is not None + face.thumbnail = generate_thumbnail(aligned.face, size=96) + if self._globals.var_filter_mode.get() == "Misaligned Faces": + self._detected_faces.tk_face_count_changed.set(True) + self._tk_edited.set(True) + + +__all__ = get_module_objects(__name__) diff --git a/tools/manual/faceviewer/__init__.py b/tools/manual/faceviewer/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tools/manual/faceviewer/frame.py b/tools/manual/faceviewer/frame.py new file mode 100644 index 0000000000..30ebbdf7ef --- /dev/null +++ b/tools/manual/faceviewer/frame.py @@ -0,0 +1,805 @@ +#!/usr/bin/env python3 +""" The Faces Viewer Frame and Canvas for Faceswap's Manual Tool. """ +from __future__ import annotations +import colorsys +import gettext +import logging +import platform +import tkinter as tk +from tkinter import ttk +import typing as T +from math import floor, ceil +from threading import Thread, Event + +import numpy as np + +from lib.gui.custom_widgets import RightClickMenu, Tooltip +from lib.gui.utils import get_config, get_images +from lib.image import hex_to_rgb, rgb_to_hex +from lib.logger import parse_class_init +from lib.utils import get_module_objects + +from .viewport import Viewport + +if T.TYPE_CHECKING: + from tools.manual.detected_faces import DetectedFaces + from tools.manual.frameviewer.frame import DisplayFrame + from tools.manual.manual import TkGlobals + +logger = logging.getLogger(__name__) + +# LOCALES +_LANG = gettext.translation("tools.manual", localedir="locales", fallback=True) +_ = _LANG.gettext + + +class FacesFrame(ttk.Frame): # pylint:disable=too-many-ancestors + """ The faces display frame (bottom section of GUI). This frame holds the faces viewport and + the tkinter objects. 
+
+    Parameters
+    ----------
+    parent: :class:`ttk.Frame`
+        The paned window that the faces frame resides in
+    tk_globals: :class:`~tools.manual.manual.TkGlobals`
+        The tkinter variables that apply to the whole of the GUI
+    detected_faces: :class:`~tools.manual.detected_faces.DetectedFaces`
+        The :class:`~lib.align.DetectedFace` objects for this video
+    display_frame: :class:`~tools.manual.frameviewer.frame.DisplayFrame`
+        The section of the Manual Tool that holds the frames viewer
+    """
+    def __init__(self,
+                 parent: ttk.Frame,
+                 tk_globals: TkGlobals,
+                 detected_faces: DetectedFaces,
+                 display_frame: DisplayFrame) -> None:
+        logger.debug(parse_class_init(locals()))
+        super().__init__(parent)
+        self.pack(side=tk.TOP, fill=tk.BOTH, expand=True)
+        self._actions_frame = FacesActionsFrame(self)
+
+        self._faces_frame = ttk.Frame(self)
+        self._faces_frame.pack_propagate(False)
+        self._faces_frame.pack(side=tk.LEFT, fill=tk.BOTH, expand=True)
+        self._event = Event()
+        self._canvas = FacesViewer(self._faces_frame,
+                                   tk_globals,
+                                   self._actions_frame._tk_vars,
+                                   detected_faces,
+                                   display_frame,
+                                   self._event)
+        self._add_scrollbar()
+        logger.debug("Initialized %s", self.__class__.__name__)
+
+    def _add_scrollbar(self) -> None:
+        """ Add a scrollbar to the faces frame """
+        logger.debug("Add Faces Viewer Scrollbar")
+        scrollbar = ttk.Scrollbar(self._faces_frame, command=self._on_scroll)
+        scrollbar.pack(side=tk.RIGHT, fill=tk.Y)
+        self._canvas.config(yscrollcommand=scrollbar.set)
+        self.bind("<Configure>", self._update_viewport)
+        logger.debug("Added Faces Viewer Scrollbar")
+        self.update_idletasks()  # Update so scrollbar width is correct
+
+    def _on_scroll(self, *event: tk.Event) -> None:
+        """ Callback on scrollbar scroll. Updates the canvas location and displays/hides
+        thumbnail images.
+
+        Parameters
+        ----------
+        event: :class:`tkinter.Event`
+            The scrollbar callback event
+        """
+        self._canvas.yview(*event)
+        self._canvas.viewport.update()
+
+    def _update_viewport(self, event: tk.Event) -> None:  # pylint:disable=unused-argument
+        """ Update the faces viewport and scrollbar.
+
+        Parameters
+        ----------
+        event: :class:`tkinter.Event`
+            Unused but required
+        """
+        self._canvas.viewport.update()
+        self._canvas.configure(scrollregion=self._canvas.bbox("backdrop"))
+
+    def canvas_scroll(self, direction: T.Literal["up", "down", "page-up", "page-down"]) -> None:
+        """ Scroll the canvas on an up/down or page-up/page-down key press.
+
+        Notes
+        -----
+        To protect against a held down key press stacking tasks and locking up the GUI
+        a background thread is launched and discards subsequent key presses whilst the
+        previous update occurs.
+
+        Parameters
+        ----------
+        direction: ["up", "down", "page-up", "page-down"]
+            The requested page scroll direction and amount.
+        """
+
+        if self._event.is_set():
+            logger.trace("Update already running. "  # type:ignore[attr-defined]
+                         "Aborting repeated keypress")
+            return
+        logger.trace("Running update on received key press: %s",  # type:ignore[attr-defined]
+                     direction)
+
+        amount = 1 if direction.endswith("down") else -1
+        units = "pages" if direction.startswith("page") else "units"
+        self._event.set()
+        thread = Thread(target=self._canvas.canvas_scroll,
+                        args=(amount, units, self._event))
+        thread.start()
+
+    def set_annotation_display(self, key: str) -> None:
+        """ Set the optional annotation overlay based on keyboard shortcut.
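+
+        With the default :attr:`FacesActionsFrame.key_bindings` mapping, ``"F9"``
+        toggles the mesh overlay and ``"F10"`` the mask overlay, for example
+        (``faces_frame`` being a :class:`FacesFrame` instance)::
+
+            faces_frame.set_annotation_display("F9")  # toggle the landmark mesh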
+ + Parameters + ---------- + key: str + The pressed key + """ + self._actions_frame.on_click(self._actions_frame.key_bindings[key]) + + +class FacesActionsFrame(ttk.Frame): # pylint:disable=too-many-ancestors + """ The left hand action frame holding the optional annotation buttons. + + Parameters + ---------- + parent: :class:`FacesFrame` + The Faces frame that this actions frame reside in + """ + def __init__(self, parent: FacesFrame) -> None: + logger.debug(parse_class_init(locals())) + super().__init__(parent) + self.pack(side=tk.LEFT, fill=tk.Y, padx=(2, 4), pady=2) + self._tk_vars: dict[T.Literal["mesh", "mask"], tk.BooleanVar] = {} + self._configure_styles() + self._buttons = self._add_buttons() + logger.debug("Initialized %s", self.__class__.__name__) + + @property + def key_bindings(self) -> dict[str, T.Literal["mask", "mesh"]]: + """ dict: The mapping of key presses to optional annotations to display. Keyboard shortcuts + utilize the function keys. """ + return {f"F{idx + 9}": display + for idx, display in enumerate(T.get_args(T.Literal["mesh", "mask"]))} + + @property + def _helptext(self) -> dict[T.Literal["mask", "mesh"], str]: + """ dict: `button key`: `button helptext`. The help text to display for each button. """ + inverse_keybindings = {val: key for key, val in self.key_bindings.items()} + retval: dict[T.Literal["mask", "mesh"], str] = {"mesh": _('Display the landmarks mesh'), + "mask": _('Display the mask')} + for item in retval: + retval[item] += f" ({inverse_keybindings[item]})" + return retval + + def _configure_styles(self) -> None: + """ Configure the background color for button frame and the button styles. """ + style = ttk.Style() + style.configure("display.TFrame", background='#d3d3d3') + style.configure("display_selected.TButton", relief="flat", background="#bedaf1") + style.configure("display_deselected.TButton", relief="flat") + self.config(style="display.TFrame") + + def _add_buttons(self) -> dict[T.Literal["mesh", "mask"], ttk.Button]: + """ Add the display buttons to the Faces window. + + Returns + ------- + dict[Literal["mesh", "mask"], tk.Button]] + The display name and its associated button. + """ + frame = ttk.Frame(self) + frame.pack(side=tk.TOP, fill=tk.Y) + buttons = {} + display: T.Literal["mask", "mesh"] + for display in self.key_bindings.values(): + var = tk.BooleanVar() + var.set(False) + self._tk_vars[display] = var + + lookup = "landmarks" if display == "mesh" else display + button = ttk.Button( + frame, + image=get_images().icons[lookup], + command=T.cast( + T.Callable, + lambda t=display: self.on_click(t)), # pyright:ignore[reportArgumentType] + style="display_deselected.TButton") + button.state(["!pressed", "!focus"]) + button.pack() + Tooltip(button, text=self._helptext[display]) + buttons[display] = button + return buttons + + def on_click(self, display: T.Literal["mesh", "mask"]) -> None: + """ Click event for the optional annotation buttons. Loads and unloads the annotations from + the faces viewer. 
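+
+        Each call flips the associated :class:`tkinter.BooleanVar` and restyles
+        the button, so calling ``on_click("mesh")`` twice returns the viewer to
+        its original state.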
+ + Parameters + ---------- + display: Literal["mesh", "mask"] + The display name for the button that has called this event as exists in + :attr:`_buttons` + """ + is_pressed = not self._tk_vars[display].get() + style = "display_selected.TButton" if is_pressed else "display_deselected.TButton" + state = ["pressed", "focus"] if is_pressed else ["!pressed", "!focus"] + btn = self._buttons[display] + btn.configure(style=style) + btn.state(state) + self._tk_vars[display].set(is_pressed) + + +class FacesViewer(tk.Canvas): # pylint:disable=too-many-ancestors + """ The :class:`tkinter.Canvas` that holds the faces viewer section of the Manual Tool. + + Parameters + ---------- + parent: :class:`tkinter.ttk.Frame` + The parent frame for the canvas + tk_globals: :class:`~tools.manual.manual.TkGlobals` + The tkinter variables that apply to the whole of the GUI + tk_action_vars: dict + The :class:`tkinter.BooleanVar` objects for selectable optional annotations + as set by the buttons in the :class:`FacesActionsFrame` + detected_faces: :class:`~tools.manual.detected_faces.DetectedFaces` + The :class:`~lib.align.DetectedFace` objects for this video + display_frame: :class:`~tools.manual.frameviewer.frame.DisplayFrame` + The section of the Manual Tool that holds the frames viewer + event: :class:`threading.Event` + The threading event object for repeated key press protection + """ + def __init__(self, parent: ttk.Frame, + tk_globals: TkGlobals, + tk_action_vars: dict[T.Literal["mesh", "mask"], tk.BooleanVar], + detected_faces: DetectedFaces, + display_frame: DisplayFrame, + event: Event) -> None: + logger.debug(parse_class_init(locals())) + super().__init__(parent, + bd=0, + highlightthickness=0, + bg=get_config().user_theme["group_panel"]["panel_background"]) + self.pack(side=tk.LEFT, fill=tk.BOTH, expand=True, anchor=tk.E) + self._sizes = {"tiny": 32, "small": 64, "medium": 96, "large": 128, "extralarge": 192} + + self._globals = tk_globals + self._tk_optional_annotations = tk_action_vars + self._event = event + self._display_frame = display_frame + self._grid = Grid(self, detected_faces) + self._view = Viewport(self, detected_faces.tk_edited) + self._annotation_colors = {"mesh": self.get_muted_color("Mesh"), + "box": self.control_colors["ExtractBox"]} + + ContextMenu(self, detected_faces) + self._bind_mouse_wheel_scrolling() + self._set_tk_callbacks(detected_faces) + logger.debug("Initialized %s", self.__class__.__name__) + + @property + def face_size(self) -> int: + """ int: The currently selected thumbnail size in pixels """ + scaling = get_config().scaling_factor + size = self._sizes[self._globals.var_faces_size.get().lower().replace(" ", "")] + scaled = size * scaling + return int(round(scaled / 2) * 2) + + @property + def viewport(self) -> Viewport: + """ :class:`~tools.manual.faceviewer.viewport.Viewport`: The viewport area of the + faces viewer. """ + return self._view + + @property + def layout(self) -> Grid: + """ :class:`Grid`: The grid for the current :class:`FacesViewer`. """ + return self._grid + + @property + def optional_annotations(self) -> dict[T.Literal["mesh", "mask"], bool]: + """ dict[Literal["mesh", "mask"], bool]: The values currently set for the + selectable optional annotations. """ + return {opt: val.get() for opt, val in self._tk_optional_annotations.items()} + + @property + def selected_mask(self) -> str: + """ str: The currently selected mask from the display frame control panel. 
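+
+        For example, ``"components"`` when the control panel's mask drop-down is
+        set to "Components" (the value is lower-cased here).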
""" + return self._display_frame.tk_selected_mask.get().lower() + + @property + def control_colors(self) -> dict[str, str]: + """dict[str, str]: The frame Editor name as key with the current user selected hex code as + value. """ + return ({key: val.get() for key, val in self._display_frame.tk_control_colors.items()}) + + # << CALLBACK FUNCTIONS >> # + def _set_tk_callbacks(self, detected_faces: DetectedFaces): + """ Set the tkinter variable call backs. + + Parameters + ---------- + detected_faces: :class:`~tools.manual.detected_faces.DetectedFaces` + The Manual Tool's Detected Faces object + + Redraw the grid on a face size change, a filter change or on add/remove faces. + Updates the annotation colors when user amends a color drop down. + Updates the mask type when the user changes the selected mask types + Toggles the face viewer annotations on an optional annotation button press. + """ + for strvar in (self._globals.var_faces_size, self._globals.var_filter_mode): + strvar.trace_add( + "write", + lambda *e, v=strvar: self.refresh_grid(v)) # pyright:ignore[reportArgumentType] + boolvar = detected_faces.tk_face_count_changed + boolvar.trace_add("write", + lambda *e, v=boolvar: self.refresh_grid(v, retain_position=True)) + + self._display_frame.tk_control_colors["Mesh"].trace_add( + "write", lambda *e: self._update_mesh_color()) + self._display_frame.tk_control_colors["ExtractBox"].trace_add( + "write", lambda *e: self._update_box_color()) + self._display_frame.tk_selected_mask.trace_add( + "write", lambda *e: self._update_mask_type()) + + for opt, var in self._tk_optional_annotations.items(): + var.trace_add("write", + lambda *e, o=opt: self._toggle_annotations( + o)) # pyright:ignore[reportArgumentType] + + self.bind("", lambda *e: self._view.update()) + + def refresh_grid(self, trigger_var: tk.BooleanVar, retain_position: bool = False) -> None: + """ Recalculate the full grid and redraw. Used when the active filter pull down is used, a + face has been added or removed, or the face thumbnail size has changed. + + Parameters + ---------- + trigger_var: :class:`tkinter.BooleanVar` + The tkinter variable that has triggered the grid update. Will either be the variable + indicating that the face size have been changed, or the variable indicating that the + selected filter mode has been changed. + retain_position: bool, optional + ``True`` if the grid should be set back to the position it was at after the update has + been processed, otherwise ``False``. Default: ``False``. + """ + if not trigger_var.get(): + return + size_change = isinstance(trigger_var, tk.StringVar) + move_to = self.yview()[0] if retain_position else 0.0 + self._grid.update() + if move_to != 0.0: + self.yview_moveto(move_to) + if size_change: + self._view.reset() + self._view.update(refresh_annotations=retain_position) + if not size_change: + trigger_var.set(False) + + def _update_mask_type(self) -> None: + """ Update the displayed mask in the :class:`FacesViewer` canvas when the user changes + the mask type. """ + state: T.Literal["normal", "hidden"] + state = "normal" if self.optional_annotations["mask"] else "hidden" + logger.debug("Updating mask type: (mask_type: %s. state: %s)", self.selected_mask, state) + self._view.toggle_mask(state, self.selected_mask) + + # << MOUSE HANDLING >> + def _bind_mouse_wheel_scrolling(self) -> None: + """ Bind mouse wheel to scroll the :class:`FacesViewer` canvas. 
""" + if platform.system() == "Linux": + self.bind("", self._scroll) + self.bind("", self._scroll) + else: + self.bind("", self._scroll) + + def _scroll(self, event: tk.Event) -> None: + """ Handle mouse wheel scrolling over the :class:`FacesViewer` canvas. + + Update is run in a thread to avoid repeated scroll actions stacking and locking up the GUI. + + Parameters + ---------- + event: :class:`tkinter.Event` + The event fired by the mouse scrolling + """ + if self._event.is_set(): + logger.trace("Update already running. " # type:ignore[attr-defined] + "Aborting repeated mousewheel") + return + if platform.system() == "Darwin": + adjust = event.delta + elif platform.system() == "Windows": + adjust = int(event.delta / 120) + elif event.num == 5: + adjust = -1 + else: + adjust = 1 + self._event.set() + thread = Thread(target=self.canvas_scroll, args=(-1 * adjust, "units", self._event)) + thread.start() + + def canvas_scroll(self, amount: int, units: T.Literal["pages", "units"], event: Event) -> None: + """ Scroll the canvas on an up/down or page-up/page-down key press. + + Parameters + ---------- + amount: int + The number of units to scroll the canvas + units: Literal["pages", "units"] + The unit type to scroll by + event: :class:`threading.Event` + event to indicate to the calling process whether the scroll is still updating + """ + self.yview_scroll(int(amount), units) + self._view.update() + self._view.hover_box.on_hover(None) + event.clear() + + # << OPTIONAL ANNOTATION METHODS >> # + def _update_mesh_color(self) -> None: + """ Update the mesh color when user updates the control panel. """ + color = self.get_muted_color("Mesh") + if self._annotation_colors["mesh"] == color: + return + highlight_color = self.control_colors["Mesh"] + + self.itemconfig("viewport_polygon", outline=color) + self.itemconfig("viewport_line", fill=color) + self.itemconfig("active_mesh_polygon", outline=highlight_color) + self.itemconfig("active_mesh_line", fill=highlight_color) + self._annotation_colors["mesh"] = color + + def _update_box_color(self) -> None: + """ Update the active box color when user updates the control panel. """ + color = self.control_colors["ExtractBox"] + + if self._annotation_colors["box"] == color: + return + self.itemconfig("active_highlighter", outline=color) + self._annotation_colors["box"] = color + + def get_muted_color(self, color_key: str) -> str: + """ Creates a muted version of the given annotation color for non-active faces. + + Parameters + ---------- + color_key: str + The annotation key to obtain the color for from :attr:`control_colors` + + Returns + ------- + str + The hex color code of the muted color + """ + scale = 0.65 + hls = np.array(colorsys.rgb_to_hls(*hex_to_rgb(self.control_colors[color_key]))) + scale = (1 - scale) + 1 if hls[1] < 120 else scale + hls[1] = max(0., min(256., scale * hls[1])) + rgb = np.clip(np.rint(colorsys.hls_to_rgb(*hls)).astype("uint8"), 0, 255) + retval = rgb_to_hex(rgb) + return retval + + def _toggle_annotations(self, annotation: T.Literal["mesh", "mask"]) -> None: + """ Toggle optional annotations on or off after the user depresses an optional button. 
+ + Parameters + ---------- + annotation: ["mesh", "mask"] + The optional annotation to toggle on or off + """ + state: T.Literal["hidden", "normal"] + state = "normal" if self.optional_annotations[annotation] else "hidden" + logger.debug("Toggle annotation: (annotation: %s, state: %s)", annotation, state) + if annotation == "mesh": + self._view.toggle_mesh(state) + if annotation == "mask": + self._view.toggle_mask(state, self.selected_mask) + + +class Grid(): + """ Holds information on the current filtered grid layout. + + The grid keeps information on frame indices, face indices, x and y positions and detected face + objects laid out in a numpy array to reflect the current full layout of faces within the face + viewer based on the currently selected filter and face thumbnail size. + + Parameters + ---------- + canvas: :class:`~FacesViewer` + The :class:`~tools.manual.faceviewer.frame.FacesViewer` canvas + detected_faces: :class:`~tools.manual.detected_faces.DetectedFaces` + The :class:`~lib.align.DetectedFace` objects for this video + """ + def __init__(self, canvas: FacesViewer, detected_faces: DetectedFaces): + logger.debug(parse_class_init(locals())) + self._canvas = canvas + self._detected_faces = detected_faces + self._raw_indices = detected_faces.filter.raw_indices + self._frames_list = detected_faces.filter.frames_list + + self._is_valid: bool = False + self._face_size: int = 0 + self._grid: np.ndarray | None = None + self._display_faces: np.ndarray | None = None + + self._canvas.update_idletasks() + self._canvas.create_rectangle(0, 0, 0, 0, tags=["backdrop"]) + self.update() + logger.debug("Initialized %s", self.__class__.__name__) + + @property + def face_size(self) -> int: + """ int: The pixel size of each thumbnail within the face viewer. """ + return self._face_size + + @property + def is_valid(self) -> bool: + """ bool: ``True`` if the current filter means that the grid holds faces. ``False`` if + there are no faces displayed in the grid. """ + return self._is_valid + + @property + def columns_rows(self) -> tuple[int, int]: + """ tuple: the (`columns`, `rows`) required to hold all display images. """ + if not self._is_valid: + return (0, 0) + assert self._grid is not None + retval = tuple(reversed(self._grid.shape[1:])) + return T.cast(tuple[int, int], retval) + + @property + def dimensions(self) -> tuple[int, int]: + """ tuple: The (`width`, `height`) required to hold all display images. """ + if self._is_valid: + assert self._grid is not None + retval = tuple(dim * self._face_size for dim in reversed(self._grid.shape[1:])) + assert len(retval) == 2 + else: + retval = (0, 0) + return T.cast(tuple[int, int], retval) + + @property + def _visible_row_indices(self) -> tuple[int, int]: + """tuple: A 1 dimensional array of the (`top_row_index`, `bottom_row_index`) of the grid + currently in the viewable area. 
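+
+        The row bounds are located with a binary search (:func:`numpy.searchsorted`)
+        over the grid's per-row y coordinates, so the lookup remains cheap for
+        long videos.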
+ """ + height = self.dimensions[1] + visible = (max(0, floor(height * self._canvas.yview()[0]) - self._face_size), + ceil(height * self._canvas.yview()[1])) + logger.trace("height: %s, yview: %s, face_size: %s, " # type:ignore[attr-defined] + "visible: %s", height, self._canvas.yview(), self._face_size, visible) + assert self._grid is not None + y_points = self._grid[3, :, 1] + top = np.searchsorted(y_points, visible[0], side="left") + bottom = np.searchsorted(y_points, visible[1], side="right") + return int(top), int(bottom) + + @property + def visible_area(self) -> tuple[np.ndarray, np.ndarray]: + """tuple[:class:`numpy.ndarray`, :class:`numpy.ndarray`]: Tuple containing 2 arrays. + + 1st array contains an array of shape (`4`, `rows`, `columns`) corresponding + to the viewable area of the display grid. 1st dimension contains frame indices, 2nd + dimension face indices. The 3rd and 4th dimension contain the x and y position of the top + left corner of the face respectively. + + 2nd array contains :class:`~lib.align.DetectedFace` objects laid out in (rows, columns) + + Any locations that are not populated by a face will have a frame and face index of -1 + """ + if not self._is_valid: + retval: tuple[np.ndarray, np.ndarray] = np.zeros((4, 0, 0)), np.zeros((0, 0)) + else: + assert self._grid is not None + assert self._display_faces is not None + top, bottom = self._visible_row_indices + retval = self._grid[:, top:bottom, :], self._display_faces[top:bottom, :] + logger.trace([r if r is None else r.shape for r in retval]) # type:ignore[attr-defined] + return retval + + def y_coord_from_frame(self, frame_index: int) -> int: + """ Return the y coordinate for the first face that appears in the given frame. + + Parameters + ---------- + frame_index: int + The frame index to locate in the grid + + Returns + ------- + int + The y coordinate of the first face for the given frame + """ + assert self._grid is not None + return min(self._grid[3][np.where(self._grid[0] == frame_index)]) + + def frame_has_faces(self, frame_index: int) -> bool | np.bool_: + """ Check whether the given frame index contains any faces. + + Parameters + ---------- + frame_index: int + The frame index to locate in the grid + + Returns + ------- + bool + ``True`` if there are faces in the given frame otherwise ``False`` + """ + if not self._is_valid: + return False + assert self._grid is not None + return np.any(self._grid[0] == frame_index) + + def update(self) -> None: + """ Update the underlying grid. + + Called on initialization, on a filter change or on add/remove faces. Recalculates the + underlying grid for the current filter view and updates the attributes :attr:`_grid`, + :attr:`_display_faces`, :attr:`_raw_indices`, :attr:`_frames_list` and :attr:`is_valid` + """ + self._face_size = self._canvas.face_size + self._raw_indices = self._detected_faces.filter.raw_indices + self._frames_list = self._detected_faces.filter.frames_list + self._get_grid() + self._get_display_faces() + self._canvas.coords("backdrop", 0, 0, *self.dimensions) + self._canvas.configure(scrollregion=self._canvas.bbox("backdrop")) + self._canvas.yview_moveto(0.0) + + def _get_grid(self) -> None: + """ Get the grid information for faces currently displayed in the :class:`FacesViewer`. + and set to :attr:`_grid`. Creates a numpy array of shape (`4`, `rows`, `columns`) + corresponding to the display grid. 1st dimension contains frame indices, 2nd dimension face + indices. 
The 3rd and 4th dimension contain the x and y position of the top left corner of + the face respectively. + + Any locations that are not populated by a face will have a frame and face index of -1""" + labels = self._get_labels() + if not self._is_valid: + logger.debug("Setting grid to None for no faces.") + self._grid = None + return + assert labels is not None + x_coords = np.linspace(0, + labels.shape[2] * self._face_size, + num=labels.shape[2], + endpoint=False, + dtype="int") + y_coords = np.linspace(0, + labels.shape[1] * self._face_size, + num=labels.shape[1], + endpoint=False, + dtype="int") + self._grid = np.array((*labels, *np.meshgrid(x_coords, y_coords)), dtype="int") + logger.debug(self._grid.shape) + + def _get_labels(self) -> np.ndarray | None: + """ Get the frame and face index for each grid position for the current filter. + + Returns + ------- + :class:`numpy.ndarray` | None + Array of dimensions (2, rows, columns) corresponding to the display grid, with frame + index as the first dimension and face index within the frame as the 2nd dimension. + + Any remaining placeholders at the end of the grid which are not populated with a face + are given the index -1 + """ + face_count = len(self._raw_indices["frame"]) + self._is_valid = face_count != 0 + if not self._is_valid: + return None + columns = self._canvas.winfo_width() // self._face_size + rows = ceil(face_count / columns) + remainder = face_count % columns + padding = [] if remainder == 0 else [-1 for _ in range(columns - remainder)] + labels = np.array((self._raw_indices["frame"] + padding, + self._raw_indices["face"] + padding), + dtype="int").reshape((2, rows, columns)) + logger.debug("face-count: %s, columns: %s, rows: %s, remainder: %s, padding: %s, labels " + "shape: %s", face_count, columns, rows, remainder, padding, labels.shape) + return labels + + def _get_display_faces(self): + """ Get the detected faces for the current filter, arrange to grid and set to + :attr:`_display_faces`. This is an array of dimensions (rows, columns) corresponding to the + display grid, containing the corresponding :class:`lib.align.DetectFace` object + + Any remaining placeholders at the end of the grid which are not populated with a face are + replaced with ``None``""" + if not self._is_valid: + logger.debug("Setting display_faces to None for no faces.") + self._display_faces = None + return + current_faces = self._detected_faces.current_faces + columns, rows = self.columns_rows + face_count = len(self._raw_indices["frame"]) + padding = [None for _ in range(face_count, columns * rows)] + self._display_faces = np.array( + [None if idx is None or face_idx is None else current_faces[idx][face_idx] + for idx, face_idx + in zip(self._raw_indices["frame"] + padding, self._raw_indices["face"] + padding)], + dtype="object").reshape(rows, columns) + logger.debug("faces: (shape: %s, dtype: %s)", + self._display_faces.shape, self._display_faces.dtype) + + def transport_index_from_frame(self, frame_index: int) -> int | None: + """ Return the main frame's transport index for the given frame index based on the current + filter criteria. + + Parameters + ---------- + frame_index: int + The absolute index for the frame within the full frames list + + Returns + ------- + int | None + The index of the requested frame within the filtered frames view. 
``None`` if there are no valid frames.
+        """
+        retval = self._frames_list.index(frame_index) if frame_index in self._frames_list else None
+        logger.trace("frame_index: %s, transport_index: %s",  # type:ignore[attr-defined]
+                     frame_index, retval)
+        return retval
+
+
+class ContextMenu():  # pylint:disable=too-few-public-methods
+    """ Enables a right click context menu for the
+    :class:`~tools.manual.faceviewer.frame.FacesViewer`.
+
+    Parameters
+    ----------
+    canvas: :class:`tkinter.Canvas`
+        The :class:`FacesViewer` canvas
+    detected_faces: :class:`~tools.manual.detected_faces`
+        The manual tool's detected faces class
+    """
+    def __init__(self, canvas, detected_faces):
+        logger.debug("Initializing: %s (canvas: %s, detected_faces: %s)",
+                     self.__class__.__name__, canvas, detected_faces)
+        self._canvas = canvas
+        self._detected_faces = detected_faces
+        self._menu = RightClickMenu(["Delete Face"], [self._delete_face])
+        self._frame_index = None
+        self._face_index = None
+        self._canvas.bind("<Button-2>" if platform.system() == "Darwin" else "<Button-3>",
+                          self._pop_menu)
+        logger.debug("Initialized: %s", self.__class__.__name__)
+
+    def _pop_menu(self, event):
+        """ Pop up the context menu on a right click mouse event.
+
+        Parameters
+        ----------
+        event: :class:`tkinter.Event`
+            The mouse event that has triggered the pop up menu
+        """
+        frame_idx, face_idx = self._canvas.viewport.face_from_point(
+            self._canvas.canvasx(event.x), self._canvas.canvasy(event.y))[:2]
+        if frame_idx == -1:
+            logger.trace("No valid item under mouse")  # type:ignore[attr-defined]
+            self._frame_index = self._face_index = None
+            return
+        self._frame_index = frame_idx
+        self._face_index = face_idx
+        logger.trace("Popping right click menu")  # type:ignore[attr-defined]
+        self._menu.popup(event)
+
+    def _delete_face(self):
+        """ Delete the selected face on a right click mouse delete action. """
+        logger.trace("Right click delete received. frame_id: %s, "  # type:ignore[attr-defined]
+                     "face_id: %s", self._frame_index, self._face_index)
+        self._detected_faces.update.delete(self._frame_index, self._face_index)
+        self._frame_index = self._face_index = None
+
+
+__all__ = get_module_objects(__name__)
diff --git a/tools/manual/faceviewer/interact.py b/tools/manual/faceviewer/interact.py
new file mode 100644
index 0000000000..d083f33cb1
--- /dev/null
+++ b/tools/manual/faceviewer/interact.py
@@ -0,0 +1,427 @@
+#!/usr/bin/env python3
+""" Handles the viewport area for mouse hover actions and the active frame """
+from __future__ import annotations
+import logging
+import tkinter as tk
+import typing as T
+from dataclasses import dataclass
+
+import numpy as np
+
+from lib.logger import parse_class_init
+from lib.utils import get_module_objects
+
+if T.TYPE_CHECKING:
+    from lib.align import DetectedFace
+    from .viewport import Viewport
+
+logger = logging.getLogger(__name__)
+
+
+class HoverBox():
+    """ Handle the current mouse location when over the :class:`Viewport`.
+
+    Highlights the face currently underneath the cursor and handles actions when clicking
+    on a face. 
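+
+    A minimal wiring sketch, assuming an existing ``viewport`` (names are illustrative)::
+
+        hover = HoverBox(viewport)     # binds <Motion>, <Leave> and <ButtonPress-1>
+        hover.on_hover(None)           # force a refresh from the current pointer position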
+
+    Parameters
+    ----------
+    viewport: :class:`Viewport`
+        The viewport object for the :class:`~tools.manual.faceviewer.frame.FacesViewer` canvas
+    """
+    def __init__(self, viewport: Viewport) -> None:
+        logger.debug(parse_class_init(locals()))
+        self._viewport = viewport
+        self._canvas = viewport._canvas
+        self._grid = viewport._canvas.layout
+        self._globals = viewport._canvas._globals
+        self._navigation = viewport._canvas._display_frame.navigation
+        self._box = self._canvas.create_rectangle(0.,  # type:ignore[call-overload]
+                                                  0.,
+                                                  float(self._size),
+                                                  float(self._size),
+                                                  outline="#0000ff",
+                                                  width=2,
+                                                  state="hidden",
+                                                  fill="#0000ff",
+                                                  stipple="gray12",
+                                                  tags="hover_box")
+        self._current_frame_index = None
+        self._current_face_index = None
+        self._canvas.bind("<Leave>", lambda e: self._clear())
+        self._canvas.bind("<Motion>", self.on_hover)
+        self._canvas.bind("<ButtonPress-1>", lambda e: self._select_frame())
+        logger.debug("Initialized: %s", self.__class__.__name__)
+
+    @property
+    def _size(self) -> int:
+        """ int: the currently set viewport face size in pixels. """
+        return self._viewport.face_size
+
+    def on_hover(self, event: tk.Event | None) -> None:
+        """ Highlight the face and set the mouse cursor for the mouse's current location.
+
+        Parameters
+        ----------
+        event: :class:`tkinter.Event` or ``None``
+            The tkinter mouse event. Provides the current location of the mouse cursor. If
+            ``None`` is passed as the event (for example when this function is being called
+            outside of a mouse event) then the location of the cursor will be calculated
+        """
+        if event is None:
+            pnts = np.array((self._canvas.winfo_pointerx(), self._canvas.winfo_pointery()))
+            pnts -= np.array((self._canvas.winfo_rootx(), self._canvas.winfo_rooty()))
+        else:
+            pnts = np.array((event.x, event.y))
+
+        coords = (int(self._canvas.canvasx(pnts[0])), int(self._canvas.canvasy(pnts[1])))
+        face = self._viewport.face_from_point(*coords)
+        frame_idx, face_idx = face[:2]
+
+        if frame_idx == self._current_frame_index and face_idx == self._current_face_index:
+            return
+
+        is_zoomed = self._globals.is_zoomed
+        if (-1 in face or (frame_idx == self._globals.frame_index
+                           and (not is_zoomed or
+                                (is_zoomed and face_idx == self._globals.face_index)))):
+            self._clear()
+            self._canvas.config(cursor="")
+            self._current_frame_index = None
+            self._current_face_index = None
+            return
+
+        logger.debug("Viewport hover: frame_idx: %s, face_idx: %s", frame_idx, face_idx)
+
+        self._canvas.config(cursor="hand2")
+        self._highlight(face[2:])
+        self._current_frame_index = frame_idx
+        self._current_face_index = face_idx
+
+    def _clear(self) -> None:
+        """ Hide the hover box when the mouse is not over a face. """
+        if self._canvas.itemcget(self._box, "state") != "hidden":
+            self._canvas.itemconfig(self._box, state="hidden")
+
+    def _highlight(self, top_left: np.ndarray) -> None:
+        """ Display the hover box around the face that the mouse is currently over.
+
+        Parameters
+        ----------
+        top_left: :class:`np.ndarray`
+            The top left point of the highlight box location
+        """
+        coords = (*top_left, *[x + self._size for x in top_left])
+        self._canvas.coords(self._box, *coords)
+        self._canvas.itemconfig(self._box, state="normal")
+        self._canvas.tag_raise(self._box)
+
+    def _select_frame(self) -> None:
+        """ Select the face and the subsequent frame (in the editor view) when a face is clicked
+        on in the :class:`Viewport`. """
+        frame_id = self._current_frame_index
+        is_zoomed = self._globals.is_zoomed
+        logger.debug("Face clicked. 
Global frame index: %s, Current frame_id: %s, is_zoomed: %s", + self._globals.frame_index, frame_id, is_zoomed) + if frame_id is None or (frame_id == self._globals.frame_index and not is_zoomed): + return + face_idx = self._current_face_index if is_zoomed else 0 + self._globals.set_face_index(face_idx) + transport_id = self._grid.transport_index_from_frame(frame_id) + logger.trace("frame_index: %s, transport_id: %s, face_idx: %s", + frame_id, transport_id, face_idx) + if transport_id is None: + return + self._navigation.stop_playback() + self._globals.var_transport_index.set(transport_id) + self._viewport.move_active_to_top() + self.on_hover(None) + + +@dataclass +class Asset: + """ Holds all of the display assets identifiers for the active frame's face viewer objects + + Parameters + ---------- + images: list[int] + Indices for a frame's tk image ids displayed in the active frame + meshes: list[dict[Literal["polygon", "line"], list[int]]] + Indices for a frame's tk line/polygon object ids displayed in the active frame + faces: list[:class:`~lib.align.detected_faces.DetectedFace`] + DetectedFace objects that exist in the current frame + boxes: list[int] + Indices for a frame's bounding box object ids displayed in the active frame + """ + images: list[int] + """list[int]: Indices for a frame's tk image ids displayed in the active frame""" + meshes: list[dict[T.Literal["polygon", "line"], list[int]]] + """list[dict[Literal["polygon", "line"], list[int]]]: Indices for a frame's tk line/polygon + object ids displayed in the active frame""" + faces: list[DetectedFace] + """list[:class:`~lib.align.detected_faces.DetectedFace`]: DetectedFace objects that exist + in the current frame""" + boxes: list[int] + """list[int]: Indices for a frame's bounding box object ids displayed in the active + frame""" + + +class ActiveFrame(): + """ Handles the display of faces and annotations for the currently active frame. + + Parameters + ---------- + canvas: :class:`tkinter.Canvas` + The :class:`~tools.manual.faceviewer.frame.FacesViewer` canvas + tk_edited_variable: :class:`tkinter.BooleanVar` + The tkinter callback variable indicating that a face has been edited + """ + def __init__(self, viewport: Viewport, tk_edited_variable: tk.BooleanVar) -> None: + logger.debug(parse_class_init(locals())) + self._objects = viewport._objects + self._viewport = viewport + self._grid = viewport._grid + self._tk_faces = viewport._tk_faces + self._canvas = viewport._canvas + self._globals = viewport._canvas._globals + self._navigation = viewport._canvas._display_frame.navigation + self._last_execution: dict[T.Literal["frame_index", "size"], + int] = {"frame_index": -1, "size": viewport.face_size} + self._tk_vars: dict[T.Literal["selected_editor", "edited"], + tk.StringVar | tk.BooleanVar] = { + "selected_editor": self._canvas._display_frame.tk_selected_action, + "edited": tk_edited_variable} + self._assets: Asset = Asset([], [], [], []) + + self._globals.var_update_active_viewport.trace_add("write", + lambda *e: self._reload_callback()) + tk_edited_variable.trace_add("write", lambda *e: self._update_on_edit()) + logger.debug("Initialized: %s", self.__class__.__name__) + + @property + def frame_index(self) -> int: + """ int: The frame index of the currently displayed frame. """ + return self._globals.frame_index + + @property + def current_frame(self) -> np.ndarray: + """ :class:`numpy.ndarray`: A BGR version of the frame currently being displayed. 
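+        (OpenCV ``BGR`` channel ordering; reverse the final axis, e.g. ``frame[..., ::-1]``,
+        for RGB.)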
""" + return self._globals.current_frame.image + + @property + def _size(self) -> int: + """ int: The size of the thumbnails displayed in the viewport, in pixels. """ + return self._viewport.face_size + + @property + def _optional_annotations(self) -> dict[T.Literal["mesh", "mask"], bool]: + """ dict[Literal["mesh", "mask"], bool]: The currently selected optional + annotations """ + return self._canvas.optional_annotations + + def _reload_callback(self) -> None: + """ If a frame has changed, triggering the variable, then update the active frame. Return + having done nothing if the variable is resetting. """ + if self._globals.var_update_active_viewport.get(): + self.reload_annotations() + + def reload_annotations(self) -> None: + """ Handles the reloading of annotations for the currently active faces. + + Highlights the faces within the viewport of those faces that exist in the currently + displaying frame. Applies annotations based on the optional annotations and current + editor selections. + """ + logger.trace("Reloading annotations") # type:ignore[attr-defined] + if self._assets.images: + self._clear_previous() + + self._set_active_objects() + self._check_active_in_view() + + if not self._assets.images: + logger.trace("No active faces. Returning") # type:ignore[attr-defined] + self._last_execution["frame_index"] = self.frame_index + return + + if self._last_execution["frame_index"] != self.frame_index: + self.move_to_top() + self._create_new_boxes() + + self._update_face() + self._canvas.tag_raise("active_highlighter") + self._globals.var_update_active_viewport.set(False) + self._last_execution["frame_index"] = self.frame_index + + def _clear_previous(self) -> None: + """ Reverts the previously selected annotations to their default state. """ + logger.trace("Clearing previous active frame") # type:ignore[attr-defined] + self._canvas.itemconfig("active_highlighter", state="hidden") + + for key in T.get_args(T.Literal["polygon", "line"]): + tag = f"active_mesh_{key}" + self._canvas.itemconfig(tag, **self._viewport.mesh_kwargs[key], width=1) + self._canvas.dtag(tag) + + if self._viewport.selected_editor == "mask" and not self._optional_annotations["mask"]: + for name, tk_face in self._tk_faces.items(): + if name.startswith(f"{self._last_execution['frame_index']}_"): + tk_face.update_mask(None) + + def _set_active_objects(self) -> None: + """ Collect the objects that exist in the currently active frame from the main grid. """ + if self._grid.is_valid: + rows, cols = np.where(self._objects.visible_grid[0] == self.frame_index) + logger.trace("Setting active objects: (rows: %s, " # type:ignore[attr-defined] + "columns: %s)", rows, cols) + self._assets.images = self._objects.images[rows, cols].tolist() + self._assets.meshes = self._objects.meshes[rows, cols].tolist() + self._assets.faces = self._objects.visible_faces[rows, cols].tolist() + else: + logger.trace("No valid grid. Clearing active objects") # type:ignore[attr-defined] + self._assets.images = [] + self._assets.meshes = [] + self._assets.faces = [] + + def _check_active_in_view(self) -> None: + """ If the frame has changed, there are faces in the frame, but they don't appear in the + viewport, then bring the active faces to the top of the viewport. """ + if (not self._assets.images and + self._last_execution["frame_index"] != self.frame_index and + self._grid.frame_has_faces(self.frame_index)): + y_coord = self._grid.y_coord_from_frame(self.frame_index) + logger.trace("Active not in view. 
Moving to: %s", y_coord) # type:ignore[attr-defined] + self._canvas.yview_moveto(y_coord / self._canvas.bbox("backdrop")[3]) + self._viewport.update() + + def move_to_top(self) -> None: + """ Move the currently selected frame's faces to the top of the viewport if they are moving + off the bottom of the viewer. """ + height = self._canvas.bbox("backdrop")[3] + bot = int(self._canvas.coords(self._assets.images[-1])[1] + self._size) + + y_top, y_bot = (int(round(pnt * height)) for pnt in self._canvas.yview()) + + if y_top < bot < y_bot: # bottom face is still in fully visible area + logger.trace("Active faces in frame. Returning") # type:ignore[attr-defined] + return + + top = int(self._canvas.coords(self._assets.images[0])[1]) + if y_top == top: + logger.trace("Top face already on top row. Returning") # type:ignore[attr-defined] + return + + if self._canvas.winfo_height() > self._size: + logger.trace("Viewport taller than single face height. " # type:ignore[attr-defined] + "Moving Active faces to top: %s", top) + self._canvas.yview_moveto(top / height) + self._viewport.update() + elif self._canvas.winfo_height() <= self._size and y_top != top: + logger.trace("Viewport shorter than single face height. " # type:ignore[attr-defined] + "Moving Active faces to top: %s", top) + self._canvas.yview_moveto(top / height) + self._viewport.update() + + def _create_new_boxes(self) -> None: + """ The highlight boxes (border around selected faces) are the only additional annotations + that are required for the highlighter. If more faces are displayed in the current frame + than highlight boxes are available, then new boxes are created to accommodate the + additional faces. """ + new_boxes_count = max(0, len(self._assets.images) - len(self._assets.boxes)) + if new_boxes_count == 0: + return + logger.debug("new_boxes_count: %s", new_boxes_count) + for _ in range(new_boxes_count): + box = self._canvas.create_rectangle(0., # type:ignore[call-overload] + 0., + float(self._viewport.face_size), + float(self._viewport.face_size), + outline="#00FF00", + width=2, + state="hidden", + tags=["active_highlighter"]) + logger.trace("Created new highlight_box: %s", box) # type:ignore[attr-defined] + self._assets.boxes.append(box) + + def _update_on_edit(self) -> None: + """ Update the active faces on a frame edit. """ + if not self._tk_vars["edited"].get(): + return + self._set_active_objects() + self._update_face() + assert isinstance(self._tk_vars["edited"], tk.BooleanVar) + self._tk_vars["edited"].set(False) + + def _update_face(self) -> None: + """ Update the highlighted annotations for faces in the currently selected frame. """ + for face_idx, (image_id, mesh_ids, box_id, det_face), in enumerate( + zip(self._assets.images, + self._assets.meshes, + self._assets.boxes, + self._assets.faces)): + if det_face is None: + continue + top_left = self._canvas.coords(image_id) + coords = [*top_left, *[x + self._size for x in top_left]] + tk_face = self._viewport.get_tk_face(self.frame_index, face_idx, det_face) + self._canvas.itemconfig(image_id, image=tk_face.photo) + self._show_box(box_id, coords) + self._show_mesh(mesh_ids, face_idx, det_face, top_left) + self._last_execution["size"] = self._viewport.face_size + + def _show_box(self, item_id: int, coordinates: list[float]) -> None: + """ Display the highlight box around the given coordinates. 
+ + Parameters + ---------- + item_id: int + The tkinter canvas object identifier for the highlight box + coordinates: list[float] + The (x, y, x1, y1) coordinates of the top left corner of the box + """ + self._canvas.coords(item_id, *coordinates) + self._canvas.itemconfig(item_id, state="normal") + + def _show_mesh(self, + mesh_ids: dict[T.Literal["polygon", "line"], list[int]], + face_index: int, + detected_face: DetectedFace, + top_left: list[float]) -> None: + """ Display the mesh annotation for the given face, at the given location. + + Parameters + ---------- + mesh_ids: dict[Literal["polygon", "line"], list[int]] + Dictionary containing the `polygon` and `line` tkinter canvas identifiers that make up + the mesh for the given face + face_index: int + The face index within the frame for the given face + detected_face: :class:`~lib.align.DetectedFace` + The detected face object that contains the landmarks for generating the mesh + top_left: list[float] + The (x, y) top left co-ordinates of the mesh's bounding box + """ + state = "normal" if (self._tk_vars["selected_editor"].get() != "Mask" or + self._optional_annotations["mesh"]) else "hidden" + kwargs: dict[T.Literal["polygon", "line"], dict[str, T.Any]] = { + "polygon": {"fill": "", "width": 2, "outline": self._canvas.control_colors["Mesh"]}, + "line": {"fill": self._canvas.control_colors["Mesh"], "width": 2}} + + assert isinstance(self._tk_vars["edited"], tk.BooleanVar) + edited = (self._tk_vars["edited"].get() and + self._tk_vars["selected_editor"].get() not in ("Mask", "View")) + landmarks = self._viewport.get_landmarks(self.frame_index, + face_index, + detected_face, + top_left, + edited) + for key, kwarg in kwargs.items(): + if key not in mesh_ids: + continue + for idx, mesh_id in enumerate(mesh_ids[key]): + self._canvas.coords(mesh_id, *landmarks[key][idx].flatten()) + self._canvas.itemconfig(mesh_id, state=state, **kwarg) + self._canvas.addtag_withtag(f"active_mesh_{key}", mesh_id) + + +__all__ = get_module_objects(__name__) diff --git a/tools/manual/faceviewer/viewport.py b/tools/manual/faceviewer/viewport.py new file mode 100644 index 0000000000..27ce490aff --- /dev/null +++ b/tools/manual/faceviewer/viewport.py @@ -0,0 +1,792 @@ +#!/usr/bin/env python3 +""" Handles the visible area of the :class:`~tools.manual.faceviewer.frame.FacesViewer` canvas. """ +from __future__ import annotations +import logging +import tkinter as tk +import typing as T + +import cv2 +import numpy as np +from PIL import Image, ImageTk + +from lib.align import AlignedFace, LANDMARK_PARTS, LandmarkType +from lib.logger import parse_class_init +from lib.utils import get_module_objects + +from .interact import ActiveFrame, HoverBox + +if T.TYPE_CHECKING: + from lib.align import CenteringType, DetectedFace + from .frame import FacesViewer + +logger = logging.getLogger(__name__) + + +class Viewport(): + """ Handles the display of faces and annotations in the currently viewable area of the canvas. 
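+
+    A minimal construction sketch (assumed names), as wired up by the faces viewer frame::
+
+        edited = tk.BooleanVar(value=False)
+        viewport = Viewport(canvas, edited)
+        viewport.update()   # populate the currently visible area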
+ + Parameters + ---------- + canvas: :class:`tkinter.Canvas` + The :class:`~tools.manual.faceviewer.frame.FacesViewer` canvas + tk_edited_variable: :class:`tkinter.BooleanVar` + The variable that indicates that a face has been edited + """ + def __init__(self, canvas: FacesViewer, tk_edited_variable: tk.BooleanVar) -> None: + logger.debug(parse_class_init(locals())) + self._canvas = canvas + self._grid = canvas.layout + self._centering: CenteringType = "face" + self._tk_selected_editor = canvas._display_frame.tk_selected_action + self._landmarks: dict[str, dict[T.Literal["polygon", "line"], list[np.ndarray]]] = {} + self._tk_faces: dict[str, TKFace] = {} + self._objects = VisibleObjects(self) + self._hoverbox = HoverBox(self) + self._active_frame = ActiveFrame(self, tk_edited_variable) + self._tk_selected_editor.trace( + "w", lambda *e: self._active_frame.reload_annotations()) + logger.debug("Initialized %s", self.__class__.__name__) + + @property + def face_size(self) -> int: + """ int: The pixel size of each thumbnail """ + return self._grid.face_size + + @property + def mesh_kwargs(self) -> dict[T.Literal["polygon", "line"], dict[str, T.Any]]: + """ dict[Literal["polygon", "line"], str | int]: Dynamic keyword arguments defining the + color and state for the objects that make up a single face's mesh annotation based on the + current user selected options. Values are the keyword arguments for that given type. """ + state = "normal" if self._canvas.optional_annotations["mesh"] else "hidden" + color = self._canvas.control_colors["Mesh"] + return {"polygon": {"fill": "", "outline": color, "state": state}, + "line": {"fill": color, "state": state}} + + @property + def hover_box(self) -> HoverBox: + """ :class:`HoverBox`: The hover box for the viewport. """ + return self._hoverbox + + @property + def selected_editor(self) -> str: + """ str: The currently selected editor. """ + return self._tk_selected_editor.get().lower() + + def toggle_mesh(self, state: T.Literal["hidden", "normal"]) -> None: + """ Toggles the mesh optional annotations on and off. + + Parameters + ---------- + state: Literal["hidden", "normal"] + The state to set the mesh annotations to + """ + logger.debug("Toggling mesh annotations to: %s", state) + self._canvas.itemconfig("viewport_mesh", state=state) + self.update() + + def toggle_mask(self, state: T.Literal["hidden", "normal"], mask_type: str) -> None: + """ Toggles the mask optional annotation on and off. + + Parameters + ---------- + state: Literal["hidden", "normal"] + Whether the mask should be displayed or hidden + mask_type: str + The type of mask to overlay onto the face + """ + logger.debug("Toggling mask annotations to: %s. mask_type: %s", state, mask_type) + for (frame_idx, face_idx), det_face in zip( + self._objects.visible_grid[:2].transpose(1, 2, 0).reshape(-1, 2), + self._objects.visible_faces.flatten()): + if frame_idx == -1: + continue + + key = "_".join([str(frame_idx), str(face_idx)]) + mask = None if state == "hidden" else self._obtain_mask(det_face, mask_type) + self._tk_faces[key].update_mask(mask) + self.update() + + @classmethod + def _obtain_mask(cls, detected_face: DetectedFace, mask_type: str) -> np.ndarray | None: + """ Obtain the mask for the correct "face" centering that is used in the thumbnail display. 
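+        Masks stored with a different centering (e.g. ``head``) are re-cropped to ``face``
+        centering so that they overlay the thumbnail correctly. A hedged usage sketch, assuming
+        the face carries a stored "components" mask::
+
+            mask = self._obtain_mask(detected_face, "components")
+            if mask is not None:
+                assert mask.ndim == 2   # single channel, "face" centering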
+ + Parameters + ----------- + detected_face: :class:`lib.align.DetectedFace` + The Detected Face object to obtain the mask for + mask_type: str + The type of mask to obtain + + Returns + ------- + :class:`numpy.ndarray` or ``None`` + The single channel mask of requested mask type, if it exists, otherwise ``None`` + """ + mask = detected_face.mask.get(mask_type) + if not mask: + return None + if mask.stored_centering != "face": + face = AlignedFace(detected_face.landmarks_xy) + mask.set_sub_crop(face.pose.offset[mask.stored_centering], + face.pose.offset["face"], + centering="face") + return mask.mask.squeeze() + + def reset(self) -> None: + """ Reset all the cached objects on a face size change. """ + self._landmarks = {} + self._tk_faces = {} + + def update(self, refresh_annotations: bool = False) -> None: + """ Update the viewport. + + Parameters + ---------- + refresh_annotations: bool, optional + ``True`` if mesh annotations should be re-calculated otherwise ``False``. + Default: ``False`` + + Obtains the objects that are currently visible. Updates the visible area of the canvas + and reloads the active frame's annotations. """ + self._objects.update() + self._update_viewport(refresh_annotations) + self._active_frame.reload_annotations() + + def _update_viewport(self, refresh_annotations: bool) -> None: + """ Update the viewport + + Parameters + ---------- + refresh_annotations: bool + ``True`` if mesh annotations should be re-calculated otherwise ``False`` + + Clear out cached objects that are not currently in view. Populate the cache for any + faces that are now in view. Populate the correct face image and annotations for each + object in the viewport based on current location. If optional mesh annotations are + enabled, then calculates newly displayed meshes. """ + if not self._grid.is_valid: + return + self._discard_tk_faces() + + for collection in zip(self._objects.visible_grid.transpose(1, 2, 0), + self._objects.images, + self._objects.meshes, + self._objects.visible_faces): + for (frame_idx, face_idx, pnt_x, pnt_y), image_id, mesh_ids, face in zip(*collection): + if frame_idx == self._active_frame.frame_index and not refresh_annotations: + logger.trace("Skipping active frame: %s", # type:ignore[attr-defined] + frame_idx) + continue + if frame_idx == -1: + logger.trace("Blanking non-existant face") # type:ignore[attr-defined] + self._canvas.itemconfig(image_id, image="") + for area in mesh_ids.values(): + for mesh_id in area: + self._canvas.itemconfig(mesh_id, state="hidden") + continue + + tk_face = self.get_tk_face(frame_idx, face_idx, face) + self._canvas.itemconfig(image_id, image=tk_face.photo) + + if (self._canvas.optional_annotations["mesh"] + or frame_idx == self._active_frame.frame_index + or refresh_annotations): + landmarks = self.get_landmarks(frame_idx, face_idx, face, [pnt_x, pnt_y], + refresh=True) + self._locate_mesh(mesh_ids, landmarks) + + def _discard_tk_faces(self) -> None: + """ Remove any :class:`TKFace` objects from the cache that are not currently displayed. """ + keys = [f"{pnt_x}_{pnt_y}" + for pnt_x, pnt_y in self._objects.visible_grid[:2].T.reshape(-1, 2)] + for key in list(self._tk_faces): + if key not in keys: + del self._tk_faces[key] + logger.trace("keys: %s allocated_faces: %s", # type:ignore[attr-defined] + keys, len(self._tk_faces)) + + def get_tk_face(self, frame_index: int, face_index: int, face: DetectedFace) -> TKFace: + """ Obtain the :class:`TKFace` object for the given face from the cache. 
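+        Cache keys take the form ``"{frame_index}_{face_index}"``, e.g. ``"42_0"`` for the
+        first face of frame 42.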
If the face does + not exist in the cache, then it is generated and added prior to returning. + + Parameters + ---------- + frame_index: int + The frame index to obtain the face for + face_index: int + The face index of the face within the requested frame + face: :class:`~lib.align.DetectedFace` + The detected face object, containing the thumbnail jpg + + Returns + ------- + :class:`TKFace` + An object for displaying in the faces viewer canvas populated with the aligned mesh + landmarks and face thumbnail + """ + is_active = frame_index == self._active_frame.frame_index + key = "_".join([str(frame_index), str(face_index)]) + if key not in self._tk_faces or is_active: + logger.trace("creating new tk_face: (key: %s, " # type:ignore[attr-defined] + "is_active: %s)", key, is_active) + if is_active: + image = AlignedFace(face.landmarks_xy, + image=self._active_frame.current_frame, + centering=self._centering, + size=self.face_size).face + else: + thumb = face.thumbnail + assert thumb is not None + image = AlignedFace(face.landmarks_xy, + image=cv2.imdecode(thumb, cv2.IMREAD_UNCHANGED), + centering=self._centering, + size=self.face_size, + is_aligned=True).face + assert image is not None + tk_face = self._get_tk_face_object(face, image, is_active) + self._tk_faces[key] = tk_face + else: + logger.trace("tk_face exists: %s", key) # type:ignore[attr-defined] + tk_face = self._tk_faces[key] + return tk_face + + def _get_tk_face_object(self, + face: DetectedFace, + image: np.ndarray, + is_active: bool) -> TKFace: + """ Obtain an existing unallocated, or a newly created :class:`TKFace` and populate it with + face information from the requested frame and face index. + + If the face is currently active, then the face is generated from the currently displayed + frame, otherwise it is generated from the jpg thumbnail. + + Parameters + ---------- + face: :class:`lib.align.DetectedFace` + A detected face object to create the :class:`TKFace` from + image: :class:`numpy.ndarray` + The jpg thumbnail or the 3 channel image for the face + is_active: bool + ``True`` if the face in the currently active frame otherwise ``False`` + + Returns + ------- + :class:`TKFace` + An object for displaying in the faces viewer canvas populated with the aligned face + image with a mask applied, if required. + """ + get_mask = (self._canvas.optional_annotations["mask"] or + (is_active and self.selected_editor == "mask")) + mask = self._obtain_mask(face, self._canvas.selected_mask) if get_mask else None + tk_face = TKFace(image, size=self.face_size, mask=mask) + logger.trace("face: %s, tk_face: %s", face, tk_face) # type:ignore[attr-defined] + return tk_face + + def get_landmarks(self, + frame_index: int, + face_index: int, + face: DetectedFace, + top_left: list[float], + refresh: bool = False + ) -> dict[T.Literal["polygon", "line"], list[np.ndarray]]: + """ Obtain the landmark points for each mesh annotation. + + First tries to obtain the aligned landmarks from the cache. If the landmarks do not exist + in the cache, or a refresh has been requested, then the landmarks are calculated from the + detected face object. 
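+
+        A hedged usage sketch, assuming a ``viewport`` and a 68 point ``DetectedFace``::
+
+            landmarks = viewport.get_landmarks(42, 0, face, top_left=[96.0, 192.0])
+            for points in landmarks["line"]:
+                ...   # (n, 2) arrays of canvas co-ordinates, offset by ``top_left``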
+ + Parameters + ---------- + frame_index: int + The frame index to obtain the face for + face_index: int + The face index of the face within the requested frame + face: :class:`lib.align.DetectedFace` + The detected face object to obtain landmarks for + top_left: list[float] + The top left (x, y) points of the face's bounding box within the viewport + refresh: bool, optional + Whether to force a reload of the face's aligned landmarks, even if they already exist + within the cache. Default: ``False`` + + Returns + ------- + dict + The key is the tkinter canvas object type for each part of the mesh annotation + (`polygon`, `line`). The value is a list containing the (x, y) coordinates of each + part of the mesh annotation, from the top left corner location. + """ + key = f"{frame_index}_{face_index}" + landmarks: dict[T.Literal["polygon", "line"], list[np.ndarray]] | None + landmarks = self._landmarks.get(key, None) + if not landmarks or refresh: + aligned = AlignedFace(face.landmarks_xy, + centering=self._centering, + size=self.face_size) + landmarks = {"polygon": [], "line": []} + for start, end, fill in LANDMARK_PARTS[aligned.landmark_type].values(): + points = aligned.landmarks[start:end] + top_left + shape: T.Literal["polygon", "line"] = "polygon" if fill else "line" + landmarks[shape].append(points) + self._landmarks[key] = landmarks + return landmarks + + def _locate_mesh(self, mesh_ids, landmarks): + """ Place the mesh annotation canvas objects in the correct location. + + Parameters + ---------- + mesh_ids: list + The list of mesh id objects to set coordinates for + landmarks: dict + The mesh point groupings and whether each group should be a line or a polygon + """ + for key, area in landmarks.items(): + if key not in mesh_ids: + continue + for coords, mesh_id in zip(area, mesh_ids[key]): + self._canvas.coords(mesh_id, *coords.flatten()) + + def face_from_point(self, point_x: int, point_y: int) -> np.ndarray: + """ Given an (x, y) point on the :class:`Viewport`, obtain the face information at that + location. + + Parameters + ---------- + point_x: int + The x position on the canvas of the point to retrieve the face for + point_y: int + The y position on the canvas of the point to retrieve the face for + + Returns + ------- + :class:`numpy.ndarray` + Array of shape (4, ) containing the (`frame index`, `face index`, `x_point of top left + corner`, `y point of top left corner`) of the face at the given coordinates. + + If the given coordinates are not over a face, then the frame and face indices will be + -1 + """ + if not self._grid.is_valid or point_x > self._grid.dimensions[0]: + retval = np.array((-1, -1, -1, -1)) + else: + x_idx = np.searchsorted(self._objects.visible_grid[2, 0, :], point_x, side="left") - 1 + y_idx = np.searchsorted(self._objects.visible_grid[3, :, 0], point_y, side="left") - 1 + if x_idx < 0 or y_idx < 0: + retval = np.array((-1, -1, -1, -1)) + else: + retval = self._objects.visible_grid[:, y_idx, x_idx] + logger.trace(retval) # type:ignore[attr-defined] + return retval + + def move_active_to_top(self) -> None: + """ Check whether the active frame is going off the bottom of the viewport, if so: move it + to the top of the viewport. """ + self._active_frame.move_to_top() + + +class Recycler: + """ Tkinter can slow down when constantly creating new objects. 
+
+    This class delivers recycled objects when stale objects are available, otherwise it
+    creates a new object.
+
+    Parameters
+    ----------
+    canvas: :class:`~tools.manual.faceviewer.frame.FacesViewer`
+        The canvas that holds the faces display
+    """
+    def __init__(self, canvas: FacesViewer) -> None:
+        self._canvas = canvas
+        self._assets: dict[T.Literal["image", "line", "polygon"],
+                           list[int]] = {"image": [], "line": [], "polygon": []}
+        self._mesh_methods: dict[T.Literal["line", "polygon"],
+                                 T.Callable] = {"line": canvas.create_line,
+                                                "polygon": canvas.create_polygon}
+
+    def recycle_assets(self, asset_ids: list[int]) -> None:
+        """ Recycle assets that are no longer required
+
+        Parameters
+        ----------
+        asset_ids: list[int]
+            The IDs of the assets to be recycled
+        """
+        logger.trace("Recycling %s objects", len(asset_ids))  # type:ignore[attr-defined]
+        for asset_id in asset_ids:
+            asset_type = T.cast(T.Literal["image", "line", "polygon"], self._canvas.type(asset_id))
+            assert asset_type in self._assets
+            coords = (0, 0, 0, 0) if asset_type == "line" else (0, 0)
+            self._canvas.coords(asset_id, *coords)
+
+            if asset_type == "image":
+                self._canvas.itemconfig(asset_id, image="")
+
+            self._assets[asset_type].append(asset_id)
+        logger.trace("Recycled objects: %s", self._assets)  # type:ignore[attr-defined]
+
+    def get_image(self, coordinates: tuple[float | int, float | int]) -> int:
+        """ Obtain a recycled or new image object ID
+
+        Parameters
+        ----------
+        coordinates: tuple[float | int, float | int]
+            The co-ordinates that the image should be displayed at
+
+        Returns
+        -------
+        int
+            The canvas object id for the created image
+        """
+        if self._assets["image"]:
+            retval = self._assets["image"].pop()
+            self._canvas.coords(retval, *coordinates)
+            logger.trace("Recycled image: %s", retval)  # type:ignore[attr-defined]
+        else:
+            retval = self._canvas.create_image(*coordinates,
+                                               anchor=tk.NW,
+                                               tags=["viewport", "viewport_image"])
+            logger.trace("Created new image: %s", retval)  # type:ignore[attr-defined]
+        return retval
+
+    def get_mesh(self, face: DetectedFace) -> dict[T.Literal["polygon", "line"], list[int]]:
+        """ Get the mesh annotation for the landmarks. This is made up of a series of polygons
+        or lines, depending on which part of the face is being annotated. Creates a new series of
+        objects, or pulls existing objects from the recycled objects pool if they are available. 
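+
+        A hedged usage sketch, assuming 68 point landmarks (the id values shown are
+        illustrative)::
+
+            mesh = recycler.get_mesh(face)
+            # e.g. {"line": [12, 13, ...], "polygon": [27, 28]}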
+
+        Parameters
+        ----------
+        face: :class:`~lib.align.detected_face.DetectedFace`
+            The detected face object to obtain the mesh for
+
+        Returns
+        -------
+        dict[Literal["polygon", "line"], list[int]]
+            The dictionary of line and polygon tkinter canvas object ids for the mesh annotation
+        """
+        mesh_kwargs = self._canvas.viewport.mesh_kwargs
+        mesh_parts = LANDMARK_PARTS[LandmarkType.from_shape(face.landmarks_xy.shape)]
+        retval: dict[T.Literal["polygon", "line"], list[int]] = {}
+        for _, _, fill in mesh_parts.values():
+            asset_type: T.Literal["polygon", "line"] = "polygon" if fill else "line"
+            kwargs = mesh_kwargs[asset_type]
+            if self._assets[asset_type]:
+                asset_id = self._assets[asset_type].pop()
+                self._canvas.itemconfig(asset_id, **kwargs)
+                logger.trace("Recycled mesh %s: %s",  # type:ignore[attr-defined]
+                             asset_type, asset_id)
+            else:
+                coords = (0, 0) if asset_type == "polygon" else (0, 0, 0, 0)
+                tags = ["viewport", "viewport_mesh", f"viewport_{asset_type}"]
+                asset_id = self._mesh_methods[asset_type](coords, width=1, tags=tags, **kwargs)
+                logger.trace("Created new mesh %s: %s",  # type:ignore[attr-defined]
+                             asset_type, asset_id)
+
+            retval.setdefault(asset_type, []).append(asset_id)
+        logger.trace("Got mesh: %s", retval)  # type:ignore[attr-defined]
+        return retval
+
+
+class VisibleObjects():
+    """ Holds the objects from the :class:`~tools.manual.faceviewer.frame.Grid` that appear in
+    the viewable area of the :class:`Viewport`.
+
+    Parameters
+    ----------
+    viewport: :class:`Viewport`
+        The viewport object for the :class:`~tools.manual.faceviewer.frame.FacesViewer` canvas
+    """
+    def __init__(self, viewport: Viewport) -> None:
+        logger.debug(parse_class_init(locals()))
+        self._viewport = viewport
+        self._canvas = viewport._canvas
+        self._grid = viewport._grid
+        self._size = viewport.face_size
+
+        self._visible_grid = np.zeros((4, 0, 0))
+        self._visible_faces = np.zeros((0, 0))
+        self._recycler = Recycler(self._canvas)
+        self._images: np.ndarray = np.zeros((0, 0), dtype=np.int64)
+        self._meshes: np.ndarray = np.zeros((0, 0))
+        logger.debug("Initialized: %s", self.__class__.__name__)
+
+    @property
+    def visible_grid(self) -> np.ndarray:
+        """ :class:`numpy.ndarray`: The currently visible section of the
+        :class:`~tools.manual.faceviewer.frame.Grid`
+
+        A numpy array of shape (`4`, `rows`, `columns`) corresponding to the viewable area of the
+        display grid. 1st dimension contains frame indices, 2nd dimension face indices. The 3rd
+        and 4th dimension contain the x and y position of the top left corner of the face
+        respectively.
+
+        Any locations that are not populated by a face will have a frame and face index of -1. """
+        return self._visible_grid
+
+    @property
+    def visible_faces(self) -> np.ndarray:
+        """ :class:`numpy.ndarray`: The currently visible :class:`~lib.align.DetectedFace`
+        objects.
+
+        A numpy array of shape (`rows`, `columns`) corresponding to the viewable area of the
+        display grid and containing the detected faces at their currently viewable position.
+
+        Any locations that are not populated by a face will have ``None`` in its place. """
+        return self._visible_faces
+
+    @property
+    def images(self) -> np.ndarray:
+        """ :class:`numpy.ndarray`: The viewport's tkinter canvas image objects.
+
+        A numpy array of shape (`rows`, `columns`) corresponding to the viewable area of the
+        display grid and containing the tkinter canvas image object for the face at the
+        corresponding location. 
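+        The ids are plain tkinter canvas item handles, so e.g. ``canvas.coords(images[0, 0])``
+        yields the position of the top left face.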
""" + return self._images + + @property + def meshes(self) -> np.ndarray: + """ :class:`numpy.ndarray`: The viewport's tkinter canvas mesh annotation objects. + + A numpy array of shape (`rows`, `columns`) corresponding to the viewable area of the + display grid and containing a dictionary of the corresponding tkinter polygon and line + objects required to build a face's mesh annotation for the face at the corresponding + location. """ + return self._meshes + + @property + def _top_left(self) -> np.ndarray: + """ :class:`numpy.ndarray`: The canvas (`x`, `y`) position of the face currently in the + viewable area's top left position. """ + if not np.any(self._images): + retval = [0.0, 0.0] + else: + retval = self._canvas.coords(self._images[0][0]) + return np.array(retval, dtype="int") + + def update(self) -> None: + """ Load and unload thumbnails in the visible area of the faces viewer. """ + if self._canvas.optional_annotations["mesh"]: # Display any hidden end of row meshes + self._canvas.itemconfig("viewport_mesh", state="normal") + + self._visible_grid, self._visible_faces = self._grid.visible_area + if (np.any(self._images) and np.any(self._visible_grid) + and self._visible_grid.shape[1:] != self._images.shape): + self._reset_viewport() + + required_rows = self._visible_grid.shape[1] if self._grid.is_valid else 0 + existing_rows = len(self._images) + logger.trace("existing_rows: %s. required_rows: %s", # type:ignore[attr-defined] + existing_rows, required_rows) + + if existing_rows > required_rows: + self._remove_rows(existing_rows, required_rows) + if existing_rows < required_rows: + self._add_rows(existing_rows, required_rows) + + self._shift() + + def _reset_viewport(self) -> None: + """ Reset all objects in the viewport on a column count change. Reset the viewport size + to the newly specified face size. """ + logger.debug("Resetting Viewport") + self._size = self._viewport.face_size + images = self._images.flatten().tolist() + meshes = [parts for mesh in [mesh.values() for mesh in self._meshes.flatten()] + for parts in mesh] + mesh_ids = [asset for mesh in meshes for asset in mesh] + self._recycler.recycle_assets(images + mesh_ids) + self._images = np.zeros((0, 0), np.int64) + self._meshes = np.zeros((0, 0)) + + def _remove_rows(self, existing_rows: int, required_rows: int) -> None: + """ Remove and recycle rows from the viewport that are not in the view area. + + Parameters + ---------- + existing_rows: int + The number of existing rows within the viewport + required_rows: int + The number of rows required by the viewport + """ + logger.debug("Removing rows from viewport: (existing_rows: %s, required_rows: %s)", + existing_rows, required_rows) + images = self._images[required_rows: existing_rows].flatten().tolist() + meshes = [parts + for mesh in [mesh.values() + for mesh in self._meshes[required_rows: existing_rows].flatten()] + for parts in mesh] + mesh_ids = [asset for mesh in meshes for asset in mesh] + self._recycler.recycle_assets(images + mesh_ids) + self._images = self._images[:required_rows] + self._meshes = self._meshes[:required_rows] + logger.trace("self._images: %s, self._meshes: %s", # type:ignore[attr-defined] + self._images.shape, self._meshes.shape) + + def _add_rows(self, existing_rows: int, required_rows: int) -> None: + """ Add rows to the viewport. 
+ + Parameters + ---------- + existing_rows: int + The number of existing rows within the viewport + required_rows: int + The number of rows required by the viewport + """ + logger.debug("Adding rows to viewport: (existing_rows: %s, required_rows: %s)", + existing_rows, required_rows) + columns = self._grid.columns_rows[0] + + base_coords: list[list[float | int]] + + if not np.any(self._images): + base_coords = [[col * self._size, 0] for col in range(columns)] + else: + base_coords = [self._canvas.coords(item_id) for item_id in self._images[0]] + logger.trace("existing rows: %s, required_rows: %s, " # type:ignore[attr-defined] + "base_coords: %s", existing_rows, required_rows, base_coords) + images = [] + meshes = [] + for row in range(existing_rows, required_rows): + y_coord = base_coords[0][1] + (row * self._size) + images.append([self._recycler.get_image((coords[0], y_coord)) + for coords in base_coords]) + meshes.append([{} if face is None else self._recycler.get_mesh(face) + for face in self._visible_faces[row]]) + + a_images: np.ndarray = np.array(images) + a_meshes: np.ndarray = np.array(meshes) + + if not np.any(self._images): + logger.debug("Adding initial viewport objects: (image shapes: %s, mesh shapes: %s)", + a_images.shape, a_meshes.shape) + self._images = a_images + self._meshes = a_meshes + else: + logger.debug("Adding new viewport objects: (image shapes: %s, mesh shapes: %s)", + a_images.shape, a_meshes.shape) + self._images = np.concatenate((self._images, a_images)) + self._meshes = np.concatenate((self._meshes, a_meshes)) + + logger.trace("self._images: %s, self._meshes: %s", # type:ignore[attr-defined] + self._images.shape, self._meshes.shape) + + def _shift(self) -> bool: + """ Shift the viewport in the y direction if required + + Returns + ------- + bool + ``True`` if the viewport was shifted otherwise ``False`` + """ + current_y = self._top_left[1] + required_y = self.visible_grid[3, 0, 0] if self._grid.is_valid else 0 + logger.trace("current_y: %s, required_y: %s", # type:ignore[attr-defined] + current_y, required_y) + if current_y == required_y: + logger.trace("No move required") # type:ignore[attr-defined] + return False + shift_amount = required_y - current_y + logger.trace("Shifting viewport: %s", shift_amount) # type:ignore[attr-defined] + self._canvas.move("viewport", 0, shift_amount) + return True + + +class TKFace(): + """ An object that holds a single :class:`tkinter.PhotoImage` face, ready for placement in the + :class:`Viewport`, Handles the placement of and removal of masks for the face as well as + updates on any edits. + + Parameters + ---------- + face: :class:`numpy.ndarray` + The face, sized correctly as a 3 channel BGR image or an encoded jpg to create a + :class:`tkinter.PhotoImage` from + size: int, optional + The pixel size of the face image. Default: `128` + mask: :class:`numpy.ndarray` or ``None``, optional + The mask to be applied to the face image. Pass ``None`` if no mask is to be used. 
Default: ``None``
+    """
+    def __init__(self, face: np.ndarray, size: int = 128, mask: np.ndarray | None = None) -> None:
+        logger.trace(parse_class_init(locals()))  # type:ignore[attr-defined]
+        self._size = size
+        if face.ndim == 2 and face.shape[1] == 1:
+            self._face = self._image_from_jpg(face)
+        else:
+            self._face = face[..., 2::-1]
+        self._photo = ImageTk.PhotoImage(self._generate_tk_face_data(mask))
+
+        logger.trace("Initialized %s", self.__class__.__name__)  # type:ignore[attr-defined]
+
+    # << PUBLIC PROPERTIES >> #
+    @property
+    def photo(self) -> ImageTk.PhotoImage:
+        """ :class:`PIL.ImageTk.PhotoImage`: The face in a format that can be placed on the
+        :class:`~tools.manual.faceviewer.frame.FacesViewer` canvas. """
+        return self._photo
+
+    # << PUBLIC METHODS >> #
+    def update(self, face: np.ndarray, mask: np.ndarray | None) -> None:
+        """ Update the :attr:`photo` with the given face and mask.
+
+        Parameters
+        ----------
+        face: :class:`numpy.ndarray`
+            The face, sized correctly as a 3 channel BGR image
+        mask: :class:`numpy.ndarray` or ``None``
+            The mask to be applied to the face image. Pass ``None`` if no mask is to be used
+        """
+        self._face = face[..., 2::-1]
+        self._photo.paste(self._generate_tk_face_data(mask))
+
+    def update_mask(self, mask: np.ndarray | None) -> None:
+        """ Update the mask in the 4th channel of :attr:`photo` to the given mask.
+
+        Parameters
+        ----------
+        mask: :class:`numpy.ndarray` or ``None``
+            The mask to be applied to the face image. Pass ``None`` if no mask is to be used
+        """
+        self._photo.paste(self._generate_tk_face_data(mask))
+
+    # << PRIVATE METHODS >> #
+    def _image_from_jpg(self, face: np.ndarray) -> np.ndarray:
+        """ Convert an encoded jpg into a 3 channel image, sized to :attr:`_size`.
+
+        Parameters
+        ----------
+        face: :class:`numpy.ndarray`
+            The encoded jpg as a two dimension numpy array
+
+        Returns
+        -------
+        :class:`numpy.ndarray`
+            The decoded jpg as a 3 channel image in RGB channel order
+        """
+        retval = cv2.imdecode(face, cv2.IMREAD_UNCHANGED)
+        assert retval is not None
+        interp = cv2.INTER_CUBIC if retval.shape[0] < self._size else cv2.INTER_AREA
+        if retval.shape[0] != self._size:
+            retval = cv2.resize(retval, (self._size, self._size), interpolation=interp)
+        return retval[..., 2::-1]
+
+    def _generate_tk_face_data(self, mask: np.ndarray | None) -> Image.Image:
+        """ Create the :class:`tkinter.PhotoImage` data from the current :attr:`_face`.
+
+        Parameters
+        ----------
+        mask: :class:`numpy.ndarray` or ``None``
+            The mask to add to the image. ``None`` if a mask is not being used
+
+        Returns
+        -------
+        :class:`PIL.Image.Image`
+            The face formatted for the :class:`~tools.manual.faceviewer.frame.FacesViewer` canvas.
+        """
+        mask = np.ones(self._face.shape[:2], dtype="uint8") * 255 if mask is None else mask
+        if mask.shape[0] != self._size:
+            mask = cv2.resize(mask, self._face.shape[:2], interpolation=cv2.INTER_AREA)
+        img = np.concatenate((self._face, mask[..., None]), axis=-1)
+        return Image.fromarray(img)
+
+
+__all__ = get_module_objects(__name__)
diff --git a/tools/manual/frameviewer/__init__.py b/tools/manual/frameviewer/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tools/manual/frameviewer/control.py b/tools/manual/frameviewer/control.py
new file mode 100644
index 0000000000..5bbb681bdf
--- /dev/null
+++ b/tools/manual/frameviewer/control.py
@@ -0,0 +1,305 @@
+#!/usr/bin/env python3
+""" Handles Navigation and Background Image for the Frame Viewer section of the manual
+tool GUI. 
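+:class:`Navigation` drives the transport controls (play, step and seek within the current
+filter), while :class:`BackgroundImage` renders the full frame, or the zoomed face, behind
+the editor annotations.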
""" + +import logging +import tkinter as tk + +import cv2 +import numpy as np +from PIL import Image, ImageTk + +from lib.align import AlignedFace +from lib.utils import get_module_objects + +logger = logging.getLogger(__name__) + + +class Navigation(): + """ Handles playback and frame navigation for the Frame Viewer Window. + + Parameters + ---------- + display_frame: :class:`DisplayFrame` + The parent frame viewer window + """ + def __init__(self, display_frame): + logger.debug("Initializing %s", self.__class__.__name__) + self._display_frame = display_frame + self._globals = display_frame._globals + self._det_faces = display_frame._det_faces + self._nav = display_frame._nav + self._tk_is_playing = tk.BooleanVar() + self._tk_is_playing.set(False) + self._det_faces.tk_face_count_changed.trace("w", self._update_total_frame_count) + logger.debug("Initialized %s", self.__class__.__name__) + + @property + def _current_nav_frame_count(self): + """ int: The current frame count for the transport slider """ + return self._nav["scale"].cget("to") + 1 + + def nav_scale_callback(self, *args, reset_progress=True): # pylint:disable=unused-argument + """ Adjust transport slider scale for different filters. Hide or display optional filter + controls. + """ + self._display_frame.pack_threshold_slider() + if reset_progress: + self.stop_playback() + frame_count = self._det_faces.filter.count + if self._current_nav_frame_count == frame_count: + logger.trace("Filtered count has not changed. Returning") + if self._globals.var_filter_mode.get() == "Misaligned Faces": + self._det_faces.tk_face_count_changed.set(True) + self._update_total_frame_count() + if reset_progress: + self._globals.var_transport_index.set(0) + + def _update_total_frame_count(self, *args): # pylint:disable=unused-argument + """ Update the displayed number of total frames that meet the current filter criteria. + + Parameters + ---------- + args: tuple + Required for tkinter trace callback but unused + """ + frame_count = self._det_faces.filter.count + if self._current_nav_frame_count == frame_count: + logger.trace("Filtered count has not changed. Returning") + return + max_frame = max(0, frame_count - 1) + logger.debug("Filtered frame count has changed. Updating from %s to %s", + self._current_nav_frame_count, frame_count) + self._nav["scale"].config(to=max_frame) + self._nav["label"].config(text=f"/{max_frame}") + state = "disabled" if max_frame == 0 else "normal" + self._nav["entry"].config(state=state) + + @property + def tk_is_playing(self): + """ :class:`tkinter.BooleanVar`: Whether the stream is currently playing. """ + return self._tk_is_playing + + def handle_play_button(self): + """ Handle the play button. + + Switches the :attr:`tk_is_playing` variable. + """ + is_playing = self.tk_is_playing.get() + self.tk_is_playing.set(not is_playing) + + def stop_playback(self): + """ Stop play back if playing """ + if self.tk_is_playing.get(): + logger.trace("Stopping playback") + self.tk_is_playing.set(False) + + def increment_frame(self, frame_count=None, is_playing=False): + """ Update The frame navigation position to the next frame based on filter. 
""" + if not is_playing: + self.stop_playback() + position = self._get_safe_frame_index() + face_count_change = not self._det_faces.filter.frame_meets_criteria + if face_count_change: + position -= 1 + frame_count = self._det_faces.filter.count if frame_count is None else frame_count + if not face_count_change and (frame_count == 0 or position == frame_count - 1): + logger.debug("End of Stream. Not incrementing") + self.stop_playback() + return + self._globals.var_transport_index.set(min(position + 1, max(0, frame_count - 1))) + + def decrement_frame(self): + """ Update The frame navigation position to the previous frame based on filter. """ + self.stop_playback() + position = self._get_safe_frame_index() + face_count_change = not self._det_faces.filter.frame_meets_criteria + if not face_count_change and (self._det_faces.filter.count == 0 or position == 0): + logger.debug("End of Stream. Not decrementing") + return + self._globals.var_transport_index.set(min(max(0, self._det_faces.filter.count - 1), + max(0, position - 1))) + + def _get_safe_frame_index(self): + """ Obtain the current frame position from the var_transport_index variable in + a safe manner (i.e. handle for non-numeric) + + Returns + ------- + int + The current transport frame index + """ + try: + retval = self._globals.var_transport_index.get() + except tk.TclError as err: + if "expected floating-point" not in str(err): + raise + val = str(err).rsplit(" ", maxsplit=1)[-1].replace("\"", "") + retval = "".join(ch for ch in val if ch.isdigit()) + retval = 0 if not retval else int(retval) + self._globals.var_transport_index.set(retval) + return retval + + def goto_first_frame(self): + """ Go to the first frame that meets the filter criteria. """ + self.stop_playback() + position = self._globals.var_transport_index.get() + if position == 0: + return + self._globals.var_transport_index.set(0) + + def goto_last_frame(self): + """ Go to the last frame that meets the filter criteria. """ + self.stop_playback() + position = self._globals.var_transport_index.get() + frame_count = self._det_faces.filter.count + if position == frame_count - 1: + return + self._globals.var_transport_index.set(frame_count - 1) + + +class BackgroundImage(): + """ The background image of the canvas """ + def __init__(self, canvas): + self._canvas = canvas + self._globals = canvas._globals + self._det_faces = canvas._det_faces + placeholder = np.ones((*reversed(self._globals.frame_display_dims), 3), dtype="uint8") + self._tk_frame = ImageTk.PhotoImage(Image.fromarray(placeholder)) + self._tk_face = ImageTk.PhotoImage(Image.fromarray(placeholder)) + self._image = self._canvas.create_image(self._globals.frame_display_dims[0] / 2, + self._globals.frame_display_dims[1] / 2, + image=self._tk_frame, + anchor=tk.CENTER, + tags="main_image") + self._zoomed_centering = "face" + + @property + def _current_view_mode(self): + """ str: `frame` if global zoom mode variable is set to ``False`` other wise `face`. """ + retval = "face" if self._globals.is_zoomed else "frame" + logger.trace(retval) + return retval + + def refresh(self, view_mode): + """ Update the displayed frame. + + Parameters + ---------- + view_mode: ["frame", "face"] + The currently active editor's selected view mode. + """ + self._switch_image(view_mode) + logger.trace("Updating background frame") + getattr(self, f"_update_tk_{self._current_view_mode}")() + + def _switch_image(self, view_mode): + """ Switch the image between the full frame image and the zoomed face image. 
+ + Parameters + ---------- + view_mode: ["frame", "face"] + The currently active editor's selected view mode. + """ + if view_mode == self._current_view_mode and ( + self._canvas.active_editor.zoomed_centering == self._zoomed_centering): + return + self._zoomed_centering = self._canvas.active_editor.zoomed_centering + logger.trace("Switching background image from '%s' to '%s'", + self._current_view_mode, view_mode) + img = getattr(self, f"_tk_{view_mode}") + self._canvas.itemconfig(self._image, image=img) + self._globals.set_zoomed(view_mode == "face") + self._globals.set_face_index(0) + + def _update_tk_face(self): + """ Update the currently zoomed face. """ + face = self._get_zoomed_face() + padding = self._get_padding((min(self._globals.frame_display_dims), + min(self._globals.frame_display_dims))) + face = cv2.copyMakeBorder(face, *padding, cv2.BORDER_CONSTANT) + if self._tk_frame.height() != face.shape[0]: + self._resize_frame() + + logger.trace("final shape: %s", face.shape) + self._tk_face.paste(Image.fromarray(face)) + + def _get_zoomed_face(self): + """ Get the zoomed face or a blank image if no faces are available. + + Returns + ------- + :class:`numpy.ndarray` + The face sized to the shortest dimensions of the face viewer + """ + frame_idx = self._globals.frame_index + face_idx = self._globals.face_index + faces_in_frame = self._det_faces.face_count_per_index[frame_idx] + size = min(self._globals.frame_display_dims) + + if face_idx + 1 > faces_in_frame: + logger.debug("Resetting face index to 0 for more faces in frame than current index: (" + "faces_in_frame: %s, zoomed_face_index: %s", faces_in_frame, face_idx) + self._globals.set_face_index(0) + + if faces_in_frame == 0: + face = np.ones((size, size, 3), dtype="uint8") + else: + det_face = self._det_faces.current_faces[frame_idx][face_idx] + face = AlignedFace(det_face.landmarks_xy, + image=self._globals.current_frame.image, + centering=self._zoomed_centering, + size=size).face + logger.trace("face shape: %s", face.shape) + return face[..., 2::-1] + + def _update_tk_frame(self): + """ Place the currently held frame into :attr:`_tk_frame`. """ + img = cv2.resize(self._globals.current_frame.image, + self._globals.current_frame.display_dims, + interpolation=self._globals.current_frame.interpolation)[..., 2::-1] + padding = self._get_padding(img.shape[:2]) + if any(padding): + img = cv2.copyMakeBorder(img, *padding, cv2.BORDER_CONSTANT) + logger.trace("final shape: %s", img.shape) + + if self._tk_frame.height() != img.shape[0]: + self._resize_frame() + + self._tk_frame.paste(Image.fromarray(img)) + + def _get_padding(self, size): + """ Obtain the Left, Top, Right, Bottom padding required to place the square face or frame + in to the Photo Image + + Returns + ------- + tuple + The (Left, Top, Right, Bottom) padding to apply to the face image in pixels + """ + pad_lt = ((self._globals.frame_display_dims[1] - size[0]) // 2, + (self._globals.frame_display_dims[0] - size[1]) // 2) + padding = (pad_lt[0], + self._globals.frame_display_dims[1] - size[0] - pad_lt[0], + pad_lt[1], + self._globals.frame_display_dims[0] - size[1] - pad_lt[1]) + logger.debug("Frame dimensions: %s, size: %s, padding: %s", + self._globals.frame_display_dims, size, padding) + return padding + + def _resize_frame(self): + """ Resize the :attr:`_tk_frame`, attr:`_tk_face` photo images, update the canvas to + offset the image correctly. 
+ """ + logger.trace("Resizing video frame on resize event: %s", self._globals.frame_display_dims) + placeholder = np.ones((*reversed(self._globals.frame_display_dims), 3), dtype="uint8") + self._tk_frame = ImageTk.PhotoImage(Image.fromarray(placeholder)) + self._tk_face = ImageTk.PhotoImage(Image.fromarray(placeholder)) + self._canvas.coords(self._image, + self._globals.frame_display_dims[0] / 2, + self._globals.frame_display_dims[1] / 2) + img = self._tk_face if self._current_view_mode == "face" else self._tk_frame + self._canvas.itemconfig(self._image, image=img) + + +__all__ = get_module_objects(__name__) diff --git a/tools/manual/frameviewer/editor/__init__.py b/tools/manual/frameviewer/editor/__init__.py new file mode 100644 index 0000000000..7902b47227 --- /dev/null +++ b/tools/manual/frameviewer/editor/__init__.py @@ -0,0 +1,8 @@ +#!/usr/bin/env python3 +""" The Frame Viewer for Faceswap's Manual Tool. """ + +from ._base import View +from .bounding_box import BoundingBox +from .extract_box import ExtractBox +from .landmarks import Landmarks, Mesh +from .mask import Mask diff --git a/tools/manual/frameviewer/editor/_base.py b/tools/manual/frameviewer/editor/_base.py new file mode 100644 index 0000000000..10906f301a --- /dev/null +++ b/tools/manual/frameviewer/editor/_base.py @@ -0,0 +1,636 @@ +#!/usr/bin/env python3 +""" Editor objects for the manual adjustments tool """ + +import gettext +import logging +import tkinter as tk + +from collections import OrderedDict + +import numpy as np + +from lib.gui.control_helper import ControlPanelOption + +logger = logging.getLogger(__name__) + +# LOCALES +_LANG = gettext.translation("tools.manual", localedir="locales", fallback=True) +_ = _LANG.gettext + + +class Editor(): + """ Parent Class for Object Editors. + + Editors allow the user to use a variety of tools to manipulate alignments from the main + display frame. + + Parameters + ---------- + canvas: :class:`tkinter.Canvas` + The canvas that holds the image and annotations + detected_faces: :class:`~tools.manual.detected_faces.DetectedFaces` + The _detected_faces data for this manual session + control_text: str + The text that is to be displayed at the top of the Editor's control panel. + """ + def __init__(self, canvas, detected_faces, control_text="", key_bindings=None): + logger.debug("Initializing %s: (canvas: '%s', detected_faces: %s, control_text: %s)", + self.__class__.__name__, canvas, detected_faces, control_text) + self.zoomed_centering = "face" # Override for different zoomed centering per editor + self._canvas = canvas + self._globals = canvas._globals + self._det_faces = detected_faces + + self._current_color = {} + self._actions = OrderedDict() + self._controls = {"header": control_text, "controls": []} + self._add_key_bindings(key_bindings) + + self._add_actions() + self._add_controls() + self._add_annotation_format_controls() + + self._mouse_location = None + self._drag_data = {} + self._drag_callback = None + self.bind_mouse_motion() + logger.debug("Initialized %s", self.__class__.__name__) + + @property + def _default_colors(self): + """ dict: The default colors for each annotation """ + return {"BoundingBox": "#0000ff", + "ExtractBox": "#00ff00", + "Landmarks": "#ff00ff", + "Mask": "#ff0000", + "Mesh": "#00ffff"} + + @property + def _is_active(self): + """ bool: ``True`` if this editor is currently active otherwise ``False``. 
+
+        Notes
+        -----
+        When initializing, the active_editor parameter will not be set in the parent,
+        so return ``False`` in this instance
+        """
+        return hasattr(self._canvas, "active_editor") and self._canvas.active_editor == self
+
+    @property
+    def view_mode(self):
+        """ ["frame", "face"]: The view mode for the currently selected editor. If the editor
+        does not have a view mode that can be updated, then `"frame"` will be returned. """
+        tk_var = self._actions.get("magnify", {}).get("tk_var", None)
+        retval = "frame" if tk_var is None or not tk_var.get() else "face"
+        return retval
+
+    @property
+    def _zoomed_roi(self):
+        """ :class:`numpy.ndarray`: The (`left`, `top`, `right`, `bottom`) roi of the zoomed face
+        in the display frame. """
+        half_size = min(self._globals.frame_display_dims) / 2
+        left = self._globals.frame_display_dims[0] / 2 - half_size
+        top = 0
+        right = self._globals.frame_display_dims[0] / 2 + half_size
+        bottom = self._globals.frame_display_dims[1]
+        retval = np.rint(np.array((left, top, right, bottom))).astype("int32")
+        logger.trace("Zoomed ROI: %s", retval)
+        return retval
+
+    @property
+    def _zoomed_dims(self):
+        """ tuple: The (`width`, `height`) of the zoomed ROI. """
+        roi = self._zoomed_roi
+        return (roi[2] - roi[0], roi[3] - roi[1])
+
+    @property
+    def _control_vars(self):
+        """ dict: The tk control panel variables for the currently selected editor. """
+        return self._canvas.control_tk_vars.get(self.__class__.__name__, {})
+
+    @property
+    def controls(self):
+        """ dict: The control panel options and header text for the current editor """
+        return self._controls
+
+    @property
+    def _control_color(self):
+        """ str: The hex color code set in the control panel for the current editor. """
+        annotation = self.__class__.__name__
+        return self._annotation_formats[annotation]["color"].get()
+
+    @property
+    def _annotation_formats(self):
+        """ dict: The format (color, opacity etc.) of each editor's annotation display. """
+        return self._canvas.annotation_formats
+
+    @property
+    def actions(self):
+        """ dict: The optional action buttons for the actions frame in the GUI for the
+        current editor """
+        return self._actions
+
+    @property
+    def _face_iterator(self):
+        """ list: The detected face objects to be iterated. This will either be all faces in the
+        frame (normal view) or the single zoomed in face (zoom mode). """
+        if self._globals.frame_index == -1:
+            faces = []
+        else:
+            faces = self._det_faces.current_faces[self._globals.frame_index]
+            faces = ([faces[self._globals.face_index]]
+                     if self._globals.is_zoomed and faces else faces)
+        return faces
+
+    def _add_key_bindings(self, key_bindings):
+        """ Add the editor specific key bindings for the currently viewed editor.
+
+        Parameters
+        ----------
+        key_bindings: dict
+            The key binding to method dictionary for this editor.
+        """
+        if key_bindings is None:
+            return
+        for key, method in key_bindings.items():
+            logger.debug("Binding key '%s' to method %s for editor '%s'",
+                         key, method, self.__class__.__name__)
+            self._canvas.key_bindings.setdefault(key, {})["bound_to"] = None
+            self._canvas.key_bindings[key][self.__class__.__name__] = method
+
+    @staticmethod
+    def _get_anchor_points(bounding_box):
+        """ Retrieve the (x, y) co-ordinates for each of the 4 corners of a bounding box's
+        anchors, for both the displayed anchors and the anchor grab locations.
+
+        Parameters
+        ----------
+        bounding_box: tuple
+            The (`top-left`, `top-right`, `bottom-right`, `bottom-left`) (x, y) coordinates of the
+            bounding box
+
+        Returns
+        -------
+        display_anchors: tuple
+            The (`left`, `top`, `right`, `bottom`) co-ordinates of the circle at each corner of
+            the bounding box, sized for display
+        grab_anchors: tuple
+            The (`left`, `top`, `right`, `bottom`) co-ordinates of the circle at each corner of
+            the bounding box, at a larger size for grabbing with a mouse
+        """
+        radius = 3
+        grab_radius = radius * 3
+        display_anchors = tuple((cnr[0] - radius, cnr[1] - radius,
+                                 cnr[0] + radius, cnr[1] + radius)
+                                for cnr in bounding_box)
+        grab_anchors = tuple((cnr[0] - grab_radius, cnr[1] - grab_radius,
+                              cnr[0] + grab_radius, cnr[1] + grab_radius)
+                             for cnr in bounding_box)
+        return display_anchors, grab_anchors
+
+    def update_annotation(self):
+        """ Update the display annotations for the current objects.
+
+        Override for specific editors.
+        """
+        logger.trace("Default annotations. Not storing Objects")
+
+    def hide_annotation(self, tag=None):
+        """ Hide annotations for this editor.
+
+        Parameters
+        ----------
+        tag: str, optional
+            The specific tag to hide annotations for. If ``None`` then all annotations for this
+            editor are hidden, otherwise only the annotations specified by the given tag are
+            hidden. Default: ``None``
+        """
+        tag = self.__class__.__name__ if tag is None else tag
+        logger.trace("Hiding annotations for tag: %s", tag)
+        self._canvas.itemconfig(tag, state="hidden")
+
+    def _object_tracker(self, key, object_type, face_index,
+                        coordinates, object_kwargs):
+        """ Create an annotation object on the canvas, or update the existing annotation if it
+        has already been created.
+
+        Parameters
+        ----------
+        key: str
+            The key for this annotation's tag creation
+        object_type: str
+            This can be any string that is a natural extension to :class:`tkinter.Canvas.create_`
+        face_index: int
+            The index of the face within the current frame
+        coordinates: tuple or list
+            The bounding box coordinates for this object
+        object_kwargs: dict
+            The keyword arguments for this object
+
+        Returns
+        -------
+        int:
+            The tkinter canvas item identifier for the created object
+        """
+        object_color_keys = self._get_object_color_keys(key, object_type)
+        tracking_id = "_".join((key, str(face_index)))
+        face_tag = f"face_{face_index}"
+        face_objects = set(self._canvas.find_withtag(face_tag))
+        annotation_objects = set(self._canvas.find_withtag(key))
+        existing_object = tuple(face_objects.intersection(annotation_objects))
+        if not existing_object:
+            item_id = self._add_new_object(key,
+                                           object_type,
+                                           face_index,
+                                           coordinates,
+                                           object_kwargs)
+            update_color = bool(object_color_keys)
+        else:
+            item_id = existing_object[0]
+            update_color = self._update_existing_object(
+                existing_object[0],
+                coordinates,
+                object_kwargs,
+                tracking_id,
+                object_color_keys)
+        if update_color:
+            self._current_color[tracking_id] = object_kwargs[object_color_keys[0]]
+        return item_id
+
+    @staticmethod
+    def _get_object_color_keys(key, object_type):
+        """ The canvas object's parameter that needs to be adjusted for color varies based on
+        the type of object that is being used. Returns the correct parameter based on object.
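+
+        A minimal sketch of the mapping (hypothetical keys, mirroring the logic below):
+
+        >>> Editor._get_object_color_keys("bb_box", "line")
+        ['fill']
+        >>> Editor._get_object_color_keys("lm_dsp_0", "oval")
+        ['fill', 'outline']
+        >>> Editor._get_object_color_keys("bb_box", "rectangle")
+        ['outline']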
+
+        Parameters
+        ----------
+        key: str
+            The key for this annotation's tag creation
+        object_type: str
+            This can be any string that is a natural extension to :class:`tkinter.Canvas.create_`
+
+        Returns
+        -------
+        list:
+            The list of keyword arguments for this object's color parameter(s), or an empty list
+            if it is not relevant for this object
+        """
+        if object_type in ("line", "text"):
+            retval = ["fill"]
+        elif object_type == "image":
+            retval = []
+        elif object_type == "oval" and key.startswith("lm_dsp_"):
+            retval = ["fill", "outline"]
+        else:
+            retval = ["outline"]
+        logger.trace("returning %s for key: %s, object_type: %s", retval, key, object_type)
+        return retval
+
+    def _add_new_object(self, key, object_type, face_index, coordinates, object_kwargs):
+        """ Add a new object to the canvas.
+
+        Parameters
+        ----------
+        key: str
+            The key for this annotation's tag creation
+        object_type: str
+            This can be any string that is a natural extension to :class:`tkinter.Canvas.create_`
+        face_index: int
+            The index of the face within the current frame
+        coordinates: tuple or list
+            The bounding box coordinates for this object
+        object_kwargs: dict
+            The keyword arguments for this object
+
+        Returns
+        -------
+        int:
+            The tkinter canvas item identifier for the created object
+        """
+        logger.debug("Adding object: (key: '%s', object_type: '%s', face_index: %s, "
+                     "coordinates: %s, object_kwargs: %s)", key, object_type, face_index,
+                     coordinates, object_kwargs)
+        object_kwargs["tags"] = self._set_object_tags(face_index, key)
+        item_id = getattr(self._canvas,
+                          f"create_{object_type}")(*coordinates, **object_kwargs)
+        return item_id
+
+    def _set_object_tags(self, face_index, key):
+        """ Create the tkinter object tags for the incoming object.
+
+        Parameters
+        ----------
+        face_index: int
+            The face index within the current frame for the face that tags are being created for
+        key: str
+            The base tag for this object, for which additional tags will be generated
+
+        Returns
+        -------
+        list
+            The generated tags for the current object
+        """
+        tags = [f"face_{face_index}",
+                self.__class__.__name__,
+                f"{self.__class__.__name__}_face_{face_index}",
+                key,
+                f"{key}_face_{face_index}"]
+        if "_" in key:
+            split_key = key.split("_")
+            if split_key[-1].isdigit():
+                base_tag = "_".join(split_key[:-1])
+                tags.append(base_tag)
+                tags.append(f"{base_tag}_face_{face_index}")
+        return tags
+
+    def _update_existing_object(self, item_id, coordinates, object_kwargs,
+                                tracking_id, object_color_keys):
+        """ Update an existing tracked object.
+
+        Parameters
+        ----------
+        item_id: int
+            The canvas object item_id to be updated
+        coordinates: tuple or list
+            The bounding box coordinates for this object
+        object_kwargs: dict
+            The keyword arguments for this object
+        tracking_id: str
+            The tracking identifier for this object's color
+        object_color_keys: list
+            The list of keyword arguments for this object to update for color
+
+        Returns
+        -------
+        bool
+            ``True`` if :attr:`_current_color` should be updated otherwise ``False``
+        """
+        update_color = (object_color_keys and
+                        object_kwargs[object_color_keys[0]] != self._current_color[tracking_id])
+        update_kwargs = {"state": object_kwargs.get("state", "normal")}
+        if update_color:
+            for key in object_color_keys:
+                update_kwargs[key] = object_kwargs[object_color_keys[0]]
+        if self._canvas.type(item_id) == "image" and "image" in object_kwargs:  # noqa:E721
+            update_kwargs["image"] = object_kwargs["image"]
+        logger.trace("Updating coordinates: (item_id: '%s', object_kwargs: %s, "
+                     "coordinates: %s, update_kwargs: %s)", item_id, object_kwargs,
+                     coordinates, update_kwargs)
+        self._canvas.itemconfig(item_id, **update_kwargs)
+        self._canvas.coords(item_id, *coordinates)
+        return update_color
+
+    # << MOUSE CALLBACKS >>
+    # Mouse cursor display
+    def bind_mouse_motion(self):
+        """ Binds the current editor's mouse ``<Motion>`` event to the editor's
+        :func:`_update_cursor` method.
+
+        Called on initialization and active editor update.
+        """
+        self._canvas.bind("<Motion>", self._update_cursor)
+
+    def _update_cursor(self, event):  # pylint:disable=unused-argument
+        """ Update the mouse cursor display, as bound to the mouse's ``<Motion>`` event.
+
+        The default is to always return a standard cursor, so this method should be overridden
+        for editor specific cursor updates.
+
+        Parameters
+        ----------
+        event: :class:`tkinter.Event`
+            The tkinter mouse event. Unused for default tracking, but available for specific
+            editor tracking.
+        """
+        self._canvas.config(cursor="")
+
+    # Mouse click and drag actions
+    def set_mouse_click_actions(self):
+        """ Add the bindings for left mouse button click and drag actions.
+
+        This binds the mouse to the :func:`_drag_start`, :func:`_drag` and :func:`_drag_stop`
+        methods.
+
+        By default these methods do nothing (except for :func:`_drag_stop`, which resets
+        :attr:`_drag_data`).
+
+        These bindings should be added for all editors. To add additional bindings,
+        `super().set_mouse_click_actions` should be called prior to adding them.
+        """
+        logger.debug("Setting mouse bindings")
+        self._canvas.bind("<ButtonPress-1>", self._drag_start)
+        self._canvas.bind("<ButtonRelease-1>", self._drag_stop)
+        self._canvas.bind("<B1-Motion>", self._drag)
+
+    def _drag_start(self, event):  # pylint:disable=unused-argument
+        """ The action to perform when the user starts clicking and dragging the mouse.
+
+        The default does nothing except reset :attr:`_drag_data` and :attr:`_drag_callback`.
+        Override for Editor specific click and drag start actions.
+
+        Parameters
+        ----------
+        event: :class:`tkinter.Event`
+            The tkinter mouse event. Unused for the default action, but available for editor
+            specific actions
+        """
+        self._drag_data = {}
+        self._drag_callback = None
+
+    def _drag(self, event):
+        """ The default callback for the drag part of a mouse click and drag action.
+
+        :attr:`_drag_callback` should be set in :func:`self._drag_start`. This callback will then
+        be executed on a mouse drag event.
+
+        Parameters
+        ----------
+        event: :class:`tkinter.Event`
+            The tkinter mouse event.
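+
+        Notes
+        -----
+        A minimal sketch of how a subclass is expected to wire this up (the callback name here
+        is illustrative):
+
+        >>> def _drag_start(self, event):
+        ...     self._drag_data["start_location"] = (event.x, event.y)
+        ...     self._drag_callback = self._move  # executed for each subsequent drag event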
+ """ + if self._drag_callback is None: + return + self._drag_callback(event) # pylint:disable=not-callable + + def _drag_stop(self, event): # pylint:disable=unused-argument + """ The action to perform when the user stops clicking and dragging the mouse. + + Default is to set :attr:`_drag_data` to `dict`. Override for Editor specific stop actions. + + Parameters + ---------- + event: :class:`tkinter.Event` + The tkinter mouse event. Unused but required + """ + self._drag_data = {} + + def _scale_to_display(self, points): + """ Scale and offset the given points to the current display scale and offset values. + + Parameters + ---------- + points: :class:`numpy.ndarray` + Array of x, y co-ordinates to adjust + + Returns + ------- + :class:`numpy.ndarray` + The adjusted x, y co-ordinates for display purposes rounded to the nearest integer + """ + retval = np.rint((points * self._globals.current_frame.scale) + + self._canvas.offset).astype("int32") + logger.trace("Original points: %s, scaled points: %s", points, retval) + return retval + + def scale_from_display(self, points, do_offset=True): + """ Scale and offset the given points from the current display to the correct original + values. + + Parameters + ---------- + points: :class:`numpy.ndarray` + Array of x, y co-ordinates to adjust + offset: bool, optional + ``True`` if the offset should be calculated otherwise ``False``. Default: ``True`` + + Returns + ------- + :class:`numpy.ndarray` + The adjusted x, y co-ordinates to the original frame location rounded to the nearest + integer + """ + offset = self._canvas.offset if do_offset else (0, 0) + retval = np.rint((points - offset) / self._globals.current_frame.scale).astype("int32") + logger.trace("Original points: %s, scaled points: %s", points, retval) + return retval + + # << ACTION CONTROL PANEL OPTIONS >> + def _add_actions(self): + """ Add the Action buttons for this editor's optional left hand side action sections. + + The default does nothing. Override for editor specific actions. + """ + self._actions = self._actions + + def _add_action(self, title, icon, helptext, group=None, hotkey=None): + """ Add an action dictionary to :attr:`_actions`. This will create a button in the optional + actions frame to the left hand side of the frames viewer. + + Parameters + ---------- + title: str + The title of the action to be generated + icon: str + The name of the icon that is used to display this action's button + helptext: str + The tooltip text to display for this action + group: str, optional + If a group is passed in, then any buttons belonging to that group will be linked (i.e. + only one button can be active at a time.). If ``None`` is passed in then the button + will act independently. Default: ``None`` + hotkey: str, optional + The hotkey binding for this action. Set to ``None`` if there is no hotkey binding. + Default: ``None`` + """ + var = tk.BooleanVar() + action = {"icon": icon, + "helptext": helptext, + "group": group, + "tk_var": var, + "hotkey": hotkey} + logger.debug("Adding action: %s", action) + self._actions[title] = action + + def _add_controls(self): + """ Add the controls for this editor's control panel. + + The default does nothing. Override for editor specific controls. + """ + self._controls = self._controls + + def _add_control(self, option, global_control=False): + """ Add a control panel control to :attr:`_controls` and add a trace to the variable + to update display. 
+
+        Parameters
+        ----------
+        option: :class:`lib.gui.control_helper.ControlPanelOption`
+            The control panel option to add to this editor's control
+        global_control: bool, optional
+            Whether the given control is a global control (i.e. annotation formatting).
+            Default: ``False``
+        """
+        self._controls["controls"].append(option)
+        if global_control:
+            logger.debug("Added global control: '%s' for editor: '%s'",
+                         option.title, self.__class__.__name__)
+            return
+        logger.debug("Added local control: '%s' for editor: '%s'",
+                     option.title, self.__class__.__name__)
+        editor_key = self.__class__.__name__
+        group_key = option.group.replace(" ", "").lower()
+        group_key = "none" if group_key == "_master" else group_key
+        annotation_key = option.title.replace(" ", "")
+        self._canvas.control_tk_vars.setdefault(
+            editor_key, {}).setdefault(group_key, {})[annotation_key] = option.tk_var
+
+    def _add_annotation_format_controls(self):
+        """ Add the annotation display (color/size) controls to :attr:`_annotation_formats`.
+
+        These should be universal and available for all editors.
+        """
+        editors = ("Bounding Box", "Extract Box", "Landmarks", "Mask", "Mesh")
+        if not self._annotation_formats:
+            opacity = ControlPanelOption("Mask Opacity",
+                                         int,
+                                         group="Color",
+                                         min_max=(0, 100),
+                                         default=40,
+                                         rounding=1,
+                                         helptext="Set the mask opacity")
+            for editor in editors:
+                annotation_key = editor.replace(" ", "")
+                logger.debug("Adding to global format controls: '%s'", editor)
+                colors = ControlPanelOption(editor,
+                                            str,
+                                            group="Color",
+                                            subgroup="colors",
+                                            choices="colorchooser",
+                                            default=self._default_colors[annotation_key],
+                                            helptext="Set the annotation color")
+                colors.set(self._default_colors[annotation_key])
+                self._annotation_formats.setdefault(annotation_key, {})["color"] = colors
+                self._annotation_formats[annotation_key]["mask_opacity"] = opacity
+
+        for editor in editors:
+            annotation_key = editor.replace(" ", "")
+            for group, ctl in self._annotation_formats[annotation_key].items():
+                logger.debug("Adding global format control to editor: (editor:'%s', group: '%s')",
+                             editor, group)
+                self._add_control(ctl, global_control=True)
+
+
+class View(Editor):
+    """ The view Editor.
+
+    Does not allow any editing, just used for previewing annotations.
+
+    This is the default start-up editor.
+
+    Parameters
+    ----------
+    canvas: :class:`tkinter.Canvas`
+        The canvas that holds the image and annotations
+    detected_faces: :class:`~tools.manual.detected_faces.DetectedFaces`
+        The _detected_faces data for this manual session
+    """
+    def __init__(self, canvas, detected_faces):
+        control_text = "Viewer\nPreview the frame's annotations."
+        super().__init__(canvas, detected_faces, control_text)
+
+    def _add_actions(self):
+        """ Add the optional action buttons to the viewer. The current action is Zoom.
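+
+        Toggling the magnify action sets the global full-update variable (via the trace added
+        below), so that the frame viewer redraws in the newly selected view mode.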
""" + self._add_action("magnify", "zoom", _("Magnify/Demagnify the View"), + group=None, hotkey="M") + self._actions["magnify"]["tk_var"].trace_add( + "write", + lambda *e: self._globals.var_full_update.set(True)) diff --git a/tools/manual/frameviewer/editor/bounding_box.py b/tools/manual/frameviewer/editor/bounding_box.py new file mode 100644 index 0000000000..1f07e76da7 --- /dev/null +++ b/tools/manual/frameviewer/editor/bounding_box.py @@ -0,0 +1,414 @@ +#!/usr/bin/env python3 +""" Bounding Box Editor for the manual adjustments tool """ + +import gettext +import platform +from functools import partial + +import numpy as np + +from lib.gui.custom_widgets import RightClickMenu +from lib.utils import get_module_objects +from ._base import ControlPanelOption, Editor, logger + + +# LOCALES +_LANG = gettext.translation("tools.manual", localedir="locales", fallback=True) +_ = _LANG.gettext + + +class BoundingBox(Editor): + """ The Bounding Box Editor. + + Adjusting the bounding box feeds the aligner to generate new 68 point landmarks. + + Parameters + ---------- + canvas: :class:`tkinter.Canvas` + The canvas that holds the image and annotations + detected_faces: :class:`~tools.manual.detected_faces.DetectedFaces` + The _detected_faces data for this manual session + """ + def __init__(self, canvas, detected_faces): + self._tk_aligner = None + self._right_click_menu = RightClickMenu([_("Delete Face")], + [self._delete_current_face], + ["Del"]) + control_text = _("Bounding Box Editor\nEdit the bounding box being fed into the aligner " + "to recalculate the landmarks.\n\n" + " - Grab the corner anchors to resize the bounding box.\n" + " - Click and drag the bounding box to relocate.\n" + " - Click in empty space to create a new bounding box.\n" + " - Right click a bounding box to delete a face.") + key_bindings = {"": self._delete_current_face} + super().__init__(canvas, detected_faces, + control_text=control_text, key_bindings=key_bindings) + + @property + def _corner_order(self): + """ dict: The position index of bounding box corners """ + return {0: ("top", "left"), + 1: ("top", "right"), + 2: ("bottom", "right"), + 3: ("bottom", "left")} + + @property + def _bounding_boxes(self): + """ list: The :func:`tkinter.Canvas.coords` for all displayed bounding boxes. """ + item_ids = self._canvas.find_withtag("bb_box") + return [self._canvas.coords(item_id) for item_id in item_ids + if self._canvas.itemcget(item_id, "state") != "hidden"] + + def _add_controls(self): + """ Controls for feeding the Aligner. Exposes Normalization Method as a parameter. """ + align_ctl = ControlPanelOption( + "Aligner", + str, + group="Aligner", + choices=["cv2-dnn", "FAN"], + default="FAN", + is_radio=True, + helptext=_("Aligner to use. FAN will obtain better alignments, but cv2-dnn can be " + "useful if FAN cannot get decent alignments and you want to set a base to " + "edit from.")) + self._tk_aligner = align_ctl.tk_var + self._add_control(align_ctl) + + norm_ctl = ControlPanelOption( + "Normalization method", + str, + group="Aligner", + choices=["none", "clahe", "hist", "mean"], + default="hist", + is_radio=True, + helptext=_("Normalization method to use for feeding faces to the aligner. This can " + "help the aligner better align faces with difficult lighting conditions. " + "Different methods will yield different results on different sets. NB: " + "This does not impact the output face, just the input to the aligner." + "\n\tnone: Don't perform normalization on the face." 
+ "\n\tclahe: Perform Contrast Limited Adaptive Histogram Equalization on " + "the face." + "\n\thist: Equalize the histograms on the RGB channels." + "\n\tmean: Normalize the face colors to the mean.")) + var = norm_ctl.tk_var + var.trace("w", + lambda *e, v=var: self._det_faces.extractor.set_normalization_method(v.get())) + self._add_control(norm_ctl) + + def update_annotation(self): + """ Get the latest bounding box data from alignments and update. """ + if self._globals.is_zoomed: + logger.trace("Image is zoomed. Hiding Bounding Box.") + self.hide_annotation() + return + key = "bb_box" + color = self._control_color + for idx, face in enumerate(self._face_iterator): + box = np.array([(face.left, face.top), (face.right, face.bottom)]) + box = self._scale_to_display(box).astype("int32").flatten() + kwargs = {"outline": color, "width": 1} + logger.trace("frame_index: %s, face_index: %s, box: %s, kwargs: %s", + self._globals.frame_index, idx, box, kwargs) + self._object_tracker(key, "rectangle", idx, box, kwargs) + self._update_anchor_annotation(idx, box, color) + logger.trace("Updated bounding box annotations") + + def _update_anchor_annotation(self, face_index, bounding_box, color): + """ Update the anchor annotations for each corner of the bounding box. + + The anchors only display when the bounding box editor is active. + + Parameters + ---------- + face_index: int + The index of the face being annotated + bounding_box: :class:`numpy.ndarray` + The scaled bounding box to get the corner anchors for + color: str + The hex color of the bounding box line + """ + if not self._is_active: + self.hide_annotation("bb_anc_dsp") + self.hide_annotation("bb_anc_grb") + return + fill_color = "gray" + activefill_color = "white" if self._is_active else "" + anchor_points = self._get_anchor_points(((bounding_box[0], bounding_box[1]), + (bounding_box[2], bounding_box[1]), + (bounding_box[2], bounding_box[3]), + (bounding_box[0], bounding_box[3]))) + for idx, (anc_dsp, anc_grb) in enumerate(zip(*anchor_points)): + dsp_kwargs = {"outline": color, "fill": fill_color, "width": 1} + grb_kwargs = {"outline": '', "fill": '', "width": 1, "activefill": activefill_color} + dsp_key = f"bb_anc_dsp_{idx}" + grb_key = f"bb_anc_grb_{idx}" + self._object_tracker(dsp_key, "oval", face_index, anc_dsp, dsp_kwargs) + self._object_tracker(grb_key, "oval", face_index, anc_grb, grb_kwargs) + logger.trace("Updated bounding box anchor annotations") + + # << MOUSE HANDLING >> + # Mouse cursor display + def _update_cursor(self, event): + """ Set the cursor action. + + Update :attr:`_mouse_location` with the current cursor position and display appropriate + icon. + + If the cursor is over a corner anchor, then pop resize icon. + If the cursor is over a bounding box, then pop move icon. + If the cursor is over the image, then pop add icon. + + Parameters + ---------- + event: :class:`tkinter.Event` + The current tkinter mouse event + """ + if self._check_cursor_anchors(): + return + if self._check_cursor_bounding_box(event): + return + if self._check_cursor_image(event): + return + + self._canvas.config(cursor="") + self._mouse_location = None + + def _check_cursor_anchors(self): + """ Check whether the cursor is over a corner anchor. 
+
+        If it is, set the appropriate cursor type and set :attr:`_mouse_location` to
+        ("anchor", "`face index`_`corner index`")
+
+        Returns
+        -------
+        bool
+            ``True`` if cursor is over an anchor point otherwise ``False``
+        """
+        anchors = set(self._canvas.find_withtag("bb_anc_grb"))
+        item_ids = set(self._canvas.find_withtag("current")).intersection(anchors)
+        if not item_ids:
+            return False
+        item_id = list(item_ids)[0]
+        tags = self._canvas.gettags(item_id)
+        face_idx = int(next(tag for tag in tags if tag.startswith("face_")).split("_")[-1])
+        corner_idx = int(next(tag for tag in tags
+                              if tag.startswith("bb_anc_grb_")
+                              and "face_" not in tag).split("_")[-1])
+        pos_x, pos_y = self._corner_order[corner_idx]
+        self._canvas.config(cursor=f"{pos_x}_{pos_y}_corner")
+        self._mouse_location = ("anchor", f"{face_idx}_{corner_idx}")
+        return True
+
+    def _check_cursor_bounding_box(self, event):
+        """ Check whether the cursor is over a bounding box.
+
+        If it is, set the appropriate cursor type and set :attr:`_mouse_location` to:
+        ("box", `face index`)
+
+        Parameters
+        ----------
+        event: :class:`tkinter.Event`
+            The tkinter mouse event
+
+        Returns
+        -------
+        bool
+            ``True`` if cursor is over a bounding box otherwise ``False``
+
+        Notes
+        -----
+        We can't use tags on unfilled rectangles as the interior of the rectangle is not tagged.
+        """
+        for face_idx, bbox in enumerate(self._bounding_boxes):
+            if bbox[0] <= event.x <= bbox[2] and bbox[1] <= event.y <= bbox[3]:
+                self._canvas.config(cursor="fleur")
+                self._mouse_location = ("box", str(face_idx))
+                return True
+        return False
+
+    def _check_cursor_image(self, event):
+        """ Check whether the cursor is over the image.
+
+        If it is, set the appropriate cursor type and set :attr:`_mouse_location` to:
+        ("image", )
+
+        Parameters
+        ----------
+        event: :class:`tkinter.Event`
+            The tkinter mouse event
+
+        Returns
+        -------
+        bool
+            ``True`` if cursor is over the image otherwise ``False``
+        """
+        if self._globals.frame_index == -1:
+            return False
+        display_dims = self._globals.current_frame.display_dims
+        if (self._canvas.offset[0] <= event.x <= display_dims[0] + self._canvas.offset[0] and
+                self._canvas.offset[1] <= event.y <= display_dims[1] + self._canvas.offset[1]):
+            self._canvas.config(cursor="plus")
+            self._mouse_location = ("image", )
+            return True
+        return False
+
+    # Mouse Actions
+    def set_mouse_click_actions(self):
+        """ Add context menu to OS specific right click action. """
+        super().set_mouse_click_actions()
+        self._canvas.bind("<Button-2>" if platform.system() == "Darwin" else "<Button-3>",
+                          self._context_menu)
+
+    def _drag_start(self, event):
+        """ The action to perform when the user starts clicking and dragging the mouse.
+
+        If :attr:`_mouse_location` indicates a corner anchor, then the bounding box is resized
+        based on the adjusted corner, and the alignments re-generated.
+
+        If :attr:`_mouse_location` indicates a bounding box, then the bounding box is moved, and
+        the alignments re-generated.
+
+        If :attr:`_mouse_location` indicates being over the main image, then a new bounding box
+        is created, and alignments generated.
+
+        Parameters
+        ----------
+        event: :class:`tkinter.Event`
+            The tkinter mouse event.
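+
+        Notes
+        -----
+        When a new bounding box is created, :func:`_update_cursor` and this method are re-run
+        so that the newly created box is immediately picked up for dragging.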
+ """ + if self._mouse_location is None: + self._drag_data = {} + self._drag_callback = None + return + if self._mouse_location[0] == "anchor": + corner_idx = int(self._mouse_location[1].split("_")[-1]) + self._drag_data["corner"] = self._corner_order[corner_idx] + self._drag_callback = self._resize + elif self._mouse_location[0] == "box": + self._drag_data["current_location"] = (event.x, event.y) + self._drag_callback = self._move + elif self._mouse_location[0] == "image": + self._create_new_bounding_box(event) + # Refresh cursor and _mouse_location for new bounding box and reset _drag_start + self._update_cursor(event) + self._drag_start(event) + + def _drag_stop(self, event): # pylint:disable=unused-argument + """ Trigger a viewport thumbnail update on click + drag release + + Parameters + ---------- + event: :class:`tkinter.Event` + The tkinter mouse event. Required but unused. + """ + if self._mouse_location is None: + return + face_idx = int(self._mouse_location[1].split("_")[0]) + self._det_faces.update.post_edit_trigger(self._globals.frame_index, face_idx) + + def _create_new_bounding_box(self, event): + """ Create a new bounding box when user clicks on image, outside of existing boxes. + + The bounding box is created as a square located around the click location, with dimensions + 1 quarter the size of the frame's shortest side + + Parameters + ---------- + event: :class:`tkinter.Event` + The tkinter mouse event + """ + size = min(self._globals.current_frame.display_dims) // 8 + box = (event.x - size, event.y - size, event.x + size, event.y + size) + logger.debug("Creating new bounding box: %s ", box) + self._det_faces.update.add(self._globals.frame_index, *self._coords_to_bounding_box(box)) + + def _resize(self, event): + """ Resizes a bounding box on a corner anchor drag event. + + Parameters + ---------- + event: :class:`tkinter.Event` + The tkinter mouse event. + """ + face_idx = int(self._mouse_location[1].split("_")[0]) + face_tag = f"bb_box_face_{face_idx}" + box = self._canvas.coords(face_tag) + logger.trace("Face Index: %s, Corner Index: %s. Original ROI: %s", + face_idx, self._drag_data["corner"], box) + # Switch top/bottom and left/right and set partial so indices match and we don't + # need branching logic for min/max. + limits = (partial(min, box[2] - 20), + partial(min, box[3] - 20), + partial(max, box[0] + 20), + partial(max, box[1] + 20)) + rect_xy_indices = [("left", "top", "right", "bottom").index(pnt) + for pnt in self._drag_data["corner"]] + box[rect_xy_indices[1]] = limits[rect_xy_indices[1]](event.x) + box[rect_xy_indices[0]] = limits[rect_xy_indices[0]](event.y) + logger.trace("New ROI: %s", box) + self._det_faces.update.bounding_box(self._globals.frame_index, + face_idx, + *self._coords_to_bounding_box(box), + aligner=self._tk_aligner.get()) + + def _move(self, event): + """ Moves the bounding box on a bounding box drag event. + + Parameters + ---------- + event: :class:`tkinter.Event` + The tkinter mouse event. 
+ """ + logger.trace("event: %s, mouse_location: %s", event, self._mouse_location) + face_idx = int(self._mouse_location[1]) + shift = (event.x - self._drag_data["current_location"][0], + event.y - self._drag_data["current_location"][1]) + face_tag = f"bb_box_face_{face_idx}" + coords = np.array(self._canvas.coords(face_tag)) + (*shift, *shift) + logger.trace("face_tag: %s, shift: %s, new co-ords: %s", face_tag, shift, coords) + self._det_faces.update.bounding_box(self._globals.frame_index, + face_idx, + *self._coords_to_bounding_box(coords), + aligner=self._tk_aligner.get()) + self._drag_data["current_location"] = (event.x, event.y) + + def _coords_to_bounding_box(self, coords): + """ Converts tkinter coordinates to :class:`lib.align.DetectedFace` bounding + box format, scaled up and offset for feeding the model. + + Returns + ------- + tuple + The (`x`, `width`, `y`, `height`) integer points of the bounding box. + """ + logger.trace("in: %s", coords) + coords = self.scale_from_display( + np.array(coords).reshape((2, 2))).flatten().astype("int32") + logger.trace("out: %s", coords) + return (coords[0], coords[2] - coords[0], coords[1], coords[3] - coords[1]) + + def _context_menu(self, event): + """ Create a right click context menu to delete the alignment that is being + hovered over. """ + if self._mouse_location is None or self._mouse_location[0] != "box": + return + self._right_click_menu.popup(event) + + def _delete_current_face(self, *args): # pylint:disable=unused-argument + """ Called by the right click delete event. Deletes the face that the mouse is currently + over. + + Parameters + ---------- + args: tuple (unused) + The event parameter is passed in by the hot key binding, so args is required + """ + if self._mouse_location is None or self._mouse_location[0] != "box": + logger.debug("Delete called without valid location. _mouse_location: %s", + self._mouse_location) + return + logger.debug("Deleting face. _mouse_location: %s", self._mouse_location) + self._det_faces.update.delete(self._globals.frame_index, int(self._mouse_location[1])) + + +__all__ = get_module_objects(__name__) diff --git a/tools/manual/frameviewer/editor/extract_box.py b/tools/manual/frameviewer/editor/extract_box.py new file mode 100644 index 0000000000..e13cb9f9af --- /dev/null +++ b/tools/manual/frameviewer/editor/extract_box.py @@ -0,0 +1,415 @@ +#!/usr/bin/env python3 +""" Extract Box Editor for the manual adjustments tool """ +import gettext +import platform + +import numpy as np + +from lib.align import AlignedFace +from lib.gui.custom_widgets import RightClickMenu +from lib.gui.utils import get_config +from lib.utils import get_module_objects +from ._base import Editor, logger + + +# LOCALES +_LANG = gettext.translation("tools.manual", localedir="locales", fallback=True) +_ = _LANG.gettext + + +class ExtractBox(Editor): + """ The Extract Box Editor. + + Adjust the calculated Extract Box to shift all of the 68 point landmarks in place. + + Parameters + ---------- + canvas: :class:`tkinter.Canvas` + The canvas that holds the image and annotations + detected_faces: :class:`~tools.manual.detected_faces.DetectedFaces` + The _detected_faces data for this manual session + """ + def __init__(self, canvas, detected_faces): + self._right_click_menu = RightClickMenu([_("Delete Face")], + [self._delete_current_face], + ["Del"]) + control_text = _("Extract Box Editor\nMove the extract box that has been generated by the " + "aligner. 
Click and drag:\n\n"
+                         " - Inside the bounding box to relocate the landmarks.\n"
+                         " - The corner anchors to resize the landmarks.\n"
+                         " - Outside of the corners to rotate the landmarks.")
+        key_bindings = {"<Delete>": self._delete_current_face}
+        super().__init__(canvas, detected_faces,
+                         control_text=control_text, key_bindings=key_bindings)
+
+    @property
+    def _corner_order(self):
+        """ dict: The position index of bounding box corners """
+        return {0: ("top", "left"),
+                3: ("top", "right"),
+                2: ("bottom", "right"),
+                1: ("bottom", "left")}
+
+    def update_annotation(self):
+        """ Draw the latest Extract Boxes around the faces. """
+        color = self._control_color
+        roi = self._zoomed_roi
+        for idx, face in enumerate(self._face_iterator):
+            logger.trace("Drawing Extract Box: (idx: %s)", idx)
+            if self._globals.is_zoomed:
+                box = np.array((roi[0], roi[1], roi[2], roi[1], roi[2], roi[3], roi[0], roi[3]))
+            else:
+                aligned = AlignedFace(face.landmarks_xy, centering="face")
+                box = self._scale_to_display(aligned.original_roi).flatten()
+            top_left = box[:2] - 10
+            kwargs = {"fill": color, "font": ('Default', 20, 'bold'), "text": str(idx)}
+            self._object_tracker("eb_text", "text", idx, top_left, kwargs)
+            kwargs = {"fill": '', "outline": color, "width": 1}
+            self._object_tracker("eb_box", "polygon", idx, box, kwargs)
+            self._update_anchor_annotation(idx, box, color)
+        logger.trace("Updated extract box annotations")
+
+    def _update_anchor_annotation(self, face_index, extract_box, color):
+        """ Update the anchor annotations for each corner of the extract box.
+
+        The anchors only display when the extract box editor is active.
+
+        Parameters
+        ----------
+        face_index: int
+            The index of the face being annotated
+        extract_box: :class:`numpy.ndarray`
+            The scaled extract box to get the corner anchors for
+        color: str
+            The hex color of the extract box line
+        """
+        if not self._is_active or self._globals.is_zoomed:
+            self.hide_annotation("eb_anc_dsp")
+            self.hide_annotation("eb_anc_grb")
+            return
+        fill_color = "gray"
+        activefill_color = "white" if self._is_active else ""
+        anchor_points = self._get_anchor_points((extract_box[:2],
+                                                 extract_box[2:4],
+                                                 extract_box[4:6],
+                                                 extract_box[6:]))
+        for idx, (anc_dsp, anc_grb) in enumerate(zip(*anchor_points)):
+            dsp_kwargs = {"outline": color, "fill": fill_color, "width": 1}
+            grb_kwargs = {"outline": '', "fill": '', "width": 1, "activefill": activefill_color}
+            dsp_key = f"eb_anc_dsp_{idx}"
+            grb_key = f"eb_anc_grb_{idx}"
+            self._object_tracker(dsp_key, "oval", face_index, anc_dsp, dsp_kwargs)
+            self._object_tracker(grb_key, "oval", face_index, anc_grb, grb_kwargs)
+        logger.trace("Updated extract box anchor annotations")
+
+    # << MOUSE HANDLING >>
+    # Mouse cursor display
+    def _update_cursor(self, event):
+        """ Update the cursor when it is hovering over an extract box and update
+        :attr:`_mouse_location` with the current cursor position.
+
+        Parameters
+        ----------
+        event: :class:`tkinter.Event`
+            The current tkinter mouse event
+        """
+        if self._check_cursor_anchors():
+            return
+        if self._check_cursor_box():
+            return
+        if self._check_cursor_rotate(event):
+            return
+        self._canvas.config(cursor="")
+        self._mouse_location = None
+
+    def _check_cursor_anchors(self):
+        """ Check whether the cursor is over a corner anchor.
+
+        If it is, set the appropriate cursor type and set :attr:`_mouse_location` to
+        ("anchor", `face index`, `corner index`)
+
+        Returns
+        -------
+        bool
+            ``True`` if cursor is over an anchor point otherwise ``False``
+        """
+        # pylint:disable=duplicate-code
+        anchors = set(self._canvas.find_withtag("eb_anc_grb"))
+        item_ids = set(self._canvas.find_withtag("current")).intersection(anchors)
+        if not item_ids:
+            return False
+        item_id = list(item_ids)[0]
+        tags = self._canvas.gettags(item_id)
+        face_idx = int(next(tag for tag in tags if tag.startswith("face_")).split("_")[-1])
+        corner_idx = int(next(tag for tag in tags
+                              if tag.startswith("eb_anc_grb_")
+                              and "face_" not in tag).split("_")[-1])
+
+        pos_x, pos_y = self._corner_order[corner_idx]
+        self._canvas.config(cursor=f"{pos_x}_{pos_y}_corner")
+        self._mouse_location = ("anchor", face_idx, corner_idx)
+        return True
+
+    def _check_cursor_box(self):
+        """ Check whether the cursor is inside an extract box.
+
+        If it is, set the appropriate cursor type and set :attr:`_mouse_location` to
+        ("box", `face index`)
+
+        Returns
+        -------
+        bool
+            ``True`` if cursor is inside an extract box otherwise ``False``
+        """
+        extract_boxes = set(self._canvas.find_withtag("eb_box"))
+        item_ids = set(self._canvas.find_withtag("current")).intersection(extract_boxes)
+        if not item_ids:
+            return False
+        item_id = list(item_ids)[0]
+        self._canvas.config(cursor="fleur")
+        self._mouse_location = ("box", next(int(tag.split("_")[-1])
+                                            for tag in self._canvas.gettags(item_id)
+                                            if tag.startswith("face_")))
+        return True
+
+    def _check_cursor_rotate(self, event):
+        """ Check whether the cursor is in an area to rotate the extract box.
+
+        If it is, set the appropriate cursor type and set :attr:`_mouse_location` to
+        ("rotate", `face index`)
+
+        Notes
+        -----
+        This code is executed after the check has been completed to see if the mouse is inside
+        the extract box. For this reason, we don't bother running a check to see if the mouse
+        is inside the box, as this code will never run if that is the case.
+
+        Parameters
+        ----------
+        event: :class:`tkinter.Event`
+            The current tkinter mouse event
+
+        Returns
+        -------
+        bool
+            ``True`` if cursor is over a rotate point otherwise ``False``
+        """
+        # pylint:disable=duplicate-code
+        distance = 30
+        boxes = np.array([np.array(self._canvas.coords(item_id)).reshape(4, 2)
+                          for item_id in self._canvas.find_withtag("eb_box")
+                          if self._canvas.itemcget(item_id, "state") != "hidden"])
+        position = np.array((event.x, event.y)).astype("float32")
+        for face_idx, points in enumerate(boxes):
+            if any(np.all(position > point - distance) and np.all(position < point + distance)
+                   for point in points):
+                self._canvas.config(cursor="exchange")
+                self._mouse_location = ("rotate", face_idx)
+                return True
+        return False
+
+    # Mouse click actions
+    def set_mouse_click_actions(self):
+        """ Add context menu to OS specific right click action. """
+        # pylint:disable=duplicate-code
+        super().set_mouse_click_actions()
+        self._canvas.bind("<Button-2>" if platform.system() == "Darwin" else "<Button-3>",
+                          self._context_menu)
+
+    def _drag_start(self, event):
+        """ The action to perform when the user starts clicking and dragging the mouse.
+
+        Selects the correct extract box action based on the initial cursor position.
+
+        Parameters
+        ----------
+        event: :class:`tkinter.Event`
+            The tkinter mouse event.
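+
+        Notes
+        -----
+        The drag callback is selected from a small dispatch mapping of the current
+        :attr:`_mouse_location` mode (``anchor`` resizes, ``rotate`` rotates and ``box``
+        moves), so no per-event branching is required.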
+ """ + if self._mouse_location is None: + self._drag_data = {} + self._drag_callback = None + return + self._drag_data["current_location"] = np.array((event.x, event.y)) + callback = {"anchor": self._resize, "rotate": self._rotate, "box": self._move} + self._drag_callback = callback[self._mouse_location[0]] + + def _drag_stop(self, event): # pylint:disable=unused-argument + """ Trigger a viewport thumbnail update on click + drag release + + Parameters + ---------- + event: :class:`tkinter.Event` + The tkinter mouse event. Required but unused. + """ + if self._mouse_location is None: + return + self._det_faces.update.post_edit_trigger(self._globals.frame_index, + self._mouse_location[1]) + + def _move(self, event): + """ Updates the underlying detected faces landmarks based on mouse dragging delta, + which moves the Extract box on a drag event. + + Parameters + ---------- + event: :class:`tkinter.Event` + The tkinter mouse event. + """ + if not self._drag_data: + return + shift_x = event.x - self._drag_data["current_location"][0] + shift_y = event.y - self._drag_data["current_location"][1] + scaled_shift = self.scale_from_display(np.array((shift_x, shift_y)), do_offset=False) + self._det_faces.update.landmarks(self._globals.frame_index, + self._mouse_location[1], + *scaled_shift) + self._drag_data["current_location"] = (event.x, event.y) + + def _resize(self, event): + """ Resizes the landmarks contained within an extract box on a corner anchor drag event. + + Parameters + ---------- + event: :class:`tkinter.Event` + The tkinter mouse event. + """ + face_idx = self._mouse_location[1] + face_tag = f"eb_box_face_{face_idx}" + position = np.array((event.x, event.y)) + box = np.array(self._canvas.coords(face_tag)) + center = np.array((sum(box[0::2]) / 4, sum(box[1::2]) / 4)) + if not self._check_in_bounds(center, box, position): + logger.trace("Drag out of bounds. Not updating") + self._drag_data["current_location"] = position + return + + start = self._drag_data["current_location"] + distance = ((np.linalg.norm(center - start) - np.linalg.norm(center - position)) + * get_config().scaling_factor) + size = ((box[2] - box[0]) ** 2 + (box[3] - box[1]) ** 2) ** 0.5 + scale = 1 - (distance / size) + logger.trace("face_index: %s, center: %s, start: %s, position: %s, distance: %s, " + "size: %s, scale: %s", face_idx, center, start, position, distance, size, + scale) + if size * scale < 20: + # Don't over shrink the box + logger.trace("Box would size to less than 20px. Not updating") + self._drag_data["current_location"] = position + return + + self._det_faces.update.landmarks_scale(self._globals.frame_index, + face_idx, + scale, + self.scale_from_display(center)) + self._drag_data["current_location"] = position + + def _check_in_bounds(self, center, box, position): + """ Ensure that a resize drag does is not going to cross the center point from it's initial + corner location. 
+
+        Parameters
+        ----------
+        center: :class:`numpy.ndarray`
+            The (`x`, `y`) center point of the face extract box
+        box: :class:`numpy.ndarray`
+            The canvas coordinates of the extract box polygon's corners
+        position: :class:`numpy.ndarray`
+            The current (`x`, `y`) position of the mouse cursor
+
+        Returns
+        -------
+        bool
+            ``True`` if the drag operation does not cross the center point otherwise ``False``
+        """
+        # Generate lines that span the full frame (x and y) along the center point
+        center_x = np.array(((center[0], 0), (center[0], self._globals.frame_display_dims[1])))
+        center_y = np.array(((0, center[1]), (self._globals.frame_display_dims[0], center[1])))
+
+        # Generate a line coming from the current corner location to the current cursor position
+        full_line = np.array((box[self._mouse_location[2] * 2:self._mouse_location[2] * 2 + 2],
+                              position))
+        logger.trace("center: %s, center_x_line: %s, center_y_line: %s, full_line: %s",
+                     center, center_x, center_y, full_line)
+
+        # Check whether any of the generated lines intersect
+        for line in (center_x, center_y):
+            if (self._is_ccw(full_line[0], *line) != self._is_ccw(full_line[1], *line) and
+                    self._is_ccw(*full_line, line[0]) != self._is_ccw(*full_line, line[1])):
+                logger.trace("line: %s crosses center: %s", full_line, center)
+                return False
+        return True
+
+    @staticmethod
+    def _is_ccw(point_a, point_b, point_c):
+        """ Check whether 3 points are counter clockwise from each other.
+
+        Parameters
+        ----------
+        point_a: :class:`numpy.ndarray`
+            The first (`x`, `y`) point to check for counter clockwise ordering
+        point_b: :class:`numpy.ndarray`
+            The second (`x`, `y`) point to check for counter clockwise ordering
+        point_c: :class:`numpy.ndarray`
+            The third (`x`, `y`) point to check for counter clockwise ordering
+
+        Returns
+        -------
+        bool
+            ``True`` if the 3 points are provided in counter clockwise order otherwise ``False``
+        """
+        return ((point_c[1] - point_a[1]) * (point_b[0] - point_a[0]) >
+                (point_b[1] - point_a[1]) * (point_c[0] - point_a[0]))
+
+    def _rotate(self, event):
+        """ Rotates the landmarks contained within an extract box on a corner rotate drag event.
+
+        Parameters
+        ----------
+        event: :class:`tkinter.Event`
+            The tkinter mouse event.
+        """
+        face_idx = self._mouse_location[1]
+        face_tag = f"eb_box_face_{face_idx}"
+        box = np.array(self._canvas.coords(face_tag))
+        position = np.array((event.x, event.y))
+
+        center = np.array((sum(box[0::2]) / 4, sum(box[1::2]) / 4))
+        init_to_center = self._drag_data["current_location"] - center
+        new_to_center = position - center
+        angle = np.rad2deg(np.arctan2(*new_to_center) - np.arctan2(*init_to_center))
+        logger.trace("face_index: %s, box: %s, center: %s, init_to_center: %s, "
+                     "new_to_center: %s, angle: %s", face_idx, box, center, init_to_center,
+                     new_to_center, angle)
+
+        self._det_faces.update.landmarks_rotate(self._globals.frame_index,
+                                                face_idx,
+                                                angle,
+                                                self.scale_from_display(center))
+        self._drag_data["current_location"] = position
+
+    def _get_scale(self):
+        """ Obtain the scaling for the extract box resize """
+
+    def _context_menu(self, event):
+        """ Create a right click context menu to delete the alignment that is being
+        hovered over. """
+        if self._mouse_location is None or self._mouse_location[0] != "box":
+            return
+        self._right_click_menu.popup(event)
+
+    def _delete_current_face(self, *args):  # pylint:disable=unused-argument
+        """ Called by the right click delete event. Deletes the face that the mouse is currently
+        over.
+
+        Parameters
+        ----------
+        args: tuple (unused)
+            The event parameter is passed in by the hot key binding, so args is required
+        """
+        if self._mouse_location is None or self._mouse_location[0] != "box":
+            return
+        self._det_faces.update.delete(self._globals.frame_index, self._mouse_location[1])
+
+
+__all__ = get_module_objects(__name__)
diff --git a/tools/manual/frameviewer/editor/landmarks.py b/tools/manual/frameviewer/editor/landmarks.py
new file mode 100644
index 0000000000..2e502cb57d
--- /dev/null
+++ b/tools/manual/frameviewer/editor/landmarks.py
@@ -0,0 +1,468 @@
+#!/usr/bin/env python3
+""" Landmarks Editor and Landmarks Mesh viewer for the manual adjustments tool """
+import gettext
+import numpy as np
+
+from lib.align import AlignedFace, LANDMARK_PARTS, LandmarkType
+from lib.utils import get_module_objects
+from ._base import Editor, logger
+
+# LOCALES
+_LANG = gettext.translation("tools.manual", localedir="locales", fallback=True)
+_ = _LANG.gettext
+
+
+class Landmarks(Editor):
+    """ The Landmarks Editor.
+
+    Adjust individual landmark points and re-generate the Extract Box.
+
+    Parameters
+    ----------
+    canvas: :class:`tkinter.Canvas`
+        The canvas that holds the image and annotations
+    detected_faces: :class:`~tools.manual.detected_faces.DetectedFaces`
+        The _detected_faces data for this manual session
+    """
+    def __init__(self, canvas, detected_faces):
+        control_text = _("Landmark Point Editor\nEdit the individual landmark points.\n\n"
+                         " - Click and drag individual points to relocate.\n"
+                         " - Draw a box to select multiple points to relocate.")
+        self._selection_box = canvas.create_rectangle(0, 0, 0, 0,
+                                                      dash=(2, 4),
+                                                      state="hidden",
+                                                      outline="gray",
+                                                      fill="blue",
+                                                      stipple="gray12")
+        super().__init__(canvas, detected_faces, control_text)
+        # Clear selection box on an editor or frame change
+        self._canvas._tk_action_var.trace("w", lambda *e: self._reset_selection())
+        self._globals.var_frame_index.trace_add("write", lambda *e: self._reset_selection())
+
+    def _add_actions(self):
+        """ Add the optional action buttons to the viewer. The current action is Zoom. """
+        self._add_action("magnify", "zoom", _("Magnify/Demagnify the View"),
+                         group=None, hotkey="M")
+        self._actions["magnify"]["tk_var"].trace("w", self._toggle_zoom)
+
+    # CALLBACKS
+    def _toggle_zoom(self, *args):  # pylint:disable=unused-argument
+        """ Clear any selections when switching mode and perform an update.
+
+        Parameters
+        ----------
+        args: tuple
+            tkinter callback arguments. Required but unused.
+        """
+        self._reset_selection()
+        self._globals.var_full_update.set(True)
+
+    def _reset_selection(self, event=None):  # pylint:disable=unused-argument
+        """ Reset the selection box and the selected landmark annotations.
+
+        Parameters
+        ----------
+        event: :class:`tkinter.Event`, optional
+            The tkinter mouse event that triggered the reset, if there was one.
+            Default: ``None``
+        """
+        self._canvas.itemconfig("lm_selected", outline=self._control_color)
+        self._canvas.dtag("lm_selected")
+        self._canvas.itemconfig(self._selection_box,
+                                stipple="gray12",
+                                fill="blue",
+                                outline="gray",
+                                state="hidden")
+        self._canvas.coords(self._selection_box, 0, 0, 0, 0)
+        self._drag_data = {}
+        if event is not None:
+            self._drag_start(event)
+
+    def update_annotation(self):
+        """ Get the latest Landmarks points and update.
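+
+        Notes
+        -----
+        In zoomed mode the landmarks are re-projected through an aligned face sized to the
+        shortest display dimension and offset into the zoomed ROI; in frame mode they are
+        simply scaled from frame to display co-ordinates.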
""" + zoomed_offset = self._zoomed_roi[:2] + for face_idx, face in enumerate(self._face_iterator): + face_index = self._globals.face_index if self._globals.is_zoomed else face_idx + if self._globals.is_zoomed: + aligned = AlignedFace(face.landmarks_xy, + centering="face", + size=min(self._globals.frame_display_dims)) + landmarks = aligned.landmarks + zoomed_offset + # Hide all landmarks and only display selected + self._canvas.itemconfig("lm_dsp", state="hidden") + self._canvas.itemconfig(f"lm_dsp_face_{face_index}", state="normal") + else: + landmarks = self._scale_to_display(face.landmarks_xy) + for lm_idx, landmark in enumerate(landmarks): + self._display_landmark(landmark, face_index, lm_idx) + self._label_landmark(landmark, face_index, lm_idx) + self._grab_landmark(landmark, face_index, lm_idx) + logger.trace("Updated landmark annotations") + + def _display_landmark(self, bounding_box, face_index, landmark_index): + """ Add an individual landmark display annotation to the canvas. + + Parameters + ---------- + bounding_box: :class:`numpy.ndarray` + The (left, top), (right, bottom) (x, y) coordinates of the oval bounding box for this + landmark + face_index: int + The index of the face within the current frame + landmark_index: int + The index point of this landmark + """ + radius = 1 + color = self._control_color + bbox = (bounding_box[0] - radius, bounding_box[1] - radius, + bounding_box[0] + radius, bounding_box[1] + radius) + key = f"lm_dsp_{landmark_index}" + kwargs = {"outline": color, "fill": color, "width": radius} + self._object_tracker(key, "oval", face_index, bbox, kwargs) + + def _label_landmark(self, bounding_box, face_index, landmark_index): + """ Add a text label for a landmark to the canvas. + + Parameters + ---------- + bounding_box: :class:`numpy.ndarray` + The (left, top), (right, bottom) (x, y) coordinates of the oval bounding box for this + landmark + face_index: int + The index of the face within the current frame + landmark_index: int + The index point of this landmark + """ + if not self._is_active: + return + top_left = np.array(bounding_box[:2]) - 20 + # NB The text must be visible to be able to get the bounding box, so set to hidden + # after the bounding box has been retrieved + + keys = [f"lm_lbl_{landmark_index}", f"lm_lbl_bg_{landmark_index}"] + text_kwargs = {"fill": "black", "font": ("Default", 10), "text": str(landmark_index + 1)} + bg_kwargs = {"fill": "#ffffea", "outline": "black"} + + text_id = self._object_tracker(keys[0], "text", face_index, top_left, text_kwargs) + bbox = self._canvas.bbox(text_id) + bbox = [bbox[0] - 2, bbox[1] - 2, bbox[2] + 2, bbox[3] + 2] + bg_id = self._object_tracker(keys[1], "rectangle", face_index, bbox, bg_kwargs) + self._canvas.tag_lower(bg_id, text_id) + self._canvas.itemconfig(text_id, state="hidden") + self._canvas.itemconfig(bg_id, state="hidden") + + def _grab_landmark(self, bounding_box, face_index, landmark_index): + """ Add an individual landmark grab anchor to the canvas. 
+
+        Parameters
+        ----------
+        bounding_box: :class:`numpy.ndarray`
+            The (x, y) display coordinates of this landmark point. The grab anchor's oval
+            bounding box is calculated from this point and the grab radius
+        face_index: int
+            The index of the face within the current frame
+        landmark_index: int
+            The index point of this landmark
+        """
+        if not self._is_active:
+            return
+        radius = 7
+        bbox = (bounding_box[0] - radius, bounding_box[1] - radius,
+                bounding_box[0] + radius, bounding_box[1] + radius)
+        key = f"lm_grb_{landmark_index}"
+        kwargs = {"outline": "",
+                  "fill": "",
+                  "width": 1,
+                  "dash": (2, 4)}
+        self._object_tracker(key, "oval", face_index, bbox, kwargs)
+
+    # << MOUSE HANDLING >>
+    # Mouse cursor display
+    def _update_cursor(self, event):
+        """ Set the cursor action.
+
+        Launch the cursor update action for the currently selected edit mode.
+
+        Parameters
+        ----------
+        event: :class:`tkinter.Event`
+            The current tkinter mouse event
+        """
+        self._hide_labels()
+        if self._drag_data:
+            self._update_cursor_select_mode(event)
+        else:
+            objs = self._canvas.find_withtag(f"lm_grb_face_{self._globals.face_index}"
+                                             if self._globals.is_zoomed else "lm_grb")
+            item_ids = set(self._canvas.find_overlapping(event.x - 6,
+                                                         event.y - 6,
+                                                         event.x + 6,
+                                                         event.y + 6)).intersection(objs)
+            bboxes = [self._canvas.bbox(idx) for idx in item_ids]
+            item_id = next((idx for idx, bbox in zip(item_ids, bboxes)
+                            if bbox[0] <= event.x <= bbox[2] and bbox[1] <= event.y <= bbox[3]),
+                           None)
+            if item_id:
+                self._update_cursor_point_mode(item_id)
+            else:
+                self._canvas.config(cursor="")
+                self._mouse_location = None
+
+    def _hide_labels(self):
+        """ Clear all landmark text labels from display """
+        self._canvas.itemconfig("lm_lbl", state="hidden")
+        self._canvas.itemconfig("lm_lbl_bg", state="hidden")
+        self._canvas.itemconfig("lm_grb", fill="", outline="")
+
+    def _update_cursor_point_mode(self, item_id):
+        """ Update the cursor when the mouse is over an individual landmark's grab anchor.
+
+        Displays the landmark label for the landmark under the cursor and updates
+        :attr:`_mouse_location` with the current cursor position.
+
+        Parameters
+        ----------
+        item_id: int
+            The tkinter canvas object id for the landmark point that the cursor is over
+        """
+        self._canvas.itemconfig(item_id, outline="yellow")
+        tags = self._canvas.gettags(item_id)
+        face_idx = int(next(tag for tag in tags if tag.startswith("face_")).split("_")[-1])
+        lm_idx = int(next(tag for tag in tags if tag.startswith("lm_grb_")).split("_")[-1])
+        obj_idx = (face_idx, lm_idx)
+
+        self._canvas.config(cursor="none")
+        for prefix in ("lm_lbl_", "lm_lbl_bg_"):
+            tag = f"{prefix}{lm_idx}_face_{face_idx}"
+            logger.trace("Displaying: %s tag: %s", self._canvas.type(tag), tag)
+            self._canvas.itemconfig(tag, state="normal")
+        self._mouse_location = obj_idx
+
+    def _update_cursor_select_mode(self, event):
+        """ Update the mouse cursor when in select mode.
+
+        The standard cursor is returned when creating a new selection box. The move cursor is
+        returned when over an existing selection box.
+
+        Parameters
+        ----------
+        event: :class:`tkinter.Event`
+            The current tkinter mouse event
+        """
+        bbox = self._canvas.coords(self._selection_box)
+        if bbox[0] <= event.x <= bbox[2] and bbox[1] <= event.y <= bbox[3]:
+            self._canvas.config(cursor="fleur")
+        else:
+            self._canvas.config(cursor="")
+
+    # Mouse actions
+    def _drag_start(self, event):
+        """ The action to perform when the user starts clicking and dragging the mouse.
+
+        The underlying Detected Face's landmark is updated for the point being edited.
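+
+        Notes
+        -----
+        Three cases are handled here: relocating a single point (when the cursor is over a
+        grab anchor), drawing a new selection box, and moving an existing selection box. Any
+        other click and drag resets the selection.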
+ + Parameters + ---------- + event: :class:`tkinter.Event` + The tkinter mouse event. + """ + sel_box = self._canvas.coords(self._selection_box) + if self._mouse_location is not None: # Point edit mode + self._drag_data["start_location"] = (event.x, event.y) + self._drag_callback = self._move_point + elif not self._drag_data: # Initial point selection box + self._drag_data["start_location"] = (event.x, event.y) + self._drag_callback = self._select + elif sel_box[0] <= event.x <= sel_box[2] and sel_box[1] <= event.y <= sel_box[3]: + # Move point selection box + self._drag_data["start_location"] = (event.x, event.y) + self._drag_callback = self._move_selection + else: # Reset + self._drag_data = {} + self._drag_callback = None + self._reset_selection(event) + + def _drag_stop(self, event): # pylint:disable=unused-argument + """ In select mode, call the select mode callback. + + In point mode: trigger a viewport thumbnail update on click + drag release + + If there is drag data, and there are selected points in the drag data then + trigger the selected points stop code. + + Otherwise reset the selection box and return + + Parameters + ---------- + event: :class:`tkinter.Event` + The tkinter mouse event. Required but unused. + """ + if self._mouse_location is not None: # Point edit mode + self._det_faces.update.post_edit_trigger(self._globals.frame_index, + self._mouse_location[0]) + self._mouse_location = None + self._drag_data = {} + elif self._drag_data and self._drag_data.get("selected", False): + self._drag_stop_selected() + else: + logger.debug("No selected data. Clearing. drag_data: %s", self._drag_data) + self._reset_selection() + + def _drag_stop_selected(self): + """ Action to perform when mouse drag is stopped in selected points editor mode. + + If there is already a selection, update the viewport thumbnail + + If this is a new selection, then obtain the selected points and track + """ + if "face_index" in self._drag_data: # Selected data has been moved + self._det_faces.update.post_edit_trigger(self._globals.frame_index, + self._drag_data["face_index"]) + return + + # This is a new selection + face_idx = set() + landmark_indices = [] + + for item_id in self._canvas.find_withtag("lm_selected"): + tags = self._canvas.gettags(item_id) + face_idx.add(next(int(tag.split("_")[-1]) + for tag in tags if tag.startswith("face_"))) + landmark_indices.append(next(int(tag.split("_")[-1]) + for tag in tags + if tag.startswith("lm_dsp_") and "face" not in tag)) + if len(face_idx) != 1: + logger.trace("Not exactly 1 face in selection. Aborting. Face indices: %s", face_idx) + self._reset_selection() + return + + self._drag_data["face_index"] = face_idx.pop() + self._drag_data["landmarks"] = landmark_indices + self._canvas.itemconfig(self._selection_box, stipple="", fill="", outline="#ffff00") + self._snap_selection_to_points() + + def _snap_selection_to_points(self): + """ Snap the selection box to the selected points. + + As the landmarks are calculated and redrawn, the selection box can drift. This is + particularly true in zoomed mode. The selection box is therefore redrawn to bind just + outside of the selected points. 
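+
+        Example
+        -------
+        An illustrative sketch: for two selected ovals with canvas coordinates
+        (10, 10, 14, 14) and (40, 20, 44, 24), the per-column minimums and maximums are
+        (10, 10, 14, 14) and (40, 20, 44, 24) respectively, giving a redrawn selection box of
+        (5, 5, 49, 29) once the 5 pixel padding is applied.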
+ """ + all_coords = np.array([self._canvas.coords(item_id) + for item_id in self._canvas.find_withtag("lm_selected")]) + mins = np.min(all_coords, axis=0) + maxes = np.max(all_coords, axis=0) + box_coords = [np.min(mins[[0, 2]] - 5), + np.min(mins[[1, 3]] - 5), + np.max(maxes[[0, 2]] + 5), + np.max(maxes[[1, 3]]) + 5] + self._canvas.coords(self._selection_box, *box_coords) + + def _move_point(self, event): + """ Moves the selected landmark point box and updates the underlying landmark on a point + drag event. + + Parameters + ---------- + event: :class:`tkinter.Event` + The tkinter mouse event. + """ + face_idx, lm_idx = self._mouse_location + shift_x = event.x - self._drag_data["start_location"][0] + shift_y = event.y - self._drag_data["start_location"][1] + + if self._globals.is_zoomed: + scaled_shift = np.array((shift_x, shift_y)) + else: + scaled_shift = self.scale_from_display(np.array((shift_x, shift_y)), do_offset=False) + self._det_faces.update.landmark(self._globals.frame_index, + face_idx, + lm_idx, + *scaled_shift, + self._globals.is_zoomed) + self._drag_data["start_location"] = (event.x, event.y) + + def _select(self, event): + """ Create a selection box on mouse drag event when in "select" mode + + Parameters + ---------- + event: :class:`tkinter.Event` + The tkinter mouse event. + """ + if self._canvas.itemcget(self._selection_box, "state") == "hidden": + self._canvas.itemconfig(self._selection_box, state="normal") + coords = (*self._drag_data["start_location"], event.x, event.y) + self._canvas.coords(self._selection_box, *coords) + enclosed = set(self._canvas.find_enclosed(*coords)) + landmarks = set(self._canvas.find_withtag("lm_dsp")) + + for item_id in list(enclosed.intersection(landmarks)): + self._canvas.addtag_withtag("lm_selected", item_id) + self._canvas.itemconfig("lm_selected", outline="#ffff00") + self._drag_data["selected"] = True + + def _move_selection(self, event): + """ Move a selection box and the landmarks contained when in "select" mode and a selection + box has been drawn. """ + shift_x = event.x - self._drag_data["start_location"][0] + shift_y = event.y - self._drag_data["start_location"][1] + if self._globals.is_zoomed: + scaled_shift = np.array((shift_x, shift_y)) + else: + scaled_shift = self.scale_from_display(np.array((shift_x, shift_y)), do_offset=False) + self._canvas.move(self._selection_box, shift_x, shift_y) + + self._det_faces.update.landmark(self._globals.frame_index, + self._drag_data["face_index"], + self._drag_data["landmarks"], + *scaled_shift, + self._globals.is_zoomed) + self._snap_selection_to_points() + self._drag_data["start_location"] = (event.x, event.y) + + +class Mesh(Editor): + """ The Landmarks Mesh Display. + + There are no editing options for Mesh editor. It is purely aesthetic and updated when other + editors are used. 
+ + Parameters + ---------- + canvas: :class:`tkinter.Canvas` + The canvas that holds the image and annotations + detected_faces: :class:`~tools.manual.detected_faces.DetectedFaces` + The _detected_faces data for this manual session + """ + def __init__(self, canvas, detected_faces): + super().__init__(canvas, detected_faces, None) + + def update_annotation(self): # pylint:disable=too-many-locals + """ Get the latest Landmarks and update the mesh.""" + key = "mesh" + color = self._control_color + zoomed_offset = self._zoomed_roi[:2] + for face_idx, face in enumerate(self._face_iterator): + face_index = self._globals.face_index if self._globals.is_zoomed else face_idx + if self._globals.is_zoomed: + aligned = AlignedFace(face.landmarks_xy, + centering="face", + size=min(self._globals.frame_display_dims)) + landmarks = aligned.landmarks + zoomed_offset + landmark_mapping = LANDMARK_PARTS[aligned.landmark_type] + # Hide all meshes and only display selected + self._canvas.itemconfig("Mesh", state="hidden") + self._canvas.itemconfig(f"Mesh_face_{face_index}", state="normal") + else: + landmarks = self._scale_to_display(face.landmarks_xy) + landmark_mapping = LANDMARK_PARTS[LandmarkType.from_shape(landmarks.shape)] + logger.trace("Drawing Landmarks Mesh: (landmarks: %s, color: %s)", landmarks, color) + for idx, (start, end, fill) in enumerate(landmark_mapping.values()): + key = f"mesh_{idx}" + pts = landmarks[start:end].flatten() + if fill: + kwargs = {"fill": "", "outline": color, "width": 1} + asset = "polygon" + else: + kwargs = {"fill": color, "width": 1} + asset = "line" + self._object_tracker(key, asset, face_index, pts, kwargs) + # Place mesh as bottom annotation + self._canvas.tag_raise(self.__class__.__name__, "main_image") + + +__all__ = get_module_objects(__name__) diff --git a/tools/manual/frameviewer/editor/mask.py b/tools/manual/frameviewer/editor/mask.py new file mode 100644 index 0000000000..5101cde3a8 --- /dev/null +++ b/tools/manual/frameviewer/editor/mask.py @@ -0,0 +1,612 @@ +#!/usr/bin/env python3 +""" Mask Editor for the manual adjustments tool """ +import gettext +import tkinter as tk + +import numpy as np +import cv2 +from PIL import Image, ImageTk + +from lib.utils import get_module_objects + +from ._base import ControlPanelOption, Editor, logger + +# LOCALES +_LANG = gettext.translation("tools.manual", localedir="locales", fallback=True) +_ = _LANG.gettext + + +class Mask(Editor): + """ The mask Editor. + + Edit a mask in the alignments file. + + Parameters + ---------- + canvas: :class:`tkinter.Canvas` + The canvas that holds the image and annotations + detected_faces: :class:`~tools.manual.detected_faces.DetectedFaces` + The _detected_faces data for this manual session + """ + def __init__(self, canvas, detected_faces): + self._meta = [] + self._tk_faces = [] + self._internal_size = 512 + control_text = _("Mask Editor\nEdit the mask." + "\n - NB: For Landmark based masks (e.g. components/extended) it is " + "better to make sure the landmarks are correct rather than editing the " + "mask directly. 
Any change to the landmarks after editing the mask will "
+                         "override your manual edits.")
+        key_bindings = {"[": lambda *e, i=False: self._adjust_brush_radius(increase=i),
+                        "]": lambda *e, i=True: self._adjust_brush_radius(increase=i)}
+        super().__init__(canvas, detected_faces,
+                         control_text=control_text, key_bindings=key_bindings)
+        # Bind control click for reverse painting
+        self._canvas.bind("<Control-ButtonPress-1>", self._control_click)
+        self._mask_type = self._set_tk_mask_change_callback()
+        self._cursor_shape = self._set_tk_cursor_shape_change_callback()
+        self._mouse_location = [self._get_cursor_shape(), False]
+
+    @property
+    def _opacity(self):
+        """ float: The mask opacity setting from the control panel from 0.0 - 1.0. """
+        annotation = self.__class__.__name__
+        return self._annotation_formats[annotation]["mask_opacity"].get() / 100.0
+
+    @property
+    def _brush_radius(self):
+        """ int: The radius of the brush to use as set in control panel options """
+        return self._control_vars["brush"]["BrushSize"].get()
+
+    @property
+    def _edit_mode(self):
+        """ str: The currently selected edit mode based on the optional action buttons.
+        One of "draw" or "erase" """
+        action = [name for name, option in self._actions.items()
+                  if option["group"] == "paint" and option["tk_var"].get()]
+        return "draw" if not action else action[0]
+
+    @property
+    def _cursor_color(self):
+        """ str: The hex code for the selected cursor color """
+        return self._control_vars["brush"]["CursorColor"].get()
+
+    @property
+    def _cursor_shape_name(self):
+        """ str: The selected cursor shape """
+        return self._control_vars["display"]["CursorShape"].get()
+
+    def _add_actions(self):
+        """ Add the optional action buttons to the viewer. Current actions are Draw, Erase
+        and Zoom. """
+        self._add_action("magnify", "zoom", _("Magnify/Demagnify the View"),
+                         group=None, hotkey="M")
+        self._add_action("draw", "draw", _("Draw Tool"), group="paint", hotkey="D")
+        self._add_action("erase", "erase", _("Erase Tool"), group="paint", hotkey="E")
+        self._actions["magnify"]["tk_var"].trace(
+            "w",
+            lambda *e: self._globals.var_full_update.set(True))
+
+    def _add_controls(self):
+        """ Add the mask specific control panel controls.
+
+        Current controls are:
+            - the mask type to edit
+            - the size of brush to use
+            - the cursor display color
+            - the cursor shape
+        """
+        masks = sorted(msk.title() for msk in list(self._det_faces.available_masks) + ["None"])
+        default = masks[0] if len(masks) == 1 else [mask for mask in masks if mask != "None"][0]
+        self._add_control(ControlPanelOption("Mask type",
+                                             str,
+                                             group="Display",
+                                             choices=masks,
+                                             default=default,
+                                             is_radio=True,
+                                             helptext=_("Select which mask to edit")))
+        self._add_control(ControlPanelOption("Brush Size",
+                                             int,
+                                             group="Brush",
+                                             min_max=(1, 100),
+                                             default=10,
+                                             rounding=1,
+                                             helptext=_("Set the brush size. ([ - decrease, "
+                                                        "] - increase)")))
+        self._add_control(ControlPanelOption("Cursor Color",
+                                             str,
+                                             group="Brush",
+                                             choices="colorchooser",
+                                             default="#ffffff",
+                                             helptext=_("Select the brush cursor color.")))
+        self._add_control(ControlPanelOption("Cursor Shape",
+                                             str,
+                                             group="Display",
+                                             choices=["Circle", "Rectangle"],
+                                             default="Circle",
+                                             is_radio=True,
+                                             helptext=_("Select a shape for the masking "
+                                                        "cursor.")))
+
+    def _set_tk_mask_change_callback(self):
+        """ Add a trace to change the displayed mask on a mask type change.
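+
+        Returns
+        -------
+        str
+            The mask type that is selected at the point the callback is added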
""" + var = self._control_vars["display"]["MaskType"] + var.trace("w", lambda *e: self._on_mask_type_change()) + return var.get() + + def _set_tk_cursor_shape_change_callback(self): + """ Add a trace to change the displayed cursor on a cursor shape type change. """ + var = self._control_vars["display"]["CursorShape"] + var.trace("w", lambda *e: self._on_cursor_shape_change()) + return var.get() + + def _on_cursor_shape_change(self): + self._mouse_location[0] = self._get_cursor_shape() + + def _on_mask_type_change(self): + """ Update the displayed mask on a mask type change """ + mask_type = self._control_vars["display"]["MaskType"].get() + if mask_type == self._mask_type: + return + self._meta = {"position": self._globals.frame_index} + self._mask_type = mask_type + self._globals.var_full_update.set(True) + + def hide_annotation(self, tag=None): + """ Clear the mask :attr:`_meta` dict when hiding the annotation. """ + super().hide_annotation() + self._meta = {} + + def update_annotation(self): + """ Update the mask annotation with the latest mask. """ + position = self._globals.frame_index + if position != self._meta.get("position", -1): + # Reset meta information when moving to a new frame + self._meta = {"position": position} + key = self.__class__.__name__ + mask_type = self._control_vars["display"]["MaskType"].get().lower() + color = self._control_color[1:] + rgb_color = np.array(tuple(int(color[i:i + 2], 16) for i in (0, 2, 4))) + roi_color = self._annotation_formats["ExtractBox"]["color"].get() + opacity = self._opacity + for idx, face in enumerate(self._face_iterator): + face_idx = self._globals.face_index if self._globals.is_zoomed else idx + mask = face.mask.get(mask_type, None) + if mask is None: + continue + self._set_face_meta_data(mask, face_idx) + self._update_mask_image(key.lower(), face_idx, rgb_color, opacity) + self._update_roi_box(mask, face_idx, roi_color) + + self._canvas.tag_raise(self._mouse_location[0]) # Always keep brush cursor on top + logger.trace("Updated mask annotation") + + def _set_face_meta_data(self, mask, face_index): + """ Set the metadata for the current face if it has changed or is new. + + Parameters + ---------- + mask: :class:`numpy.ndarray` + The one channel mask cropped to the ROI + face_index: int + The index pertaining to the current face + """ + masks = self._meta.get("mask", None) + if masks is not None and len(masks) - 1 == face_index: + logger.trace("Meta information already defined for face: %s", face_index) + return + + logger.debug("Defining meta information for face: %s", face_index) + scale = self._internal_size / mask.stored_size + self._set_full_frame_meta(mask, scale) + dims = (self._internal_size, self._internal_size) + self._meta.setdefault("mask", []).append(cv2.resize(mask.stored_mask, + dims, + interpolation=cv2.INTER_CUBIC)) + if self.zoomed_centering != mask.stored_centering: + self.zoomed_centering = mask.stored_centering + + def _set_full_frame_meta(self, mask, mask_scale): + """ Sets the meta information for displaying the mask in full frame mode. 
+ + Parameters + ---------- + mask: :class:`lib.align.Mask` + The mask object + mask_scale: float + The scaling factor from the stored mask size to the internal mask size + + Sets the following parameters to :attr:`_meta`: + - roi_mask: the rectangular ROI box from the full frame that contains the original ROI + for the full frame mask + - top_left: The location that the roi_mask should be placed in the display frame + - affine_matrix: The matrix for transposing the mask to a full frame + - interpolator: The cv2 interpolation method to use for transposing mask to a + full frame + - slices: The (`x`, `y`) slice objects required to extract the mask ROI + from the full frame + """ + frame_dims = self._globals.current_frame.display_dims + scaled_mask_roi = np.rint(mask.original_roi * + self._globals.current_frame.scale).astype("int32") + + # Scale and clip the ROI to fit within display frame boundaries + clipped_roi = scaled_mask_roi.clip(min=(0, 0), max=frame_dims) + + # Obtain min and max points to get ROI as a rectangle + min_max = {"min": clipped_roi.min(axis=0), "max": clipped_roi.max(axis=0)} + + # Create a bounding box rectangle ROI + roi_dims = np.rint((min_max["max"][1] - min_max["min"][1], + min_max["max"][0] - min_max["min"][0])).astype("uint16") + roi = {"mask": np.zeros(roi_dims, dtype="uint8")[..., None], + "corners": np.expand_dims(scaled_mask_roi - min_max["min"], axis=0)} + # Block out areas outside of the actual mask ROI polygon + cv2.fillPoly(roi["mask"], roi["corners"], 255) + logger.trace("Setting Full Frame mask ROI. shape: %s", roi["mask"].shape) + + # obtain the slices for cropping mask from full frame + xy_slices = (slice(int(round(min_max["min"][1])), int(round(min_max["max"][1]))), + slice(int(round(min_max["min"][0])), int(round(min_max["max"][0])))) + + # Adjust affine matrix for internal mask size and display dimensions + adjustments = (np.array([[mask_scale, 0., 0.], [0., mask_scale, 0.]]), + np.array([[1 / self._globals.current_frame.scale, 0., 0.], + [0., 1 / self._globals.current_frame.scale, 0.], + [0., 0., 1.]])) + in_matrix = np.dot(adjustments[0], + np.concatenate((mask.affine_matrix, np.array([[0., 0., 1.]])))) + affine_matrix = np.dot(in_matrix, adjustments[1]) + + # Get the size of the mask roi box in the frame + side_sizes = (scaled_mask_roi[1][0] - scaled_mask_roi[0][0], + scaled_mask_roi[1][1] - scaled_mask_roi[0][1]) + mask_roi_size = (side_sizes[0] ** 2 + side_sizes[1] ** 2) ** 0.5 + + self._meta.setdefault("roi_mask", []).append(roi["mask"]) + self._meta.setdefault("affine_matrix", []).append(affine_matrix) + self._meta.setdefault("interpolator", []).append(mask.interpolator) + self._meta.setdefault("slices", []).append(xy_slices) + self._meta.setdefault("top_left", []).append(min_max["min"] + self._canvas.offset) + self._meta.setdefault("mask_roi_size", []).append(mask_roi_size) + + def _update_mask_image(self, key, face_index, rgb_color, opacity): + """ Obtain a mask, overlay over image and add to canvas or update. 
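+
+        Notes
+        -----
+        A :class:`PIL.ImageTk.PhotoImage` is maintained per face. The mask is pasted into the
+        existing object where its dimensions are unchanged; the object is only re-created when
+        the display width changes or a new face appears.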
+ + Parameters + ---------- + key: str + The base annotation name for creating tags + face_index: int + The index of the face within the current frame + rgb_color: tuple + The color that the mask should be displayed as + opacity: float + The opacity to apply to the mask + """ + mask = (self._meta["mask"][face_index] * opacity).astype("uint8") + if self._globals.is_zoomed: + display_image = self._update_mask_image_zoomed(mask, rgb_color) + top_left = self._zoomed_roi[:2] + # Hide all masks and only display selected + self._canvas.itemconfig("Mask", state="hidden") + self._canvas.itemconfig(f"Mask_face_{face_index}", state="normal") + else: + display_image = self._update_mask_image_full_frame(mask, rgb_color, face_index) + top_left = self._meta["top_left"][face_index] + + if len(self._tk_faces) < face_index + 1: + logger.trace("Adding new Photo Image for face index: %s", face_index) + self._tk_faces.append(ImageTk.PhotoImage(display_image)) + elif self._tk_faces[face_index].width() != display_image.width: + logger.trace("Replacing existing Photo Image on width change for face index: %s", + face_index) + self._tk_faces[face_index] = ImageTk.PhotoImage(display_image) + else: + logger.trace("Updating existing image") + self._tk_faces[face_index].paste(display_image) + + self._object_tracker(key, + "image", + face_index, + top_left, + {"image": self._tk_faces[face_index], "anchor": tk.NW}) + + def _update_mask_image_zoomed(self, mask, rgb_color): + """ Update the mask image when zoomed in. + + Parameters + ---------- + mask: :class:`numpy.ndarray` + The raw mask + rgb_color: tuple + The rgb color selected for the mask + + Returns + ------- + :class: `PIL.Image` + The zoomed mask image formatted for display + """ + rgb = np.tile(rgb_color, self._zoomed_dims + (1, )).astype("uint8") + mask = cv2.resize(mask, + tuple(reversed(self._zoomed_dims)), + interpolation=cv2.INTER_CUBIC)[..., None] + rgba = np.concatenate((rgb, mask), axis=2) + return Image.fromarray(rgba) + + def _update_mask_image_full_frame(self, mask, rgb_color, face_index): + """ Update the mask image when in full frame view. + + Parameters + ---------- + mask: :class:`numpy.ndarray` + The raw mask + rgb_color: tuple + The rgb color selected for the mask + face_index: int + The index of the face being displayed + + Returns + ------- + :class: `PIL.Image` + The full frame mask image formatted for display + """ + frame_dims = self._globals.current_frame.display_dims + frame = np.zeros(frame_dims + (1, ), dtype="uint8") + interpolator = self._meta["interpolator"][face_index] + slices = self._meta["slices"][face_index] + mask = cv2.warpAffine(mask, + self._meta["affine_matrix"][face_index], + frame_dims, + frame, + flags=cv2.WARP_INVERSE_MAP | interpolator, + borderMode=cv2.BORDER_CONSTANT)[slices[0], slices[1]] + mask = mask[..., None] if mask.ndim == 2 else mask + rgb = np.tile(rgb_color, mask.shape).astype("uint8") + rgba = np.concatenate((rgb, np.minimum(mask, self._meta["roi_mask"][face_index])), axis=2) + return Image.fromarray(rgba) + + def _update_roi_box(self, mask, face_index, color): + """ Update the region of interest box for the current mask. 
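+
+        Parameters
+        ----------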
+ + mask: :class:`~lib.align.Mask` + The current mask object to create an ROI box for + face_index: int + The index of the face within the current frame + color: str + The hex color code that the mask should be displayed as + """ + if self._globals.is_zoomed: + roi = self._zoomed_roi + box = np.array((roi[0], roi[1], roi[2], roi[1], roi[2], roi[3], roi[0], roi[3])) + else: + box = self._scale_to_display(mask.original_roi).flatten() + top_left = box[:2] - 10 + kwargs = {"fill": color, "font": ("Default", 20, "bold"), "text": str(face_index)} + self._object_tracker("mask_text", "text", face_index, top_left, kwargs) + kwargs = {"fill": "", "outline": color, "width": 1} + self._object_tracker("mask_roi", "polygon", face_index, box, kwargs) + if self._globals.is_zoomed: + # Raise box above zoomed image + self._canvas.tag_raise(f"mask_roi_face_{face_index}") + + # << MOUSE HANDLING >> + # Mouse cursor display + def _update_cursor(self, event): + """ Set the cursor action. + + Update :attr:`_mouse_location` with the current cursor position and display appropriate + icon. + + Checks whether the mouse is over a mask ROI box and pops the paint icon. + + Parameters + ---------- + event: :class:`tkinter.Event` + The current tkinter mouse event + """ + roi_boxes = self._canvas.find_withtag("mask_roi") + item_ids = set(self._canvas.find_withtag("current")).intersection(roi_boxes) + if not item_ids: + self._canvas.config(cursor="") + self._canvas.itemconfig(self._mouse_location[0], state="hidden") + self._mouse_location[1] = None + return + item_id = list(item_ids)[0] + tags = self._canvas.gettags(item_id) + face_idx = int(next(tag for tag in tags if tag.startswith("face_")).split("_")[-1]) + + radius = self._brush_radius + coords = (event.x - radius, event.y - radius, event.x + radius, event.y + radius) + self._canvas.config(cursor="none") + self._canvas.coords(self._mouse_location[0], *coords) + self._canvas.itemconfig(self._mouse_location[0], + state="normal", + outline=self._cursor_color) + self._mouse_location[1] = face_idx + self._canvas.update_idletasks() + + def _control_click(self, event): + """ The action to perform when the user starts clicking and dragging the mouse whilst + pressing the control button. + + For editing the mask this will activate the opposite action than what is currently selected + (e.g. it will erase if draw is set and it will draw if erase is set) + + Parameters + ---------- + event: :class:`tkinter.Event` + The tkinter mouse event. + """ + self._drag_start(event, control_click=True) + + def _drag_start(self, event, control_click=False): # pylint:disable=arguments-differ + """ The action to perform when the user starts clicking and dragging the mouse. + + Paints on the mask with the appropriate draw or erase action. + + Parameters + ---------- + event: :class:`tkinter.Event` + The tkinter mouse event. + control_click: bool, optional + Indicates whether the control button is depressed when drag has commenced. If ``True`` + then the opposite of the selected action is performed. 
Default: ``False`` + """ + face_idx = self._mouse_location[1] + if face_idx is None: + self._drag_data = {} + self._drag_callback = None + else: + self._drag_data["starting_location"] = np.array((event.x, event.y)) + self._drag_data["control_click"] = control_click + self._drag_data["color"] = np.array(tuple(int(self._control_color[1:][i:i + 2], 16) + for i in (0, 2, 4))) + self._drag_data["opacity"] = self._opacity + self._get_cursor_shape_mark( + self._meta["mask"][face_idx], + np.array(((event.x, event.y), )), + face_idx) + self._drag_callback = self._paint + + def _paint(self, event): + """ Paint or erase from Mask and update cursor on click and drag. + + Parameters + ---------- + event: :class:`tkinter.Event` + The tkinter mouse event. + """ + face_idx = self._mouse_location[1] + line = np.array((self._drag_data["starting_location"], (event.x, event.y))) + line, scale = self._transform_points(face_idx, line) + brush_radius = int(round(self._brush_radius * scale)) + color = 0 if self._edit_mode == "erase" else 255 + # Reverse action on control click + color = abs(color - 255) if self._drag_data["control_click"] else color + cv2.line(self._meta["mask"][face_idx], + tuple(line[0]), + tuple(line[1]), + color, + brush_radius * 2) + self._update_mask_image("mask", + face_idx, + self._drag_data["color"], + self._drag_data["opacity"]) + self._drag_data["starting_location"] = np.array((event.x, event.y)) + self._update_cursor(event) + + def _transform_points(self, face_index, points): + """ Transform the edit points from a full frame or zoomed view back to the mask. + + Parameters + ---------- + face_index: int + The index of the face within the current frame + points: :class:`numpy.ndarray` + The points that are to be translated from the viewer to the underlying + Detected Face + """ + if self._globals.is_zoomed: + offset = self._zoomed_roi[:2] + scale = self._internal_size / self._zoomed_dims[0] + t_points = np.rint((points - offset) * scale).astype("int32").squeeze() + else: + scale = self._internal_size / self._meta["mask_roi_size"][face_index] + t_points = np.expand_dims(points - self._canvas.offset, axis=0) + t_points = cv2.transform(t_points, self._meta["affine_matrix"][face_index]).squeeze() + t_points = np.rint(t_points).astype("int32") + logger.trace("original points: %s, transformed points: %s, scale: %s", + points, t_points, scale) + return t_points, scale + + def _drag_stop(self, event): + """ The action to perform when the user stops clicking and dragging the mouse. + + If a line hasn't been drawn then draw a circle. Update alignments. + + Parameters + ---------- + event: :class:`tkinter.Event` + The tkinter mouse event. Unused but required + """ + if not self._drag_data: + return + face_idx = self._mouse_location[1] + location = np.array(((event.x, event.y), )) + if np.array_equal(self._drag_data["starting_location"], location[0]): + self._get_cursor_shape_mark(self._meta["mask"][face_idx], location, face_idx) + self._mask_to_alignments(face_idx) + self._drag_data = {} + self._update_cursor(event) + + def _get_cursor_shape_mark(self, img, location, face_idx): + """ Draw object depending on the cursor shape selection. Defaults to circle. 
+
+        Parameters
+        ----------
+        img: :class:`numpy.ndarray`
+            The mask image to draw on
+        location: :class:`numpy.ndarray`
+            The cursor location coordinates, which will be transformed to mask coordinates
+        face_idx: int
+            The index of the face within the current frame
+        """
+        points, scale = self._transform_points(face_idx, location)
+        radius = int(round(self._brush_radius * scale))
+        color = 0 if self._edit_mode == "erase" else 255
+        # Reverse action on control click
+        color = abs(color - 255) if self._drag_data["control_click"] else color
+
+        if self._cursor_shape_name == "Rectangle":
+            point2 = points.copy()
+            points[0] = points[0] - radius
+            points[1] = points[1] - radius
+            point2[0] = point2[0] + radius
+            point2[1] = point2[1] + radius
+            cv2.rectangle(img, tuple(points), tuple(point2), color, -1)
+        else:
+            cv2.circle(img, tuple(points), radius, color, thickness=-1)
+
+    def _get_cursor_shape(self, x_1=0, y_1=0, x_2=0, y_2=0, outline="black", state="hidden"):
+        """ Create the hidden brush cursor canvas item for the currently selected cursor shape.
+
+        Returns
+        -------
+        int
+            The canvas object id for the created rectangle or oval cursor item
+        """
+        if self._cursor_shape_name == "Rectangle":
+            return self._canvas.create_rectangle(x_1, y_1, x_2, y_2, outline=outline, state=state)
+        return self._canvas.create_oval(x_1, y_1, x_2, y_2, outline=outline, state=state)
+
+    def _mask_to_alignments(self, face_index):
+        """ Update the annotated mask to alignments.
+
+        Parameters
+        ----------
+        face_index: int
+            The index of the face in the current frame
+        """
+        mask_type = self._control_vars["display"]["MaskType"].get().lower()
+        mask = self._meta["mask"][face_index].astype("float32") / 255.0
+        self._det_faces.update.mask(self._globals.frame_index, face_index, mask, mask_type)
+
+    def _adjust_brush_radius(self, increase=True):  # pylint:disable=unused-argument
+        """ Adjust the brush radius up or down by 2px.
+
+        Sets the control panel option for brush radius to 2 less or 2 more than its current
+        value.
+
+        Parameters
+        ----------
+        increase: bool, optional
+            ``True`` to increment brush radius, ``False`` to decrement. 
Default: ``True`` + """ + radius_var = self._control_vars["brush"]["BrushSize"] + current_val = radius_var.get() + new_val = min(100, current_val + 2) if increase else max(1, current_val - 2) + logger.trace("Adjusting brush radius from %s to %s", current_val, new_val) + radius_var.set(new_val) + + delta = new_val - current_val + if delta == 0: + return + current_coords = self._canvas.coords(self._mouse_location[0]) + new_coords = tuple(coord - delta if idx < 2 else coord + delta + for idx, coord in enumerate(current_coords)) + logger.trace("Adjusting brush coordinates from %s to %s", current_coords, new_coords) + self._canvas.coords(self._mouse_location[0], new_coords) + + +__all__ = get_module_objects(__name__) diff --git a/tools/manual/frameviewer/frame.py b/tools/manual/frameviewer/frame.py new file mode 100644 index 0000000000..e7ed015198 --- /dev/null +++ b/tools/manual/frameviewer/frame.py @@ -0,0 +1,828 @@ +#!/usr/bin/env python3 +""" The frame viewer section of the manual tool GUI """ +import gettext +import logging +import tkinter as tk +from tkinter import ttk, TclError + +from functools import partial +from time import time + +from lib.gui.control_helper import set_slider_rounding +from lib.gui.custom_widgets import Tooltip +from lib.gui.utils import get_images +from lib.utils import get_module_objects + +from .control import Navigation, BackgroundImage +from .editor import (BoundingBox, ExtractBox, Landmarks, Mask, # noqa pylint:disable=unused-import + Mesh, View) + +logger = logging.getLogger(__name__) + +# LOCALES +_LANG = gettext.translation("tools.manual", localedir="locales", fallback=True) +_ = _LANG.gettext + + +class DisplayFrame(ttk.Frame): # pylint:disable=too-many-ancestors + """ The main video display frame (top left section of GUI). 
+
+    Parameters
+    ----------
+    parent: :class:`ttk.PanedWindow`
+        The paned window that the display frame resides in
+    tk_globals: :class:`~tools.manual.manual.TkGlobals`
+        The tkinter variables that apply to the whole of the GUI
+    detected_faces: :class:`tools.manual.detected_faces.DetectedFaces`
+        The detected faces stored in the alignments file
+    """
+    def __init__(self, parent, tk_globals, detected_faces):
+        logger.debug("Initializing %s: (parent: %s, tk_globals: %s, detected_faces: %s)",
+                     self.__class__.__name__, parent, tk_globals, detected_faces)
+        super().__init__(parent)
+
+        self._globals = tk_globals
+        self._det_faces = detected_faces
+        self._optional_widgets = {}
+
+        self._actions_frame = ActionsFrame(self)
+        main_frame = ttk.Frame(self)
+
+        self._transport_frame = ttk.Frame(main_frame)
+        self._nav = self._add_nav()
+        self._navigation = Navigation(self)
+        self._buttons = self._add_transport()
+        self._add_transport_tk_trace()
+
+        video_frame = ttk.Frame(main_frame)
+        video_frame.bind("<Configure>", self._resize)
+
+        self._canvas = FrameViewer(video_frame,
+                                   self._globals,
+                                   self._det_faces,
+                                   self._actions_frame.actions,
+                                   self._actions_frame.tk_selected_action)
+
+        self._actions_frame.add_optional_buttons(self.editors)
+
+        self._transport_frame.pack(side=tk.BOTTOM, padx=5, fill=tk.X)
+        video_frame.pack(side=tk.TOP, expand=True, fill=tk.BOTH)
+        main_frame.pack(side=tk.RIGHT, expand=True, fill=tk.BOTH)
+        self.pack(side=tk.LEFT, anchor=tk.NW, expand=True, fill=tk.BOTH)
+
+        logger.debug("Initialized %s", self.__class__.__name__)
+
+    @property
+    def _helptext(self):
+        """ dict: {`name`: `help text`} Helptext lookup for navigation buttons """
+        return {
+            "play": _("Play/Pause (SPACE)"),
+            "beginning": _("Go to First Frame (HOME)"),
+            "prev": _("Go to Previous Frame (Z)"),
+            "next": _("Go to Next Frame (X)"),
+            "end": _("Go to Last Frame (END)"),
+            "extract": _("Extract the faces to a folder... (Ctrl+E)"),
+            "save": _("Save the Alignments file (Ctrl+S)"),
+            "mode": _("Filter Frames to only those Containing the Selected Item (F)"),
+            "distance": _("Set the distance from an 'average face' to be considered misaligned. "
+                          "Higher distances are more restrictive")}
+
+    @property
+    def _btn_action(self):
+        """ dict: {`name`: `action`} Command lookup for navigation buttons """
+        actions = {"play": self._navigation.handle_play_button,
+                   "beginning": self._navigation.goto_first_frame,
+                   "prev": self._navigation.decrement_frame,
+                   "next": self._navigation.increment_frame,
+                   "end": self._navigation.goto_last_frame,
+                   "extract": self._det_faces.extract,
+                   "save": self._det_faces.save}
+        return actions
+
+    @property
+    def tk_selected_action(self):
+        """ :class:`tkinter.StringVar`: The variable holding the currently selected action """
+        return self._actions_frame.tk_selected_action
+
+    @property
+    def active_editor(self):
+        """ :class:`Editor`: The current editor in use based on :attr:`selected_action`. """
+        return self._canvas.active_editor
+
+    @property
+    def editors(self):
+        """ dict: All of the :class:`Editor` objects that the canvas holds """
+        return self._canvas.editors
+
+    @property
+    def navigation(self):
+        """ :class:`~tools.manual.frameviewer.control.Navigation`: Class that handles frame
+        Navigation and transport. 
""" + return self._navigation + + @property + def tk_control_colors(self): + """ :dict: Editor key with :class:`tkinter.StringVar` containing the selected color hex + code for each annotation """ + return {key: val["color"].tk_var for key, val in self._canvas.annotation_formats.items()} + + @property + def tk_selected_mask(self): + """ :dict: Editor key with :class:`tkinter.StringVar` containing the selected color hex + code for each annotation """ + return self._canvas.control_tk_vars["Mask"]["display"]["MaskType"] + + @property + def _filter_modes(self): + """ list: The filter modes combo box values """ + return ["All Frames", "Has Face(s)", "No Faces", "Multiple Faces", "Misaligned Faces"] + + def _add_nav(self): + """ Add the slider to navigate through frames """ + max_frame = self._globals.frame_count - 1 + frame = ttk.Frame(self._transport_frame) + + frame.pack(side=tk.TOP, fill=tk.X, pady=(0, 5)) + lbl_frame = ttk.Frame(frame) + lbl_frame.pack(side=tk.RIGHT) + tbox = ttk.Entry(lbl_frame, + width=7, + textvariable=self._globals.var_transport_index, + justify=tk.RIGHT) + tbox.pack(padx=0, side=tk.LEFT) + lbl = ttk.Label(lbl_frame, text=f"/{max_frame}") + lbl.pack(side=tk.RIGHT) + + cmd = partial(set_slider_rounding, + var=self._globals.var_transport_index, + d_type=int, + round_to=1, + min_max=(0, max_frame)) + + nav = ttk.Scale(frame, + variable=self._globals.var_transport_index, + from_=0, + to=max_frame, + command=cmd) + nav.pack(side=tk.LEFT, fill=tk.BOTH, expand=True) + self._globals.var_transport_index.trace_add("write", self._set_frame_index) + return {"entry": tbox, "scale": nav, "label": lbl} + + def _set_frame_index(self, *args): # pylint:disable=unused-argument + """ Set the actual frame index based on current slider position and filter mode. 
""" + try: + slider_position = self._globals.var_transport_index.get() + except TclError: + # don't update the slider when the entry box has been cleared of any value + return + frames = self._det_faces.filter.frames_list + actual_position = max(0, min(len(frames) - 1, slider_position)) + if actual_position != slider_position: + self._globals.var_transport_index.set(actual_position) + frame_idx = frames[actual_position] if frames else -1 + logger.trace("slider_position: %s, frame_idx: %s", actual_position, frame_idx) + self._globals.var_frame_index.set(frame_idx) + + def _add_transport(self): + """ Add video transport controls """ + frame = ttk.Frame(self._transport_frame) + frame.pack(side=tk.BOTTOM, fill=tk.X) + icons = get_images().icons + buttons = {} + for action in ("play", "beginning", "prev", "next", "end", "save", "extract", "mode"): + padx = (0, 6) if action in ("play", "prev", "mode") else (0, 0) + side = tk.RIGHT if action in ("extract", "save", "mode") else tk.LEFT + state = ["!disabled"] if action != "save" else ["disabled"] + if action != "mode": + icon = action if action != "extract" else "folder" + wgt = ttk.Button(frame, image=icons[icon], command=self._btn_action[action]) + wgt.state(state) + else: + wgt = self._add_filter_section(frame) + wgt.pack(side=side, padx=padx) + if action != "mode": + Tooltip(wgt, text=self._helptext[action]) + buttons[action] = wgt + logger.debug("Transport buttons: %s", buttons) + return buttons + + def _add_transport_tk_trace(self): + """ Add the tkinter variable traces to buttons """ + self._navigation.tk_is_playing.trace("w", self._play) + self._det_faces.tk_unsaved.trace("w", self._toggle_save_state) + + def _add_filter_section(self, frame): + """ Add the section that holds the filter mode combo and any optional filter widgets + + Parameters + ---------- + frame: :class:`tkinter.ttk.Frame` + The Frame that holds the filter section + + Returns + ------- + :class:`tkinter.ttk.Frame` + The filter section frame + """ + filter_frame = ttk.Frame(frame) + self._add_filter_mode_combo(filter_frame) + self._add_filter_threshold_slider(filter_frame) + filter_frame.pack(side=tk.RIGHT) + return filter_frame + + def _add_filter_mode_combo(self, frame): + """ Add the navigation mode combo box to the filter frame. + + Parameters + ---------- + frame: :class:`tkinter.ttk.Frame` + The Filter Frame that holds the filter combo box + """ + self._globals.var_filter_mode.set("All Frames") + self._globals.var_filter_mode.trace("w", self._navigation.nav_scale_callback) + nav_frame = ttk.Frame(frame) + lbl = ttk.Label(nav_frame, text="Filter:") + lbl.pack(side=tk.LEFT, padx=(0, 5)) + combo = ttk.Combobox( + nav_frame, + textvariable=self._globals.var_filter_mode, + state="readonly", + values=self._filter_modes) + combo.pack(side=tk.RIGHT) + Tooltip(nav_frame, text=self._helptext["mode"]) + nav_frame.pack(side=tk.RIGHT) + + def _add_filter_threshold_slider(self, frame): + """ Add the optional filter threshold slider for misaligned filter to the filter frame. 
+
+        Parameters
+        ----------
+        frame: :class:`tkinter.ttk.Frame`
+            The Filter Frame that holds the filter threshold slider
+        """
+        slider_frame = ttk.Frame(frame)
+        tk_var = self._globals.var_filter_distance
+
+        min_max = (5, 20)
+        ctl_frame = ttk.Frame(slider_frame)
+        ctl_frame.pack(padx=2, side=tk.RIGHT)
+
+        lbl = ttk.Label(ctl_frame, text="Distance:", anchor=tk.W)
+        lbl.pack(side=tk.LEFT, anchor=tk.N, expand=True)
+
+        tbox = ttk.Entry(ctl_frame, width=6, textvariable=tk_var, justify=tk.RIGHT)
+        tbox.pack(padx=(0, 5), side=tk.RIGHT)
+
+        ctl = ttk.Scale(
+            ctl_frame,
+            variable=tk_var,
+            command=lambda val, var=tk_var, dt=int, rn=1, mm=min_max:
+            set_slider_rounding(val, var, dt, rn, mm))
+        ctl["from_"] = min_max[0]
+        ctl["to"] = min_max[1]
+        ctl.pack(padx=5, fill=tk.X, expand=True)
+        for item in (tbox, ctl):
+            Tooltip(item,
+                    text=self._helptext["distance"],
+                    wrap_length=200)
+        tk_var.trace_add("write", self._navigation.nav_scale_callback)
+        self._optional_widgets["distance_slider"] = slider_frame
+
+    def pack_threshold_slider(self):
+        """ Display or hide the threshold slider depending on the current filter mode. For the
+        misaligned faces filter, display the slider. Hide it for all other filters. """
+        if self._globals.var_filter_mode.get() == "Misaligned Faces":
+            self._optional_widgets["distance_slider"].pack(side=tk.LEFT)
+        else:
+            self._optional_widgets["distance_slider"].pack_forget()
+
+    def cycle_filter_mode(self):
+        """ Cycle the navigation mode combo entry """
+        current_mode = self._globals.var_filter_mode.get()
+        idx = (self._filter_modes.index(current_mode) + 1) % len(self._filter_modes)
+        self._globals.var_filter_mode.set(self._filter_modes[idx])
+
+    def set_action(self, key):
+        """ Set the current action based on keyboard shortcut.
+
+        Parameters
+        ----------
+        key: str
+            The pressed key
+        """
+        # Allow key pad keys for numeric presses
+        key = key.replace("KP_", "") if key.startswith("KP_") else key
+        self._actions_frame.on_click(self._actions_frame.key_bindings[key])
+
+    def _resize(self, event):
+        """ Resize the image to fit the frame, maintaining aspect ratio.
+
+        Parameters
+        ----------
+        event: :class:`tkinter.Event`
+            The tkinter resize event
+        """
+        framesize = (event.width, event.height)
+        logger.trace("Resizing video frame. Framesize: %s", framesize)
+        self._globals.set_frame_display_dims(*framesize)
+        self._globals.var_full_update.set(True)
+
+    # << TRANSPORT >> #
+    def _play(self, *args, frame_count=None):  # pylint:disable=unused-argument
+        """ Play the video file.
+
+        Parameters
+        ----------
+        args: tuple
+            tkinter callback arguments. Required but unused.
+        frame_count: int, optional
+            The number of frames in the current filter view. ``None`` to query the filter for
+            the count. Default: ``None``
+        """
+        start = time()
+        is_playing = self._navigation.tk_is_playing.get()
+        icon = "pause" if is_playing else "play"
+        self._buttons["play"].config(image=get_images().icons[icon])
+
+        if not is_playing:
+            logger.debug("Pause detected. Stopping.")
+            return
+
+        # Populate the filtered frames count on first frame
+        frame_count = self._det_faces.filter.count if frame_count is None else frame_count
+        self._navigation.increment_frame(frame_count=frame_count, is_playing=True)
+        delay = 16  # Cap speed at approx 60fps max. Unlikely to hit, but just in case
+        duration = int((time() - start) * 1000)
+        delay = max(1, delay - duration)
+        self.after(delay, lambda f=frame_count: self._play(frame_count=f))
+
+    def _toggle_save_state(self, *args):  # pylint:disable=unused-argument
+        """ Toggle the state of the save button when alignments are updated. """
+        state = ["!disabled"] if self._det_faces.tk_unsaved.get() else ["disabled"]
+        self._buttons["save"].state(state)
+
+
+class ActionsFrame(ttk.Frame):  # pylint:disable=too-many-ancestors
+    """ The left hand action frame holding the action buttons. 
+ + Parameters + ---------- + parent: :class:`DisplayFrame` + The Display frame that the Actions reside in + """ + def __init__(self, parent): + super().__init__(parent) + self.pack(side=tk.LEFT, fill=tk.Y, padx=(2, 4), pady=2) + self._globals = parent._globals + self._det_faces = parent._det_faces + + self._configure_styles() + self._actions = ("View", "BoundingBox", "ExtractBox", "Landmarks", "Mask") + self._initial_action = "View" + self._buttons = self._add_buttons() + self._static_buttons = self._add_static_buttons() + self._selected_action = self._set_selected_action_tkvar() + self._optional_buttons = {} # Has to be set from parent after canvas is initialized + + @property + def actions(self): + """ tuple: The available action names as a tuple of strings. """ + return self._actions + + @property + def tk_selected_action(self): + """ :class:`tkinter.StringVar`: The variable holding the currently selected action """ + return self._selected_action + + @property + def key_bindings(self): + """ dict: {`key`: `action`}. The mapping of key presses to actions. Keyboard shortcut is + the first letter of each action. """ + return {f"F{idx + 1}": action for idx, action in enumerate(self._actions)} + + @property + def _helptext(self): + """ dict: `button key`: `button helptext`. The help text to display for each button. """ + inverse_keybindings = {val: key for key, val in self.key_bindings.items()} + retval = {"View": _('View alignments'), + "BoundingBox": _('Bounding box editor'), + "ExtractBox": _("Location editor"), + "Mask": _("Mask editor"), + "Landmarks": _("Landmark point editor")} + for item in retval: + retval[item] += f" ({inverse_keybindings[item]})" + return retval + + def _configure_styles(self): + """ Configure background color for Actions widget """ + style = ttk.Style() + style.configure("actions.TFrame", background='#d3d3d3') + style.configure("actions_selected.TButton", relief="flat", background="#bedaf1") + style.configure("actions_deselected.TButton", relief="flat") + self.config(style="actions.TFrame") + + def _add_buttons(self): + """ Add the action buttons to the Display window. + + Returns + ------- + dict: + The action name and its associated button. + """ + frame = ttk.Frame(self) + frame.pack(side=tk.TOP, fill=tk.Y) + buttons = {} + for action in self.key_bindings.values(): + if action == self._initial_action: + btn_style = "actions_selected.TButton" + state = (["pressed", "focus"]) + else: + btn_style = "actions_deselected.TButton" + state = (["!pressed", "!focus"]) + + button = ttk.Button(frame, + image=get_images().icons[action.lower()], + command=lambda t=action: self.on_click(t), + style=btn_style) + button.state(state) + button.pack() + Tooltip(button, text=self._helptext[action]) + buttons[action] = button + return buttons + + def on_click(self, action): + """ Click event for all of the main buttons. + + Parameters + ---------- + action: str + The action name for the button that has called this event as exists in :attr:`_buttons` + """ + for title, button in self._buttons.items(): + if action == title: + button.configure(style="actions_selected.TButton") + button.state(["pressed", "focus"]) + else: + button.configure(style="actions_deselected.TButton") + button.state(["!pressed", "!focus"]) + self._selected_action.set(action) + + def _set_selected_action_tkvar(self): + """ Set the tkinter string variable that holds the currently selected editor action. + Add traceback to display or hide editor specific optional buttons. 
+ + Returns + ------- + :class:`tkinter.StringVar + The variable that holds the currently selected action + """ + var = tk.StringVar() + var.set(self._initial_action) + var.trace("w", self._display_optional_buttons) + return var + + def _add_static_buttons(self): + """ Add the buttons to copy alignments from previous and next frames """ + lookup = {"copy_prev": (_("Previous"), "C"), + "copy_next": (_("Next"), "V"), + "reload": ("", "R")} + frame = ttk.Frame(self) + frame.pack(side=tk.TOP, fill=tk.Y) + sep = ttk.Frame(frame, height=2, relief=tk.RIDGE) + sep.pack(fill=tk.X, pady=5, side=tk.TOP) + buttons = {} + for action in ("copy_prev", "copy_next", "reload"): + if action == "reload": + icon = "reload3" + cmd = lambda f=self._globals: self._det_faces.revert_to_saved(f.frame_index) # noqa:E731,E501 # pylint:disable=line-too-long,unnecessary-lambda-assignment + helptext = _("Revert to saved Alignments ({})").format(lookup[action][1]) + else: + icon = action + direction = action.replace("copy_", "") + cmd = lambda f=self._globals, d=direction: self._det_faces.update.copy( # noqa:E731,E501 # pylint:disable=line-too-long,unnecessary-lambda-assignment + f.frame_index, d) + helptext = _("Copy {} Alignments ({})").format(*lookup[action]) + state = ["!disabled"] if action == "copy_next" else ["disabled"] + button = ttk.Button(frame, + image=get_images().icons[icon], + command=cmd, + style="actions_deselected.TButton") + button.state(state) + button.pack() + Tooltip(button, text=helptext) + buttons[action] = button + self._globals.var_frame_index.trace_add("write", self._disable_enable_copy_buttons) + self._globals.var_full_update.trace_add("write", self._disable_enable_reload_button) + return buttons + + def _disable_enable_copy_buttons(self, *args): # pylint:disable=unused-argument + """ Disable or enable the static buttons """ + position = self._globals.frame_index + face_count_per_index = self._det_faces.face_count_per_index + prev_exists = position != -1 and any(count != 0 + for count in face_count_per_index[:position]) + next_exists = position != -1 and any(count != 0 + for count in face_count_per_index[position + 1:]) + states = {"prev": ["!disabled"] if prev_exists else ["disabled"], + "next": ["!disabled"] if next_exists else ["disabled"]} + for direction in ("prev", "next"): + self._static_buttons[f"copy_{direction}"].state(states[direction]) + + def _disable_enable_reload_button(self, *args): # pylint:disable=unused-argument + """ Disable or enable the static buttons """ + position = self._globals.frame_index + state = ["!disabled"] if (position != -1 and + self._det_faces.is_frame_updated(position)) else ["disabled"] + self._static_buttons["reload"].state(state) + + def add_optional_buttons(self, editors): + """ Add the optional editor specific action buttons """ + for name, editor in editors.items(): + actions = editor.actions + if not actions: + self._optional_buttons[name] = None + continue + frame = ttk.Frame(self) + sep = ttk.Frame(frame, height=2, relief=tk.RIDGE) + sep.pack(fill=tk.X, pady=5, side=tk.TOP) + seen_groups = set() + for action in actions.values(): + group = action["group"] + if group is not None and group not in seen_groups: + btn_style = "actions_selected.TButton" + state = (["pressed", "focus"]) + action["tk_var"].set(True) + seen_groups.add(group) + else: + btn_style = "actions_deselected.TButton" + state = (["!pressed", "!focus"]) + action["tk_var"].set(False) + button = ttk.Button(frame, + image=get_images().icons[action["icon"]], + style=btn_style) + 
button.config(command=lambda b=button: self._on_optional_click(b)) + button.state(state) + button.pack() + + helptext = action["helptext"] + hotkey = action["hotkey"] + helptext += "" if hotkey is None else f" ({hotkey.upper()})" + Tooltip(button, text=helptext) + self._optional_buttons.setdefault( + name, {})[button] = {"hotkey": hotkey, + "group": group, + "tk_var": action["tk_var"]} + self._optional_buttons[name]["frame"] = frame + self._display_optional_buttons() + + def _on_optional_click(self, button): + """ Click event for all of the optional buttons. + + Parameters + ---------- + button: str + The action name for the button that has called this event as exists in :attr:`_buttons` + """ + options = self._optional_buttons[self._selected_action.get()] + group = options[button]["group"] + for child in options["frame"].winfo_children(): + if child.winfo_class() != "TButton": + continue + child_group = options[child]["group"] + if child == button and group is not None: + child.configure(style="actions_selected.TButton") + child.state(["pressed", "focus"]) + options[child]["tk_var"].set(True) + elif child != button and group is not None and child_group == group: + child.configure(style="actions_deselected.TButton") + child.state(["!pressed", "!focus"]) + options[child]["tk_var"].set(False) + elif group is None and child_group is None: + if child.cget("style") == "actions_selected.TButton": + child.configure(style="actions_deselected.TButton") + child.state(["!pressed", "!focus"]) + options[child]["tk_var"].set(False) + else: + child.configure(style="actions_selected.TButton") + child.state(["pressed", "focus"]) + options[child]["tk_var"].set(True) + + def _display_optional_buttons(self, *args): # pylint:disable=unused-argument + """ Pack or forget the optional buttons depending on active editor """ + self._unbind_optional_hotkeys() + for editor, option in self._optional_buttons.items(): + if option is None: + continue + if editor == self._selected_action.get(): + logger.debug("Displaying optional buttons for '%s'", editor) + option["frame"].pack(side=tk.TOP, fill=tk.Y) + for child in option["frame"].winfo_children(): + if child.winfo_class() != "TButton": + continue + hotkey = option[child]["hotkey"] + if hotkey is not None: + logger.debug("Binding optional hotkey for editor '%s': %s", editor, hotkey) + self.winfo_toplevel().bind(hotkey.lower(), + lambda e, b=child: self._on_optional_click(b)) + elif option["frame"].winfo_ismapped(): + logger.debug("Hiding optional buttons for '%s'", editor) + option["frame"].pack_forget() + + def _unbind_optional_hotkeys(self): + """ Unbind all mapped optional button hotkeys """ + for editor, option in self._optional_buttons.items(): + if option is None or not option["frame"].winfo_ismapped(): + continue + for child in option["frame"].winfo_children(): + if child.winfo_class() != "TButton": + continue + hotkey = option[child]["hotkey"] + if hotkey is not None: + logger.debug("Unbinding optional hotkey for editor '%s': %s", editor, hotkey) + self.winfo_toplevel().unbind(hotkey.lower()) + + +class FrameViewer(tk.Canvas): # pylint:disable=too-many-ancestors + """ Annotation onto tkInter Canvas. 
+
+ Parameters
+ ----------
+ parent: :class:`tkinter.ttk.Frame`
+ The parent frame for the canvas
+ tk_globals: :class:`~tools.manual.manual.TkGlobals`
+ The tkinter variables that apply to the whole of the GUI
+ detected_faces: :class:`~tools.manual.detected_faces.DetectedFaces`
+ The alignments data for this manual session
+ actions: tuple
+ The available actions from :attr:`ActionFrame.actions`
+ tk_action_var: :class:`tkinter.StringVar`
+ The variable holding the currently selected action
+ """
+ def __init__(self, parent, tk_globals, detected_faces, actions, tk_action_var):
+ logger.debug("Initializing %s: (parent: %s, tk_globals: %s, detected_faces: %s, "
+ "actions: %s, tk_action_var: %s)", self.__class__.__name__,
+ parent, tk_globals, detected_faces, actions, tk_action_var)
+ super().__init__(parent, bd=0, highlightthickness=0, background="black")
+ self.pack(side=tk.TOP, fill=tk.BOTH, expand=True, anchor=tk.E)
+ self._globals = tk_globals
+ self._det_faces = detected_faces
+ self._actions = actions
+ self._tk_action_var = tk_action_var
+ self._image = BackgroundImage(self)
+ self._editor_globals = {"control_tk_vars": {},
+ "annotation_formats": {},
+ "key_bindings": {}}
+ self._max_face_count = 0
+ self._editors = self._get_editors()
+ self._add_callbacks()
+ self._change_active_editor()
+ logger.debug("Initialized %s", self.__class__.__name__)
+
+ @property
+ def selected_action(self):
+ """str: The name of the currently selected Editor action """
+ return self._tk_action_var.get()
+
+ @property
+ def control_tk_vars(self):
+ """ dict: dictionary of tkinter variables as populated by the right hand control panel.
+ Tracking for all control panel variables, for access from all editors. """
+ return self._editor_globals["control_tk_vars"]
+
+ @property
+ def key_bindings(self):
+ """ dict: dictionary of key bindings for each editor for access from all editors. """
+ return self._editor_globals["key_bindings"]
+
+ @property
+ def annotation_formats(self):
+ """ dict: The selected formatting options for each annotation """
+ return self._editor_globals["annotation_formats"]
+
+ @property
+ def active_editor(self):
+ """ :class:`Editor`: The current editor in use based on :attr:`selected_action`. """
+ return self._editors[self.selected_action]
+
+ @property
+ def editors(self):
+ """ dict: All of the :class:`Editor` objects that exist """
+ return self._editors
+
+ @property
+ def editor_display(self):
+ """ dict: List of editors and any additional annotations they should display. """
+ return {"View": ["BoundingBox", "ExtractBox", "Landmarks", "Mesh"],
+ "BoundingBox": ["Mesh"],
+ "ExtractBox": ["Mesh"],
+ "Landmarks": ["ExtractBox", "Mesh"],
+ "Mask": []}
+
+ @property
+ def offset(self):
+ """ tuple: The (`width`, `height`) offset of the canvas based on the size of the currently
+ displayed image """
+ frame_dims = self._globals.current_frame.display_dims
+ offset_x = (self._globals.frame_display_dims[0] - frame_dims[0]) / 2
+ offset_y = (self._globals.frame_display_dims[1] - frame_dims[1]) / 2
+ logger.trace("offset_x: %s, offset_y: %s", offset_x, offset_y)
+ return offset_x, offset_y
+
+ def _get_editors(self):
+ """ Get the object editors for the canvas.
+
+ Returns
+ -------
+ dict
+ The {`action`: :class:`Editor`} dictionary of editors, one for each
+ :attr:`_actions` name.
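+
+ Notes
+ -----
+ Editor classes are resolved by name via :func:`eval`, so every action name in
+ :attr:`_actions` (plus ``"Mesh"``) must match an Editor class that has been
+ imported into this module's namespace.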
+ """ + editors = {} + for editor_name in self._actions + ("Mesh", ): + editor = eval(editor_name)(self, # pylint:disable=eval-used + self._det_faces) + editors[editor_name] = editor + logger.debug(editors) + return editors + + def _add_callbacks(self): + """ Add the callback trace functions to the :class:`tkinter.Variable` s + + Adds callbacks for: + :attr:`_globals.var_full_update` Update the display for the current image + :attr:`__tk_action_var` Update the mouse display tracking for current action + """ + self._globals.var_full_update.trace_add("write", self._update_display) + self._tk_action_var.trace_add("write", self._change_active_editor) + + def _change_active_editor(self, *args): # pylint:disable=unused-argument + """ Update the display for the active editor. + + Hide the annotations that are not relevant for the selected editor. + Set the selected editor's cursor tracking. + + Parameters + ---------- + args: tuple, unused + Required for tkinter callback but unused + """ + to_display = [self.selected_action] + self.editor_display[self.selected_action] + to_hide = [editor for editor in self._editors if editor not in to_display] + for editor in to_hide: + self._editors[editor].hide_annotation() + + self.active_editor.bind_mouse_motion() + self.active_editor.set_mouse_click_actions() + self._globals.var_full_update.set(True) + + def _update_display(self, *args): # pylint:disable=unused-argument + """ Update the display on frame cache update + + Notes + ----- + A little hacky, but the editors to display or hide are processed in alphabetical + order, so that they are always processed in the same order (for tag lowering and raising) + """ + if not self._globals.var_full_update.get(): + return + zoomed_centering = self.active_editor.zoomed_centering + self._image.refresh(self.active_editor.view_mode) + to_display = sorted([self.selected_action] + self.editor_display[self.selected_action]) + self._hide_additional_faces() + for editor in to_display: + self._editors[editor].update_annotation() + self._bind_unbind_keys() + if zoomed_centering != self.active_editor.zoomed_centering: + # Refresh the image if editor annotation has changed the zoom centering of the image + self._image.refresh(self.active_editor.view_mode) + self._globals.var_full_update.set(False) + self.update_idletasks() + + def _hide_additional_faces(self): + """ Hide additional faces if the number of faces on the canvas reduces on a frame + change. """ + if self._globals.is_zoomed: + current_face_count = 1 + elif self._globals.frame_index == -1: + current_face_count = 0 + else: + current_face_count = len(self._det_faces.current_faces[self._globals.frame_index]) + + if current_face_count > self._max_face_count: + # Most faces seen to date so nothing to hide. Update max count and return + logger.debug("Incrementing max face count from: %s to: %s", + self._max_face_count, current_face_count) + self._max_face_count = current_face_count + return + for idx in range(current_face_count, self._max_face_count): + tag = f"face_{idx}" + if any(self.itemcget(item_id, "state") != "hidden" + for item_id in self.find_withtag(tag)): + logger.debug("Hiding face tag '%s'", tag) + self.itemconfig(tag, state="hidden") + + def _bind_unbind_keys(self): + """ Bind or unbind this editor's hotkeys depending on whether it is active. 
""" + unbind_keys = [key for key, binding in self.key_bindings.items() + if binding["bound_to"] is not None + and binding["bound_to"] != self.selected_action] + for key in unbind_keys: + logger.debug("Unbinding key '%s'", key) + self.winfo_toplevel().unbind(key) + self.key_bindings[key]["bound_to"] = None + + bind_keys = {key: binding[self.selected_action] + for key, binding in self.key_bindings.items() + if self.selected_action in binding + and binding["bound_to"] != self.selected_action} + for key, method in bind_keys.items(): + logger.debug("Binding key '%s' to method %s", key, method) + self.winfo_toplevel().bind(key, method) + self.key_bindings[key]["bound_to"] = self.selected_action + + +__all__ = get_module_objects(__name__) diff --git a/tools/manual/globals.py b/tools/manual/globals.py new file mode 100644 index 0000000000..548ebdc27d --- /dev/null +++ b/tools/manual/globals.py @@ -0,0 +1,312 @@ +#!/usr/bin/env python3 +""" Holds global tkinter variables and information pertaining to the entire Manual tool """ +from __future__ import annotations + +import logging +import os +import sys +import tkinter as tk + +from dataclasses import dataclass, field + +import cv2 +import numpy as np + +from lib.gui.utils import get_config +from lib.logger import parse_class_init +from lib.utils import get_module_objects, VIDEO_EXTENSIONS + +logger = logging.getLogger(__name__) + + +@dataclass +class CurrentFrame: + """ Dataclass for holding information about the currently displayed frame """ + image: np.ndarray = field(default_factory=lambda: np.zeros(1)) + """:class:`numpy.ndarry`: The currently displayed frame in original dimensions """ + scale: float = 1.0 + """float: The scaling factor to use to resize the image to the display window """ + interpolation: int = cv2.INTER_AREA + """int: The opencv interpolator ID to use for resizing the image to the display window """ + display_dims: tuple[int, int] = (0, 0) + """tuple[int, int]`: The size of the currently displayed frame, in the display window """ + filename: str = "" + """str: The filename of the currently displayed frame """ + + def __repr__(self) -> str: + """ Clean string representation showing numpy arrays as shape and dtype + + Returns + ------- + str + Loggable representation of the dataclass + """ + properties = [f"{k}={(v.shape, v.dtype) if isinstance(v, np.ndarray) else v}" + for k, v in self.__dict__.items()] + return f"{self.__class__.__name__} ({', '.join(properties)}" + + +@dataclass +class TKVars: + """ Holds the global TK Variables """ + frame_index: tk.IntVar + """:class:`tkinter.IntVar`: The absolute frame index of the currently displayed frame""" + transport_index: tk.IntVar + """:class:`tkinter.IntVar`: The transport index of the currently displayed frame when filters + have been applied """ + face_index: tk.IntVar + """:class:`tkinter.IntVar`: The face index of the currently selected face""" + filter_distance: tk.IntVar + """:class:`tkinter.IntVar`: The amount to filter by distance""" + + update: tk.BooleanVar + """:class:`tkinter.BooleanVar`: Whether an update has been performed """ + update_active_viewport: tk.BooleanVar + """:class:`tkinter.BooleanVar`: Whether the viewport needs updating """ + is_zoomed: tk.BooleanVar + """:class:`tkinter.BooleanVar`: Whether the main window is zoomed in to a face or out to a + full frame""" + + filter_mode: tk.StringVar + """:class:`tkinter.StringVar`: The currently selected filter mode """ + faces_size: tk.StringVar + """:class:`tkinter.StringVar`: The pixel size of faces in the 
viewport """
+
+ def __repr__(self) -> str:
+ """ Clean string representation showing variable type as well as their value
+
+ Returns
+ -------
+ str
+ Loggable representation of the dataclass
+ """
+ properties = [f"{k}={v.__class__.__name__}({v.get()})" for k, v in self.__dict__.items()]
+ return f"{self.__class__.__name__} ({', '.join(properties)})"
+
+
+class TkGlobals():
+ """ Holds Tkinter Variables and other frame information that need to be accessible from all
+ areas of the GUI.
+
+ Parameters
+ ----------
+ input_location: str
+ The location of the input folder of frames or video file
+ """
+ def __init__(self, input_location: str) -> None:
+ logger.debug(parse_class_init(locals()))
+ self._tk_vars = self._get_tk_vars()
+
+ self._is_video = self._check_input(input_location)
+ self._frame_count = 0 # set by FrameLoader
+ self._frame_display_dims = (int(round(896 * get_config().scaling_factor)),
+ int(round(504 * get_config().scaling_factor)))
+ self._current_frame = CurrentFrame()
+ logger.debug("Initialized %s", self.__class__.__name__)
+
+ @classmethod
+ def _get_tk_vars(cls) -> TKVars:
+ """ Create and initialize the tkinter variables.
+
+ Returns
+ -------
+ :class:`TKVars`
+ The global tkinter variables
+ """
+ retval = TKVars(frame_index=tk.IntVar(value=0),
+ transport_index=tk.IntVar(value=0),
+ face_index=tk.IntVar(value=0),
+ filter_distance=tk.IntVar(value=10),
+ update=tk.BooleanVar(value=False),
+ update_active_viewport=tk.BooleanVar(value=False),
+ is_zoomed=tk.BooleanVar(value=False),
+ filter_mode=tk.StringVar(),
+ faces_size=tk.StringVar())
+ logger.debug(retval)
+ return retval
+
+ @property
+ def current_frame(self) -> CurrentFrame:
+ """ :class:`CurrentFrame`: The currently displayed frame in the frame viewer with its
+ meta information. """
+ return self._current_frame
+
+ @property
+ def frame_count(self) -> int:
+ """ int: The total number of frames for the input location """
+ return self._frame_count
+
+ @property
+ def frame_display_dims(self) -> tuple[int, int]:
+ """ tuple: The (`width`, `height`) of the video display frame in pixels. """
+ return self._frame_display_dims
+
+ @property
+ def is_video(self) -> bool:
+ """ bool: ``True`` if the input is a video file, ``False`` if it is a folder of images. """
+ return self._is_video
+
+ # TK Variables that need to be exposed
+ @property
+ def var_full_update(self) -> tk.BooleanVar:
+ """ :class:`tkinter.BooleanVar`: Flag to indicate that whole GUI should be refreshed """
+ return self._tk_vars.update
+
+ @property
+ def var_transport_index(self) -> tk.IntVar:
+ """ :class:`tkinter.IntVar`: The current index of the display frame's transport slider. """
+ return self._tk_vars.transport_index
+
+ @property
+ def var_frame_index(self) -> tk.IntVar:
+ """ :class:`tkinter.IntVar`: The current absolute frame index of the currently
+ displayed frame. """
+ return self._tk_vars.frame_index
+
+ @property
+ def var_filter_distance(self) -> tk.IntVar:
+ """ :class:`tkinter.IntVar`: The variable holding the currently selected threshold
+ distance for misaligned filter mode. """
+ return self._tk_vars.filter_distance
+
+ @property
+ def var_filter_mode(self) -> tk.StringVar:
+ """ :class:`tkinter.StringVar`: The variable holding the currently selected navigation
+ filter mode. """
+ return self._tk_vars.filter_mode
+
+ @property
+ def var_faces_size(self) -> tk.StringVar:
+ """ :class:`tkinter.StringVar`: The variable holding the currently selected Faces Viewer
+ thumbnail size.
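+ One of ``Tiny``, ``Small``, ``Medium``, ``Large`` or ``Extra Large``, as set from
+ the Face Size combobox in the options panel.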
""" + return self._tk_vars.faces_size + + @property + def var_update_active_viewport(self) -> tk.BooleanVar: + """ :class:`tkinter.BooleanVar`: Boolean Variable that is traced by the viewport's active + frame to update. """ + return self._tk_vars.update_active_viewport + + # Raw values returned from TK Variables + @property + def face_index(self) -> int: + """ int: The currently displayed face index when in zoomed mode. """ + return self._tk_vars.face_index.get() + + @property + def frame_index(self) -> int: + """ int: The currently displayed frame index. NB This returns -1 if there are no frames + that meet the currently selected filter criteria. """ + return self._tk_vars.frame_index.get() + + @property + def is_zoomed(self) -> bool: + """ bool: ``True`` if the frame viewer is zoomed into a face, ``False`` if the frame viewer + is displaying a full frame. """ + return self._tk_vars.is_zoomed.get() + + @staticmethod + def _check_input(frames_location: str) -> bool: + """ Check whether the input is a video + + Parameters + ---------- + frames_location: str + The input location for video or images + + Returns + ------- + bool: 'True' if input is a video 'False' if it is a folder. + """ + if os.path.isdir(frames_location): + retval = False + elif os.path.splitext(frames_location)[1].lower() in VIDEO_EXTENSIONS: + retval = True + else: + logger.error("The input location '%s' is not valid", frames_location) + sys.exit(1) + logger.debug("Input '%s' is_video: %s", frames_location, retval) + return retval + + def set_face_index(self, index: int) -> None: + """ Set the currently selected face index + + Parameters + ---------- + index: int + The currently selected face index + """ + logger.trace("Setting face index from %s to %s", # type:ignore[attr-defined] + self.face_index, index) + self._tk_vars.face_index.set(index) + + def set_frame_count(self, count: int) -> None: + """ Set the count of total number of frames to :attr:`frame_count` when the + :class:`FramesLoader` has completed loading. + + Parameters + ---------- + count: int + The number of frames that exist for this session + """ + logger.debug("Setting frame_count to : %s", count) + self._frame_count = count + + def set_current_frame(self, image: np.ndarray, filename: str) -> None: + """ Set the frame and meta information for the currently displayed frame. Populates the + attribute :attr:`current_frame` + + Parameters + ---------- + image: :class:`numpy.ndarray` + The image used to display in the Frame Viewer + filename: str + The filename of the current frame + """ + scale = min(self.frame_display_dims[0] / image.shape[1], + self.frame_display_dims[1] / image.shape[0]) + self._current_frame.image = image + self._current_frame.filename = filename + self._current_frame.scale = scale + self._current_frame.interpolation = cv2.INTER_CUBIC if scale > 1.0 else cv2.INTER_AREA + self._current_frame.display_dims = (int(round(image.shape[1] * scale)), + int(round(image.shape[0] * scale))) + logger.trace(self._current_frame) # type:ignore[attr-defined] + + def set_frame_display_dims(self, width: int, height: int) -> None: + """ Set the size, in pixels, of the video frame display window and resize the displayed + frame. + + Used on a frame resize callback, sets the :attr:frame_display_dims`. 
+
+ Parameters
+ ----------
+ width: int
+ The width of the frame holding the video canvas in pixels
+ height: int
+ The height of the frame holding the video canvas in pixels
+ """
+ self._frame_display_dims = (int(width), int(height))
+ image = self._current_frame.image
+ scale = min(self.frame_display_dims[0] / image.shape[1],
+ self.frame_display_dims[1] / image.shape[0])
+ self._current_frame.scale = scale
+ self._current_frame.interpolation = cv2.INTER_CUBIC if scale > 1.0 else cv2.INTER_AREA
+ self._current_frame.display_dims = (int(round(image.shape[1] * scale)),
+ int(round(image.shape[0] * scale)))
+ logger.trace(self._current_frame) # type:ignore[attr-defined]
+
+ def set_zoomed(self, state: bool) -> None:
+ """ Set the current zoom state
+
+ Parameters
+ ----------
+ state: bool
+ ``True`` for zoomed, ``False`` for full frame
+ """
+ logger.trace("Setting zoom state from %s to %s", # type:ignore[attr-defined]
+ self.is_zoomed, state)
+ self._tk_vars.is_zoomed.set(state)
+
+
+__all__ = get_module_objects(__name__)
diff --git a/tools/manual/manual.py b/tools/manual/manual.py
new file mode 100644
index 0000000000..e426d41704
--- /dev/null
+++ b/tools/manual/manual.py
@@ -0,0 +1,776 @@
+#!/usr/bin/env python3
+""" Main entry point for the Manual Tool. A GUI app for editing alignments files """
+from __future__ import annotations
+
+import logging
+import os
+import sys
+import typing as T
+import tkinter as tk
+from tkinter import ttk
+from dataclasses import dataclass
+from time import sleep
+
+import numpy as np
+
+from lib.gui.control_helper import ControlPanel
+from lib.gui.utils import get_images, get_config, initialize_config, initialize_images
+from lib.image import SingleFrameLoader, read_image_meta
+from lib.logger import parse_class_init
+from lib.multithreading import MultiThread
+from lib.utils import get_module_objects, handle_deprecated_cliopts
+from plugins.extract import ExtractMedia, Extractor
+
+from .detected_faces import DetectedFaces
+from .faceviewer.frame import FacesFrame
+from .frameviewer.frame import DisplayFrame
+from .globals import TkGlobals
+from .thumbnails import ThumbsCreator
+
+if T.TYPE_CHECKING:
+ from argparse import Namespace
+ from lib import align
+ from lib.align import DetectedFace
+ from lib.queue_manager import EventQueue
+
+logger = logging.getLogger(__name__)
+
+TypeManualExtractor = T.Literal["FAN", "cv2-dnn", "mask"]
+
+
+@dataclass
+class _Containers:
+ """ Dataclass for holding the main area containers in the GUI """
+ main: ttk.PanedWindow
+ """:class:`tkinter.ttk.PanedWindow`: The main window holding the full GUI """
+ top: ttk.Frame
+ """:class:`tkinter.ttk.Frame`: The top part (frame viewer) of the GUI"""
+ bottom: ttk.Frame
+ """:class:`tkinter.ttk.Frame`: The bottom part (face viewer) of the GUI"""
+
+
+class Manual(tk.Tk):
+ """ The main entry point for Faceswap's Manual Editor Tool. This tool is part of the Faceswap
+ Tools suite and should be called from the ``python tools.py manual`` command.
+
+ Allows for visual interaction with frames, faces and alignments file to perform various
+ adjustments to the alignments file.
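+ A typical invocation (long-form flag names assumed from the argument dests defined in
+ ``tools/manual/cli.py``; shown here illustratively) would be:
+ ``python tools.py manual --alignments /path/to/alignments.fsa --frames /path/to/frames``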
+
+ Parameters
+ ----------
+ arguments: :class:`argparse.Namespace`
+ The :mod:`argparse` arguments as passed in from :mod:`tools.py`
+ """
+
+ def __init__(self, arguments: Namespace) -> None:
+ logger.debug(parse_class_init(locals()))
+ super().__init__()
+ arguments = handle_deprecated_cliopts(arguments)
+ self._validate_non_faces(arguments.frames)
+
+ self._initialize_tkinter()
+ self._globals = TkGlobals(arguments.frames)
+
+ extractor = Aligner(self._globals)
+ self._detected_faces = DetectedFaces(self._globals,
+ arguments.alignments_path,
+ arguments.frames,
+ extractor)
+
+ video_meta_data = self._detected_faces.video_meta_data
+ valid_meta = all(val is not None for val in video_meta_data.values())
+
+ loader = FrameLoader(self._globals,
+ arguments.frames,
+ video_meta_data,
+ self._detected_faces.frame_list)
+
+ if valid_meta: # Load the faces whilst other threads complete if we have valid meta data
+ self._detected_faces.load_faces()
+
+ self._containers = self._create_containers()
+ self._wait_for_threads(extractor, loader, valid_meta)
+ if not valid_meta: # If meta data needs updating, load faces after other threads
+ self._detected_faces.load_faces()
+
+ self._generate_thumbs(arguments.frames, arguments.thumb_regen, arguments.single_process)
+
+ self._display = DisplayFrame(self._containers.top,
+ self._globals,
+ self._detected_faces)
+ _Options(self._containers.top, self._globals, self._display)
+
+ self._faces_frame = FacesFrame(self._containers.bottom,
+ self._globals,
+ self._detected_faces,
+ self._display)
+ self._display.tk_selected_action.set("View")
+
+ self.bind("<Key>", self._handle_key_press)
+ self._set_initial_layout()
+ logger.debug("Initialized %s", self.__class__.__name__)
+
+ @classmethod
+ def _validate_non_faces(cls, frames_folder: str) -> None:
+ """ Quick check on the input to make sure that a folder of extracted faces is not being
+ passed in. """
+ if not os.path.isdir(frames_folder):
+ logger.debug("Input '%s' is not a folder", frames_folder)
+ return
+ test_file = next((fname
+ for fname in os.listdir(frames_folder)
+ if os.path.splitext(fname)[-1].lower() == ".png"),
+ None)
+ if not test_file:
+ logger.debug("Input '%s' does not contain any .pngs", frames_folder)
+ return
+ test_file = os.path.join(frames_folder, test_file)
+ meta = read_image_meta(test_file)
+ logger.debug("Test file: (filename: %s, metadata: %s)", test_file, meta)
+ if "itxt" in meta and "alignments" in meta["itxt"]:
+ logger.error("The input folder '%s' contains extracted faces.", frames_folder)
+ logger.error("The Manual Tool works with source frames or a video file, not extracted "
+ "faces. Please update your input.")
+ sys.exit(1)
+ logger.debug("Test input file '%s' does not contain Faceswap header data", test_file)
+
+ def _wait_for_threads(self, extractor: Aligner, loader: FrameLoader, valid_meta: bool) -> None:
+ """ The :class:`Aligner` and :class:`FrameLoader` are launched in background threads.
+ Wait for them to be initialized prior to proceeding.
+
+ Parameters
+ ----------
+ extractor: :class:`Aligner`
+ The extraction pipeline for the Manual Tool
+ loader: :class:`FrameLoader`
+ The frames loader for the Manual Tool
+ valid_meta: bool
+ Whether the input video had valid meta-data on import, or if it had to be created.
+ ``True`` if valid meta data existed previously, ``False`` if it needed to be created
+
+ Notes
+ -----
+ Because some of the initialization checks perform extra work once their threads are
+ complete, they should only be queried until they first return ``True``, and should not
+ be queried again.
+ """
+ extractor_init = False
+ frames_init = False
+ while True:
+ extractor_init = extractor_init if extractor_init else extractor.is_initialized
+ frames_init = frames_init if frames_init else loader.is_initialized
+ if extractor_init and frames_init:
+ logger.debug("Threads initialized")
+ break
+ logger.debug("Threads not initialized. Waiting...")
+ sleep(1)
+
+ extractor.link_faces(self._detected_faces)
+ if not valid_meta:
+ logger.debug("Saving video meta data to alignments file")
+ self._detected_faces.save_video_meta_data(
+ **loader.video_meta_data) # type:ignore[arg-type]
+
+ def _generate_thumbs(self, input_location: str, force: bool, single_process: bool) -> None:
+ """ Check whether thumbnails are stored in the alignments file and if not generate them.
+
+ Parameters
+ ----------
+ input_location: str
+ The input video or folder of images
+ force: bool
+ ``True`` if the thumbnails should be regenerated even if they exist, otherwise
+ ``False``
+ single_process: bool
+ ``True`` will extract thumbs from a video in a single process, ``False`` will run
+ parallel threads
+ """
+ thumbs = ThumbsCreator(self._detected_faces, input_location, single_process)
+ if thumbs.has_thumbs and not force:
+ return
+ logger.debug("Generating thumbnails cache")
+ thumbs.generate_cache()
+ logger.debug("Generated thumbnails cache")
+
+ def _initialize_tkinter(self) -> None:
+ """ Initialize a standalone tkinter instance. """
+ logger.debug("Initializing tkinter")
+ for widget in ("TButton", "TCheckbutton", "TRadiobutton"):
+ self.unbind_class(widget, "<Key-space>")
+ initialize_config(self, None, None)
+ initialize_images()
+ get_config().set_geometry(940, 600, fullscreen=True)
+ self.title("Faceswap.py - Visual Alignments")
+ logger.debug("Initialized tkinter")
+
+ def _create_containers(self) -> _Containers:
+ """ Create the paned window containers for various GUI elements
+
+ Returns
+ -------
+ :class:`_Containers`:
+ The main containers of the manual tool.
+ """
+ logger.debug("Creating containers")
+
+ main = ttk.PanedWindow(self,
+ orient=tk.VERTICAL,
+ name="pw_main")
+ main.pack(fill=tk.BOTH, expand=True)
+
+ top = ttk.Frame(main, name="frame_top")
+ main.add(top)
+
+ bottom = ttk.Frame(main, name="frame_bottom")
+ main.add(bottom)
+
+ retval = _Containers(main=main, top=top, bottom=bottom)
+
+ logger.debug("Created containers: %s", retval)
+ return retval
+
+ def _handle_key_press(self, event: tk.Event) -> None:
+ """ Keyboard shortcuts
+
+ Parameters
+ ----------
+ event: :class:`tkinter.Event`
+ The tkinter key press event
+
+ Notes
+ -----
+ The following keys are reserved for the :mod:`tools.lib_manual.editor` classes
+ * Delete - Used for deleting faces
+ * [] - decrease / increase brush size
+ * B, D, E, M - Optional Actions (Brush, Drag, Erase, Zoom)
+ """
+ # Alt modifier appears to be broken in Windows so don't use it.
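+ # tkinter reports any active modifier keys as a bitmask in event.state; the
+ # standard mask values below (0x0001 = Shift, 0x0004 = Control) are AND-ed
+ # against it to build the "modifier_keypress" lookup for the bindings dict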
+ modifiers = {0x0001: 'shift',
+ 0x0004: 'ctrl'}
+
+ globs = self._globals
+ bindings = {
+ "z": self._display.navigation.decrement_frame,
+ "x": self._display.navigation.increment_frame,
+ "space": self._display.navigation.handle_play_button,
+ "home": self._display.navigation.goto_first_frame,
+ "end": self._display.navigation.goto_last_frame,
+ "down": lambda d="down": self._faces_frame.canvas_scroll(d),
+ "up": lambda d="up": self._faces_frame.canvas_scroll(d),
+ "next": lambda d="page-down": self._faces_frame.canvas_scroll(d),
+ "prior": lambda d="page-up": self._faces_frame.canvas_scroll(d),
+ "f": self._display.cycle_filter_mode,
+ "f1": lambda k=event.keysym: self._display.set_action(k),
+ "f2": lambda k=event.keysym: self._display.set_action(k),
+ "f3": lambda k=event.keysym: self._display.set_action(k),
+ "f4": lambda k=event.keysym: self._display.set_action(k),
+ "f5": lambda k=event.keysym: self._display.set_action(k),
+ "f9": lambda k=event.keysym: self._faces_frame.set_annotation_display(k),
+ "f10": lambda k=event.keysym: self._faces_frame.set_annotation_display(k),
+ "c": lambda f=globs.frame_index, d="prev": self._detected_faces.update.copy(f, d),
+ "v": lambda f=globs.frame_index, d="next": self._detected_faces.update.copy(f, d),
+ "ctrl_s": self._detected_faces.save,
+ "r": lambda f=globs.frame_index: self._detected_faces.revert_to_saved(f)}
+
+ # Allow keypad keys to be used for numbers
+ press = event.keysym.replace("KP_", "") if event.keysym.startswith("KP_") else event.keysym
+ assert isinstance(event.state, int)
+ modifier = "_".join(val for key, val in modifiers.items() if event.state & key != 0)
+ key_press = "_".join([modifier, press]) if modifier else press
+ if key_press.lower() in bindings:
+ logger.trace("key press: %s, action: %s", # type:ignore[attr-defined]
+ key_press, bindings[key_press.lower()])
+ self.focus_set()
+ bindings[key_press.lower()]()
+
+ def _set_initial_layout(self) -> None:
+ """ Set the favicon and the bottom frame position to the correct location to display the
+ full frame window.
+
+ Notes
+ -----
+ Setting the favicon pops up the tkinter GUI (without loaded elements) as soon as it is
+ called, so this is done last.
+ """
+ logger.debug("Setting initial layout")
+ self.tk.call("wm",
+ "iconphoto",
+ self._w, # type:ignore[attr-defined] # pylint:disable=protected-access
+ get_images().icons["favicon"])
+ location = int(self.winfo_screenheight() // 1.5)
+ self._containers.main.sashpos(0, location)
+ self.update_idletasks()
+
+ def process(self) -> None:
+ """ The entry point for the Visual Alignments tool from :mod:`tools.manual.cli`.
+
+ Launch the tkinter Visual Alignments Window and run main loop.
+ """
+ logger.debug("Launching mainloop")
+ self.mainloop()
+
+
+class _Options(ttk.Frame): # pylint:disable=too-many-ancestors
+ """ Control panel options for currently displayed Editor. This is the right hand panel of the
+ GUI that holds editor specific settings and annotation display settings.
+ + Parameters + ---------- + parent: :class:`tkinter.ttk.Frame` + The parent frame for the control panel options + tk_globals: :class:`~tools.manual.manual.TkGlobals` + The tkinter variables that apply to the whole of the GUI + display_frame: :class:`DisplayFrame` + The frame that holds the editors + """ + def __init__(self, + parent: ttk.Frame, + tk_globals: TkGlobals, + display_frame: DisplayFrame) -> None: + logger.debug(parse_class_init(locals())) + super().__init__(parent) + + self._globals = tk_globals + self._display_frame = display_frame + self._control_panels = self._initialize() + self._set_tk_callbacks() + self._update_options() + self.pack(side=tk.RIGHT, fill=tk.Y) + logger.debug("Initialized %s", self.__class__.__name__) + + def _initialize(self) -> dict[str, ControlPanel]: + """ Initialize all of the control panels, then display the default panel. + + Adds the control panel to :attr:`_control_panels` and sets the traceback to update + display when a panel option has been changed. + + Notes + ----- + All panels must be initialized at the beginning so that the global format options are not + reset to default when the editor is first selected. + + The Traceback must be set after the panel has first been packed as otherwise it interferes + with the loading of the faces pane. + + Returns + ------- + dict[str, :class:`~lib.gui.control_helper.ControlPanel`] + The configured control panels + """ + self._initialize_face_options() + frame = ttk.Frame(self) + frame.pack(side=tk.TOP, fill=tk.BOTH, expand=True) + panels = {} + for name, editor in self._display_frame.editors.items(): + logger.debug("Initializing control panel for '%s' editor", name) + controls = editor.controls + panel = ControlPanel(frame, controls["controls"], + option_columns=2, + columns=1, + max_columns=1, + header_text=controls["header"], + blank_nones=False, + label_width=12, + style="CPanel", + scrollbar=False) + panel.pack_forget() + panels[name] = panel + return panels + + def _initialize_face_options(self) -> None: + """ Set the Face Viewer options panel, beneath the standard control options. """ + frame = ttk.Frame(self) + frame.pack(side=tk.BOTTOM, fill=tk.X, padx=5, pady=5) + size_frame = ttk.Frame(frame) + size_frame.pack(side=tk.RIGHT) + lbl = ttk.Label(size_frame, text="Face Size:") + lbl.pack(side=tk.LEFT) + cmb = ttk.Combobox(size_frame, + values=["Tiny", "Small", "Medium", "Large", "Extra Large"], + state="readonly", + textvariable=self._globals.var_faces_size) + self._globals.var_faces_size.set("Medium") + cmb.pack(side=tk.RIGHT, padx=5) + + def _set_tk_callbacks(self) -> None: + """ Sets the callback to change to the relevant control panel options when the selected + editor is changed, and the display update on panel option change.""" + self._display_frame.tk_selected_action.trace("w", self._update_options) + seen_controls = set() + for name, editor in self._display_frame.editors.items(): + for ctl in editor.controls["controls"]: + if ctl in seen_controls: + # Some controls are re-used (annotation format), so skip if trace has already + # been set + continue + logger.debug("Adding control update callback: (editor: %s, control: %s)", + name, ctl.title) + seen_controls.add(ctl) + ctl.tk_var.trace("w", lambda *e: self._globals.var_full_update.set(True)) + + def _update_options(self, *args) -> None: # pylint:disable=unused-argument + """ Update the control panel display for the current editor. + + If the options have not already been set, then adds the control panel to + :attr:`_control_panels`. 
Displays the current editor's control panel + + Parameters + ---------- + args: tuple + Unused but required for tkinter variable callback + """ + self._clear_options_frame() + editor = self._display_frame.tk_selected_action.get() + logger.debug("Displaying control panel for editor: '%s'", editor) + self._control_panels[editor].pack(expand=True, fill=tk.BOTH) + + def _clear_options_frame(self) -> None: + """ Hides the currently displayed control panel """ + for editor, panel in self._control_panels.items(): + if panel.winfo_ismapped(): + logger.debug("Hiding control panel for: %s", editor) + panel.pack_forget() + + +class Aligner(): + """ The :class:`Aligner` class sets up an extraction pipeline for each of the current Faceswap + Aligners, along with the Landmarks based Maskers. When new landmarks are required, the bounding + boxes from the GUI are passed to this class for pushing through the pipeline. The resulting + Landmarks and Masks are then returned. + + Parameters + ---------- + tk_globals: :class:`~tools.manual.manual.TkGlobals` + The tkinter variables that apply to the whole of the GUI + """ + def __init__(self, tk_globals: TkGlobals) -> None: + logger.debug("Initializing: %s (tk_globals: %s)", + self.__class__.__name__, tk_globals) + self._globals = tk_globals + + self._detected_faces: DetectedFaces | None = None + self._frame_index: int | None = None + self._face_index: int | None = None + + self._aligners: dict[TypeManualExtractor, Extractor | None] = {"cv2-dnn": None, + "FAN": None, + "mask": None} + self._aligner: TypeManualExtractor = "FAN" + + self._init_thread = self._background_init_aligner() + logger.debug("Initialized: %s", self.__class__.__name__) + + @property + def _in_queue(self) -> EventQueue: + """ :class:`queue.Queue` - The input queue to the extraction pipeline. """ + aligner = self._aligners[self._aligner] + assert aligner is not None + return aligner.input_queue + + @property + def _feed_face(self) -> ExtractMedia: + """ :class:`~plugins.extract.extract_media.ExtractMedia`: The current face for feeding into + the aligner, formatted for the pipeline """ + assert self._frame_index is not None + assert self._face_index is not None + assert self._detected_faces is not None + face = self._detected_faces.current_faces[self._frame_index][self._face_index] + return ExtractMedia( + self._globals.current_frame.filename, + self._globals.current_frame.image, + detected_faces=[face]) + + @property + def is_initialized(self) -> bool: + """ bool: The Aligners are initialized in a background thread so that other tasks can be + performed whilst we wait for initialization. 
``True`` is returned if the aligner has
+ completed initialization, otherwise ``False``."""
+ thread_is_alive = self._init_thread.is_alive()
+ if thread_is_alive:
+ logger.trace("Aligner not yet initialized") # type:ignore[attr-defined]
+ self._init_thread.check_and_raise_error()
+ else:
+ logger.trace("Aligner initialized") # type:ignore[attr-defined]
+ self._init_thread.join()
+ return not thread_is_alive
+
+ def _background_init_aligner(self) -> MultiThread:
+ """ Launch the aligner in a background thread so we can run other tasks whilst
+ waiting for initialization
+
+ Returns
+ -------
+ :class:`lib.multithreading.MultiThread`
+ The background aligner loader thread
+ """
+ logger.debug("Launching aligner initialization thread")
+ thread = MultiThread(self._init_aligner,
+ thread_count=1,
+ name=f"{self.__class__.__name__}.init_aligner")
+ thread.start()
+ logger.debug("Launched aligner initialization thread")
+ return thread
+
+ def _init_aligner(self) -> None:
+ """ Initialize the Aligners in a background thread and store them in :attr:`_aligners`. """
+ logger.debug("Initialize Aligner")
+ # Make sure non-GPU aligner is allocated first
+ for model in T.get_args(TypeManualExtractor):
+ logger.debug("Initializing aligner: %s", model)
+ plugin = None if model == "mask" else model
+ aligner = Extractor(None,
+ plugin,
+ ["components", "extended"],
+ multiprocess=True,
+ normalize_method="hist",
+ disable_filter=True)
+ if plugin:
+ aligner.set_batchsize("align", 1) # Set the batchsize to 1
+ aligner.launch()
+ logger.debug("Initialized %s Extractor", model)
+ self._aligners[model] = aligner
+
+ def link_faces(self, detected_faces: DetectedFaces) -> None:
+ """ As the Aligner has the potential to take the longest to initialize, it is kicked off
+ as early as possible. At this time :class:`~tools.manual.detected_faces.DetectedFaces` is
+ not yet available.
+
+ Once the Aligner has initialized, this function is called to add the
+ :class:`~tools.manual.detected_faces.DetectedFaces` class as a property of the Aligner.
+
+ Parameters
+ ----------
+ detected_faces: :class:`~tools.manual.detected_faces.DetectedFaces`
+ The class that holds the :class:`~lib.align.DetectedFace` objects for the
+ current Manual session
+ """
+ logger.debug("Linking detected_faces: %s", detected_faces)
+ self._detected_faces = detected_faces
+
+ def get_landmarks(self, frame_index: int, face_index: int, aligner: TypeManualExtractor
+ ) -> np.ndarray:
+ """ Feed the detected face into the alignment pipeline and retrieve the landmarks.
+
+ The face to feed into the aligner is generated from the given frame and face indices.
+ + Parameters + ---------- + frame_index: int + The frame index to extract the aligned face for + face_index: int + The face index within the current frame to extract the face for + aligner: Literal["FAN", "cv2-dnn"] + The aligner to use to extract the face + + Returns + ------- + :class:`numpy.ndarray` + The 68 point landmark alignments + """ + logger.trace("frame_index: %s, face_index: %s, aligner: %s", # type:ignore[attr-defined] + frame_index, face_index, aligner) + self._frame_index = frame_index + self._face_index = face_index + self._aligner = aligner + self._in_queue.put(self._feed_face) + extractor = self._aligners[aligner] + assert extractor is not None + detected_face = next(extractor.detected_faces()).detected_faces[0] + logger.trace("landmarks: %s", detected_face.landmarks_xy) # type:ignore[attr-defined] + return detected_face.landmarks_xy + + def _remove_nn_masks(self, detected_face: DetectedFace) -> None: + """ Remove any non-landmarks based masks on a landmark edit + + Parameters + ---------- + detected_face: + The detected face object to remove masks from + """ + del_masks = {m for m in detected_face.mask if m not in ("components", "extended")} + logger.debug("Removing masks after landmark update: %s", del_masks) + for mask in del_masks: + del detected_face.mask[mask] + + def get_masks(self, frame_index: int, face_index: int) -> dict[str, align.aligned_mask.Mask]: + """ Feed the aligned face into the mask pipeline and retrieve the updated masks. + + The face to feed into the aligner is generated from the given frame and face indices. + This is to be called when a manual update is done on the landmarks, and new masks need + generating. + + Parameters + ---------- + frame_index: int + The frame index to extract the aligned face for + face_index: int + The face index within the current frame to extract the face for + + Returns + ------- + dict[str, :class:`~lib.align.aligned_mask.Mask`] + The updated masks + """ + logger.trace("frame_index: %s, face_index: %s", # type:ignore[attr-defined] + frame_index, face_index) + self._frame_index = frame_index + self._face_index = face_index + self._aligner = "mask" + self._in_queue.put(self._feed_face) + assert self._aligners["mask"] is not None + detected_face = next(self._aligners["mask"].detected_faces()).detected_faces[0] + self._remove_nn_masks(detected_face) + logger.debug("mask: %s", detected_face.mask) + return detected_face.mask + + def set_normalization_method(self, method: T.Literal["none", "clahe", "hist", "mean"]) -> None: + """ Change the normalization method for faces fed into the aligner. + The normalization method is user adjustable from the GUI. When this method is triggered + the method is updated for all aligner pipelines. + + Parameters + ---------- + method: Literal["none", "clahe", "hist", "mean"] + The normalization method to use + """ + logger.debug("Setting normalization method to: '%s'", method) + for plugin, aligner in self._aligners.items(): + assert aligner is not None + if plugin == "mask": + continue + logger.debug("Setting to: '%s'", method) + aligner.aligner.set_normalize_method(method) + + +class FrameLoader(): + """ Loads the frames, sets the frame count to :attr:`TkGlobals.frame_count` and handles the + return of the correct frame for the GUI. 
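+ Loading is performed in a background thread; :attr:`is_initialized` should be polled
+ until it returns ``True`` before frames are requested from the loader.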
+
+ Parameters
+ ----------
+ tk_globals: :class:`~tools.manual.manual.TkGlobals`
+ The tkinter variables that apply to the whole of the GUI
+ frames_location: str
+ The path to the input frames
+ video_meta_data: dict
+ The meta data held within the alignments file, if it exists and the input is a video
+ file_list: list[str]
+ The list of filenames that exist within the alignments file
+ """
+ def __init__(self,
+ tk_globals: TkGlobals,
+ frames_location: str,
+ video_meta_data: dict[str, list[int] | list[float] | None],
+ file_list: list[str]) -> None:
+ logger.debug(parse_class_init(locals()))
+ self._globals = tk_globals
+ self._loader: SingleFrameLoader | None = None
+ self._current_idx = 0
+ self._init_thread = self._background_init_frames(frames_location,
+ video_meta_data,
+ file_list)
+ self._globals.var_frame_index.trace_add("write", self._set_frame)
+ logger.debug("Initialized %s", self.__class__.__name__)
+
+ @property
+ def is_initialized(self) -> bool:
+ """ bool: ``True`` if the Frame Loader has completed initialization. """
+ thread_is_alive = self._init_thread.is_alive()
+ if thread_is_alive:
+ self._init_thread.check_and_raise_error()
+ else:
+ self._init_thread.join()
+ self._set_frame(initialize=True) # Setting initial frame must be done from main thread
+ return not thread_is_alive
+
+ @property
+ def video_meta_data(self) -> dict[str, list[int] | list[float] | None]:
+ """ dict: The pts_time and key frames for the loader. """
+ assert self._loader is not None
+ return self._loader.video_meta_data
+
+ def _background_init_frames(self,
+ frames_location: str,
+ video_meta_data: dict[str, list[int] | list[float] | None],
+ frame_list: list[str]) -> MultiThread:
+ """ Launch the images loader in a background thread so we can run other tasks whilst
+ waiting for initialization.
+
+ Parameters
+ ----------
+ frames_location: str
+ The location of the source video file/frames folder
+ video_meta_data: dict
+ The meta data for video file sources
+ frame_list: list[str]
+ The list of frames that exist in the alignments file
+
+ Returns
+ -------
+ :class:`lib.multithreading.MultiThread`
+ The background frames loader thread
+ """
+ thread = MultiThread(self._load_images,
+ frames_location,
+ video_meta_data,
+ frame_list,
+ thread_count=1,
+ name=f"{self.__class__.__name__}.init_frames")
+ thread.start()
+ return thread
+
+ def _load_images(self,
+ frames_location: str,
+ video_meta_data: dict[str, list[int] | list[float] | None],
+ frame_list: list[str]) -> None:
+ """ Load the images in a background thread.
+
+ Parameters
+ ----------
+ frames_location: str
+ The location of the source video file/frames folder
+ video_meta_data: dict
+ The meta data for video file sources
+ frame_list: list[str]
+ The list of frames that exist in the alignments file
+ """
+ self._loader = SingleFrameLoader(frames_location, video_meta_data=video_meta_data)
+ if not self._loader.is_video and len(frame_list) < self._loader.count:
+ files = [os.path.basename(f) for f in self._loader.file_list]
+ skip_list = [idx for idx, fname in enumerate(files) if fname not in frame_list]
+ logger.debug("Adding %s entries to skip list for images not in alignments file",
+ len(skip_list))
+ self._loader.add_skip_list(skip_list)
+ self._globals.set_frame_count(self._loader.process_count)
+
+ def _set_frame(self, # pylint:disable=unused-argument
+ *args,
+ initialize: bool = False) -> None:
+ """ Set the currently loaded frame to :attr:`_current_frame` and trigger a full GUI update.
+
+ If the loader has not been initialized, or the navigation position is the same as the
+ current position and the face is not zoomed in, then this returns having done nothing.
+
+ Parameters
+ ----------
+ args: tuple
+ :class:`tkinter.Event` arguments. Required but not used.
+ initialize: bool, optional
+ ``True`` if initializing for the first frame to be displayed otherwise ``False``.
+ Default: ``False``
+ """
+ position = self._globals.frame_index
+ if not initialize and (position == self._current_idx and not self._globals.is_zoomed):
+ logger.trace("Update criteria not met. Not updating: " # type:ignore[attr-defined]
+ "(initialize: %s, position: %s, current_idx: %s, is_zoomed: %s)",
+ initialize, position, self._current_idx, self._globals.is_zoomed)
+ return
+ if position == -1:
+ filename = "No Frame"
+ frame = np.ones(self._globals.frame_display_dims + (3, ), dtype="uint8")
+ else:
+ assert self._loader is not None
+ filename, frame = self._loader.image_from_index(position)
+ logger.trace("filename: %s, frame: %s, position: %s", # type:ignore[attr-defined]
+ filename, frame.shape, position)
+ self._globals.set_current_frame(frame, filename)
+ self._current_idx = position
+ self._globals.var_full_update.set(True)
+ self._globals.var_update_active_viewport.set(True)
+
+
+__all__ = get_module_objects(__name__)
diff --git a/tools/manual/thumbnails.py b/tools/manual/thumbnails.py
new file mode 100644
index 0000000000..7586b29429
--- /dev/null
+++ b/tools/manual/thumbnails.py
@@ -0,0 +1,305 @@
+#!/usr/bin/env python3
+""" Thumbnail generator for the manual tool """
+from __future__ import annotations
+import logging
+import typing as T
+import os
+
+from dataclasses import dataclass
+from time import sleep
+from threading import Lock
+
+import imageio
+import numpy as np
+
+from tqdm import tqdm
+from lib.align import AlignedFace
+from lib.image import SingleFrameLoader, generate_thumbnail
+from lib.multithreading import MultiThread
+from lib.utils import get_module_objects
+
+if T.TYPE_CHECKING:
+ from .detected_faces import DetectedFaces
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class ProgressBar:
+ """ Thread-safe progress bar for tracking thumbnail generation progress """
+ pbar: tqdm | None = None
+ lock = Lock()
+
+
+@dataclass
+class VideoMeta:
+ """ Holds meta information about a video file
+
+ Parameters
+ ----------
+ key_frames: list[int]
+ List of key frame indices for the video
+ pts_times: list[float]
+ List of presentation timestamps for the video
+ """
+ key_frames: list[int] | None = None
+ pts_times: list[float] | None = None
+
+
+class ThumbsCreator():
+ """ Background loader to generate thumbnails for the alignments file. Generates low resolution
+ thumbnails in parallel threads for faster processing.
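+ The generated thumbnails are written back into the alignments file once all worker
+ threads have completed.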
+
+ Parameters
+ ----------
+ detected_faces: :class:`~tools.manual.detected_faces.DetectedFaces`
+ The :class:`~lib.align.DetectedFace` objects for this video
+ input_location: str
+ The location of the input folder of frames or video file
+ single_process: bool
+ ``True`` to generate thumbs in a single process, otherwise ``False``
+ """
+ def __init__(self,
+ detected_faces: DetectedFaces,
+ input_location: str,
+ single_process: bool) -> None:
+ logger.debug("Initializing %s: (detected_faces: %s, input_location: %s, "
+ "single_process: %s)", self.__class__.__name__, detected_faces,
+ input_location, single_process)
+ self._size = 80
+ self._pbar = ProgressBar()
+ self._meta = VideoMeta(
+ key_frames=T.cast(list[int] | None,
+ detected_faces.video_meta_data.get("keyframes", None)),
+ pts_times=T.cast(list[float] | None,
+ detected_faces.video_meta_data.get("pts_time", None)))
+ self._location = input_location
+ self._alignments = detected_faces._alignments
+ self._frame_faces = detected_faces._frame_faces
+
+ self._is_video = self._meta.pts_times is not None and self._meta.key_frames is not None
+
+ cpu_count = os.cpu_count()
+ self._num_threads = 1 if cpu_count is None or cpu_count <= 2 else cpu_count - 2
+
+ if self._is_video and single_process:
+ self._num_threads = 1
+ elif self._is_video and not single_process:
+ assert self._meta.key_frames is not None
+ self._num_threads = min(self._num_threads, len(self._meta.key_frames))
+ else:
+ self._num_threads = max(self._num_threads, 32)
+ self._threads: list[MultiThread] = []
+ logger.debug("Initialized %s", self.__class__.__name__)
+
+ @property
+ def has_thumbs(self) -> bool:
+ """ bool: ``True`` if the underlying alignments file holds thumbnail images
+ otherwise ``False``. """
+ return self._alignments.thumbnails.has_thumbnails
+
+ def generate_cache(self) -> None:
+ """ Extract the face thumbnails from a video or folder of images into the
+ alignments file. """
+ self._pbar.pbar = tqdm(desc="Caching Thumbnails",
+ leave=False,
+ total=len(self._frame_faces))
+ if self._is_video:
+ self._launch_video()
+ else:
+ self._launch_folder()
+ while True:
+ self._check_and_raise_error()
+ if all(not thread.is_alive() for thread in self._threads):
+ break
+ sleep(1)
+ self._join_threads()
+ self._pbar.pbar.close()
+ self._alignments.save()
+
+ # << PRIVATE METHODS >> #
+ def _check_and_raise_error(self) -> None:
+ """ Monitor the loading threads for errors and raise if any occur. """
+ for thread in self._threads:
+ thread.check_and_raise_error()
+
+ def _join_threads(self) -> None:
+ """ Join the loading threads """
+ logger.debug("Joining face viewer loading threads")
+ for thread in self._threads:
+ thread.join()
+
+ def _launch_video(self) -> None:
+ """ Launch multiple :class:`lib.multithreading.MultiThread` objects to load faces from
+ a video file.
+
+ Splits the video into segments and passes each of these segments to separate background
+ threads for some speed up.
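+
+ Notes
+ -----
+ Segment boundaries are derived from the video's key frames: with, for example, 100
+ key frames and 4 threads, each thread starts at key frame ``idx * 25`` and reads
+ from that key frame's pts_time up to the pts_time at the next thread's starting key
+ frame (the final thread reads through to the end of the video).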
+ """ + key_frames = self._meta.key_frames + pts_times = self._meta.pts_times + assert key_frames is not None and pts_times is not None + key_frame_split = len(key_frames) // self._num_threads + for idx in range(self._num_threads): + is_final = idx == self._num_threads - 1 + start_idx: int = idx * key_frame_split + keyframe_idx = len(key_frames) - 1 if is_final else start_idx + key_frame_split + end_idx = key_frames[keyframe_idx] + start_pts = pts_times[key_frames[start_idx]] + end_pts = False if idx + 1 == self._num_threads else pts_times[end_idx] + starting_index = pts_times.index(start_pts) + if end_pts: + segment_count = len(pts_times[key_frames[start_idx]:end_idx]) + else: + segment_count = len(pts_times[key_frames[start_idx]:]) + logger.debug("thread index: %s, start_idx: %s, end_idx: %s, start_pts: %s, " + "end_pts: %s, starting_index: %s, segment_count: %s", idx, start_idx, + end_idx, start_pts, end_pts, starting_index, segment_count) + thread = MultiThread(self._load_from_video, + start_pts, + end_pts, + starting_index, + segment_count) + thread.start() + self._threads.append(thread) + + def _launch_folder(self) -> None: + """ Launch :class:`lib.multithreading.MultiThread` to retrieve faces from a + folder of images. + + Goes through the file list one at a time, passing each file to a separate background + thread for some speed up. + """ + reader = SingleFrameLoader(self._location) + skip_list = [idx for idx, f in enumerate(reader.file_list) + if os.path.basename(f) not in self._alignments.data] + if skip_list: + reader.add_skip_list(skip_list) + num_threads = min(reader.process_count, self._num_threads) + frame_split = reader.process_count // self._num_threads + logger.debug("total images: %s, num_threads: %s, frames_per_thread: %s", + reader.process_count, num_threads, frame_split) + for idx in range(num_threads): + is_final = idx == num_threads - 1 + start_idx = idx * frame_split + end_idx = reader.process_count if is_final else start_idx + frame_split + thread = MultiThread(self._load_from_folder, reader, start_idx, end_idx) + thread.start() + self._threads.append(thread) + + def _load_from_video(self, + pts_start: float, + pts_end: float, + start_index: int, + segment_count: int) -> None: + """ Loads faces from video for the given segment of the source video. + + Each segment of the video is extracted from in a different background thread. + + Parameters + ---------- + pts_start: float + The start time to cut the segment out of the video + pts_end: float + The end time to cut the segment out of the video + start_index: int + The frame index that this segment starts from. Used for calculating the actual frame + index of each frame extracted + segment_count: int + The number of frames that appear in this segment. 
Used for ending early in case more
+ frames come out of the segment than should appear (sometimes more frames are picked up
+ at the end of the segment, so these are discarded)
+ """
+ logger.debug("pts_start: %s, pts_end: %s, start_index: %s, segment_count: %s",
+ pts_start, pts_end, start_index, segment_count)
+ reader = self._get_reader(pts_start, pts_end)
+ idx = 0
+ sample_filename, ext = os.path.splitext(next(fname for fname in self._alignments.data))
+ vidname = sample_filename[:sample_filename.rfind("_")]
+ for idx, frame in enumerate(reader):
+ frame_idx = idx + start_index
+ filename = f"{vidname}_{frame_idx + 1:06d}{ext}"
+ self._set_thumbnail(filename, frame[..., ::-1], frame_idx)
+ if idx == segment_count - 1:
+ # Sometimes extra frames are picked up at the end of a segment, so stop
+ # processing when segment frame count has been hit.
+ break
+ reader.close()
+ logger.debug("Segment complete: (starting_frame_index: %s, processed_count: %s)",
+ start_index, idx)
+
+ def _get_reader(self, pts_start: float, pts_end: float):
+ """ Get an imageio iterator for this thread's segment.
+
+ Parameters
+ ----------
+ pts_start: float
+ The start time to cut the segment out of the video
+ pts_end: float
+ The end time to cut the segment out of the video
+
+ Returns
+ -------
+ :class:`imageio.Reader`
+ A reader iterator for the requested segment of video
+ """
+ input_params = ["-ss", str(pts_start)]
+ if pts_end:
+ input_params.extend(["-to", str(pts_end)])
+ logger.debug("pts_start: %s, pts_end: %s, input_params: %s",
+ pts_start, pts_end, input_params)
+ return imageio.get_reader(self._location,
+ "ffmpeg", # type:ignore[arg-type]
+ input_params=input_params)
+
+ def _load_from_folder(self,
+ reader: SingleFrameLoader,
+ start_index: int,
+ end_index: int) -> None:
+ """ Loads faces from the given range of frame indices from a folder of images.
+
+ Each frame range is extracted in a different background thread.
+
+ Parameters
+ ----------
+ reader: :class:`lib.image.SingleFrameLoader`
+ The reader that is used to retrieve the requested frame
+ start_index: int
+ The starting frame index for the images to extract faces from
+ end_index: int
+ The end frame index for the images to extract faces from
+ """
+ logger.debug("reader: %s, start_index: %s, end_index: %s",
+ reader, start_index, end_index)
+ for frame_index in range(start_index, end_index):
+ filename, frame = reader.image_from_index(frame_index)
+ self._set_thumbnail(filename, frame, frame_index)
+ logger.debug("Segment complete: (start_index: %s, processed_count: %s)",
+ start_index, end_index - start_index)
+
+ def _set_thumbnail(self, filename: str, frame: np.ndarray, frame_index: int) -> None:
+ """ Extracts the faces from the frame and adds the thumbnails to the alignments file
+
+ Parameters
+ ----------
+ filename: str
+ The filename of the frame within the alignments file
+ frame: :class:`numpy.ndarray`
+ The frame that contains the faces
+ frame_index: int
+ The frame index of this frame in the :attr:`_frame_faces`
+ """
+ for face_idx, face in enumerate(self._frame_faces[frame_index]):
+ aligned = AlignedFace(face.landmarks_xy,
+ image=frame,
+ centering="head",
+ size=96)
+ face.thumbnail = generate_thumbnail(aligned.face, size=96)
+ assert face.thumbnail is not None
+ self._alignments.thumbnails.add_thumbnail(filename, face_idx, face.thumbnail)
+ with self._pbar.lock:
+ assert self._pbar.pbar is not None
+ self._pbar.pbar.update(1)
+
+
+__all__ = get_module_objects(__name__)
diff --git a/tools/mask/__init__.py b/tools/mask/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tools/mask/cli.py b/tools/mask/cli.py
new file mode 100644
index 0000000000..44a5c6c7ec
--- /dev/null
+++ b/tools/mask/cli.py
@@ -0,0 +1,243 @@
+#!/usr/bin/env python3
+""" Command Line Arguments for tools """
+import gettext
+
+from lib.cli.args import FaceSwapArgs
+from lib.cli.actions import (DirOrFileFullPaths, DirFullPaths, FileFullPaths, Radio, Slider)
+from lib.utils import get_module_objects
+from plugins.plugin_loader import PluginLoader
+
+
+# pylint:disable=duplicate-code
+# LOCALES
+_LANG = gettext.translation("tools.mask.cli", localedir="locales", fallback=True)
+_ = _LANG.gettext
+
+_HELPTEXT = _("This tool allows you to generate, import, export or preview masks for existing "
+ "alignments.")
+
+
+class MaskArgs(FaceSwapArgs):
+ """ Class to parse the command line arguments for Mask tool """
+
+ @staticmethod
+ def get_info():
+ """ Return command information """
+ return _("Mask tool\nGenerate, import, export or preview masks for existing alignments "
+ "files.")
+
+ @staticmethod
+ def get_argument_list():
+ argument_list = []
+ argument_list.append({
+ "opts": ("-a", "--alignments"),
+ "action": FileFullPaths,
+ "type": str,
+ "group": _("data"),
+ "required": False,
+ "filetypes": "alignments",
+ "help": _(
+ "Full path to the alignments file that contains the masks if not at the "
+ "default location. 
+        argument_list.append({
+            "opts": ("-i", "--input"),
+            "action": DirOrFileFullPaths,
+            "type": str,
+            "group": _("data"),
+            "filetypes": "video",
+            "required": True,
+            "help": _(
+                "Directory containing extracted faces, source frames, or a video file.")})
+        argument_list.append({
+            "opts": ("-I", "--input-type"),
+            "action": Radio,
+            "type": str.lower,
+            "choices": ("faces", "frames"),
+            "dest": "input_type",
+            "group": _("data"),
+            "default": "frames",
+            "help": _(
+                "R|Whether the `input` is a folder of faces or a folder of frames/video"
+                "\nL|faces: The input is a folder containing extracted faces."
+                "\nL|frames: The input is a folder containing frames or is a video")})
+        argument_list.append({
+            "opts": ("-B", "--batch-mode"),
+            "action": "store_true",
+            "dest": "batch_mode",
+            "default": False,
+            "group": _("data"),
+            "help": _(
+                "R|Run the mask tool on multiple sources. If selected then the other options "
+                "should be set as follows:"
+                "\nL|input: A parent folder containing either all of the video files to be "
+                "processed, or containing sub-folders of frames/faces."
+                "\nL|output-folder: If provided, then sub-folders will be created within the "
+                "given location to hold the previews for each input."
+                "\nL|alignments: Alignments field will be ignored for batch processing. The "
+                "alignments files must exist at the default location (for frames). For batch "
+                "processing of masks with 'faces' as the input type, then only the PNG header "
+                "within the extracted faces will be updated.")})
+        argument_list.append({
+            "opts": ("-M", "--masker"),
+            "action": Radio,
+            "type": str.lower,
+            "choices": PluginLoader.get_available_extractors("mask"),
+            "default": "extended",
+            "group": _("process"),
+            "help": _(
+                "R|Masker to use."
+                "\nL|bisenet-fp: Relatively lightweight NN based mask that provides more "
+                "refined control over the area to be masked including full head masking "
+                "(configurable in mask settings)."
+                "\nL|components: Mask designed to provide facial segmentation based on the "
+                "positioning of landmark locations. A convex hull is constructed around the "
+                "exterior of the landmarks to create a mask."
+                "\nL|custom: A dummy mask that fills the mask area with all 1s or 0s "
+                "(configurable in settings). This is only required if you intend to manually "
+                "edit the custom masks yourself in the manual tool. This mask does not use the "
+                "GPU."
+                "\nL|extended: Mask designed to provide facial segmentation based on the "
+                "positioning of landmark locations. A convex hull is constructed around the "
+                "exterior of the landmarks and the mask is extended upwards onto the forehead."
+                "\nL|vgg-clear: Mask designed to provide smart segmentation of mostly frontal "
+                "faces clear of obstructions. Profile faces and obstructions may result in "
+                "sub-par performance."
+                "\nL|vgg-obstructed: Mask designed to provide smart segmentation of mostly "
+                "frontal faces. The mask model has been specifically trained to recognize "
+                "some facial obstructions (hands and eyeglasses). Profile faces may result in "
+                "sub-par performance."
+                "\nL|unet-dfl: Mask designed to provide smart segmentation of mostly frontal "
+                "faces. The mask model has been trained by community members. Profile faces "
+                "may result in sub-par performance.")})
+        argument_list.append({
+            "opts": ("-p", "--processing"),
+            "action": Radio,
+            "type": str.lower,
+            "choices": ("all", "missing", "output", "import"),
+            "default": "all",
+            "group": _("process"),
+            "help": _(
+                "R|The Mask tool process to perform."
+                "\nL|all: Update the mask for all faces in the alignments file for the selected "
+                "'masker'."
+                "\nL|missing: Create a mask for all faces in the alignments file where a mask "
+                "does not previously exist for the selected 'masker'."
+                "\nL|output: Don't update the masks, just output the selected 'masker' for "
+                "review/editing in external tools to the given output folder."
+                "\nL|import: Import masks that have been edited outside of faceswap into the "
+                "alignments file. Note: 'custom' must be the selected 'masker' and the masks "
+                "must be in the same format as the 'input-type' (frames or faces)")})
+        argument_list.append({
+            "opts": ("-m", "--mask-path"),
+            "action": DirFullPaths,
+            "type": str,
+            "group": _("import"),
+            "help": _(
+                "R|Import only. The path to the folder that contains masks to be imported."
+                "\nL|How the masks are provided is not important, but they will be stored, "
+                "internally, as 8-bit grayscale images."
+                "\nL|If the input is images, then the masks must be named exactly the same as "
+                "the input frames/faces (excluding the file extension)."
+                "\nL|If the input is a video file, then the filename of the masks is not "
+                "important but should contain the frame number at the end of the filename (but "
+                "before the file extension). The frame number can be separated from the rest of "
+                "the filename by any non-numeric character and can be padded by any number of "
+                "zeros. The frame number must correspond correctly to the frame number in the "
+                "original video (starting from frame 1).")})
+        argument_list.append({
+            "opts": ("-c", "--centering"),
+            "action": Radio,
+            "type": str.lower,
+            "choices": ("face", "head", "legacy"),
+            "default": "face",
+            "group": _("import"),
+            "help": _(
+                "R|Import/Output only. When importing masks, this is the centering to use. For "
+                "output this is only used for outputting custom imported masks, and should "
+                "correspond to the centering used when importing the mask. Note: For any job "
+                "other than 'import' and 'output' this option is ignored as mask centering is "
+                "handled internally."
+                "\nL|face: Centers the mask on the center of the face, adjusting for "
+                "pitch and yaw. Outside of requirements for full head masking/training, this "
+                "is likely to be the best choice."
+                "\nL|head: Centers the mask on the center of the head, adjusting for "
+                "pitch and yaw. Note: You should only select head centering if you intend to "
+                "include the full head (including hair) within the mask and are looking to "
+                "train a full head model."
+                "\nL|legacy: The 'original' extraction technique. Centers the mask near the "
+                "tip of the nose and crops closely to the face. Can result in the edges of "
+                "the mask appearing outside of the training area.")})
+        argument_list.append({
+            "opts": ("-s", "--storage-size"),
+            "dest": "storage_size",
+            "action": Slider,
+            "type": int,
+            "group": _("import"),
+            "min_max": (64, 1024),
+            "default": 128,
+            "rounding": 64,
+            "help": _(
+                "Import only. The size, in pixels, to internally store the mask at.\nThe "
+                "default is 128 which is fine for nearly all use cases. Larger sizes will "
+                "result in larger alignments files and longer processing.")})
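+        # Illustrative invocations only; the paths below are placeholders and the
+        # flags are those defined in this module (the tool itself is launched via
+        # tools.py, as documented in tools/mask/mask.py):
+        #
+        #   python tools.py mask -i /path/to/frames -M bisenet-fp -p all
+        #   python tools.py mask -i /path/to/video.mp4 -p import -M custom \
+        #       -m /path/to/masks -c face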
+        argument_list.append({
+            "opts": ("-o", "--output-folder"),
+            "action": DirFullPaths,
+            "dest": "output",
+            "type": str,
+            "group": _("output"),
+            "help": _(
+                "Optional output location. If provided, a preview of the masks created will "
+                "be output in the given folder.")})
+        argument_list.append({
+            "opts": ("-b", "--blur_kernel"),
+            "action": Slider,
+            "type": int,
+            "group": _("output"),
+            "min_max": (0, 9),
+            "default": 0,
+            "rounding": 1,
+            "help": _(
+                "Apply Gaussian blur to the mask output. Has the effect of smoothing the "
+                "edges of the mask giving less of a hard edge. The size is in pixels. This "
+                "value should be odd, if an even number is passed in then it will be rounded "
+                "to the next odd number. NB: Only affects the output preview. Set to 0 for "
+                "off")})
+        argument_list.append({
+            "opts": ("-t", "--threshold"),
+            "action": Slider,
+            "type": int,
+            "group": _("output"),
+            "min_max": (0, 50),
+            "default": 0,
+            "rounding": 1,
+            "help": _(
+                "Helps reduce 'blotchiness' on some masks by making light shades white "
+                "and dark shades black. Higher values will impact more of the mask. NB: "
+                "Only affects the output preview. Set to 0 for off")})
+        argument_list.append({
+            "opts": ("-O", "--output-type"),
+            "action": Radio,
+            "type": str.lower,
+            "choices": ("combined", "masked", "mask"),
+            "default": "combined",
+            "group": _("output"),
+            "help": _(
+                "R|How to format the output when processing is set to 'output'."
+                "\nL|combined: The image contains the face/frame, face mask and masked face."
+                "\nL|masked: Output the face/frame as rgba image with the face masked."
+                "\nL|mask: Only output the mask as a single channel image.")})
+        argument_list.append({
+            "opts": ("-f", "--full-frame"),
+            "action": "store_true",
+            "default": False,
+            "group": _("output"),
+            "help": _(
+                "R|Whether to output the whole frame or only the face box when using "
+                "output processing. 
Only has an effect when using frames as input.")}) + return argument_list + + +__all__ = get_module_objects(__name__) diff --git a/tools/mask/loader.py b/tools/mask/loader.py new file mode 100644 index 0000000000..191fdd46c1 --- /dev/null +++ b/tools/mask/loader.py @@ -0,0 +1,223 @@ +#!/usr/bin/env python3 +""" Handles loading of faces/frames from source locations and pairing with alignments +information """ +from __future__ import annotations + +import logging +import os +import typing as T + +import numpy as np +from tqdm import tqdm + +from lib.align import alignments, DetectedFace, update_legacy_png_header +from lib.image import FacesLoader, ImagesLoader +from lib.utils import get_module_objects +from plugins.extract import ExtractMedia + +if T.TYPE_CHECKING: + from lib.align.alignments import PNGHeaderDict +logger = logging.getLogger(__name__) + + +class Loader: + """ Loader for reading source data from disk, and yielding the output paired with alignment + information + + Parameters + ---------- + location: str + Full path to the source files location + is_faces: bool + ``True`` if the source is a folder of faceswap extracted faces + """ + def __init__(self, location: str, is_faces: bool) -> None: + logger.debug("Initializing %s (location: %s, is_faces: %s)", + self.__class__.__name__, location, is_faces) + + self._is_faces = is_faces + self._loader = FacesLoader(location) if is_faces else ImagesLoader(location) + self._alignments: alignments.Alignments | None = None + self._skip_count = 0 + + logger.debug("Initialized %s", self.__class__.__name__) + + @property + def file_list(self) -> list[str]: + """list[str]: Full file list of source files to be loaded """ + return self._loader.file_list + + @property + def is_video(self) -> bool: + """bool: ``True`` if the source is a video file otherwise ``False`` """ + return self._loader.is_video + + @property + def location(self) -> str: + """str: Full path to the source folder/video file location """ + return self._loader.location + + @property + def skip_count(self) -> int: + """int: The number of faces/frames that have been skipped due to no match in alignments + file """ + return self._skip_count + + def add_alignments(self, alignments_object: alignments.Alignments | None) -> None: + """ Add the loaded alignments to :attr:`_alignments` for content matching + + Parameters + ---------- + alignments_object: :class:`~lib.align.Alignments` | None + The alignments file object or ``None`` if not provided + """ + logger.debug("Adding alignments to loader: %s", alignments_object) + self._alignments = alignments_object + + @classmethod + def _get_detected_face(cls, alignment: alignments.AlignmentFileDict) -> DetectedFace: + """ Convert an alignment dict item to a detected_face object + + Parameters + ---------- + alignment: :class:`lib.align.alignments.AlignmentFileDict` + The alignment dict for a face + + Returns + ------- + :class:`~lib.align.detected_face.DetectedFace`: + The corresponding detected_face object for the alignment + """ + detected_face = DetectedFace() + detected_face.from_alignment(alignment) + return detected_face + + def _process_face(self, + filename: str, + image: np.ndarray, + metadata: PNGHeaderDict) -> ExtractMedia | None: + """ Process a single face when masking from face images + + Parameters + ---------- + filename: str + the filename currently being processed + image: :class:`numpy.ndarray` + The current face being processed + metadata: dict + The source frame metadata from the PNG header + + Returns + ------- + 
:class:`~plugins.extract.extract_media.ExtractMedia` | None
+            the extract media object for the processed face or ``None`` if alignment information
+            could not be found
+        """
+        frame_name = metadata["source"]["source_filename"]
+        face_index = metadata["source"]["face_index"]
+
+        if self._alignments is None:  # mask from PNG header
+            lookup_index = 0
+            aligns = [T.cast(alignments.AlignmentFileDict, metadata["alignments"])]
+        else:  # mask from Alignments file
+            lookup_index = face_index
+            aligns = self._alignments.get_faces_in_frame(frame_name)
+            if not aligns or face_index > len(aligns) - 1:
+                self._skip_count += 1
+                logger.warning("Skipping face not found in alignments file: '%s'", filename)
+                return None
+
+        alignment = aligns[lookup_index]
+        detected_face = self._get_detected_face(alignment)
+
+        retval = ExtractMedia(filename, image, detected_faces=[detected_face], is_aligned=True)
+        retval.add_frame_metadata(metadata["source"])
+        return retval
+
+    def _from_faces(self) -> T.Generator[ExtractMedia, None, None]:
+        """ Load content from pre-aligned faces and pair with corresponding metadata
+
+        Yields
+        ------
+        :class:`~plugins.extract.extract_media.ExtractMedia`
+            the extract media object for the processed face
+        """
+        log_once = False
+        for filename, image, metadata in tqdm(self._loader.load(), total=self._loader.count):
+            if not metadata:  # Legacy faces. Update the headers
+                if self._alignments is None:
+                    logger.error("Legacy faces have been discovered, but no alignments file "
+                                 "provided. You must provide an alignments file for this face set")
+                    break
+
+                if not log_once:
+                    logger.warning("Legacy faces discovered. These faces will be updated")
+                    log_once = True
+
+                metadata = update_legacy_png_header(filename, self._alignments)
+                if not metadata:  # Face not found
+                    self._skip_count += 1
+                    logger.warning("Legacy face not found in alignments file. This face has not "
+                                   "been updated: '%s'", filename)
+                    continue
+
+            if "source_frame_dims" not in metadata.get("source", {}):
+                logger.error("The faces need to be re-extracted as at least some of them do not "
+                             "contain information required to correctly generate masks.")
+                logger.error("You can re-extract the face-set by using the Alignments Tool's "
+                             "Extract job.")
+                break
+
+            retval = self._process_face(filename, image, metadata)
+            if retval is None:
+                continue
+
+            yield retval
+
+    def _from_frames(self) -> T.Generator[ExtractMedia, None, None]:
+        """ Load content from frames and pair with corresponding metadata
+
+        Yields
+        ------
+        :class:`~plugins.extract.extract_media.ExtractMedia`
+            the extract media object for the processed face
+        """
+        assert self._alignments is not None
+        for filename, image in tqdm(self._loader.load(), total=self._loader.count):
+            frame = os.path.basename(filename)
+
+            if not self._alignments.frame_exists(frame):
+                self._skip_count += 1
+                logger.warning("Skipping frame not in alignments file: '%s'", frame)
+                continue
+
+            if not self._alignments.frame_has_faces(frame):
+                logger.debug("Skipping frame with no faces: '%s'", frame)
+                continue
+
+            faces_in_frame = self._alignments.get_faces_in_frame(frame)
+            detected_faces = [self._get_detected_face(alignment) for alignment in faces_in_frame]
+            retval = ExtractMedia(filename, image, detected_faces=detected_faces)
+            yield retval
+
+    def load(self) -> T.Generator[ExtractMedia, None, None]:
+        """ Load content from source and pair with corresponding alignment data
+
+        Yields
+        ------
+        :class:`~plugins.extract.extract_media.ExtractMedia`
+            the extract media object for the processed face
+        """
+        if self._is_faces:
+            iterator = self._from_faces
+        else:
+            iterator = self._from_frames
+
+        yield from iterator()
+
+        if self._skip_count > 0:
+            logger.warning("%s face(s) skipped due to not existing in the alignments file",
+                           self._skip_count)
+
+
+__all__ = get_module_objects(__name__)
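+
+# A minimal usage sketch for the class above (paths are placeholders; assumes a
+# default "alignments.fsa" file exists alongside the frames):
+#
+#   from lib.align import Alignments
+#
+#   loader = Loader("/path/to/frames", is_faces=False)
+#   loader.add_alignments(Alignments("/path/to/frames", filename="alignments"))
+#   for media in loader.load():  # ExtractMedia objects paired with their faces
+#       print(media.filename, len(media.detected_faces))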
diff --git a/tools/mask/mask.py b/tools/mask/mask.py
new file mode 100644
index 0000000000..a849cb0b5a
--- /dev/null
+++ b/tools/mask/mask.py
@@ -0,0 +1,309 @@
+#!/usr/bin/env python3
+""" Tool to generate masks and previews of masks for existing alignments file """
+from __future__ import annotations
+import logging
+import os
+import sys
+
+from argparse import Namespace
+from multiprocessing import Process
+
+from lib.align import Alignments
+
+from lib.utils import get_module_objects, handle_deprecated_cliopts, VIDEO_EXTENSIONS
+from plugins.extract import ExtractMedia
+
+from .loader import Loader
+from .mask_import import Import
+from .mask_generate import MaskGenerator
+from .mask_output import Output
+
+
+logger = logging.getLogger(__name__)
+
+
+class Mask:
+    """ This tool is part of the Faceswap Tools suite and should be called from
+    ``python tools.py mask`` command.
+
+    Faceswap Masks tool. Generate masks from existing alignments files, and output masks
+    for preview.
+
+    Wrapper for the mask process to run in either batch mode or single use mode
+
+    Parameters
+    ----------
+    arguments: :class:`argparse.Namespace`
+        The :mod:`argparse` arguments as passed in from :mod:`tools.py`
+    """
+    def __init__(self, arguments: Namespace) -> None:
+        logger.debug("Initializing %s: (arguments: %s)", self.__class__.__name__, arguments)
+        if arguments.batch_mode and arguments.processing == "import":
+            logger.error("Batch mode is not supported for 'import' processing")
+            sys.exit(0)
+
+        self._args = arguments
+        self._input_locations = self._get_input_locations()
+
+    def _get_input_locations(self) -> list[str]:
+        """ Obtain the full path to input locations. Will be a list of locations if batch mode is
+        selected, or containing a single location if batch mode is not selected.
+
+        Returns
+        -------
+        list:
+            The list of input location paths
+        """
+        if not self._args.batch_mode:
+            return [self._args.input]
+
+        if not os.path.isdir(self._args.input):
+            logger.error("Batch mode is selected but input '%s' is not a folder", self._args.input)
+            sys.exit(1)
+
+        retval = [os.path.join(self._args.input, fname)
+                  for fname in os.listdir(self._args.input)
+                  if os.path.isdir(os.path.join(self._args.input, fname))
+                  or os.path.splitext(fname)[-1].lower() in VIDEO_EXTENSIONS]
+        logger.info("Batch mode selected. Processing locations: %s", retval)
+        return retval
+
+    def _get_output_location(self, input_location: str) -> str:
+        """ Obtain the path to an output folder for faces for a given input location.
+
+        A sub-folder within the user supplied output location will be returned based on
+        the input filename
+
+        Parameters
+        ----------
+        input_location: str
+            The full path to an input video or folder of images
+        """
+        retval = os.path.join(self._args.output,
+                              os.path.splitext(os.path.basename(input_location))[0])
+        logger.debug("Returning output: '%s' for input: '%s'", retval, input_location)
+        return retval
+
+    @staticmethod
+    def _run_mask_process(arguments: Namespace) -> None:
+        """ The mask process to be run in a spawned process.
+
+        In some instances, batch-mode memory leaks. Launching each job in a separate process
+        prevents this leak.
+
+        Parameters
+        ----------
+        arguments: :class:`argparse.Namespace`
+            The :mod:`argparse` arguments to be used for the given job
+        """
+        logger.debug("Starting process: (arguments: %s)", arguments)
+        mask = _Mask(arguments)
+        mask.process()
+        logger.debug("Finished process: (arguments: %s)", arguments)
+
+    def process(self) -> None:
+        """ The entry point for triggering the Extraction Process.
+
+        Should only be called from :class:`lib.cli.launcher.ScriptExecutor`
+        """
+        for idx, location in enumerate(self._input_locations):
+            if self._args.batch_mode:
+                logger.info("Processing job %s of %s: %s",
+                            idx + 1, len(self._input_locations), location)
+                arguments = Namespace(**self._args.__dict__)
+                arguments.input = location
+                # Due to differences in how alignments are handled for frames/faces, only default
+                # locations allowed
+                arguments.alignments = None
+                if self._args.output:
+                    arguments.output = self._get_output_location(location)
+            else:
+                arguments = self._args
+
+            if len(self._input_locations) > 1:
+                proc = Process(target=self._run_mask_process, args=(arguments, ))
+                proc.start()
+                proc.join()
+            else:
+                self._run_mask_process(arguments)
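+
+
+# The per-job process isolation used above, reduced to a standalone sketch.
+# "run_job" and the job list are illustrative stand-ins for any workload that
+# leaks memory when run repeatedly inside one interpreter:
+#
+#   from multiprocessing import Process
+#
+#   def run_job(job: str) -> None:
+#       ...  # potentially leaky work
+#
+#   for job in ["job_a", "job_b"]:
+#       proc = Process(target=run_job, args=(job,))
+#       proc.start()
+#       proc.join()  # leaked memory is reclaimed when the child exits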
+class _Mask:
+    """ This tool is part of the Faceswap Tools suite and should be called from
+    ``python tools.py mask`` command.
+
+    Faceswap Masks tool. Generate masks from existing alignments files, and output masks
+    for preview.
+
+    Parameters
+    ----------
+    arguments: :class:`argparse.Namespace`
+        The :mod:`argparse` arguments as passed in from :mod:`tools.py`
+    """
+    def __init__(self, arguments: Namespace) -> None:
+        logger.debug("Initializing %s: (arguments: %s)", self.__class__.__name__, arguments)
+        arguments = handle_deprecated_cliopts(arguments)
+        self._update_type = arguments.processing
+        self._input_is_faces = arguments.input_type == "faces"
+        self._check_input(arguments.input)
+
+        self._loader = Loader(arguments.input, self._input_is_faces)
+        self._alignments = self._get_alignments(arguments.alignments, arguments.input)
+
+        if self._loader.is_video and self._alignments is not None:
+            self._alignments.update_legacy_has_source(os.path.basename(self._loader.location))
+
+        self._loader.add_alignments(self._alignments)
+
+        self._output = Output(arguments, self._alignments, self._loader.file_list)
+
+        self._import = None
+        if self._update_type == "import":
+            self._import = Import(arguments.mask_path,
+                                  arguments.centering,
+                                  arguments.storage_size,
+                                  self._input_is_faces,
+                                  self._loader,
+                                  self._alignments,
+                                  arguments.input,
+                                  arguments.masker)
+
+        self._mask_gen: MaskGenerator | None = None
+        if self._update_type in ("all", "missing"):
+            self._mask_gen = MaskGenerator(arguments.masker,
+                                           self._update_type == "all",
+                                           self._input_is_faces,
+                                           self._loader,
+                                           self._alignments,
+                                           arguments.input)
+
+        logger.debug("Initialized %s", self.__class__.__name__)
+
+    def _check_input(self, mask_input: str) -> None:
+        """ Check the input is valid. If it isn't, exit with a logged error
+
+        Parameters
+        ----------
+        mask_input: str
+            Path to the input folder/video
+        """
+        if not os.path.exists(mask_input):
+            logger.error("Location cannot be found: '%s'", mask_input)
+            sys.exit(0)
+        if os.path.isfile(mask_input) and self._input_is_faces:
+            logger.error("Input type 'faces' was selected but input is not a folder: '%s'",
+                         mask_input)
+            sys.exit(0)
+        logger.debug("input '%s' is valid", mask_input)
+
+    def _get_alignments(self, alignments: str | None, input_location: str) -> Alignments | None:
+        """ Obtain the alignments from either the given alignments location or the default
+        location.
+
+        Parameters
+        ----------
+        alignments: str | None
+            Full path to the alignments file if provided or ``None`` if not
+        input_location: str
+            Full path to the source files to be used by the mask tool
+
+        Returns
+        -------
+        :class:`~lib.align.alignments.Alignments` | None
+            The alignments file object, or ``None`` if the masks are only to be stored in the
+            faces' PNG headers
+        """
+        if alignments:
+            logger.debug("Alignments location provided: %s", alignments)
+            return Alignments(os.path.dirname(alignments),
+                              filename=os.path.basename(alignments))
+        if self._input_is_faces and self._update_type == "output":
+            logger.debug("No alignments file provided for faces. Using PNG Header for output")
+            return None
+        if self._input_is_faces:
+            logger.warning("Faces input selected without an alignments file. Masks will only "
+                           "be updated in the faces' PNG Header")
+            return None
+
+        folder = input_location
+        if self._loader.is_video:
+            logger.debug("Alignments from Video File: '%s'", folder)
+            folder, filename = os.path.split(folder)
+            filename = f"{os.path.splitext(filename)[0]}_alignments.fsa"
+        else:
+            logger.debug("Alignments from Input Folder: '%s'", folder)
+            filename = "alignments"
+
+        retval = Alignments(folder, filename=filename)
+        return retval
+
+    def _save_output(self, media: ExtractMedia) -> None:
+        """ Output masks to disk
+
+        Parameters
+        ----------
+        media: :class:`~plugins.extract.extract_media.ExtractMedia`
+            The extract media holding the faces to output
+        """
+        filename = os.path.basename(media.frame_metadata["source_filename"]
+                                    if self._input_is_faces else media.filename)
+        dims = media.frame_metadata["source_frame_dims"] if self._input_is_faces else None
+        for idx, face in enumerate(media.detected_faces):
+            face_idx = media.frame_metadata["face_index"] if self._input_is_faces else idx
+            face.image = media.image
+            self._output.save(filename, face_idx, face, frame_dims=dims)
+
+    def _generate_masks(self) -> None:
+        """ Generate masks from a mask plugin """
+        assert self._mask_gen is not None
+
+        logger.info("Generating masks")
+
+        for media in self._mask_gen.process():
+            if self._output.should_save:
+                self._save_output(media)
+
+    def _import_masks(self) -> None:
+        """ Import masks that have been generated outside of faceswap """
+        assert self._import is not None
+        logger.info("Importing masks")
+
+        for media in self._loader.load():
+            self._import.import_mask(media)
+            if self._output.should_save:
+                self._save_output(media)
+
+        if self._alignments is not None and self._import.update_count > 0:
+            self._alignments.backup()
+            self._alignments.save()
+
+        if self._import.skip_count > 0:
+            logger.warning("No masks were found for %s item(s), so these have not been imported",
+                           self._import.skip_count)
+
+        logger.info("Imported masks for %s faces of %s",
+                    self._import.update_count,
+                    self._import.update_count + self._import.skip_count)
+
+    def _output_masks(self) -> None:
+        """ Output masks to selected output folder """
+        for media in self._loader.load():
+            self._save_output(media)
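+
+    # Default alignments discovery used by `_get_alignments` above, shown as a
+    # sketch with a placeholder path:
+    #
+    #   import os
+    #
+    #   folder, fname = os.path.split("/data/my_video.mp4")
+    #   alignments_name = f"{os.path.splitext(fname)[0]}_alignments.fsa"
+    #   # folder == "/data", alignments_name == "my_video_alignments.fsa"
+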
+    def process(self) -> None:
+        """ The entry point for the Mask tool from :file:`lib.tools.cli`. Runs the Mask
+        process """
+        logger.debug("Starting masker process")
+
+        if self._update_type in ("all", "missing"):
+            self._generate_masks()
+
+        if self._update_type == "import":
+            self._import_masks()
+
+        if self._update_type == "output":
+            self._output_masks()
+
+        self._output.close()
+        logger.debug("Completed masker process")
+
+
+__all__ = get_module_objects(__name__)
diff --git a/tools/mask/mask_generate.py b/tools/mask/mask_generate.py
new file mode 100644
index 0000000000..ca4971cf38
--- /dev/null
+++ b/tools/mask/mask_generate.py
@@ -0,0 +1,269 @@
+#!/usr/bin/env python3
+""" Handles the generation of masks from faceswap for updating into an alignments file """
+from __future__ import annotations
+
+import logging
+import os
+import typing as T
+
+from lib.image import encode_image, ImagesSaver
+from lib.multithreading import MultiThread
+from lib.utils import get_module_objects
+from plugins.extract import Extractor
+
+if T.TYPE_CHECKING:
+    from lib import align
+    from lib.align import DetectedFace
+    from lib.queue_manager import EventQueue
+    from plugins.extract import ExtractMedia
+    from plugins.extract.mask.bisenet_fp import Mask as bfp_mask
+    from .loader import Loader
+
+
+logger = logging.getLogger(__name__)
+
+
+class MaskGenerator:
+    """ Uses faceswap's extract pipeline to generate masks and update them into the alignments
+    file and/or extracted face PNG Headers
+
+    Parameters
+    ----------
+    mask_type: str
+        The mask type to generate
+    update_all: bool
+        ``True`` to update all faces, ``False`` to only update faces missing masks
+    input_is_faces: bool
+        ``True`` if the input are faceswap extracted faces otherwise ``False``
+    loader: :class:`tools.mask.loader.Loader`
+        The loader for loading source images/video from disk
+    alignments: :class:`~lib.align.alignments.Alignments` | None
+        The alignments file object, or ``None`` if the masks are only to be updated within the
+        extracted faces' PNG headers
+    input_location: str
+        Full path to the source frames/faces/video location
+    """
+    def __init__(self,
+                 mask_type: str,
+                 update_all: bool,
+                 input_is_faces: bool,
+                 loader: Loader,
+                 alignments: align.alignments.Alignments | None,
+                 input_location: str) -> None:
+        logger.debug("Initializing %s (mask_type: %s, update_all: %s, input_is_faces: %s, "
+                     "loader: %s, alignments: %s, input_location: %s)",
+                     self.__class__.__name__, mask_type, update_all, input_is_faces, loader,
+                     alignments, input_location)
+
+        self._update_all = update_all
+        self._is_faces = input_is_faces
+        self._alignments = alignments
+
+        self._extractor = self._get_extractor(mask_type)
+        self._mask_type = self._set_correct_mask_type(mask_type)
+        self._input_thread = self._set_loader_thread(loader)
+        self._saver = ImagesSaver(input_location, as_bytes=True) if input_is_faces else None
+
+        self._counts: dict[T.Literal["face", "update"], int] = {"face": 0, "update": 0}
+
+        logger.debug("Initialized %s", self.__class__.__name__)
+
+    def _get_extractor(self, mask_type: str) -> Extractor:
+        """ Obtain a Mask extractor plugin and launch it
+
+        Parameters
+        ----------
+        mask_type: str
+            The mask type to generate
+
+        Returns
+        -------
+        :class:`plugins.extract.pipeline.Extractor`:
+            The launched Extractor
+        """
+        logger.debug("masker: %s", mask_type)
+        extractor = Extractor(None, None, mask_type)
+        extractor.launch()
+        logger.debug(extractor)
+        return extractor
+
+    def _set_correct_mask_type(self, mask_type: str) -> str:
+        """ Some masks have multiple variants that they can be saved as, depending on config
+        options
+
+        Parameters
+        ----------
+        mask_type: str
+            The mask type to generate
+
+        Returns
+        -------
+        str
+            The actual mask variant to update
+        """
+        if mask_type != "bisenet-fp":
+            return mask_type
+
+        # Hacky look up into masker to get the type of mask
+        mask_plugin = T.cast("bfp_mask | None",
+                             self._extractor._mask[0])  # pylint:disable=protected-access
+        assert mask_plugin is not None
+        new_type = f"{mask_type}_{mask_plugin.storage_centering}"
+        logger.debug("Updating '%s' to '%s'", mask_type, new_type)
+        return new_type
+
+    def _needs_update(self, frame: str, idx: int, face: DetectedFace) -> bool:
+        """ Check if the mask for the current alignment needs updating for the requested mask_type
+
+        Parameters
+        ----------
+        frame: str
+            The frame name in the alignments file
+        idx: int
+            The index of the face for this frame in the alignments file
+        face: :class:`~lib.align.DetectedFace`
+            The detected face object to check
+
+        Returns
+        -------
+        bool:
+            ``True`` if the mask needs to be updated otherwise ``False``
+        """
+        if self._update_all:
+            return True
+
+        retval = not face.mask or face.mask.get(self._mask_type, None) is None
+
+        logger.trace("Needs updating: %s, '%s' - %s",  # type:ignore[attr-defined]
+                     retval, frame, idx)
+        return retval
+
+    def _feed_extractor(self, loader: Loader, extract_queue: EventQueue) -> None:
+        """ Process to feed the extractor from inside a thread
+
+        Parameters
+        ----------
+        loader: class:`tools.mask.loader.Loader`
+            The loader for loading source images/video from disk
+        extract_queue: :class:`lib.queue_manager.EventQueue`
+            The input queue to the extraction pipeline
+        """
+        for media in loader.load():
+            self._counts["face"] += len(media.detected_faces)
+
+            if self._is_faces:
+                assert len(media.detected_faces) == 1
+                needs_update = self._needs_update(media.frame_metadata["source_filename"],
+                                                  media.frame_metadata["face_index"],
+                                                  media.detected_faces[0])
+            else:
+                # To keep face indexes correct, and to cover the case where only some faces in a
+                # frame are missing a mask, we reprocess all of the faces in any frame that has
+                # at least one missing mask.
+                needs_update = any(self._needs_update(media.filename, idx, detected_face)
+                                   for idx, detected_face in enumerate(media.detected_faces))
+
+            if not needs_update:
+                logger.trace("No masks need updating in '%s'",  # type:ignore[attr-defined]
+                             media.filename)
+                continue
+
+            logger.trace("Passing to extractor: '%s'", media.filename)  # type:ignore[attr-defined]
+            extract_queue.put(media)
+
+        logger.debug("Terminating loader thread")
+        extract_queue.put("EOF")
+
+    def _set_loader_thread(self, loader: Loader) -> MultiThread:
+        """ Start the thread that loads ExtractMedia objects into the mask extraction pipeline
+        so we can just iterate through the output masks
+
+        Parameters
+        ----------
+        loader: class:`tools.mask.loader.Loader`
+            The loader for loading source images/video from disk
+
+        Returns
+        -------
+        :class:`lib.multithreading.MultiThread`
+            The started background thread that feeds the extractor
+        """
+        in_queue = self._extractor.input_queue
+        logger.debug("Starting load thread: (loader: %s, queue: %s)", loader, in_queue)
+        in_thread = MultiThread(self._feed_extractor, loader, in_queue, thread_count=1)
+        in_thread.start()
+        logger.debug("Started load thread: %s", in_thread)
+        return in_thread
+
+    def _update_from_face(self, media: ExtractMedia) -> None:
+        """ Update the alignments file and/or the extracted face
+
+        Parameters
+        ----------
+        media: :class:`~plugins.extract.extract_media.ExtractMedia`
+            The ExtractMedia object with updated masks
+        """
+        assert len(media.detected_faces) == 1
+        assert self._saver is not None
+
+        fname = media.frame_metadata["source_filename"]
+        idx = media.frame_metadata["face_index"]
+        face = media.detected_faces[0]
+
+        if self._alignments is not None:
+            logger.trace("Updating face %s in frame '%s'", idx, fname)  # type:ignore[attr-defined]
+            self._alignments.update_face(fname, idx, face.to_alignment())
+
+        logger.trace("Updating extracted face: '%s'", media.filename)  # type:ignore[attr-defined]
+        meta: align.alignments.PNGHeaderDict = {"alignments": face.to_png_meta(),
+                                                "source": media.frame_metadata}
+        self._saver.save(media.filename, encode_image(media.image, ".png", metadata=meta))
+
+    def _update_from_frame(self, media: ExtractMedia) -> None:
+        """ Update the alignments file
+
+        Parameters
+        ----------
+        media: :class:`~plugins.extract.extract_media.ExtractMedia`
+            The ExtractMedia object with updated masks
+        """
+        assert self._alignments is not None
+        fname = os.path.basename(media.filename)
+        logger.trace("Updating %s faces in frame '%s'",  # type:ignore[attr-defined]
+                     len(media.detected_faces), fname)
+        for idx, face in enumerate(media.detected_faces):
+            self._alignments.update_face(fname, idx, face.to_alignment())
+
+    def _finalize(self) -> None:
+        """ Close thread and save alignments on completion """
+        logger.debug("Finalizing MaskGenerator")
+        self._input_thread.join()
+
+        if self._counts["update"] > 0 and self._alignments is not None:
+            logger.debug("Saving alignments")
+            self._alignments.backup()
+            self._alignments.save()
+
+        if self._saver is not None:
+            logger.debug("Closing face saver")
+            self._saver.close()
+
+        if self._counts["update"] == 0:
+            logger.warning("No masks were updated of the %s faces seen", self._counts["face"])
+        else:
+            logger.info("Updated masks for %s faces of %s",
+                        self._counts["update"], self._counts["face"])
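+
+    # The feed/consume pattern used above, reduced to a generic standalone sketch
+    # (plain stdlib queue and thread; the names are illustrative, not faceswap API):
+    #
+    #   import queue
+    #   import threading
+    #
+    #   work: "queue.Queue[str]" = queue.Queue(maxsize=8)
+    #
+    #   def feeder(items: list[str]) -> None:
+    #       for item in items:
+    #           work.put(item)
+    #       work.put("EOF")  # sentinel tells the consumer to stop
+    #
+    #   threading.Thread(target=feeder, args=(["a", "b"],), daemon=True).start()
+    #   while (item := work.get()) != "EOF":
+    #       print(item)
+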
+    def process(self) -> T.Generator[ExtractMedia, None, None]:
+        """ Process the output from the extractor pipeline
+
+        Yields
+        ------
+        :class:`~plugins.extract.extract_media.ExtractMedia`
+            The ExtractMedia object with updated masks
+        """
+        for media in self._extractor.detected_faces():
+            self._input_thread.check_and_raise_error()
+            self._counts["update"] += len(media.detected_faces)
+
+            if self._is_faces:
+                self._update_from_face(media)
+            else:
+                self._update_from_frame(media)
+
+            yield media
+
+        self._finalize()
+        logger.debug("Completed MaskGenerator process")
+
+
+__all__ = get_module_objects(__name__)
diff --git a/tools/mask/mask_import.py b/tools/mask/mask_import.py
new file mode 100644
index 0000000000..2b382fa56a
--- /dev/null
+++ b/tools/mask/mask_import.py
@@ -0,0 +1,410 @@
+#!/usr/bin/env python3
+""" Import mask processing for faceswap's mask tool """
+from __future__ import annotations
+
+import logging
+import os
+import re
+import sys
+import typing as T
+
+import cv2
+from tqdm import tqdm
+
+from lib.align import AlignedFace
+from lib.image import encode_image, ImagesSaver
+from lib.utils import get_image_paths, get_module_objects
+
+if T.TYPE_CHECKING:
+    import numpy as np
+    from .loader import Loader
+    from plugins.extract import ExtractMedia
+    from lib import align
+    from lib.align import DetectedFace
+    from lib.align.aligned_face import CenteringType
+
+logger = logging.getLogger(__name__)
+
+
+class Import:
+    """ Import masks from disk into an Alignments file
+
+    Parameters
+    ----------
+    import_path: str
+        The path to the input images
+    centering: Literal["face", "head", "legacy"]
+        The centering to store the mask at
+    storage_size: int
+        The size to store the mask at
+    input_is_faces: bool
+        ``True`` if the input is aligned faces otherwise ``False``
+    loader: :class:`~tools.mask.loader.Loader`
+        The source file loader object
+    alignments: :class:`~lib.align.alignments.Alignments` | None
+        The alignments file object for the faces, if provided
+    input_location: str
+        Full path to the location of the source files
+    mask_type: str
+        The mask type to update to
+    """
+    def __init__(self,
+                 import_path: str,
+                 centering: CenteringType,
+                 storage_size: int,
+                 input_is_faces: bool,
+                 loader: Loader,
+                 alignments: align.alignments.Alignments | None,
+                 input_location: str,
+                 mask_type: str) -> None:
+        logger.debug("Initializing %s (import_path: %s, centering: %s, storage_size: %s, "
+                     "input_is_faces: %s, loader: %s, alignments: %s, input_location: %s, "
+                     "mask_type: %s)", self.__class__.__name__, import_path, centering,
+                     storage_size, input_is_faces, loader, alignments, input_location, mask_type)
+
+        self._validate_mask_type(mask_type)
+
+        self._centering: CenteringType = centering
+        self._size = storage_size
+        self._is_faces = input_is_faces
+        self._alignments = alignments
+        self._re_frame_num = re.compile(r"\d+$")
+        self._mapping = self._generate_mapping(import_path, loader)
+
+        self._saver = ImagesSaver(input_location, as_bytes=True) if input_is_faces else None
+        self._counts: dict[T.Literal["skip", "update"], int] = {"skip": 0, "update": 0}
+
+        logger.debug("Initialized %s", self.__class__.__name__)
+
+    @property
+    def skip_count(self) -> int:
+        """ int: Number of masks that were skipped as they do not exist for given faces """
+        return self._counts["skip"]
+
+    @property
+    def update_count(self) -> int:
+        """ int: Number of masks that were successfully imported for given faces """
+        return self._counts["update"]
+
+    @classmethod
+    def _validate_mask_type(cls, mask_type: str) -> None:
+        """ Validate that the mask type is 'custom' to ensure the user does not accidentally
+        overwrite existing masks they may have edited
+
+        Parameters
+        ----------
+        mask_type: str
+            The mask type that has been selected
+        """
+        if mask_type == "custom":
+            return
+
+        logger.error("Masker 'custom' must be selected for importing masks")
+        sys.exit(1)
+
+    @classmethod
+    def _get_file_list(cls, path: str) -> list[str]:
+        """ Check the mask folder exists and obtain the list of images
+
+        Parameters
+        ----------
+        path: str
+            Full path to the location of mask images to be imported
+
+        Returns
+        -------
+        list[str]
+            list of full paths to all of the images in the mask folder
+        """
+        if not os.path.isdir(path):
+            logger.error("Mask path: '%s' is not a folder", path)
+            sys.exit(1)
+        paths = get_image_paths(path)
+        if not paths:
+            logger.error("Mask path '%s' contains no images", path)
+            sys.exit(1)
+        return paths
+
+    def _warn_extra_masks(self, file_list: list[str]) -> None:
+        """ Generate a warning for each mask that exists but does not correspond to a match in
+        the source input
+
+        Parameters
+        ----------
+        file_list: list[str]
+            List of mask files that could not be mapped to a source image
+        """
+        if not file_list:
+            logger.debug("All masks exist in the source data")
+            return
+
+        for fname in file_list:
+            logger.warning("Extra mask file found: '%s'", os.path.basename(fname))
+
+        logger.warning("%s mask file(s) do not exist in the source data so will not be imported "
+                       "(see above)", len(file_list))
+
+    def _file_list_to_frame_number(self, file_list: list[str]) -> dict[int, str]:
+        """ Extract frame numbers from mask file names and return as a dictionary
+
+        Parameters
+        ----------
+        file_list: list[str]
+            List of full paths to masks to extract frame number from
+
+        Returns
+        -------
+        dict[int, str]
+            Dictionary of frame numbers to filenames
+        """
+        retval: dict[int, str] = {}
+        for filename in file_list:
+            frame_num = self._re_frame_num.findall(os.path.splitext(os.path.basename(filename))[0])
+
+            if not frame_num or len(frame_num) > 1:
+                logger.error("Could not detect frame number from mask file '%s'. "
+                             "Check your filenames", os.path.basename(filename))
+                sys.exit(1)
+
+            fnum = int(frame_num[0])
+
+            if fnum in retval:
+                logger.error("Frame number %s for mask file '%s' already exists from file: '%s'. "
+                             "Check your filenames",
+                             fnum, os.path.basename(filename), os.path.basename(retval[fnum]))
+                sys.exit(1)
+
+            retval[fnum] = filename
+
+        logger.debug("Files: %s, frame_numbers: %s", len(file_list), len(retval))
+
+        return retval
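+
+    # How the frame-number lookup above behaves on a typical mask filename
+    # (the values are illustrative):
+    #
+    #   import re
+    #
+    #   pattern = re.compile(r"\d+$")
+    #   stem = "my_video-000042"                  # mask file stem, extension removed
+    #   frame_no = int(pattern.findall(stem)[0])  # -> 42; zero padding is ignored
+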
" + "Check your filenames") + sys.exit(1) + + self._warn_extra_masks(list(mask_frames.values())) + logger.debug("Source: %s, Mask: %s, Mapped: %s", + len(source_files), len(file_list), len(retval)) + return retval + + def _map_images(self, file_list: list[str], source_files: list[str]) -> dict[str, str]: + """ Generate the mapping between the source data and the masks to be imported for + folder of image sources + + Parameters + ---------- + file_list: list[str] + List of full paths to masks to be imported + source_files: list[str] + list of filenames withing the source file + + Returns + ------- + dict[str, str] + Source filenames mapped to full path location of mask to be imported + """ + mask_count = len(file_list) + retval = {} + unmapped = [] + for filename in tqdm(source_files, desc="Mapping masks to input", leave=False): + fname = os.path.splitext(os.path.basename(filename))[0] + mapped = next((f for f in file_list + if os.path.splitext(os.path.basename(f))[0] == fname), "") + if not mapped: + unmapped.append(filename) + continue + retval[os.path.basename(filename)] = file_list.pop(file_list.index(mapped)) + + if len(unmapped) == len(source_files): + logger.error("No masks map between the source data and the mask folder. " + "Check your filenames") + sys.exit(1) + + self._warn_extra_masks(file_list) + + logger.debug("Source: %s, Mask: %s, Mapped: %s", + len(source_files), mask_count, len(retval)) + return retval + + def _generate_mapping(self, import_path: str, loader: Loader) -> dict[str, str]: + """ Generate the mapping between the source data and the masks to be imported + + Parameters + ---------- + import_path: str + The path to the input images + loader: :class:`~tools.mask.loader.Loader` + The source file loader object + + Returns + ------- + dict[str, str] + Source filenames mapped to full path location of mask to be imported + """ + file_list = self._get_file_list(import_path) + if loader.is_video: + retval = self._map_video(file_list, loader.file_list) + else: + retval = self._map_images(file_list, loader.file_list) + + return retval + + def _store_mask(self, face: DetectedFace, mask: np.ndarray) -> None: + """ Store the mask to the given DetectedFace object + + Parameters + ---------- + face: :class:`~lib.align.detected_face.DetectedFace` + The detected face object to store the mask to + mask: :class:`numpy.ndarray` + The mask to store + """ + aligned = AlignedFace(face.landmarks_xy, + mask[..., None] if self._is_faces else mask, + centering=self._centering, + size=self._size, + is_aligned=self._is_faces, + dtype="float32") + assert aligned.face is not None + face.add_mask(f"custom_{self._centering}", + aligned.face / 255., + aligned.adjusted_matrix, + aligned.interpolators[1], + storage_size=self._size, + storage_centering=self._centering) + + def _store_mask_face(self, media: ExtractMedia, mask: np.ndarray) -> None: + """ Store the mask when the input is aligned faceswap faces + + Parameters + ---------- + media: :class:`~plugins.extract.extract_media.ExtractMedia` + The extract media object containing the face(s) to import the mask for + + mask: :class:`numpy.ndarray` + The mask loaded from disk + """ + assert self._saver is not None + assert len(media.detected_faces) == 1 + + logger.trace("Adding mask for '%s'", media.filename) # type:ignore[attr-defined] + + face = media.detected_faces[0] + self._store_mask(face, mask) + + if self._alignments is not None: + idx = media.frame_metadata["source_filename"] + fname = media.frame_metadata["face_index"] + 
logger.trace("Updating face %s in frame '%s'", idx, fname) # type:ignore[attr-defined] + self._alignments.update_face(idx, + fname, + face.to_alignment()) + + logger.trace("Updating extracted face: '%s'", media.filename) # type:ignore[attr-defined] + meta: align.alignments.PNGHeaderDict = {"alignments": face.to_png_meta(), + "source": media.frame_metadata} + self._saver.save(media.filename, encode_image(media.image, ".png", metadata=meta)) + + @classmethod + def _resize_mask(cls, mask: np.ndarray, dims: tuple[int, int]) -> np.ndarray: + """ Resize a mask to the given dimensions + + Parameters + ---------- + mask: :class:`numpy.ndarray` + The mask to resize + dims: tuple[int, int] + The (height, width) target size + + Returns + ------- + :class:`numpy.ndarray` + The resized mask, or the original mask if no resizing required + """ + if mask.shape[:2] == dims: + return mask + logger.trace("Resizing mask from %s to %s", mask.shape, dims) # type:ignore[attr-defined] + interp = cv2.INTER_AREA if mask.shape[0] > dims[0] else cv2.INTER_CUBIC + + mask = cv2.resize(mask, tuple(reversed(dims)), interpolation=interp) + return mask + + def _store_mask_frame(self, media: ExtractMedia, mask: np.ndarray) -> None: + """ Store the mask when the input is frames + + Parameters + ---------- + media: :class:`~plugins.extract.extract_media.ExtractMedia` + The extract media object containing the face(s) to import the mask for + + mask: :class:`numpy.ndarray` + The mask loaded from disk + """ + assert self._alignments is not None + logger.trace("Adding %s mask(s) for '%s'", # type:ignore[attr-defined] + len(media.detected_faces), media.filename) + + mask = self._resize_mask(mask, media.image_size) + + for idx, face in enumerate(media.detected_faces): + self._store_mask(face, mask) + self._alignments.update_face(os.path.basename(media.filename), + idx, + face.to_alignment()) + + def import_mask(self, media: ExtractMedia) -> None: + """ Import the mask for the given Extract Media object + + Parameters + ---------- + media: :class:`~plugins.extract.extract_media.ExtractMedia` + The extract media object containing the face(s) to import the mask for + """ + mask_file = self._mapping.get(os.path.basename(media.filename)) + if not mask_file: + self._counts["skip"] += 1 + logger.warning("No mask file found for: '%s'", os.path.basename(media.filename)) + return + + mask = T.cast("np.ndarray", cv2.imread(mask_file, cv2.IMREAD_GRAYSCALE)) + + logger.trace("Loaded mask for frame '%s': %s", # type:ignore[attr-defined] + os.path.basename(mask_file), mask.shape) + + self._counts["update"] += len(media.detected_faces) + + if self._is_faces: + self._store_mask_face(media, mask) + else: + self._store_mask_frame(media, mask) + + +__all__ = get_module_objects(__name__) diff --git a/tools/mask/mask_output.py b/tools/mask/mask_output.py new file mode 100644 index 0000000000..bf10f98681 --- /dev/null +++ b/tools/mask/mask_output.py @@ -0,0 +1,523 @@ +#!/usr/bin/env python3 +""" Output processing for faceswap's mask tool """ +from __future__ import annotations + +import logging +import os +import sys +import typing as T +from argparse import Namespace + +import cv2 +import numpy as np +from tqdm import tqdm + +from lib.align import AlignedFace +from lib.align.alignments import AlignmentDict + +from lib.image import ImagesSaver, read_image_meta_batch +from lib.utils import get_folder, get_module_objects +from scripts.fsmedia import Alignments as ExtractAlignments + +if T.TYPE_CHECKING: + from lib import align + from lib.align import 
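+
+# Interpolation choice used by `_resize_mask` above, as a standalone sketch with
+# synthetic data (note that cv2.resize takes (width, height), hence the reversal
+# of the (height, width) dims in the method):
+#
+#   import cv2
+#   import numpy as np
+#
+#   mask = np.zeros((256, 256), dtype="uint8")
+#   small = cv2.resize(mask, (128, 128), interpolation=cv2.INTER_AREA)   # shrinking
+#   large = cv2.resize(mask, (512, 512), interpolation=cv2.INTER_CUBIC)  # enlarging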
diff --git a/tools/mask/mask_output.py b/tools/mask/mask_output.py
new file mode 100644
index 0000000000..bf10f98681
--- /dev/null
+++ b/tools/mask/mask_output.py
@@ -0,0 +1,523 @@
+#!/usr/bin/env python3
+""" Output processing for faceswap's mask tool """
+from __future__ import annotations
+
+import logging
+import os
+import sys
+import typing as T
+from argparse import Namespace
+
+import cv2
+import numpy as np
+from tqdm import tqdm
+
+from lib.align import AlignedFace
+from lib.align.alignments import AlignmentDict
+
+from lib.image import ImagesSaver, read_image_meta_batch
+from lib.utils import get_folder, get_module_objects
+from scripts.fsmedia import Alignments as ExtractAlignments
+
+if T.TYPE_CHECKING:
+    from lib import align
+    from lib.align import DetectedFace
+    from lib.align.aligned_face import CenteringType
+
+logger = logging.getLogger(__name__)
+
+
+class Output:
+    """ Handles outputting of masks for preview/editing to disk
+
+    Parameters
+    ----------
+    arguments: :class:`argparse.Namespace`
+        The command line arguments that the mask tool was called with
+    alignments: :class:`~lib.align.alignments.Alignments` | None
+        The alignments file object (or ``None`` if not provided and input is faces)
+    file_list: list[str]
+        Full file list for the loader. Used for extracting alignments from faces
+    """
+    def __init__(self, arguments: Namespace,
+                 alignments: align.alignments.Alignments | None,
+                 file_list: list[str]) -> None:
+        logger.debug("Initializing %s (arguments: %s, alignments: %s, file_list: %s)",
+                     self.__class__.__name__, arguments, alignments, len(file_list))
+
+        self._blur_kernel: int = arguments.blur_kernel
+        self._threshold: int = arguments.threshold
+        self._type: T.Literal["combined", "masked", "mask"] = arguments.output_type
+        self._full_frame: bool = arguments.full_frame
+        self._mask_type = arguments.masker
+        self._centering: CenteringType = arguments.centering
+
+        self._input_is_faces = arguments.input_type == "faces"
+        self._saver = self._set_saver(arguments.output, arguments.processing)
+        self._alignments = self._get_alignments(alignments, file_list)
+
+        self._full_frame_cache: dict[str, list[tuple[int, DetectedFace]]] = {}
+
+        logger.debug("Initialized %s", self.__class__.__name__)
+
+    @property
+    def should_save(self) -> bool:
+        """bool: ``True`` if mask images should be output otherwise ``False`` """
+        return self._saver is not None
+
+    def _get_subfolder(self, output: str) -> str:
+        """ Obtain a subfolder within the output folder to save the output based on selected
+        output options.
+
+        Parameters
+        ----------
+        output: str
+            Full path to the root output folder
+
+        Returns
+        -------
+        str:
+            The full path to where masks should be saved
+        """
+        out_type = "frame" if self._full_frame else "face"
+        retval = os.path.join(output,
+                              f"{self._mask_type}_{out_type}_{self._type}")
+        logger.info("Saving masks to '%s'", retval)
+        return retval
+
+    def _set_saver(self, output: str | None, processing: str) -> ImagesSaver | None:
+        """ Set the saver in a background thread
+
+        Parameters
+        ----------
+        output: str
+            Full path to the root output folder if provided
+        processing: str
+            The processing that has been selected
+
+        Returns
+        -------
+        ``None`` or :class:`lib.image.ImagesSaver`:
+            If output is requested, returns a :class:`lib.image.ImagesSaver` otherwise
+            returns ``None``
+        """
+        if output is None or not output:
+            if processing == "output":
+                logger.error("Processing set as 'output' but no output folder provided.")
+                sys.exit(0)
+            logger.debug("No output provided. Not creating saver")
+            return None
+        output_dir = get_folder(self._get_subfolder(output), make_folder=True)
+        retval = ImagesSaver(output_dir)
+        logger.debug(retval)
+        return retval
+
+    def _get_alignments(self,
+                        alignments: align.alignments.Alignments | None,
+                        file_list: list[str]) -> align.alignments.Alignments | None:
+        """ Obtain the alignments file. If input is faces and full frame output is requested then
+        the file needs to be generated from the input faces, if not provided
+
+        Parameters
+        ----------
+        alignments: :class:`~lib.align.alignments.Alignments` | None
+            The alignments file object (or ``None`` if not provided and input is faces)
+        file_list: list[str]
+            Full paths to the mask tool input files
+
+        Returns
+        -------
+        :class:`~lib.align.alignments.Alignments` | None
+            The alignments file if provided and/or required, otherwise ``None``
+        """
+        if alignments is not None or not self._full_frame:
+            return alignments
+        logger.debug("Generating alignments from faces")
+
+        data = T.cast(dict[str, AlignmentDict], {})
+        for _, meta in tqdm(read_image_meta_batch(file_list),
+                            desc="Reading alignments from faces",
+                            total=len(file_list),
+                            leave=False):
+            fname = meta["itxt"]["source"]["source_filename"]
+            aln = meta["itxt"]["alignments"]
+            data.setdefault(fname, {}).setdefault("faces",  # type:ignore[typeddict-item]
+                                                  []).append(aln)
+
+        dummy_args = Namespace(alignments_path="/dummy/alignments.fsa")
+        retval = ExtractAlignments(dummy_args, is_extract=True)
+        retval.update_from_dict(data)
+        return retval
+
+    def _get_background_frame(self, detected_faces: list[DetectedFace], frame_dims: tuple[int, int]
+                              ) -> np.ndarray:
+        """ Obtain the background image when final output is in full frame format. There will only
+        ever be one background, even when there are multiple faces
+
+        The output image will depend on the requested output type and whether the input is faces
+        or frames
+
+        Parameters
+        ----------
+        detected_faces: list[:class:`~lib.align.detected_face.DetectedFace`]
+            Detected face objects for the output image
+        frame_dims: tuple[int, int]
+            The size of the original frame
+
+        Returns
+        -------
+        :class:`numpy.ndarray`
+            The full frame background image for applying masks to
+        """
+        if self._type == "mask":
+            return np.zeros(frame_dims, dtype="uint8")
+
+        if not self._input_is_faces:  # Frame is in the detected faces object
+            assert detected_faces[0].image is not None
+            return np.ascontiguousarray(detected_faces[0].image)
+
+        # Outputting to frames, but input is faces. Apply the face patches to an empty canvas
+        retval = np.zeros((*frame_dims, 3), dtype="uint8")
+        for detected_face in detected_faces:
+            assert detected_face.image is not None
+            face = AlignedFace(detected_face.landmarks_xy,
+                               image=detected_face.image,
+                               centering="head",
+                               size=detected_face.image.shape[0],
+                               is_aligned=True)
+            border = cv2.BORDER_TRANSPARENT if len(detected_faces) > 1 else cv2.BORDER_CONSTANT
+            assert face.face is not None
+            cv2.warpAffine(face.face,
+                           face.adjusted_matrix,
+                           tuple(reversed(frame_dims)),
+                           retval,
+                           flags=cv2.WARP_INVERSE_MAP | face.interpolators[1],
+                           borderMode=border)
+        return retval
+
+    def _get_background_face(self,
+                             detected_face: DetectedFace,
+                             mask_centering: CenteringType,
+                             mask_size: int) -> np.ndarray:
+        """ Obtain the background image when the output is faces
+
+        The output image will depend on the requested output type and whether the input is faces
+        or frames
+
+        Parameters
+        ----------
+        detected_face: :class:`~lib.align.detected_face.DetectedFace`
+            Detected face object for the output image
+        mask_centering: Literal["face", "head", "legacy"]
+            The centering of the stored mask
+        mask_size: int
+            The pixel size of the stored mask
+
+        Returns
+        -------
+        :class:`numpy.ndarray`
+            The face background image for applying masks to
+        """
+        if self._type == "mask":
+            return np.zeros((mask_size, mask_size), dtype="uint8")
+
+        assert detected_face.image is not None
+
+        if self._input_is_faces:
+            retval = AlignedFace(detected_face.landmarks_xy,
+                                 image=detected_face.image,
+                                 centering=mask_centering,
+                                 size=mask_size,
+                                 is_aligned=True).face
+        else:
+            centering: CenteringType = ("legacy" if self._alignments is not None and
+                                        self._alignments.version == 1.0
+                                        else mask_centering)
+            detected_face.load_aligned(detected_face.image,
+                                       size=mask_size,
+                                       centering=centering,
+                                       force=True)
+            retval = detected_face.aligned.face
+
+        assert retval is not None
+        return retval
+
+    def _get_background(self,
+                        detected_faces: list[DetectedFace],
+                        frame_dims: tuple[int, int],
+                        mask_centering: CenteringType,
+                        mask_size: int) -> np.ndarray:
+        """ Obtain the background image that the final output will be placed on
+
+        Parameters
+        ----------
+        detected_faces: list[:class:`~lib.align.detected_face.DetectedFace`]
+            Detected face objects for the output image
+        frame_dims: tuple[int, int]
+            The size of the original frame
+        mask_centering: Literal["face", "head", "legacy"]
+            The centering of the stored mask
+        mask_size: int
+            The pixel size of the stored mask
+
+        Returns
+        -------
+        :class:`numpy.ndarray`
+            The background image for the mask output
+        """
+        if self._full_frame:
+            retval = self._get_background_frame(detected_faces, frame_dims)
+        else:
+            assert len(detected_faces) == 1  # If outputting faces, we should only receive 1 face
+            retval = self._get_background_face(detected_faces[0], mask_centering, mask_size)
+
+        logger.trace("Background image (size: %s, dtype: %s)",  # type:ignore[attr-defined]
+                     retval.shape, retval.dtype)
+        return retval
+
+    def _get_mask(self,
+                  detected_faces: list[DetectedFace],
+                  mask_type: str,
+                  mask_dims: tuple[int, int]) -> np.ndarray:
+        """ Generate the mask to be applied to the final output frame
+
+        Parameters
+        ----------
+        detected_faces: list[:class:`~lib.align.detected_face.DetectedFace`]
+            Detected face objects to generate the masks from
+        mask_type: str
+            The mask-type to use
+        mask_dims : tuple[int, int]
+            The size of the mask to output
+
+        Returns
+        -------
+        :class:`numpy.ndarray`
+            The 
final mask to apply to the output image + """ + retval = np.zeros(mask_dims, dtype="uint8") + for face in detected_faces: + mask_object = face.mask[mask_type] + mask_object.set_blur_and_threshold(blur_kernel=self._blur_kernel, + threshold=self._threshold) + if self._full_frame: + mask = mask_object.get_full_frame_mask(*reversed(mask_dims)) + else: + mask = mask_object.mask[..., 0] + np.maximum(retval, mask, out=retval) + logger.trace("Final mask (shape: %s, dtype: %s)", # type:ignore[attr-defined] + retval.shape, retval.dtype) + return retval + + def _build_output_image(self, background: np.ndarray, mask: np.ndarray) -> np.ndarray: + """ Collate the mask and images for the final output image, depending on selected output + type + + Parameters + ---------- + background: :class:`numpy.ndarray` + The image that the mask will be applied to + mask: :class:`numpy.ndarray` + The mask to output + + Returns + ------- + :class:`numpy.ndarray` + The final output image + """ + if self._type == "mask": + return mask + + mask = mask[..., None] + if self._type == "masked": + return np.concatenate([background, mask], axis=-1) + + height, width = background.shape[:2] + masked = (background.astype("float32") * mask.astype("float32") / 255.).astype("uint8") + mask = np.tile(mask, 3) + for img in (background, masked, mask): + cv2.rectangle(img, (0, 0), (width - 1, height - 1), (255, 255, 255), 1) + axis = 0 if background.shape[0] < background.shape[1] else 1 + retval = np.concatenate((background, masked, mask), axis=axis) + + return retval + + def _create_image(self, + detected_faces: list[DetectedFace], + mask_type: str, + frame_dims: tuple[int, int] | None) -> np.ndarray: + """ Create a mask preview image for saving out to disk + + Parameters + ---------- + detected_faces: list[:class:`~lib.align.detected_face.DetectedFace`] + Detected face objects for the output image + mask_type: str + The mask_type to process + frame_dims: tuple[int, int] | None + The size of the original frame, if input is faces otherwise ``None`` + + Returns + ------- + :class:`numpy.ndarray`: + A preview image depending on the output type in one of the following forms: + - Containing 3 sub images: The original face, the masked face and the mask + - The mask only + - The masked face + """ + assert detected_faces[0].image is not None + dims = T.cast(tuple[int, int], + frame_dims if self._input_is_faces else detected_faces[0].image.shape[:2]) + assert dims is not None and len(dims) == 2 + + mask_centering = detected_faces[0].mask[mask_type].stored_centering + mask_size = detected_faces[0].mask[mask_type].stored_size + + background = self._get_background(detected_faces, dims, mask_centering, mask_size) + mask = self._get_mask(detected_faces, + mask_type, + dims if self._full_frame else (mask_size, mask_size)) + retval = self._build_output_image(background, mask) + + logger.trace("Output image (shape: %s, dtype: %s)", # type:ignore[attr-defined] + retval.shape, retval.dtype) + return retval + + def _handle_cache(self, + frame: str, + idx: int, + detected_face: DetectedFace) -> list[tuple[int, DetectedFace]]: + """ For full frame output, cache any faces until all detected faces have been seen. 
For
+        face output, just return the detected_face object inside a list
+
+        Parameters
+        ----------
+        frame: str
+            The frame name in the alignments file
+        idx: int
+            The index of the face for this frame in the alignments file
+        detected_face: :class:`~lib.align.detected_face.DetectedFace`
+            A detected_face object for a face
+
+        Returns
+        -------
+        list[tuple[int, :class:`~lib.align.detected_face.DetectedFace`]]
+            Face index and detected face objects to be processed for this output, if any
+        """
+        if not self._full_frame:
+            return [(idx, detected_face)]
+
+        assert self._alignments is not None
+        faces_in_frame = self._alignments.count_faces_in_frame(frame)
+        if faces_in_frame == 1:
+            return [(idx, detected_face)]
+
+        self._full_frame_cache.setdefault(frame, []).append((idx, detected_face))
+
+        if len(self._full_frame_cache[frame]) != faces_in_frame:
+            logger.trace("Caching face for frame '%s'", frame)  # type:ignore[attr-defined]
+            return []
+
+        retval = self._full_frame_cache.pop(frame)
+        logger.trace("Processing '%s' from cache: %s", frame, retval)  # type:ignore[attr-defined]
+        return retval
+
+    def _get_mask_types(self,
+                        frame: str,
+                        detected_faces: list[tuple[int, DetectedFace]]) -> list[str]:
+        """ Get the mask type names for the selected mask type. Remove any detected faces where
+        the selected mask does not exist
+
+        Parameters
+        ----------
+        frame: str
+            The frame name in the alignments file
+        detected_faces: list[tuple[int, :class:`~lib.align.detected_face.DetectedFace`]]
+            The face indices and detected_face objects for output
+
+        Returns
+        -------
+        list[str]
+            List of mask type names to be processed
+        """
+        if self._mask_type == "bisenet-fp":
+            mask_types = [f"{self._mask_type}_{area}" for area in ("face", "head")]
+        else:
+            mask_types = [self._mask_type]
+
+        if self._mask_type == "custom":
+            mask_types.append(f"{self._mask_type}_{self._centering}")
+
+        final_masks = set()
+        for idx in reversed(range(len(detected_faces))):
+            face_idx, detected_face = detected_faces[idx]
+            if detected_face.mask is None or not any(mask in detected_face.mask
+                                                     for mask in mask_types):
+                logger.warning("Mask type '%s' does not exist for frame '%s' index %s. Skipping",
+                               self._mask_type, frame, face_idx)
+                del detected_faces[idx]
+                continue
+            final_masks.update([m for m in detected_face.mask if m in mask_types])
+
+        retval = list(final_masks)
+        logger.trace("Handling mask types: %s", retval)  # type:ignore[attr-defined]
+        return retval
+
+    def save(self,
+             frame: str,
+             idx: int,
+             detected_face: DetectedFace,
+             frame_dims: tuple[int, int] | None = None) -> None:
+        """ Build the mask preview image and save
+
+        Parameters
+        ----------
+        frame: str
+            The frame name in the alignments file
+        idx: int
+            The index of the face for this frame in the alignments file
+        detected_face: :class:`~lib.align.detected_face.DetectedFace`
+            A detected_face object for a face
+        frame_dims: tuple[int, int] | None, optional
+            The size of the original frame, if input is faces otherwise ``None``.
Default: ``None``
+        """
+        assert self._saver is not None
+
+        faces = self._handle_cache(frame, idx, detected_face)
+        if not faces:
+            return
+
+        mask_types = self._get_mask_types(frame, faces)
+        if not faces or not mask_types:
+            logger.debug("No valid faces/masks to process for '%s'", frame)
+            return
+
+        for mask_type in mask_types:
+            detected_faces = [f[1] for f in faces if mask_type in f[1].mask]
+            if not detected_faces:
+                logger.warning("No '%s' masks to output for '%s'", mask_type, frame)
+                continue
+            if len(detected_faces) != len(faces):
+                logger.warning("Some '%s' masks are missing for '%s'", mask_type, frame)
+
+            image = self._create_image(detected_faces, mask_type, frame_dims)
+            filename = os.path.splitext(frame)[0]
+            if len(mask_types) > 1:
+                filename += f"_{mask_type}"
+            if not self._full_frame:
+                filename += f"_{idx}"
+            filename = os.path.join(self._saver.location, f"{filename}.png")
+            logger.trace("filename: '%s', image_shape: %s", filename, image.shape)  # type: ignore
+            self._saver.save(filename, image)
+
+    def close(self) -> None:
+        """ Shut down the image saver if it is open """
+        if self._saver is None:
+            return
+        logger.debug("Shutting down saver")
+        self._saver.close()
+
+
+__all__ = get_module_objects(__name__)
diff --git a/tools/model/__init__.py b/tools/model/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tools/model/cli.py b/tools/model/cli.py
new file mode 100644
index 0000000000..d7be707daf
--- /dev/null
+++ b/tools/model/cli.py
@@ -0,0 +1,64 @@
+#!/usr/bin/env python3
+""" Command Line Arguments for tools """
+import gettext
+import typing as T
+
+from lib.cli.args import FaceSwapArgs
+from lib.cli.actions import DirFullPaths, Radio
+from lib.utils import get_module_objects
+
+# LOCALES
+_LANG = gettext.translation("tools.model.cli", localedir="locales", fallback=True)
+_ = _LANG.gettext
+
+_HELPTEXT = _("This tool lets you perform actions on saved Faceswap models.")
+
+
+class ModelArgs(FaceSwapArgs):
+    """ Class to perform actions on model files """
+
+    @staticmethod
+    def get_info() -> str:
+        """ Return command information """
+        return _("A tool for performing actions on Faceswap trained model files")
+
+    @staticmethod
+    def get_argument_list() -> list[dict[str, T.Any]]:
+        """ Put the arguments in a list so that they are accessible from both argparse and gui """
+        argument_list = []
+        argument_list.append({
+            "opts": ("-m", "--model-dir"),
+            "action": DirFullPaths,
+            "dest": "model_dir",
+            "required": True,
+            "help": _(
+                "Model directory. A directory containing the model you wish to perform an action "
+                "on.")})
+        argument_list.append({
+            "opts": ("-j", "--job"),
+            "action": Radio,
+            "type": str,
+            "choices": ("inference", "nan-scan", "restore"),
+            "required": True,
+            "help": _(
+                "R|Choose which action you want to perform."
+                "\nL|'inference' - Create an inference only copy of the model. Strips any layers "
+                "from the model which are only required for training. NB: This is for exporting "
+                "the model for use in external applications. Inference generated models cannot be "
+                "used within Faceswap. See the 'format' option for specifying the model output "
+                "format."
+                "\nL|'nan-scan' - Scan the model file for NaNs or Infs (invalid data)."
+                "\nL|'restore' - Restore a model from backup.")})
+        argument_list.append({
+            "opts": ("-s", "--swap-model"),
+            "action": "store_true",
+            "dest": "swap_model",
+            "default": False,
+            "group": _("inference"),
+            "help": _(
+                "Only used for 'inference' job.
Generate the inference model for B -> A instead "
+                "of A -> B.")})
+        return argument_list
+
+
+__all__ = get_module_objects(__name__)
diff --git a/tools/model/model.py b/tools/model/model.py
new file mode 100644
index 0000000000..80c1520a51
--- /dev/null
+++ b/tools/model/model.py
@@ -0,0 +1,284 @@
+#!/usr/bin/env python3
+""" Tool to perform actions on saved Faceswap model files """
+from __future__ import annotations
+import logging
+import os
+import sys
+import typing as T
+
+from keras import saving
+import numpy as np
+import keras
+
+
+from lib.model.backup_restore import Backup
+
+from lib.logger import parse_class_init
+# Import the following libs for custom objects
+from lib.model import initializers, layers, normalization  # noqa # pylint:disable=unused-import
+from lib.utils import get_module_objects
+from plugins.train.model._base.model import Inference as FSInference
+
+
+if T.TYPE_CHECKING:
+    import argparse
+
+logger = logging.getLogger(__name__)
+
+
+class Model():
+    """ Tool to perform actions on a model file.
+
+    Parameters
+    ----------
+    arguments: :class:`argparse.Namespace`
+        The command line arguments calling the model tool
+    """
+    def __init__(self, arguments: argparse.Namespace) -> None:
+        logger.debug(parse_class_init(locals()))
+        self._model_dir = self._check_folder(arguments.model_dir)
+        self._job = self._get_job(arguments)
+        logger.debug("Initialized %s", self.__class__.__name__)
+
+    @classmethod
+    def _get_job(cls, arguments: argparse.Namespace) -> Inference | NaNScan | Restore:
+        """ Get the correct object that holds the selected job.
+
+        Parameters
+        ----------
+        arguments: :class:`argparse.Namespace`
+            The command line arguments received for the Model tool which will be used to initiate
+            the selected job
+
+        Returns
+        -------
+        :class:`Inference` | :class:`NaNScan` | :class:`Restore`
+            The object that will perform the selected job
+        """
+        jobs: dict[str, T.Type[Inference | NaNScan | Restore]] = {
+            "inference": Inference,
+            "nan-scan": NaNScan,
+            "restore": Restore}
+        return jobs[arguments.job](arguments)
+
+    @classmethod
+    def _check_folder(cls, model_dir: str) -> str:
+        """ Check that the passed in model folder exists and contains a valid model.
+
+        If the passed in value fails any checks, process exits.
+
+        Parameters
+        ----------
+        model_dir: str
+            The model folder to be checked
+
+        Returns
+        -------
+        str
+            The confirmed location of the model folder.
+        """
+        if not os.path.exists(model_dir):
+            logger.error("Model folder does not exist: '%s'", model_dir)
+            sys.exit(1)
+
+        chkfiles = [fname
+                    for fname in os.listdir(model_dir)
+                    if fname.endswith(".keras")
+                    and not os.path.splitext(fname)[0].endswith("_inference")]
+
+        if not chkfiles:
+            logger.error("Could not find a model in the supplied folder: '%s'", model_dir)
+            sys.exit(1)
+
+        if len(chkfiles) > 1:
+            logger.error("More than one model file found in the model folder: '%s'", model_dir)
+            sys.exit(1)
+
+        model_name = os.path.splitext(chkfiles[0])[0].title()
+        logger.info("%s Model found", model_name)
+        return model_dir
+
+    def process(self) -> None:
+        """ Call the selected model job."""
+        self._job.process()
+
+
+class Inference():
+    """ Save an inference model from a trained Faceswap model.
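+
+    As an illustrative sketch only (file names hypothetical), the job reads the training
+    checkpoint in the model folder and writes a stripped, inference-only sibling::
+
+        /models/my_model.keras            <- trained checkpoint (input)
+        /models/my_model_inference.keras  <- inference-only copy (output)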
+
+    Parameters
+    ----------
+    arguments: :class:`argparse.Namespace`
+        The command line arguments calling the model tool
+    """
+    def __init__(self, arguments: argparse.Namespace) -> None:
+        logger.debug(parse_class_init(locals()))
+        self._switch = arguments.swap_model
+        self._input_file, self._output_file = self._get_output_file(arguments.model_dir)
+        logger.debug("Initialized %s", self.__class__.__name__)
+
+    def _get_output_file(self, model_dir: str) -> tuple[str, str]:
+        """ Obtain the full path for the output model file/folder
+
+        Parameters
+        ----------
+        model_dir: str
+            The full path to the folder containing the Faceswap trained model .keras file
+
+        Returns
+        -------
+        str
+            The full path to the source model file
+        str
+            The full path to the inference model save location
+        """
+        model_name = next(fname for fname in os.listdir(model_dir)
+                          if fname.endswith(".keras")
+                          and not fname.endswith("_inference.keras"))
+        in_path = os.path.join(model_dir, model_name)
+        logger.debug("Model input path: '%s'", in_path)
+
+        model_name = f"{os.path.splitext(model_name)[0]}_inference.keras"
+        out_path = os.path.join(model_dir, model_name)
+        logger.debug("Inference output path: '%s'", out_path)
+        return in_path, out_path
+
+    def process(self) -> None:
+        """ Run the inference model creation process. """
+        logger.info("Loading model '%s'", self._input_file)
+        model = saving.load_model(self._input_file, compile=False)
+        logger.info("Creating inference model...")
+        inference = FSInference(model, self._switch).model
+        logger.info("Saving to: '%s'", self._output_file)
+        inference.save(self._output_file)
+
+
+class NaNScan():
+    """ Tool to scan for NaNs and Infs in model weights.
+
+    Parameters
+    ----------
+    arguments: :class:`argparse.Namespace`
+        The command line arguments calling the model tool
+    """
+    def __init__(self, arguments: argparse.Namespace) -> None:
+        logger.debug(parse_class_init(locals()))
+        self._model_file = self._get_model_filename(arguments.model_dir)
+        logger.debug("Initialized %s", self.__class__.__name__)
+
+    @classmethod
+    def _get_model_filename(cls, model_dir: str) -> str:
+        """ Obtain the full path to the model's .keras file.
+
+        Parameters
+        ----------
+        model_dir: str
+            The full path to the folder containing the model file
+
+        Returns
+        -------
+        str
+            The full path to the saved model file
+        """
+        model_file = next(fname for fname in os.listdir(model_dir) if fname.endswith(".keras"))
+        return os.path.join(model_dir, model_file)
+
+    def _parse_weights(self,
+                       layer: keras.models.Model | keras.layers.Layer) -> dict:
+        """ Recursively pass through sub-models to scan layer weights
+
+        Parameters
+        ----------
+        layer: :class:`keras.models.Model` | :class:`keras.layers.Layer`
+            The model or layer to scan for invalid values
+
+        Returns
+        -------
+        dict
+            Nested ``nans``/``infs`` counts keyed by layer name. Empty if no invalid values found
+        """
+        weights = layer.get_weights()
+        logger.debug("Processing weights for layer '%s', length: '%s'",
+                     layer.name, len(weights))
+
+        if not weights:
+            logger.debug("Skipping layer with no weights: %s", layer.name)
+            return {}
+
+        if hasattr(layer, "layers"):  # Must be a submodel
+            retval = {}
+            for lyr in layer.layers:
+                info = self._parse_weights(lyr)
+                if not info:
+                    continue
+                retval[lyr.name] = info
+            return retval
+
+        nans = sum(np.count_nonzero(np.isnan(w)) for w in weights)
+        infs = sum(np.count_nonzero(np.isinf(w)) for w in weights)
+
+        if nans + infs == 0:
+            return {}
+        return {"nans": nans, "infs": infs}
+
+    def _parse_output(self, errors: dict, indent: int = 0) -> None:
+        """ Parse the output of the errors dictionary and print a pretty summary.
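+
+        As a rough illustration, a nested ``errors`` dict such as
+        ``{"decoder": {"conv_1": {"nans": 3, "infs": 0}}}`` would log something like::
+
+            | decoder
+            |-- conv_1                                       nans: 3, infs: 0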
+
+        Parameters
+        ----------
+        errors: dict
+            The nested dictionary of errors found when parsing the weights
+
+        indent: int, optional
+            How far should the current printed line be indented. Default: `0`
+        """
+        for key, val in errors.items():
+            logline = f"|{'--' * indent} "
+            logline += key.ljust(50 - len(logline))
+            if isinstance(val, dict) and "nans" not in val:
+                logger.info(logline)
+                self._parse_output(val, indent + 1)
+            elif isinstance(val, dict) and "nans" in val:
+                logline += f"nans: {val['nans']}, infs: {val['infs']}"
+                logger.info(logline.ljust(30))
+
+    def process(self) -> None:
+        """ Scan the loaded model for NaNs and Infs and output summary. """
+        logger.info("Loading model...")
+        model = saving.load_model(self._model_file, compile=False)
+        logger.info("Parsing weights for invalid values...")
+        errors = self._parse_weights(model)
+
+        if not errors:
+            logger.info("No invalid values found in model: '%s'", self._model_file)
+            sys.exit(0)
+
+        logger.info("Invalid values found in model: %s", self._model_file)
+        self._parse_output(errors)
+
+
+class Restore():
+    """ Restore a model from backup.
+
+    Parameters
+    ----------
+    arguments: :class:`argparse.Namespace`
+        The command line arguments calling the model tool
+    """
+    def __init__(self, arguments: argparse.Namespace) -> None:
+        logger.debug(parse_class_init(locals()))
+        self._model_dir = arguments.model_dir
+        self._model_name = self._get_model_name()
+        logger.debug("Initialized %s", self.__class__.__name__)
+
+    def process(self) -> None:
+        """ Perform the Restore process """
+        logger.info("Starting Model Restore...")
+        backup = Backup(self._model_dir, self._model_name)
+        backup.restore()
+        logger.info("Completed Model Restore")
+
+    def _get_model_name(self) -> str:
+        """ Additional checks to make sure that a backup exists in the model location.
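+
+        Returns
+        -------
+        str
+            The model name derived from the backup file name. As an illustrative example, a
+            (hypothetical) backup ``my_model.keras.bk`` yields the model name ``my_model``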
""" + bkfiles = [fname for fname in os.listdir(self._model_dir) if fname.endswith(".bk")] + if not bkfiles: + logger.error("Could not find any backup files in the supplied folder: '%s'", + self._model_dir) + sys.exit(1) + logger.verbose("Backup files: %s)", bkfiles) # type:ignore[attr-defined] + + ext = ".keras.bk" + model_name = next(fname for fname in bkfiles if fname.endswith(ext)) + return model_name[:-len(ext)] + + +__all__ = get_module_objects(__name__) diff --git a/tools/preview/__init__.py b/tools/preview/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tools/preview/cli.py b/tools/preview/cli.py new file mode 100644 index 0000000000..d2cbf75d14 --- /dev/null +++ b/tools/preview/cli.py @@ -0,0 +1,79 @@ +#!/usr/bin/env python3 +""" Command Line Arguments for tools """ +import gettext +import typing as T + +from lib.cli.args import FaceSwapArgs +from lib.cli.actions import DirOrFileFullPaths, DirFullPaths, FileFullPaths +from lib.utils import get_module_objects + +# pylint:disable=duplicate-code +# LOCALES +_LANG = gettext.translation("tools.preview", localedir="locales", fallback=True) +_ = _LANG.gettext + + +_HELPTEXT = _("This command allows you to preview swaps to tweak convert settings.") + + +class PreviewArgs(FaceSwapArgs): + """ Class to parse the command line arguments for Preview (Convert Settings) tool """ + + @staticmethod + def get_info() -> str: + """ Return command information + + Returns + ------- + str + Top line information about the Preview tool + """ + return _("Preview tool\nAllows you to configure your convert settings with a live preview") + + @staticmethod + def get_argument_list() -> list[dict[str, T.Any]]: + """ Put the arguments in a list so that they are accessible from both argparse and gui + + Returns + ------- + list[dict[str, Any]] + Top command line options for the preview tool + """ + argument_list = [] + argument_list.append({ + "opts": ("-i", "--input-dir"), + "action": DirOrFileFullPaths, + "filetypes": "video", + "dest": "input_dir", + "group": _("data"), + "required": True, + "help": _( + "Input directory or video. Either a directory containing the image files you wish " + "to process or path to a video file.")}) + argument_list.append({ + "opts": ("-a", "--alignments"), + "action": FileFullPaths, + "filetypes": "alignments", + "type": str, + "group": _("data"), + "dest": "alignments_path", + "help": _( + "Path to the alignments file for the input, if not at the default location")}) + argument_list.append({ + "opts": ("-m", "--model-dir"), + "action": DirFullPaths, + "dest": "model_dir", + "group": _("data"), + "required": True, + "help": _( + "Model directory. A directory containing the trained model you wish to process.")}) + argument_list.append({ + "opts": ("-s", "--swap-model"), + "action": "store_true", + "dest": "swap_model", + "default": False, + "help": _("Swap the model. 
Instead of A -> B, swap B -> A")})
+        return argument_list
+
+
+__all__ = get_module_objects(__name__)
diff --git a/tools/preview/control_panels.py b/tools/preview/control_panels.py
new file mode 100644
index 0000000000..6a6947515c
--- /dev/null
+++ b/tools/preview/control_panels.py
@@ -0,0 +1,665 @@
+#!/usr/bin/env python3
+"""Manages the widgets that hold the bottom 'control' area of the preview tool."""
+from __future__ import annotations
+import gettext
+import logging
+import typing as T
+
+import tkinter as tk
+
+from tkinter import ttk
+
+from lib.gui.custom_widgets import Tooltip
+from lib.gui.control_helper import ControlPanel, ControlPanelOption
+from lib.logger import parse_class_init
+from lib.gui.utils import get_images
+from lib.utils import get_module_objects
+from plugins.plugin_loader import PluginLoader
+from plugins.convert import convert_config
+
+if T.TYPE_CHECKING:
+    from collections.abc import Callable
+    from .preview import Preview
+
+logger = logging.getLogger(__name__)
+
+# LOCALES
+_LANG = gettext.translation("tools.preview", localedir="locales", fallback=True)
+_ = _LANG.gettext
+
+
+class ConfigTools():
+    """Tools for loading, saving, setting and retrieving configuration file values.
+
+    Parameters
+    ----------
+    config_file : str | None
+        Path to a custom config .ini file or ``None`` to load the default config file
+
+    Attributes
+    ----------
+    tk_vars : dict[str, dict[str, tk.BooleanVar | tk.StringVar | tk.IntVar | tk.DoubleVar]]
+        Global tkinter variables for each configuration option, keyed by section then option name
+    """
+
+    def __init__(self, config_file: str | None) -> None:
+        logger.debug(parse_class_init(locals()))
+        self._config = convert_config.load_config(config_file=config_file)
+        self.tk_vars: dict[str, dict[str, tk.Variable]] = {}
+        self._config_dicts = self._get_config_dicts()  # Holds currently saved config
+
+    @property
+    def config_dicts(self) -> dict[str, dict[str, ControlPanelOption]]:
+        """dict[str, dict[str, ControlPanelOption]] : The convert configuration options in
+        dictionary form."""
+        return self._config_dicts
+
+    @property
+    def sections(self) -> list[str]:
+        """list: The sorted section names that exist within the convert Configuration options."""
+        return sorted(set(sect.split(".")[0] for sect in self._config.sections
+                          if sect.split(".")[0] != "writer"))
+
+    @property
+    def plugins_dict(self) -> dict[str, list[str]]:
+        """dict[str, list[str]] : Dictionary of configuration option sections as key with a list
+        of containing plugin names as the value"""
+        return {section: sorted([sect.split(".")[1] for sect in self._config.sections
+                                 if sect.split(".")[0] == section])
+                for section in self.sections}
+
+    def _get_config_dicts(self) -> dict[str, dict[str, ControlPanelOption]]:
+        """Obtain a custom configuration dictionary for convert configuration items in use
+        by the preview tool formatted for control helper.
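+
+        A (purely illustrative) result maps each plugin section to its GUI options, e.g.::
+
+            {"color.color_transfer": {"clip": ControlPanelOption(...),
+                                      "preserve_paper": ControlPanelOption(...)}}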
+
+        Returns
+        -------
+        dict[str, dict[str, ControlPanelOption]]
+            Each configuration section as keys, with the values as a dict of option_name to
+            :class:`lib.gui.control_helper.ControlPanelOption`."""
+        logger.debug("Formatting Config for GUI")
+        config_dicts: dict[str, dict[str, ControlPanelOption]] = {}
+        for section_name, section in self._config.sections.items():
+            if section_name.startswith("writer."):
+                continue
+            cp_options: dict[str, ControlPanelOption] = {}
+            for option_name, option in section.options.items():
+                cp_option = ControlPanelOption.from_config_object(option_name, option)
+                cp_options[option_name] = cp_option
+                self.tk_vars.setdefault(section_name, {})[option_name] = cp_option.tk_var
+            config_dicts[section_name] = cp_options
+        logger.debug("Formatted Config for GUI: %s", config_dicts)
+        return config_dicts
+
+    def update_config(self) -> None:
+        """Update :attr:`config` with the currently selected values from the GUI."""
+        for section, options in self.tk_vars.items():
+            for option_name, tk_option in options.items():
+                try:
+                    new_value = tk_option.get()
+                except tk.TclError as err:
+                    # When manually filling in text fields, blank values will
+                    # raise an error on numeric data types so return 0
+                    logger.trace(  # type:ignore[attr-defined]
+                        "Error getting value. Defaulting to 0. Error: %s", str(err))
+                    new_value = "" if isinstance(tk_option, tk.StringVar) else 0
+                option = self._config.sections[section].options[option_name]
+                old_value = option.value
+                if new_value == old_value or (isinstance(old_value, list) and
+                                              set(str(new_value).split()) == set(old_value)):
+                    logger.trace("Skipping unchanged option '%s'",  # type:ignore[attr-defined]
+                                 option_name)
+                    continue
+                logger.debug("Updating config: '%s', '%s' from %s to %s",
+                             section, option_name, repr(old_value), repr(new_value))
+                option.set(new_value)
+
+    def reset_config_to_saved(self, section: str | None = None) -> None:
+        """Reset the GUI parameters to their saved values within the configuration file.
+
+        Parameters
+        ----------
+        section : str | None, optional
+            The configuration section to reset the values for. If ``None`` provided then all
+            sections are reset. Default: ``None``
+        """
+        logger.debug("Resetting to saved config: %s", section)
+        sections = [section] if section is not None else list(self.tk_vars.keys())
+        for section_name in sections:
+            for option_name, tk_option in self._config_dicts[section_name].items():
+                val = tk_option.value
+                if val != self.tk_vars[section_name][option_name].get():
+                    self.tk_vars[section_name][option_name].set(val)
+                    logger.debug("Setting '%s' - '%s' to saved value %s",
+                                 section_name, option_name, repr(val))
+        logger.debug("Reset to saved config: %s", section)
+
+    def reset_config_to_default(self, section: str | None = None) -> None:
+        """Reset the GUI parameters to their default configuration values.
+
+        Parameters
+        ----------
+        section : str | None, optional
+            The configuration section to reset the values for. If ``None`` provided then all
+            sections are reset.
Default: ``None``
+        """
+        logger.debug("Resetting to default: %s", section)
+        sections = [section] if section is not None else list(self.tk_vars.keys())
+        for section_name in sections:
+            for option_name, options in self._config_dicts[section_name].items():
+                default = options.default
+                if default != self.tk_vars[section_name][option_name].get():
+                    self.tk_vars[section_name][option_name].set(default)
+                    logger.debug("Setting '%s' - '%s' to default value %s",
+                                 section_name, option_name, repr(default))
+        logger.debug("Reset to default: %s", section)
+
+    def save_config(self, section: str | None = None) -> None:
+        """Save the configuration ``.ini`` file with the currently stored values.
+
+        Parameters
+        ----------
+        section : str | None, optional
+            The configuration section to save. If ``None`` provided then all sections are saved.
+            Default: ``None``
+        """
+        logger.debug("Saving %s config", section)
+
+        for section_name, sect in self._config.sections.items():
+            if section_name not in self._config_dicts:
+                logger.debug("[%s] Skipping section not in local config", section_name)
+                continue
+            if section is not None and section_name != section:
+                logger.debug("[%s] Skipping section not selected for saving", section_name)
+                continue
+            for option_name, option in sect.options.items():
+                new_opt = self.tk_vars[section_name][option_name].get()
+                fmt_opt = str(new_opt).split() if isinstance(option.value, list) else new_opt
+                logger.debug("[%s] Setting '%s' to %s", section_name, option_name, repr(fmt_opt))
+                option.set(new_opt)
+
+        self._config.save_config()
+
+
+class BusyProgressBar():
+    """An infinite progress bar for when a thread is running to swap/patch a group of samples"""
+    def __init__(self, parent: ttk.Frame) -> None:
+        self._progress_bar = self._add_busy_indicator(parent)
+
+    def _add_busy_indicator(self, parent: ttk.Frame) -> ttk.Progressbar:
+        """Place progress bar into bottom bar to indicate when processing.
+
+        Parameters
+        ----------
+        parent: tkinter object
+            The tkinter object that holds the busy indicator
+
+        Returns
+        -------
+        ttk.Progressbar
+            A Progress bar to indicate that the Preview tool is busy
+        """
+        logger.debug("Placing busy indicator")
+        pbar = ttk.Progressbar(parent, mode="indeterminate")
+        pbar.pack(side=tk.LEFT)
+        pbar.pack_forget()
+        return pbar
+
+    def stop(self) -> None:
+        """Stop and hide progress bar"""
+        logger.debug("Stopping busy indicator")
+        if not self._progress_bar.winfo_ismapped():
+            logger.debug("busy indicator already hidden")
+            return
+        self._progress_bar.stop()
+        self._progress_bar.pack_forget()
+
+    def start(self) -> None:
+        """Start and display progress bar"""
+        logger.debug("Starting busy indicator")
+        if self._progress_bar.winfo_ismapped():
+            logger.debug("busy indicator already started")
+            return
+
+        self._progress_bar.pack(side=tk.LEFT, padx=5, pady=(5, 10), fill=tk.X, expand=True)
+        self._progress_bar.start(25)
+
+
+class ActionFrame(ttk.Frame):  # pylint:disable=too-many-ancestors
+    """Frame that holds the left hand side options panel containing the command line options.
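+
+    For orientation, a representative (values illustrative) :attr:`convert_args` payload
+    built from this frame looks like::
+
+        {"color_adjustment": "avg_color", "mask_type": "extended", "face_scale": 0.0}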
+ + Parameters + ---------- + app: :class:`Preview` + The main tkinter Preview app + parent: tkinter object + The parent tkinter object that holds the Action Frame + """ + def __init__(self, app: Preview, parent: ttk.Frame) -> None: + logger.debug("Initializing %s: (app: %s, parent: %s)", + self.__class__.__name__, app, parent) + self._app = app + + super().__init__(parent) + self.pack(side=tk.LEFT, anchor=tk.N, fill=tk.Y) + self._tk_vars: dict[str, tk.Variable] = {} + + self._options = { + "color": app._patch.converter.cli_arguments.color_adjustment.replace("-", "_"), + "mask_type": app._patch.converter.cli_arguments.mask_type.replace("-", "_"), + "face_scale": app._patch.converter.cli_arguments.face_scale} + defaults = {opt: self._format_to_display(val) if opt != "face_scale" else val + for opt, val in self._options.items()} + self._busy_bar = self._build_frame(defaults, + app._samples.generate, + app._refresh, + app._samples.available_masks, + app._samples.predictor.has_predicted_mask) + + @property + def convert_args(self) -> dict[str, T.Any]: + """dict: Currently selected Command line arguments from the :class:`ActionFrame`.""" + retval = {opt if opt != "color" else "color_adjustment": + self._format_from_display(self._tk_vars[opt].get()) + for opt in self._options if opt != "face_scale"} + retval["face_scale"] = self._tk_vars["face_scale"].get() + return retval + + @property + def busy_progress_bar(self) -> BusyProgressBar: + """ + :class:`BusyProgressBar`: The progress bar that appears on the left hand side whilst a + swap/patch is being applied. + """ + return self._busy_bar + + @staticmethod + def _format_from_display(var: str) -> str: + """Format a variable from the display version to the command line action version. + + Parameters + ---------- + var: str + The variable name to format + + Returns + ------- + str + The formatted variable name + """ + return var.replace(" ", "_").lower() + + @staticmethod + def _format_to_display(var: str) -> str: + """Format a variable from the command line action version to the display version. + + Parameters + ---------- + var: str + The variable name to format + + Returns + ------- + str + The formatted variable name + """ + return var.replace("_", " ").replace("-", " ").title() + + def _build_frame(self, + defaults: dict[str, T.Any], + refresh_callback: Callable[[], None], + patch_callback: Callable[[], None], + available_masks: list[str], + has_predicted_mask: bool) -> BusyProgressBar: + """Build the :class:`ActionFrame`. 
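+
+        A rough sketch of the resulting layout (not to scale)::
+
+            +---------------------------------------+
+            | ControlPanel with the cli choices     |   <- top_frame
+            |---------------------------------------|
+            | [Update Samples]                      |
+            | [busy bar]      [save][clear][reload] |   <- bottom_frame
+            +---------------------------------------+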
+
+        Parameters
+        ----------
+        defaults: dict
+            The default command line options
+        refresh_callback: python function
+            The function to execute when a refresh callback is received
+        patch_callback: python function
+            The function to execute when a patch callback is received
+        available_masks: list
+            The available masks that exist within the alignments file
+        has_predicted_mask: bool
+            Whether the model was trained with a mask
+
+        Returns
+        -------
+        :class:`BusyProgressBar`
+            A progress bar to indicate that the Preview tool is busy
+        """
+        logger.debug("Building Action frame")
+
+        bottom_frame = ttk.Frame(self)
+        bottom_frame.pack(side=tk.BOTTOM, fill=tk.X, anchor=tk.S)
+        top_frame = ttk.Frame(self)
+        top_frame.pack(side=tk.TOP, fill=tk.BOTH, anchor=tk.N, expand=True)
+
+        self._add_cli_choices(top_frame, defaults, available_masks, has_predicted_mask)
+
+        busy_indicator = BusyProgressBar(bottom_frame)
+        self._add_refresh_button(bottom_frame, refresh_callback)
+        self._add_patch_callback(patch_callback)
+        self._add_actions(bottom_frame)
+        logger.debug("Built Action frame")
+        return busy_indicator
+
+    def _add_cli_choices(self,
+                         parent: ttk.Frame,
+                         defaults: dict[str, T.Any],
+                         available_masks: list[str],
+                         has_predicted_mask: bool) -> None:
+        """Create :class:`lib.gui.control_helper.ControlPanel` object for the command line options.
+
+        Parameters
+        ----------
+        parent: :class:`ttk.Frame`
+            The frame to hold the command line choices
+        defaults: dict
+            The default command line options
+        available_masks: list
+            The available masks that exist within the alignments file
+        has_predicted_mask: bool
+            Whether the model was trained with a mask
+        """
+        cp_options = self._get_control_panel_options(defaults, available_masks, has_predicted_mask)
+        panel_kwargs = {"blank_nones": False, "label_width": 10, "style": "CPanel"}
+        ControlPanel(parent, cp_options, header_text=None, **panel_kwargs)
+
+    def _get_control_panel_options(self,
+                                   defaults: dict[str, T.Any],
+                                   available_masks: list[str],
+                                   has_predicted_mask: bool) -> list[ControlPanelOption]:
+        """Create :class:`lib.gui.control_helper.ControlPanelOption` objects for the cli options.
+
+        Parameters
+        ----------
+        defaults: dict
+            The default command line options
+        available_masks: list
+            The available masks that exist within the alignments file
+        has_predicted_mask: bool
+            Whether the model was trained with a mask
+
+        Returns
+        -------
+        list
+            The list of `lib.gui.control_helper.ControlPanelOption` objects for the Action Frame
+        """
+        cp_options: list[ControlPanelOption] = []
+        for opt in self._options:
+            if opt == "face_scale":
+                cp_option = ControlPanelOption(title=opt,
+                                               dtype=float,
+                                               default=0.0,
+                                               rounding=2,
+                                               min_max=(-10., 10.),
+                                               group="Command Line Choices")
+            else:
+                if opt == "mask_type":
+                    choices = self._create_mask_choices(defaults,
+                                                        available_masks,
+                                                        has_predicted_mask)
+                else:
+                    choices = PluginLoader.get_available_convert_plugins(opt, True)
+                cp_option = ControlPanelOption(title=opt,
+                                               dtype=str,
+                                               default=defaults[opt],
+                                               initial_value=defaults[opt],
+                                               choices=choices,
+                                               group="Command Line Choices",
+                                               is_radio=False)
+            self._tk_vars[opt] = cp_option.tk_var
+            cp_options.append(cp_option)
+        return cp_options
+
+    def _create_mask_choices(self,
+                             defaults: dict[str, T.Any],
+                             available_masks: list[str],
+                             has_predicted_mask: bool) -> list[str]:
+        """Set the mask choices and default mask based on available masks.
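+
+        A worked example (values illustrative)::
+
+            available_masks = ["components", "extended"]   # masks in the alignments file
+            has_predicted_mask = True
+            # -> choices become ["components", "extended", "predicted", "none"]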
+
+        Parameters
+        ----------
+        defaults: dict
+            The default command line options
+        available_masks: list
+            The available masks that exist within the alignments file
+        has_predicted_mask: bool
+            Whether the model was trained with a mask
+
+        Returns
+        -------
+        list
+            The masks that are available to use from the alignments file
+        """
+        logger.debug("Initial mask choices: %s", available_masks)
+        if has_predicted_mask:
+            available_masks += ["predicted"]
+        if "none" not in available_masks:
+            available_masks += ["none"]
+        if self._format_from_display(defaults["mask_type"]) not in available_masks:
+            logger.debug("Setting default mask to first available: %s", available_masks[0])
+            defaults["mask_type"] = available_masks[0]
+        logger.debug("Final mask choices: %s", available_masks)
+        return available_masks
+
+    @classmethod
+    def _add_refresh_button(cls,
+                            parent: ttk.Frame,
+                            refresh_callback: Callable[[], None]) -> None:
+        """Add a button to refresh the images.
+
+        Parameters
+        ----------
+        parent: :class:`ttk.Frame`
+            The frame that holds the refresh button
+        refresh_callback: python function
+            The function to execute when the refresh button is pressed
+        """
+        btn = ttk.Button(parent, text="Update Samples", command=refresh_callback)
+        btn.pack(padx=5, pady=5, side=tk.TOP, fill=tk.X, anchor=tk.N)
+
+    def _add_patch_callback(self, patch_callback: Callable[[], None]) -> None:
+        """Add callback to re-patch images on action option change.
+
+        Parameters
+        ----------
+        patch_callback: python function
+            The function to execute when the images require patching
+        """
+        for tk_var in self._tk_vars.values():
+            tk_var.trace("w", patch_callback)
+
+    def _add_actions(self, parent: ttk.Frame) -> None:
+        """Add Action Buttons to the :class:`ActionFrame`.
+
+        Parameters
+        ----------
+        parent: tkinter object
+            The tkinter object that holds the action buttons
+        """
+        logger.debug("Adding util buttons")
+        frame = ttk.Frame(parent)
+        frame.pack(padx=5, pady=(5, 10), side=tk.RIGHT, fill=tk.X, anchor=tk.E)
+        text = ""
+        action: T.Callable[[], T.Any] | None = None
+        for utl in ("save", "clear", "reload"):
+            logger.debug("Adding button: '%s'", utl)
+            img = get_images().icons[utl]
+            if utl == "save":
+                text = _("Save full config")
+                action = self._app.config_tools.save_config
+            elif utl == "clear":
+                text = _("Reset full config to default values")
+                action = self._app.config_tools.reset_config_to_default
+            elif utl == "reload":
+                text = _("Reset full config to saved values")
+                action = self._app.config_tools.reset_config_to_saved
+
+            assert action is not None
+            btnutl = ttk.Button(frame,
+                                image=img,  # type:ignore[arg-type]
+                                command=action)
+            btnutl.pack(padx=2, side=tk.RIGHT)
+            Tooltip(btnutl, text=text, wrap_length=200)
+        logger.debug("Added util buttons")
+
+
+class OptionsBook(ttk.Notebook):  # pylint:disable=too-many-ancestors
+    """The notebook that holds the Convert configuration options.
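+
+    One tab is generated per configuration section, with a sub-tab per plugin, along the
+    lines of (sections and plugins illustrative)::
+
+        Color   -> | Color Transfer | Manual Balance | Match Hist |
+        Mask    -> | Box Blend | Mask Blend |
+        Scaling -> | Sharpen |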
+
+    Parameters
+    ----------
+    parent: tkinter object
+        The parent tkinter object that holds the Options book
+    config_tools: :class:`ConfigTools`
+        Tools for loading and saving configuration files
+    patch_callback: python function
+        The function to execute when a patch callback is received
+
+    Attributes
+    ----------
+    config_tools: :class:`ConfigTools`
+        Tools for loading and saving configuration files
+    """
+    def __init__(self,
+                 parent: ttk.Frame,
+                 config_tools: ConfigTools,
+                 patch_callback: Callable[[], None]) -> None:
+        logger.debug("Initializing %s: (parent: %s, config: %s)",
+                     self.__class__.__name__, parent, config_tools)
+        super().__init__(parent)
+        self.pack(side=tk.RIGHT, anchor=tk.N, fill=tk.BOTH, expand=True)
+        self.config_tools = config_tools
+
+        self._tabs: dict[str, dict[str, ttk.Notebook | ConfigFrame]] = {}
+        self._build_tabs()
+        self._build_sub_tabs()
+        self._add_patch_callback(patch_callback)
+        logger.debug("Initialized %s", self.__class__.__name__)
+
+    def _build_tabs(self) -> None:
+        """Build the notebook tabs for each configuration section."""
+        logger.debug("Build Tabs")
+        for section in self.config_tools.sections:
+            tab = ttk.Notebook(self)
+            self._tabs[section] = {"tab": tab}
+            self.add(tab, text=section.replace("_", " ").title())
+
+    def _build_sub_tabs(self) -> None:
+        """Build the notebook sub tabs for each convert section's plugin."""
+        for section, plugins in self.config_tools.plugins_dict.items():
+            for plugin in plugins:
+                config_key = ".".join((section, plugin))
+                config_dict = self.config_tools.config_dicts[config_key]
+                tab = ConfigFrame(self, config_key, config_dict)
+                self._tabs[section][plugin] = tab
+                text = plugin.replace("_", " ").title()
+                T.cast(ttk.Notebook, self._tabs[section]["tab"]).add(tab, text=text)
+
+    def _add_patch_callback(self, patch_callback: Callable[[], None]) -> None:
+        """Add callback to re-patch images on configuration option change.
+
+        Parameters
+        ----------
+        patch_callback: python function
+            The function to execute when the images require patching
+        """
+        for plugins in self.config_tools.tk_vars.values():
+            for tk_var in plugins.values():
+                tk_var.trace("w", patch_callback)
+
+
+class ConfigFrame(ttk.Frame):  # pylint:disable=too-many-ancestors
+    """Holds the configuration options for a convert plugin inside the :class:`OptionsBook`.
+
+    Parameters
+    ----------
+    parent: tkinter object
+        The tkinter object that will hold this configuration frame
+    config_key: str
+        The section/plugin key for these configuration options
+    options: dict
+        The options for this section/plugin
+    """
+
+    def __init__(self,
+                 parent: OptionsBook,
+                 config_key: str,
+                 options: dict[str, T.Any]):
+        logger.debug("Initializing %s", self.__class__.__name__)
+        super().__init__(parent)
+        self.pack(side=tk.TOP, fill=tk.BOTH, expand=True)
+
+        self._options = options
+
+        self._action_frame = ttk.Frame(self)
+        self._action_frame.pack(padx=0, pady=(0, 5), side=tk.BOTTOM, fill=tk.X, anchor=tk.E)
+        self._add_frame_separator()
+
+        self._build_frame(parent, config_key)
+        logger.debug("Initialized %s", self.__class__.__name__)
+
+    def _build_frame(self, parent: OptionsBook, config_key: str) -> None:
+        """Build the options frame for this command.
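+
+        The options are rendered with a two column :class:`ControlPanel`; a minimal
+        equivalent call (kwargs as used in the body below)::
+
+            ControlPanel(frame, cp_options, header_text=None, columns=2,
+                         option_columns=2, blank_nones=False, style="CPanel")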
+
+        Parameters
+        ----------
+        parent: tkinter object
+            The tkinter object that will hold this configuration frame
+        config_key: str
+            The section/plugin key for these configuration options
+        """
+        logger.debug("Add Config Frame")
+        panel_kwargs = {"columns": 2, "option_columns": 2, "blank_nones": False, "style": "CPanel"}
+        frame = ttk.Frame(self)
+        frame.pack(side=tk.TOP, fill=tk.BOTH, expand=True)
+        cp_options = [opt for key, opt in self._options.items() if key != "helptext"]
+        ControlPanel(frame, cp_options, header_text=None, **panel_kwargs)
+        self._add_actions(parent, config_key)
+        logger.debug("Added Config Frame")
+
+    def _add_frame_separator(self) -> None:
+        """Add a separator between top and bottom frames."""
+        logger.debug("Add frame separator")
+        sep = ttk.Frame(self._action_frame, height=2, relief=tk.RIDGE)
+        sep.pack(fill=tk.X, pady=5, side=tk.TOP)
+        logger.debug("Added frame separator")
+
+    def _add_actions(self, parent: OptionsBook, config_key: str) -> None:
+        """Add Action Buttons.
+
+        Parameters
+        ----------
+        parent: tkinter object
+            The tkinter object that will hold this configuration frame
+        config_key: str
+            The section/plugin key for these configuration options
+        """
+        logger.debug("Adding util buttons")
+
+        title = config_key.split(".")[1].replace("_", " ").title()
+        btn_frame = ttk.Frame(self._action_frame)
+        btn_frame.pack(padx=5, side=tk.BOTTOM, fill=tk.X)
+        text = ""
+        action = None
+        for utl in ("save", "clear", "reload"):
+            logger.debug("Adding button: '%s'", utl)
+            img = get_images().icons[utl]
+            if utl == "save":
+                text = _(f"Save {title} config")
+                action = parent.config_tools.save_config
+            elif utl == "clear":
+                text = _(f"Reset {title} config to default values")
+                action = parent.config_tools.reset_config_to_default
+            elif utl == "reload":
+                text = _(f"Reset {title} config to saved values")
+                action = parent.config_tools.reset_config_to_saved
+
+            btnutl = ttk.Button(btn_frame,
+                                image=img,  # type:ignore[arg-type]
+                                command=lambda cmd=action: cmd(config_key))  # type:ignore[misc]
+            btnutl.pack(padx=2, side=tk.RIGHT)
+            Tooltip(btnutl, text=text, wrap_length=200)
+        logger.debug("Added util buttons")
+
+
+__all__ = get_module_objects(__name__)
diff --git a/tools/preview/preview.py b/tools/preview/preview.py
new file mode 100644
index 0000000000..9e888867bb
--- /dev/null
+++ b/tools/preview/preview.py
@@ -0,0 +1,656 @@
+#!/usr/bin/env python3
+""" Tool to preview swaps and tweak configuration prior to running a convert """
+from __future__ import annotations
+import gettext
+import logging
+import random
+import tkinter as tk
+import typing as T
+
+from tkinter import ttk
+import os
+import sys
+
+from threading import Event, Lock, Thread
+
+import numpy as np
+
+from lib.align import DetectedFace
+from lib.cli.args_extract_convert import ConvertArgs
+from lib.gui.utils import get_images, get_config, initialize_config, initialize_images
+from lib.convert import Converter
+from lib.utils import get_module_objects, FaceswapError, handle_deprecated_cliopts
+from lib.queue_manager import queue_manager
+from scripts.fsmedia import Alignments, Images
+from scripts.convert import Predict, ConvertItem
+
+from plugins.extract import ExtractMedia
+
+from .control_panels import ActionFrame, ConfigTools, OptionsBook
+from .viewer import FacesDisplay, ImagesCanvas
+
+if T.TYPE_CHECKING:
+    from argparse import Namespace
+    from lib.queue_manager import EventQueue
+    from .control_panels import BusyProgressBar
+
+logger = logging.getLogger(__name__)
+
+# LOCALES
+_LANG = gettext.translation("tools.preview", localedir="locales", fallback=True)
+_ = _LANG.gettext
+
+
+class Preview(tk.Tk):
+    """ This tool is part of the Faceswap Tools suite and should be called from
+    ``python tools.py preview`` command.
+
+    Loads up 5 semi-random face swaps and displays them, cropped, in place in the final frame.
+    Allows user to live tweak settings, before saving the final config to
+    :file:`./config/convert.ini`
+
+    Parameters
+    ----------
+    arguments: :class:`argparse.Namespace`
+        The :mod:`argparse` arguments as passed in from :mod:`tools.py`
+    """
+    _w: str
+
+    def __init__(self, arguments: Namespace) -> None:
+        logger.debug("Initializing %s: (arguments: '%s')", self.__class__.__name__, arguments)
+        super().__init__()
+        arguments = handle_deprecated_cliopts(arguments)
+        self._config_tools = ConfigTools(arguments.configfile)
+        self._lock = Lock()
+        self._dispatcher = Dispatcher(self)
+        self._display = FacesDisplay(self, 256, 64)
+        self._samples = Samples(self, arguments, 5)
+        self._patch = Patch(self, arguments)
+
+        self._initialize_tkinter()
+        self._image_canvas: ImagesCanvas | None = None
+        self._opts_book: OptionsBook | None = None
+        self._cli_frame: ActionFrame | None = None  # cli frame holds cli options
+        logger.debug("Initialized %s", self.__class__.__name__)
+
+    @property
+    def config_tools(self) -> "ConfigTools":
+        """ :class:`ConfigTools`: The object responsible for parsing configuration options and
+        updating to/from the GUI """
+        return self._config_tools
+
+    @property
+    def dispatcher(self) -> "Dispatcher":
+        """ :class:`Dispatcher`: The object responsible for triggering events and variables and
+        handling global GUI state """
+        return self._dispatcher
+
+    @property
+    def display(self) -> FacesDisplay:
+        """ :class:`~tools.preview.viewer.FacesDisplay`: The object that holds the sample,
+        converted and patched faces """
+        return self._display
+
+    @property
+    def lock(self) -> Lock:
+        """ :class:`threading.Lock`: The threading lock object for the Preview GUI """
+        return self._lock
+
+    @property
+    def progress_bar(self) -> BusyProgressBar:
+        """ :class:`~tools.preview.control_panels.BusyProgressBar`: The progress bar that indicates
+        a swap/patch thread is running """
+        assert self._cli_frame is not None
+        return self._cli_frame.busy_progress_bar
+
+    def update_display(self):
+        """ Update the images in the canvas and redraw """
+        if not hasattr(self, "_image_canvas"):  # On first call object not yet created
+            return
+        assert self._image_canvas is not None
+        self._image_canvas.reload()
+
+    def _initialize_tkinter(self) -> None:
+        """ Initialize a standalone tkinter instance. """
+        logger.debug("Initializing tkinter")
+        initialize_config(self, None, None)
+        initialize_images()
+        get_config().set_geometry(940, 600, fullscreen=False)
+        self.title("Faceswap.py - Convert Settings")
+        self.tk.call(
+            "wm",
+            "iconphoto",
+            self._w,
+            get_images().icons["favicon"])  # pylint:disable=protected-access
+        logger.debug("Initialized tkinter")
+
+    def process(self) -> None:
+        """ The entry point for the Preview tool from :file:`lib.tools.cli`.
+
+        Launch the tkinter preview Window and run main loop.
+        """
+        self._build_ui()
+        self.mainloop()
+
+    def _refresh(self, *args) -> None:
+        """ Patch faces with current convert settings.
+
+        Parameters
+        ----------
+        *args: tuple
+            Unused, but required for tkinter callback.
+        """
+        logger.debug("Patching swapped faces.
args: %s", args) + self._dispatcher.set_busy() + self._config_tools.update_config() + with self._lock: + assert self._cli_frame is not None + self._patch.converter_arguments = self._cli_frame.convert_args + + self._dispatcher.set_needs_patch() + logger.debug("Patched swapped faces") + + def _build_ui(self) -> None: + """ Build the elements for displaying preview images and options panels. """ + container = ttk.PanedWindow(self, + orient=tk.VERTICAL) + container.pack(fill=tk.BOTH, expand=True) + setattr(container, "preview_display", self._display) # TODO subclass not setattr + self._image_canvas = ImagesCanvas(self, container) + container.add(self._image_canvas, weight=3) + + options_frame = ttk.Frame(container) + self._cli_frame = ActionFrame(self, options_frame) + self._opts_book = OptionsBook(options_frame, + self._config_tools, + self._refresh) + container.add(options_frame, weight=1) + self.update_idletasks() + container.sashpos(0, int(400 * get_config().scaling_factor)) + + +class Dispatcher(): + """ Handles the app level tk.Variables and the threading events. Dispatches events to the + correct location and handles GUI state whilst events are handled + + Parameters + ---------- + app: :class:`Preview` + The main tkinter Preview app + """ + def __init__(self, app: Preview): + logger.debug("Initializing %s: (app: %s)", self.__class__.__name__, app) + self._app = app + self._tk_busy = tk.BooleanVar(value=False) + self._evnt_needs_patch = Event() + self._is_updating = False + self._stacked_event = False + logger.debug("Initialized %s", self.__class__.__name__) + + @property + def needs_patch(self) -> Event: + """:class:`threading.Event`. Set by the parent and cleared by the child. Informs the child + patching thread that a run needs to be processed """ + return self._evnt_needs_patch + + # TKInter Variables + def set_busy(self) -> None: + """ Set the tkinter busy variable to ``True`` and display the busy progress bar """ + if self._tk_busy.get(): + logger.debug("Busy event is already set. Doing nothing") + return + if not hasattr(self._app, "progress_bar"): + logger.debug("Not setting busy during initial startup") + return + + logger.debug("Setting busy event to True") + self._tk_busy.set(True) + self._app.progress_bar.start() + self._app.update_idletasks() + + def _unset_busy(self) -> None: + """ Set the tkinter busy variable to ``False`` and hide the busy progress bar """ + self._is_updating = False + if not self._tk_busy.get(): + logger.debug("busy unset when already unset. Doing nothing") + return + logger.debug("Setting busy event to False") + self._tk_busy.set(False) + self._app.progress_bar.stop() + self._app.update_idletasks() + + # Threading Events + def _wait_for_patch(self) -> None: + """ Wait for a patch thread to complete before triggering a display refresh and unsetting + the busy indicators """ + logger.debug("Checking for patch completion...") + if self._evnt_needs_patch.is_set(): + logger.debug("Samples not patched. Waiting...") + self._app.after(1000, self._wait_for_patch) + return + + logger.debug("Patch completion detected") + self._app.update_display() + self._unset_busy() + + if self._stacked_event: + logger.debug("Processing last stacked event") + self.set_busy() + self._stacked_event = False + self.set_needs_patch() + return + + def set_needs_patch(self) -> None: + """ Sends a trigger to the patching thread that it needs to be run. 
Waits for the patching + to complete prior to triggering a display refresh and unsetting the busy indicators """ + if self._is_updating: + logger.debug("Request to run patch when it is already running. Adding stacked event.") + self._stacked_event = True + return + self._is_updating = True + logger.debug("Triggering patch") + self._evnt_needs_patch.set() + self._wait_for_patch() + + +class Samples(): + """ The display samples. + + Obtains and holds :attr:`sample_size` semi random test faces for displaying in the + preview GUI. + + The file list is split into evenly sized groups of :attr:`sample_size`. When a display set is + generated, a random image from each of the groups is selected to provide an array of images + across the length of the video. + + Parameters + ---------- + app: :class:`Preview` + The main tkinter Preview app + arguments: :class:`argparse.Namespace` + The :mod:`argparse` arguments as passed in from :mod:`tools.py` + sample_size: int + The number of samples to take from the input video/images + """ + + def __init__(self, app: Preview, arguments: Namespace, sample_size: int) -> None: + logger.debug("Initializing %s: (app: %s, arguments: '%s', sample_size: %s)", + self.__class__.__name__, app, arguments, sample_size) + self._sample_size = sample_size + self._app = app + self._input_images: list[ConvertItem] = [] + self._predicted_images: list[tuple[ConvertItem, np.ndarray]] = [] + + self._images = Images(arguments) + self._alignments = Alignments(arguments, + is_extract=False, + input_is_video=self._images.is_video) + if self._alignments.version == 1.0: + logger.error("The alignments file format has been updated since the given alignments " + "file was generated. You need to update the file to proceed.") + logger.error("To do this run the 'Alignments Tool' > 'Extract' Job.") + sys.exit(1) + + if not self._alignments.have_alignments_file: + logger.error("Alignments file not found at: '%s'", self._alignments.file) + sys.exit(1) + + if self._images.is_video: + assert isinstance(self._images.input_images, str) + self._alignments.update_legacy_has_source(os.path.basename(self._images.input_images)) + + self._filelist = self._get_filelist() + self._indices = self._get_indices() + + self._predictor = Predict(self._sample_size, arguments) + self._predictor.launch(queue_manager.get_queue("preview_predict_in")) + self._app._display.set_centering(self._predictor.centering) + self.generate() + + logger.debug("Initialized %s", self.__class__.__name__) + + @property + def available_masks(self) -> list[str]: + """ list: The mask names that are available for every face in the alignments file """ + retval = [key + for key, val in self.alignments.mask_summary.items() + if val == self.alignments.faces_count] + return retval + + @property + def sample_size(self) -> int: + """ int: The number of samples to take from the input video/images """ + return self._sample_size + + @property + def predicted_images(self) -> list[tuple[ConvertItem, np.ndarray]]: + """ list: The predicted faces output from the Faceswap model """ + return self._predicted_images + + @property + def alignments(self) -> Alignments: + """ :class:`~lib.align.Alignments`: The alignments for the preview faces """ + return self._alignments + + @property + def predictor(self) -> Predict: + """ :class:`~scripts.convert.Predict`: The Predictor for the Faceswap model """ + return self._predictor + + @property + def _random_choice(self) -> list[int]: + """ list: Random indices from the :attr:`_indices` group """ + retval = 
[random.choice(indices) for indices in self._indices] + logger.debug(retval) + return retval + + def _get_filelist(self) -> list[str]: + """ Get a list of files for the input, filtering out those frames which do + not contain faces. + + Returns + ------- + list + A list of filenames of frames that contain faces. + """ + logger.debug("Filtering file list to frames with faces") + if isinstance(self._images.input_images, str): + vid_name, ext = os.path.splitext(self._images.input_images) + filelist = [f"{vid_name}_{frame_no:06d}{ext}" + for frame_no in range(1, self._images.images_found + 1)] + else: + filelist = self._images.input_images + + retval = [filename for filename in filelist + if self._alignments.frame_has_faces(os.path.basename(filename))] + logger.debug("Filtered out frames: %s", self._images.images_found - len(retval)) + try: + assert retval + except AssertionError as err: + msg = ("No faces were found in any of the frames passed in. Make sure you are passing " + "in a frames source rather than extracted faces, and that you have provided " + "the correct alignments file.") + raise FaceswapError(msg) from err + return retval + + def _get_indices(self) -> list[list[int]]: + """ Get indices for each sample group. + + Obtain :attr:`self.sample_size` evenly sized groups of indices + pertaining to the filtered :attr:`self._file_list` + + Returns + ------- + list + list of indices relating to the filtered file list, split into groups + """ + # Remove start and end values to get a list divisible by self.sample_size + no_files = len(self._filelist) + self._sample_size = min(self._sample_size, no_files) + crop = no_files % self._sample_size + top_tail = list(range(no_files))[ + crop // 2:no_files - (crop - (crop // 2))] + # Partition the indices + size = len(top_tail) + retval = [top_tail[start:start + size // self._sample_size] + for start in range(0, size, size // self._sample_size)] + logger.debug("Indices pools: %s", [f"{idx}: (start: {min(pool)}, " + f"end: {max(pool)}, size: {len(pool)})" + for idx, pool in enumerate(retval)]) + return retval + + def generate(self) -> None: + """ Generate a sample set. + + Selects :attr:`sample_size` random faces. Runs them through prediction to obtain the + swap, then trigger the patch event to run the faces through patching. + """ + logger.debug("Generating new random samples") + self._app.dispatcher.set_busy() + self._load_frames() + self._predict() + self._app.dispatcher.set_needs_patch() + logger.debug("Generated new random samples") + + def _load_frames(self) -> None: + """ Load a sample of random frames. + + * Picks a random face from each indices group. + + * Takes the first face from the image (if there are multiple faces). Adds the images to \ + :attr:`self._input_images`. 
+
+        * Sets :attr:`_display.source` to the input images and flags that the display should be \
+          updated
+        """
+        self._input_images = []
+        for selection in self._random_choice:
+            filename = os.path.basename(self._filelist[selection])
+            image = self._images.load_one_image(self._filelist[selection])
+            # Get first face only
+            face = self._alignments.get_faces_in_frame(filename)[0]
+            detected_face = DetectedFace()
+            detected_face.from_alignment(face, image=image)
+            inbound = ExtractMedia(filename=filename, image=image, detected_faces=[detected_face])
+            self._input_images.append(ConvertItem(inbound=inbound))
+        self._app.display.source = self._input_images
+        self._app.display.update_source = True
+        logger.debug("Selected frames: %s",
+                     [frame.inbound.filename for frame in self._input_images])
+
+    def _predict(self) -> None:
+        """ Predict from the loaded frames.
+
+        With a threading lock (to prevent stacking), run the selected faces through the Faceswap
+        model predict function and add the output to :attr:`predicted_images`
+        """
+        with self._app.lock:
+            self._predicted_images = []
+            for frame in self._input_images:
+                self._predictor.in_queue.put(frame)
+            idx = 0
+            while idx < self._sample_size:
+                logger.debug("Predicting face %s of %s", idx + 1, self._sample_size)
+                items: (T.Literal["EOF"] |
+                        list[tuple[ConvertItem, np.ndarray]]) = self._predictor.out_queue.get()
+                if items == "EOF":
+                    logger.debug("Received EOF")
+                    break
+                for item in items:
+                    self._predicted_images.append(item)
+                    logger.debug("Predicted face %s of %s", idx + 1, self._sample_size)
+                    idx += 1
+        logger.debug("Predicted faces")
+
+
+class Patch():
+    """ The Patch pipeline
+
+    Runs in its own thread. Takes the output from the Faceswap model predictor and runs the faces
+    through the convert pipeline using the currently selected options.
+
+    Parameters
+    ----------
+    app: :class:`Preview`
+        The main tkinter Preview app
+    arguments: :class:`argparse.Namespace`
+        The :mod:`argparse` arguments as passed in from :mod:`tools.py`
+
+    Attributes
+    ----------
+    converter_arguments: dict
+        The currently selected converter command line arguments for the patch queue
+    """
+    def __init__(self, app: Preview, arguments: Namespace) -> None:
+        logger.debug("Initializing %s: (app: %s, arguments: '%s')",
+                     self.__class__.__name__, app, arguments)
+        self._app = app
+        self._queue_patch_in = queue_manager.get_queue("preview_patch_in")
+        self.converter_arguments: dict[str, T.Any] | None = None  # Updated converter args
+
+        configfile = arguments.configfile if hasattr(arguments, "configfile") else None
+        self._converter = Converter(output_size=app._samples.predictor.output_size,
+                                    coverage_ratio=app._samples.predictor.coverage_ratio,
+                                    centering=app._samples.predictor.centering,
+                                    draw_transparent=False,
+                                    pre_encode=None,
+                                    arguments=self._generate_converter_arguments(
+                                        arguments,
+                                        app._samples.available_masks),
+                                    configfile=configfile)
+        self._thread = Thread(target=self._process,
+                              name="patch_thread",
+                              args=(self._queue_patch_in,
+                                    self._app.dispatcher.needs_patch,
+                                    app._samples),
+                              daemon=True)
+        self._thread.start()
+        logger.debug("Initialized %s", self.__class__.__name__)
+
+    @property
+    def converter(self) -> Converter:
+        """ :class:`lib.convert.Converter`: The converter to use for patching the images. """
+        return self._converter
+
+    @staticmethod
+    def _generate_converter_arguments(arguments: Namespace,
+                                      available_masks: list[str]) -> Namespace:
+        """ Add the default converter arguments to the initial arguments.
Ensure the mask selection + is available. + + Parameters + ---------- + arguments: :class:`argparse.Namespace` + The :mod:`argparse` arguments as passed in from :mod:`tools.py` + available_masks: list + The masks that are available for convert + Returns + ---------- + arguments: :class:`argparse.Namespace` + The :mod:`argparse` arguments as passed in with converter default + arguments added + """ + valid_masks = available_masks + ["none"] + converter_arguments = ConvertArgs(None, "convert").get_optional_arguments() # type: ignore + for item in converter_arguments: + value = item.get("default", None) + # Skip options without a default value + if value is None: + continue + option = item.get("dest", item["opts"][1].replace("--", "")) + if option == "mask_type" and value not in valid_masks: + logger.debug("Amending default mask from '%s' to '%s'", value, valid_masks[0]) + value = valid_masks[0] + # Skip options already in arguments + if hasattr(arguments, option): + continue + # Add option to arguments + setattr(arguments, option, value) + logger.debug(arguments) + return arguments + + def _process(self, + patch_queue_in: EventQueue, + trigger_event: Event, + samples: Samples) -> None: + """ The face patching process. + + Runs in a thread, and waits for an event to be set. Once triggered, runs a patching + cycle and sets the :class:`Display` destination images. + + Parameters + ---------- + patch_queue_in: :class:`~lib.queue_manager.EventQueue` + The input queue for the patching process + trigger_event: :class:`threading.Event` + The event that indicates a patching run needs to be processed + samples: :class:`Samples` + The Samples for display. + """ + logger.debug("Launching patch process thread: (patch_queue_in: %s, trigger_event: %s, " + "samples: %s)", patch_queue_in, trigger_event, samples) + patch_queue_out = queue_manager.get_queue("preview_patch_out") + while True: + trigger = trigger_event.wait(1) + if not trigger: + continue + logger.debug("Patch Triggered") + queue_manager.flush_queue("preview_patch_in") + self._feed_swapped_faces(patch_queue_in, samples) + with self._app.lock: + self._update_converter_arguments() + self._converter.reinitialize() + swapped = self._patch_faces(patch_queue_in, patch_queue_out, samples.sample_size) + with self._app.lock: + self._app.display.destination = swapped + + logger.debug("Patch complete") + trigger_event.clear() + + logger.debug("Closed patch process thread") + + def _update_converter_arguments(self) -> None: + """ Update the converter arguments to the currently selected values. """ + logger.debug("Updating Converter cli arguments") + if self.converter_arguments is None: + logger.debug("No arguments to update") + return + for key, val in self.converter_arguments.items(): + logger.debug("Updating %s to %s", key, val) + setattr(self._converter.cli_arguments, key, val) + logger.debug("Updated Converter cli arguments") + + @staticmethod + def _feed_swapped_faces(patch_queue_in: EventQueue, samples: Samples) -> None: + """ Feed swapped faces to the converter's in-queue. + + Parameters + ---------- + patch_queue_in: :class:`~lib.queue_manager.EventQueue` + The input queue for the patching process + samples: :class:`Samples` + The Samples for display. 
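# A standalone sketch of the feed-then-sentinel pattern `_feed_swapped_faces` uses: queue every
# item, then push an "EOF" marker so the consumer knows the batch is complete. The stdlib
# `queue.Queue` stands in for the project's `EventQueue`; the names here are illustrative only.
import queue

def feed(out_queue: queue.Queue, items: list) -> None:
    """ Queue every item, then signal completion with an "EOF" sentinel. """
    for item in items:
        out_queue.put(item)
    out_queue.put("EOF")

work: queue.Queue = queue.Queue()
feed(work, ["face_1", "face_2"])
while (item := work.get()) != "EOF":
    print(item)  # the consumer drains items until the sentinel arrives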
+        """
+        logger.debug("feeding swapped faces to converter")
+        for item in samples.predicted_images:
+            patch_queue_in.put(item)
+        logger.debug("fed %s swapped faces to converter",
+                     len(samples.predicted_images))
+        logger.debug("Putting EOF to converter")
+        patch_queue_in.put("EOF")
+
+    def _patch_faces(self,
+                     queue_in: EventQueue,
+                     queue_out: EventQueue,
+                     sample_size: int) -> list[np.ndarray]:
+        """ Patch faces.
+
+        Run the convert process on the swapped faces and return the patched faces.
+
+        Parameters
+        ----------
+        queue_in: :class:`~lib.queue_manager.EventQueue`
+            The input queue for the patching process
+        queue_out: :class:`~lib.queue_manager.EventQueue`
+            The output queue from the patching process
+        sample_size: int
+            The number of samples to be displayed
+
+        Returns
+        -------
+        list
+            The swapped faces patched with the selected convert settings
+        """
+        logger.debug("Patching faces")
+        self._converter.process(queue_in, queue_out)
+        swapped = []
+        idx = 0
+        while idx < sample_size:
+            logger.debug("Patching image %s of %s", idx + 1, sample_size)
+            item = queue_out.get()
+            swapped.append(item[1])
+            logger.debug("Patched image %s of %s", idx + 1, sample_size)
+            idx += 1
+        logger.debug("Patched faces")
+        return swapped
+
+
+__all__ = get_module_objects(__name__)
diff --git a/tools/preview/viewer.py b/tools/preview/viewer.py
new file mode 100644
index 0000000000..fc5586ba50
--- /dev/null
+++ b/tools/preview/viewer.py
@@ -0,0 +1,300 @@
+#!/usr/bin/env python3
+""" Manages the widgets that hold the top 'viewer' area of the preview tool """
+from __future__ import annotations
+import logging
+import os
+import tkinter as tk
+import typing as T
+
+from tkinter import ttk
+from dataclasses import dataclass, field
+
+import cv2
+import numpy as np
+from PIL import Image, ImageTk
+
+from lib.align import transform_image
+from lib.align.aligned_face import CenteringType
+from lib.utils import get_module_objects
+from scripts.convert import ConvertItem
+
+
+if T.TYPE_CHECKING:
+    from .preview import Preview
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class _Faces:
+    """ Dataclass for holding faces """
+    filenames: list[str] = field(default_factory=list)
+    matrix: list[np.ndarray] = field(default_factory=list)
+    src: list[np.ndarray] = field(default_factory=list)
+    dst: list[np.ndarray] = field(default_factory=list)
+
+
+class FacesDisplay():  # pylint:disable=too-many-instance-attributes
+    """ Compiles the 2 rows of sample faces (original and swapped) into a single image
+
+    Parameters
+    ----------
+    app: :class:`Preview`
+        The main tkinter Preview app
+    size: int
+        The size of each individual face sample in pixels
+    padding: int
+        The amount of extra padding to apply to the outside of the face
+
+    Attributes
+    ----------
+    update_source: bool
+        Flag to indicate that the source images for the preview have been updated, so the preview
+        should be recompiled.
+ source: list + The list of :class:`numpy.ndarray` source preview images for top row of display + destination: list + The list of :class:`numpy.ndarray` swapped and patched preview images for bottom row of + display + """ + def __init__(self, app: Preview, size: int, padding: int) -> None: + logger.trace("Initializing %s: (app: %s, size: %s, padding: %s)", # type: ignore + self.__class__.__name__, app, size, padding) + self._size = size + self._display_dims = (1, 1) + self._app = app + self._padding = padding + + self._faces = _Faces() + self._centering: CenteringType | None = None + self._faces_source: np.ndarray = np.array([]) + self._faces_dest: np.ndarray = np.array([]) + self._tk_image: ImageTk.PhotoImage | None = None + + # Set from Samples + self.update_source = False + self.source: list[ConvertItem] = [] # Source images, filenames + detected faces + # Set from Patch + self.destination: list[np.ndarray] = [] # Swapped + patched images + + logger.trace("Initialized %s", self.__class__.__name__) # type: ignore + + @property + def tk_image(self) -> ImageTk.PhotoImage | None: + """ :class:`PIL.ImageTk.PhotoImage`: The compiled preview display in tkinter display + format """ + return self._tk_image + + @property + def _total_columns(self) -> int: + """ int: The total number of images that are being displayed """ + return len(self.source) + + def set_centering(self, centering: CenteringType) -> None: + """ The centering that the model uses is not known at initialization time. + Set :attr:`_centering` when the model has been loaded. + + Parameters + ---------- + centering: str + The centering that the model was trained on + """ + self._centering = centering + + def set_display_dimensions(self, dimensions: tuple[int, int]) -> None: + """ Adjust the size of the frame that will hold the preview samples. + + Parameters + ---------- + dimensions: tuple + The (`width`, `height`) of the frame that holds the preview + """ + self._display_dims = dimensions + + def update_tk_image(self) -> None: + """ Build the full preview images and compile :attr:`tk_image` for display. """ + logger.trace("Updating tk image") # type: ignore + self._build_faces_image() + img = np.vstack((self._faces_source, self._faces_dest)) + size = self._get_scale_size(img) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + pilimg = Image.fromarray(img) + pilimg = pilimg.resize(size, Image.Resampling.BICUBIC) + self._tk_image = ImageTk.PhotoImage(pilimg) + logger.trace("Updated tk image") # type: ignore + + def _get_scale_size(self, image: np.ndarray) -> tuple[int, int]: + """ Get the size that the full preview image should be resized to fit in the + display window. + + Parameters + ---------- + image: :class:`numpy.ndarray` + The full sized compiled preview image + + Returns + ------- + tuple + The (`width`, `height`) that the display image should be sized to fit in the display + window + """ + frameratio = float(self._display_dims[0]) / float(self._display_dims[1]) + imgratio = float(image.shape[1]) / float(image.shape[0]) + + if frameratio <= imgratio: + scale = self._display_dims[0] / float(image.shape[1]) + size = (self._display_dims[0], max(1, int(image.shape[0] * scale))) + else: + scale = self._display_dims[1] / float(image.shape[0]) + size = (max(1, int(image.shape[1] * scale)), self._display_dims[1]) + logger.trace("scale: %s, size: %s", scale, size) # type: ignore + return size + + def _build_faces_image(self) -> None: + """ Compile the source and destination rows of the preview image. 
""" + logger.trace("Building Faces Image") # type: ignore + update_all = self.update_source + self._faces_from_frames() + if update_all: + header = self._header_text() + source = np.hstack([self._draw_rect(face) for face in self._faces.src]) + self._faces_source = np.vstack((header, source)) + self._faces_dest = np.hstack([self._draw_rect(face) for face in self._faces.dst]) + logger.debug("source row shape: %s, swapped row shape: %s", + self._faces_dest.shape, self._faces_source.shape) + + def _faces_from_frames(self) -> None: + """ Extract the preview faces from the source frames and apply the requisite padding. """ + logger.debug("Extracting faces from frames: Number images: %s", len(self.source)) + if self.update_source: + self._crop_source_faces() + self._crop_destination_faces() + logger.debug("Extracted faces from frames: %s", + {k: len(v) for k, v in self._faces.__dict__.items()}) + + def _crop_source_faces(self) -> None: + """ Extract the source faces from the source frames, along with their filenames and the + transformation matrix used to extract the faces. """ + logger.debug("Updating source faces") + self._faces = _Faces() # Init new class + for item in self.source: + detected_face = item.inbound.detected_faces[0] + src_img = item.inbound.image + detected_face.load_aligned(src_img, + size=self._size, + centering=T.cast(CenteringType, self._centering)) + matrix = detected_face.aligned.matrix + self._faces.filenames.append(os.path.splitext(item.inbound.filename)[0]) + self._faces.matrix.append(matrix) + self._faces.src.append(transform_image(src_img, matrix, self._size, self._padding)) + self.update_source = False + logger.debug("Updated source faces") + + def _crop_destination_faces(self) -> None: + """ Extract the swapped faces from the swapped frames using the source face destination + matrices. """ + logger.debug("Updating destination faces") + self._faces.dst = [] + destination = self.destination if self.destination else [np.ones_like(src.inbound.image) + for src in self.source] + for idx, image in enumerate(destination): + self._faces.dst.append(transform_image(image, + self._faces.matrix[idx], + self._size, + self._padding)) + logger.debug("Updated destination faces") + + def _header_text(self) -> np.ndarray: + """ Create the header text displaying the frame name for each preview column. + + Returns + ------- + :class:`numpy.ndarray` + The header row of the preview image containing the frame names for each column + """ + font_scale = self._size / 640 + height = self._size // 8 + font = cv2.FONT_HERSHEY_SIMPLEX + # Get size of placed text for positioning + text_sizes = [cv2.getTextSize(self._faces.filenames[idx], + font, + font_scale, + 1)[0] + for idx in range(self._total_columns)] + # Get X and Y co-ordinates for each text item + text_y = int((height + text_sizes[0][1]) / 2) + text_x = [int((self._size - text_sizes[idx][0]) / 2) + self._size * idx + for idx in range(self._total_columns)] + logger.debug("filenames: %s, text_sizes: %s, text_x: %s, text_y: %s", + self._faces.filenames, text_sizes, text_x, text_y) + header_box = np.ones((height, self._size * self._total_columns, 3), np.uint8) * 255 + for idx, text in enumerate(self._faces.filenames): + cv2.putText(header_box, + text, + (text_x[idx], text_y), + font, + font_scale, + (0, 0, 0), + 1, + lineType=cv2.LINE_AA) + logger.debug("header_box.shape: %s", header_box.shape) + return header_box + + def _draw_rect(self, image: np.ndarray) -> np.ndarray: + """ Place a white border around a given image. 
+
+    def _draw_rect(self, image: np.ndarray) -> np.ndarray:
+        """ Place a white border around a given image.
+
+        Parameters
+        ----------
+        image: :class:`numpy.ndarray`
+            The image to place a border on to
+
+        Returns
+        -------
+        :class:`numpy.ndarray`
+            The given image with a border drawn around the outside
+        """
+        cv2.rectangle(image, (0, 0), (self._size - 1, self._size - 1), (255, 255, 255), 1)
+        image = np.clip(image, 0.0, 255.0)
+        return image.astype("uint8")
+
+
+class ImagesCanvas(ttk.Frame):  # pylint:disable=too-many-ancestors
+    """ tkinter Canvas that holds the preview images.
+
+    Parameters
+    ----------
+    app: :class:`Preview`
+        The main tkinter Preview app
+    parent: tkinter object
+        The parent tkinter object that holds the canvas
+    """
+    def __init__(self, app: Preview, parent: ttk.PanedWindow) -> None:
+        logger.debug("Initializing %s: (app: %s, parent: %s)",
+                     self.__class__.__name__, app, parent)
+        super().__init__(parent)
+        self.pack(expand=True, fill=tk.BOTH, padx=2, pady=2)
+
+        self._display: FacesDisplay = parent.preview_display  # type: ignore
+        self._canvas = tk.Canvas(self, bd=0, highlightthickness=0)
+        self._canvas.pack(side=tk.TOP, fill=tk.BOTH, expand=True)
+        self._displaycanvas = self._canvas.create_image(0, 0,
+                                                        image=self._display.tk_image,
+                                                        anchor=tk.NW)
+        self.bind("<Configure>", self._resize)
+        logger.debug("Initialized %s", self.__class__.__name__)
+
+    def _resize(self, event: tk.Event) -> None:
+        """ Resize the image to fit the frame, maintaining aspect ratio """
+        logger.debug("Resizing preview image")
+        framesize = (event.width, event.height)
+        self._display.set_display_dimensions(framesize)
+        self.reload()
+
+    def reload(self) -> None:
+        """ Update the images in the canvas and redraw """
+        logger.debug("Reloading preview image")
+        self._display.update_tk_image()
+        self._canvas.itemconfig(self._displaycanvas, image=self._display.tk_image)
+        logger.debug("Reloaded preview image")
+
+
+__all__ = get_module_objects(__name__)
diff --git a/tools/sort.py b/tools/sort.py
deleted file mode 100644
index ade3b76330..0000000000
--- a/tools/sort.py
+++ /dev/null
@@ -1,910 +0,0 @@
-#!/usr/bin/env python3
-"""
-A tool that allows for sorting and grouping images in different ways.
-"""
-import logging
-import os
-import sys
-import operator
-from shutil import copyfile
-
-import numpy as np
-import cv2
-from tqdm import tqdm
-
-# faceswap imports
-import face_recognition
-
-from lib.cli import FullHelpArgumentParser
-from lib import Serializer
-from lib.faces_detect import DetectedFace
-from lib.multithreading import SpawnProcess
-from lib.queue_manager import queue_manager, QueueEmpty
-from plugins.plugin_loader import PluginLoader
-
-from . import cli
-
-logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
-
-
-class Sort():
-    """ Sorts folders of faces based on input criteria """
-    # pylint: disable=no-member
-    def __init__(self, arguments):
-        self.args = arguments
-        self.changes = None
-        self.serializer = None
-
-    def process(self):
-        """ Main processing function of the sort tool """
-
-        # Setting default argument values that cannot be set by argparse
-
-        # Set output dir to the same value as input dir
-        # if the user didn't specify it.
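# A sketch of the default-threshold fallback applied just below. The threshold values are
# taken from the removed code; the dict-based lookup (in place of the if/elif chain) is an
# illustrative restructuring, not the original implementation.
DEFAULT_THRESHOLDS = {"face": 0.6, "face-cnn": 7.2, "hist": 0.3}

def resolve_threshold(group_method: str, user_value: float) -> float:
    """ Keep an explicit user threshold, otherwise fall back to the per-method default. """
    if user_value >= 0.0:
        return user_value
    return DEFAULT_THRESHOLDS.get(group_method, user_value)

assert resolve_threshold("hist", -1.0) == 0.3  # unset (-1.0) falls back to the default
assert resolve_threshold("hist", 0.2) == 0.2   # an explicit value is kept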
- if self.args.output_dir.lower() == "_output_dir": - self.args.output_dir = self.args.input_dir - - # Assigning default threshold values based on grouping method - if (self.args.final_process == "folders" - and self.args.min_threshold < 0.0): - method = self.args.group_method.lower() - if method == 'face': - self.args.min_threshold = 0.6 - elif method == 'face-cnn': - self.args.min_threshold = 7.2 - elif method == 'hist': - self.args.min_threshold = 0.3 - - # If logging is enabled, prepare container - if self.args.log_changes: - self.changes = dict() - - # Assign default sort_log.json value if user didn't specify one - if self.args.log_file_path == 'sort_log.json': - self.args.log_file_path = os.path.join(self.args.input_dir, - 'sort_log.json') - - # Set serializer based on logfile extension - serializer_ext = os.path.splitext( - self.args.log_file_path)[-1] - self.serializer = Serializer.get_serializer_from_ext( - serializer_ext) - - # Prepare sort, group and final process method names - _sort = "sort_" + self.args.sort_method.lower() - _group = "group_" + self.args.group_method.lower() - _final = "final_process_" + self.args.final_process.lower() - self.args.sort_method = _sort.replace('-', '_') - self.args.group_method = _group.replace('-', '_') - self.args.final_process = _final.replace('-', '_') - - self.sort_process() - - def launch_aligner(self): - """ Load the aligner plugin to retrieve landmarks """ - out_queue = queue_manager.get_queue("out") - kwargs = {"in_queue": queue_manager.get_queue("in"), - "out_queue": out_queue} - - for plugin in ("fan", "dlib"): - aligner = PluginLoader.get_aligner(plugin)(loglevel=self.args.loglevel) - process = SpawnProcess(aligner.run, **kwargs) - event = process.event - process.start() - # Wait for Aligner to take init - # The first ever load of the model for FAN has reportedly taken - # up to 3-4 minutes, hence high timeout. - event.wait(300) - - if not event.is_set(): - if plugin == "fan": - process.join() - logger.error("Error initializing FAN. Trying Dlib") - continue - else: - raise ValueError("Error inititalizing Aligner") - if plugin == "dlib": - return - - try: - err = None - err = out_queue.get(True, 1) - except QueueEmpty: - pass - if not err: - break - process.join() - logger.error("Error initializing FAN. Trying Dlib") - - @staticmethod - def alignment_dict(image): - """ Set the image to a dict for alignment """ - height, width = image.shape[:2] - face = DetectedFace(x=0, w=width, y=0, h=height) - face = face.to_dlib_rect() - return {"image": image, - "detected_faces": [face]} - - @staticmethod - def get_landmarks(filename): - """ Extract the face from a frame (If not alignments file found) """ - image = cv2.imread(filename) - queue_manager.get_queue("in").put(Sort.alignment_dict(image)) - face = queue_manager.get_queue("out").get() - landmarks = face["landmarks"][0] - return landmarks - - def sort_process(self): - """ - This method dynamically assigns the functions that will be used to run - the core process of sorting, optionally grouping, renaming/moving into - folders. After the functions are assigned they are executed. 
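# The dynamic dispatch described in the docstring above amounts to building a method name from
# the cli choice and resolving it with getattr. A toy illustration (hypothetical class; the
# real tool prefixes "sort_", "group_" and "final_process_" in the same way):
class Dispatcher:
    def sort_blur(self):
        return "sorted by blur"

    def run(self, method: str):
        # normalize the cli choice into a method name, then resolve and call it
        return getattr(self, "sort_" + method.lower().replace("-", "_"))()

assert Dispatcher().run("Blur") == "sorted by blur"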
- """ - sort_method = self.args.sort_method.lower() - group_method = self.args.group_method.lower() - final_method = self.args.final_process.lower() - - img_list = getattr(self, sort_method)() - if "folders" in final_method: - # Check if non-dissim sort method and group method are not the same - if group_method.replace('group_', '') not in sort_method: - img_list = self.reload_images(group_method, img_list) - img_list = getattr(self, group_method)(img_list) - else: - img_list = getattr(self, group_method)(img_list) - - getattr(self, final_method)(img_list) - - logger.info("Done.") - - # Methods for sorting - def sort_blur(self): - """ Sort by blur amount """ - input_dir = self.args.input_dir - - logger.info("Sorting by blur...") - img_list = [[img, self.estimate_blur(img)] - for img in - tqdm(self.find_images(input_dir), - desc="Loading", - file=sys.stdout)] - logger.info("Sorting...") - - img_list = sorted(img_list, key=operator.itemgetter(1), reverse=True) - - return img_list - - def sort_face(self): - """ Sort by face similarity """ - input_dir = self.args.input_dir - - logger.info("Sorting by face similarity...") - - img_list = [[img, face_recognition.face_encodings(cv2.imread(img))] - for img in - tqdm(self.find_images(input_dir), - desc="Loading", - file=sys.stdout)] - - img_list_len = len(img_list) - for i in tqdm(range(0, img_list_len - 1), - desc="Sorting", - file=sys.stdout): - min_score = float("inf") - j_min_score = i + 1 - for j in range(i + 1, len(img_list)): - f1encs = img_list[i][1] - f2encs = img_list[j][1] - if f1encs and f2encs: - score = face_recognition.face_distance(f1encs[0], - f2encs)[0] - else: - score = float("inf") - - if score < min_score: - min_score = score - j_min_score = j - (img_list[i + 1], - img_list[j_min_score]) = (img_list[j_min_score], - img_list[i + 1]) - return img_list - - def sort_face_dissim(self): - """ Sort by face dissimilarity """ - input_dir = self.args.input_dir - - logger.info("Sorting by face dissimilarity...") - - img_list = [[img, face_recognition.face_encodings(cv2.imread(img)), 0] - for img in - tqdm(self.find_images(input_dir), - desc="Loading", - file=sys.stdout)] - - img_list_len = len(img_list) - for i in tqdm(range(0, img_list_len), desc="Sorting", file=sys.stdout): - score_total = 0 - for j in range(0, img_list_len): - if i == j: - continue - try: - score_total += face_recognition.face_distance( - [img_list[i][1]], - [img_list[j][1]]) - except: - logger.info("except") - pass - - img_list[i][2] = score_total - - logger.info("Sorting...") - img_list = sorted(img_list, key=operator.itemgetter(2), reverse=True) - return img_list - - def sort_face_cnn(self): - """ Sort by CNN similarity """ - self.launch_aligner() - input_dir = self.args.input_dir - - logger.info("Sorting by face-cnn similarity...") - img_list = [] - for img in tqdm(self.find_images(input_dir), - desc="Loading", - file=sys.stdout): - landmarks = self.get_landmarks(img) - img_list.append([img, np.array(landmarks) - if landmarks - else np.zeros((68, 2))]) - - queue_manager.terminate_queues() - img_list_len = len(img_list) - for i in tqdm(range(0, img_list_len - 1), - desc="Sorting", - file=sys.stdout): - min_score = float("inf") - j_min_score = i + 1 - for j in range(i + 1, len(img_list)): - fl1 = img_list[i][1] - fl2 = img_list[j][1] - score = np.sum(np.absolute((fl2 - fl1).flatten())) - - if score < min_score: - min_score = score - j_min_score = j - (img_list[i + 1], - img_list[j_min_score]) = (img_list[j_min_score], - img_list[i + 1]) - return img_list - - def 
sort_face_cnn_dissim(self): - """ Sort by CNN dissimilarity """ - self.launch_aligner() - input_dir = self.args.input_dir - - logger.info("Sorting by face-cnn dissimilarity...") - - img_list = [] - for img in tqdm(self.find_images(input_dir), - desc="Loading", - file=sys.stdout): - landmarks = self.get_landmarks(img) - img_list.append([img, np.array(landmarks) - if landmarks - else np.zeros((68, 2)), 0]) - - img_list_len = len(img_list) - for i in tqdm(range(0, img_list_len - 1), - desc="Sorting", - file=sys.stdout): - score_total = 0 - for j in range(i + 1, len(img_list)): - if i == j: - continue - fl1 = img_list[i][1] - fl2 = img_list[j][1] - score_total += np.sum(np.absolute((fl2 - fl1).flatten())) - - img_list[i][2] = score_total - - logger.info("Sorting...") - img_list = sorted(img_list, key=operator.itemgetter(2), reverse=True) - - return img_list - - def sort_face_yaw(self): - """ Sort by yaw of face """ - self.launch_aligner() - input_dir = self.args.input_dir - - img_list = [] - for img in tqdm(self.find_images(input_dir), - desc="Loading", - file=sys.stdout): - landmarks = self.get_landmarks(img) - img_list.append( - [img, self.calc_landmarks_face_yaw(np.array(landmarks))]) - - logger.info("Sorting by face-yaw...") - img_list = sorted(img_list, key=operator.itemgetter(1), reverse=True) - - return img_list - - def sort_hist(self): - """ Sort by histogram of face similarity """ - input_dir = self.args.input_dir - - logger.info("Sorting by histogram similarity...") - - img_list = [ - [img, cv2.calcHist([cv2.imread(img)], [0], None, [256], [0, 256])] - for img in - tqdm(self.find_images(input_dir), desc="Loading", file=sys.stdout) - ] - - img_list_len = len(img_list) - for i in tqdm(range(0, img_list_len - 1), desc="Sorting", - file=sys.stdout): - min_score = float("inf") - j_min_score = i + 1 - for j in range(i + 1, len(img_list)): - score = cv2.compareHist(img_list[i][1], - img_list[j][1], - cv2.HISTCMP_BHATTACHARYYA) - if score < min_score: - min_score = score - j_min_score = j - (img_list[i + 1], - img_list[j_min_score]) = (img_list[j_min_score], - img_list[i + 1]) - return img_list - - def sort_hist_dissim(self): - """ Sort by histigram of face dissimilarity """ - input_dir = self.args.input_dir - - logger.info("Sorting by histogram dissimilarity...") - - img_list = [ - [img, - cv2.calcHist([cv2.imread(img)], [0], None, [256], [0, 256]), 0] - for img in - tqdm(self.find_images(input_dir), desc="Loading", file=sys.stdout) - ] - - img_list_len = len(img_list) - for i in tqdm(range(0, img_list_len), desc="Sorting", file=sys.stdout): - score_total = 0 - for j in range(0, img_list_len): - if i == j: - continue - score_total += cv2.compareHist(img_list[i][1], - img_list[j][1], - cv2.HISTCMP_BHATTACHARYYA) - - img_list[i][2] = score_total - - logger.info("Sorting...") - img_list = sorted(img_list, key=operator.itemgetter(2), reverse=True) - - return img_list - - # Methods for grouping - def group_blur(self, img_list): - """ Group into bins by blur """ - # Starting the binning process - num_bins = self.args.num_bins - - # The last bin will get all extra images if it's - # not possible to distribute them evenly - num_per_bin = len(img_list) // num_bins - remainder = len(img_list) % num_bins - - logger.info("Grouping by blur...") - bins = [[] for _ in range(num_bins)] - idx = 0 - for i in range(num_bins): - for _ in range(num_per_bin): - bins[i].append(img_list[idx][0]) - idx += 1 - - # If remainder is 0, nothing gets added to the last bin. 
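# The even-binning arithmetic used by group_blur here (and group_face_yaw below), as a
# standalone sketch with an illustrative helper name: items that cannot be distributed
# evenly are appended to the last bin.
def make_bins(items: list, num_bins: int) -> list[list]:
    """ Split items into num_bins evenly sized groups; extras go to the last bin. """
    per_bin, remainder = divmod(len(items), num_bins)
    bins = [items[i * per_bin:(i + 1) * per_bin] for i in range(num_bins)]
    if remainder:
        bins[-1].extend(items[-remainder:])
    return bins

assert make_bins(list(range(7)), 3) == [[0, 1], [2, 3], [4, 5, 6]]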
- for i in range(1, remainder + 1): - bins[-1].append(img_list[-i][0]) - - return bins - - def group_face(self, img_list): - """ Group into bins by face similarity """ - logger.info("Grouping by face similarity...") - - # Groups are of the form: group_num -> reference face - reference_groups = dict() - - # Bins array, where index is the group number and value is - # an array containing the file paths to the images in that group. - # The first group (0), is always the non-face group. - bins = [[]] - - # Comparison threshold used to decide how similar - # faces have to be to be grouped together. - min_threshold = self.args.min_threshold - - img_list_len = len(img_list) - - for i in tqdm(range(1, img_list_len), - desc="Grouping", - file=sys.stdout): - f1encs = img_list[i][1] - - # Check if current image is a face, if not then - # add it immediately to the non-face list. - if f1encs is None or len(f1encs) <= 0: - bins[0].append(img_list[i][0]) - - else: - current_best = [-1, float("inf")] - - for key, references in reference_groups.items(): - # Non-faces are not added to reference_groups dict, thus - # removing the need to check that f2encs is a face. - # The try-catch block is to handle the first face that gets - # processed, as the first value is None. - try: - score = self.get_avg_score_faces(f1encs, references) - except TypeError: - score = float("inf") - except ZeroDivisionError: - score = float("inf") - if score < current_best[1]: - current_best[0], current_best[1] = key, score - - if current_best[1] < min_threshold: - reference_groups[current_best[0]].append(f1encs[0]) - bins[current_best[0]].append(img_list[i][0]) - else: - reference_groups[len(reference_groups)] = img_list[i][1] - bins.append([img_list[i][0]]) - - return bins - - def group_face_cnn(self, img_list): - """ Group into bins by CNN face similarity """ - logger.info("Grouping by face-cnn similarity...") - - # Groups are of the form: group_num -> reference faces - reference_groups = dict() - - # Bins array, where index is the group number and value is - # an array containing the file paths to the images in that group. - bins = [] - - # Comparison threshold used to decide how similar - # faces have to be to be grouped together. - # It is multiplied by 1000 here to allow the cli option to use smaller - # numbers. 
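# The reference-group clustering used by group_face and group_face_cnn, reduced to a generic
# sketch: each item joins the closest existing group if it beats the threshold, otherwise it
# seeds a new group. A 1-D toy metric stands in for the face-encoding distances used here.
def greedy_group(items: list, distance, threshold: float) -> list[list]:
    """ Assign each item to the closest existing group, or start a new group. """
    groups: list[list] = []
    for item in items:
        scored = [(distance(item, group[0]), idx) for idx, group in enumerate(groups)]
        best_score, best_idx = min(scored, default=(float("inf"), -1))
        if best_score < threshold:
            groups[best_idx].append(item)
        else:
            groups.append([item])
    return groups

assert greedy_group([1, 2, 10, 11], lambda a, b: abs(a - b), 3) == [[1, 2], [10, 11]]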
- min_threshold = self.args.min_threshold * 1000 - - img_list_len = len(img_list) - - for i in tqdm(range(0, img_list_len - 1), - desc="Grouping", - file=sys.stdout): - fl1 = img_list[i][1] - - current_best = [-1, float("inf")] - - for key, references in reference_groups.items(): - try: - score = self.get_avg_score_faces_cnn(fl1, references) - except TypeError: - score = float("inf") - except ZeroDivisionError: - score = float("inf") - if score < current_best[1]: - current_best[0], current_best[1] = key, score - - if current_best[1] < min_threshold: - reference_groups[current_best[0]].append(fl1[0]) - bins[current_best[0]].append(img_list[i][0]) - else: - reference_groups[len(reference_groups)] = [img_list[i][1]] - bins.append([img_list[i][0]]) - - return bins - - def group_face_yaw(self, img_list): - """ Group into bins by yaw of face """ - # Starting the binning process - num_bins = self.args.num_bins - - # The last bin will get all extra images if it's - # not possible to distribute them evenly - num_per_bin = len(img_list) // num_bins - remainder = len(img_list) % num_bins - - logger.info("Grouping by face-yaw...") - bins = [[] for _ in range(num_bins)] - idx = 0 - for i in range(num_bins): - for _ in range(num_per_bin): - bins[i].append(img_list[idx][0]) - idx += 1 - - # If remainder is 0, nothing gets added to the last bin. - for i in range(1, remainder + 1): - bins[-1].append(img_list[-i][0]) - - return bins - - def group_hist(self, img_list): - """ Group into bins by histogram """ - logger.info("Grouping by histogram...") - - # Groups are of the form: group_num -> reference histogram - reference_groups = dict() - - # Bins array, where index is the group number and value is - # an array containing the file paths to the images in that group - bins = [] - - min_threshold = self.args.min_threshold - - img_list_len = len(img_list) - reference_groups[0] = [img_list[0][1]] - bins.append([img_list[0][0]]) - - for i in tqdm(range(1, img_list_len), - desc="Grouping", - file=sys.stdout): - current_best = [-1, float("inf")] - for key, value in reference_groups.items(): - score = self.get_avg_score_hist(img_list[i][1], value) - if score < current_best[1]: - current_best[0], current_best[1] = key, score - - if current_best[1] < min_threshold: - reference_groups[current_best[0]].append(img_list[i][1]) - bins[current_best[0]].append(img_list[i][0]) - else: - reference_groups[len(reference_groups)] = [img_list[i][1]] - bins.append([img_list[i][0]]) - - return bins - - # Final process methods - def final_process_rename(self, img_list): - """ Rename the files """ - output_dir = self.args.output_dir - - process_file = self.set_process_file_method(self.args.log_changes, - self.args.keep_original) - - # Make sure output directory exists - if not os.path.exists(output_dir): - os.makedirs(output_dir) - - description = ( - "Copying and Renaming" if self.args.keep_original - else "Moving and Renaming" - ) - - for i in tqdm(range(0, len(img_list)), - desc=description, - leave=False, - file=sys.stdout): - src = img_list[i][0] - src_basename = os.path.basename(src) - - dst = os.path.join(output_dir, '{:05d}_{}'.format(i, src_basename)) - try: - process_file(src, dst, self.changes) - except FileNotFoundError as err: - logger.error(err) - logger.error('fail to rename %s', src) - - for i in tqdm(range(0, len(img_list)), - desc=description, - file=sys.stdout): - renaming = self.set_renaming_method(self.args.log_changes) - src, dst = renaming(img_list[i][0], output_dir, i, self.changes) - - try: - os.rename(src, 
dst) - except FileNotFoundError as err: - logger.error(err) - logger.error('fail to rename %s', format(src)) - - if self.args.log_changes: - self.write_to_log(self.changes) - - def final_process_folders(self, bins): - """ Move the files to folders """ - output_dir = self.args.output_dir - - process_file = self.set_process_file_method(self.args.log_changes, - self.args.keep_original) - - # First create new directories to avoid checking - # for directory existence in the moving loop - logger.info("Creating group directories.") - for i in range(len(bins)): - directory = os.path.join(output_dir, str(i)) - if not os.path.exists(directory): - os.makedirs(directory) - - description = ( - "Copying into Groups" if self.args.keep_original - else "Moving into Groups" - ) - - logger.info("Total groups found: %s", len(bins)) - for i in tqdm(range(len(bins)), desc=description, file=sys.stdout): - for j in range(len(bins[i])): - src = bins[i][j] - src_basename = os.path.basename(src) - - dst = os.path.join(output_dir, str(i), src_basename) - try: - process_file(src, dst, self.changes) - except FileNotFoundError as err: - logger.error(err) - logger.error("Failed to move '%s' to '%s'", src, dst) - - if self.args.log_changes: - self.write_to_log(self.changes) - - # Various helper methods - def write_to_log(self, changes): - """ Write the changes to log file """ - logger.info("Writing sort log to: '%s'", self.args.log_file_path) - with open(self.args.log_file_path, 'w') as lfile: - lfile.write(self.serializer.marshal(changes)) - - def reload_images(self, group_method, img_list): - """ - Reloads the image list by replacing the comparative values with those - that the chosen grouping method expects. - :param group_method: str name of the grouping method that will be used. - :param img_list: image list that has been sorted by one of the sort - methods. - :return: img_list but with the comparative values that the chosen - grouping method expects. 
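# The splice_lists helper further below keeps the sorted order but swaps in values computed
# by a different loader. The original matches paths with list.index (quadratic); this
# illustrative variant uses a dict lookup for the same result in linear time.
def splice(sorted_list: list[list], new_vals: list[list]) -> list[list]:
    """ Replace each value with the one computed for the same path in new_vals. """
    lookup = {path: value for path, value in new_vals}
    return [[path, lookup[path]] for path, _ in sorted_list]

assert splice([["b.png", 1], ["a.png", 2]],
              [["a.png", 9], ["b.png", 8]]) == [["b.png", 8], ["a.png", 9]]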
- """ - input_dir = self.args.input_dir - logger.info("Preparing to group...") - if group_method == 'group_blur': - temp_list = [[img, self.estimate_blur(cv2.imread(img))] - for img in - tqdm(self.find_images(input_dir), - desc="Reloading", - file=sys.stdout)] - elif group_method == 'group_face': - temp_list = [ - [img, face_recognition.face_encodings(cv2.imread(img))] - for img in tqdm(self.find_images(input_dir), - desc="Reloading", - file=sys.stdout)] - elif group_method == 'group_face_cnn': - self.launch_aligner() - temp_list = [] - for img in tqdm(self.find_images(input_dir), - desc="Reloading", - file=sys.stdout): - landmarks = self.get_landmarks(img) - temp_list.append([img, np.array(landmarks) - if landmarks - else np.zeros((68, 2))]) - elif group_method == 'group_face_yaw': - self.launch_aligner() - temp_list = [] - for img in tqdm(self.find_images(input_dir), - desc="Reloading", - file=sys.stdout): - landmarks = self.get_landmarks(img) - temp_list.append( - [img, - self.calc_landmarks_face_yaw(np.array(landmarks))]) - elif group_method == 'group_hist': - temp_list = [ - [img, - cv2.calcHist([cv2.imread(img)], [0], None, [256], [0, 256])] - for img in - tqdm(self.find_images(input_dir), - desc="Reloading", - file=sys.stdout) - ] - else: - raise ValueError("{} group_method not found.".format(group_method)) - - return self.splice_lists(img_list, temp_list) - - @staticmethod - def splice_lists(sorted_list, new_vals_list): - """ - This method replaces the value at index 1 in each sub-list in the - sorted_list with the value that is calculated for the same img_path, - but found in new_vals_list. - - Format of lists: [[img_path, value], [img_path2, value2], ...] - - :param sorted_list: list that has been sorted by one of the sort - methods. - :param new_vals_list: list that has been loaded by a different method - than the sorted_list. - :return: list that is sorted in the same way as the input sorted list - but the values corresponding to each image are from new_vals_list. - """ - new_list = [] - # Make new list of just image paths to serve as an index - val_index_list = [i[0] for i in new_vals_list] - for i in tqdm(range(len(sorted_list)), - desc="Splicing", - file=sys.stdout): - current_image = sorted_list[i][0] - new_val_index = val_index_list.index(current_image) - new_list.append([current_image, new_vals_list[new_val_index][1]]) - - return new_list - - @staticmethod - def find_images(input_dir): - """ Return list of images at specified location """ - result = [] - extensions = [".jpg", ".png", ".jpeg"] - for root, _, files in os.walk(input_dir): - for file in files: - if os.path.splitext(file)[1].lower() in extensions: - result.append(os.path.join(root, file)) - return result - - @staticmethod - def estimate_blur(image_file): - """ - Estimate the amount of blur an image has - with the variance of the Laplacian. 
- Normalize by pixel number to offset the effect - of image size on pixel gradients & variance - """ - image = cv2.imread(image_file) - if image.ndim == 3: - image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) - blur_map = cv2.Laplacian(image, cv2.CV_32F) - score = np.var(blur_map) / np.sqrt(image.shape[0] * image.shape[1]) - return score - - @staticmethod - def calc_landmarks_face_pitch(flm): - """ UNUSED - Calculate the amount of pitch in a face """ - var_t = ((flm[6][1] - flm[8][1]) + (flm[10][1] - flm[8][1])) / 2.0 - var_b = flm[8][1] - return var_b - var_t - - @staticmethod - def calc_landmarks_face_yaw(flm): - """ Calculate the amount of yaw in a face """ - var_l = ((flm[27][0] - flm[0][0]) - + (flm[28][0] - flm[1][0]) - + (flm[29][0] - flm[2][0])) / 3.0 - var_r = ((flm[16][0] - flm[27][0]) - + (flm[15][0] - flm[28][0]) - + (flm[14][0] - flm[29][0])) / 3.0 - return var_r - var_l - - @staticmethod - def set_process_file_method(log_changes, keep_original): - """ - Assigns the final file processing method based on whether changes are - being logged and whether the original files are being kept in the - input directory. - Relevant cli arguments: -k, -l - :return: function reference - """ - if log_changes: - if keep_original: - def process_file(src, dst, changes): - """ Process file method if logging changes - and keeping original """ - copyfile(src, dst) - changes[src] = dst - - else: - def process_file(src, dst, changes): - """ Process file method if logging changes - and not keeping original """ - os.rename(src, dst) - changes[src] = dst - - else: - if keep_original: - def process_file(src, dst, changes): # pylint: disable=unused-argument - """ Process file method if not logging changes - and keeping original """ - copyfile(src, dst) - - else: - def process_file(src, dst, changes): # pylint: disable=unused-argument - """ Process file method if not logging changes - and not keeping original """ - os.rename(src, dst) - return process_file - - @staticmethod - def set_renaming_method(log_changes): - """ Set the method for renaming files """ - if log_changes: - def renaming(src, output_dir, i, changes): - """ Rename files method if logging changes """ - src_basename = os.path.basename(src) - - __src = os.path.join(output_dir, - '{:05d}_{}'.format(i, src_basename)) - dst = os.path.join( - output_dir, - '{:05d}{}'.format(i, os.path.splitext(src_basename)[1])) - changes[src] = dst - return __src, dst - else: - def renaming(src, output_dir, i, changes): # pylint: disable=unused-argument - """ Rename files method if not logging changes """ - src_basename = os.path.basename(src) - - src = os.path.join(output_dir, - '{:05d}_{}'.format(i, src_basename)) - dst = os.path.join( - output_dir, - '{:05d}{}'.format(i, os.path.splitext(src_basename)[1])) - return src, dst - return renaming - - @staticmethod - def get_avg_score_hist(img1, references): - """ Return the average histogram score between a face and - reference image """ - scores = [] - for img2 in references: - score = cv2.compareHist(img1, img2, cv2.HISTCMP_BHATTACHARYYA) - scores.append(score) - return sum(scores) / len(scores) - - @staticmethod - def get_avg_score_faces(f1encs, references): - """ Return the average similarity score between a face and - reference image """ - scores = [] - for f2encs in references: - score = face_recognition.face_distance(f1encs, f2encs)[0] - scores.append(score) - return sum(scores) / len(scores) - - @staticmethod - def get_avg_score_faces_cnn(fl1, references): - """ Return the average CNN similarity score - 
between a face and reference image """ - scores = [] - for fl2 in references: - score = np.sum(np.absolute((fl2 - fl1).flatten())) - scores.append(score) - return sum(scores) / len(scores) - - -def bad_args(args): # pylint: disable=unused-argument - """ Print help on bad arguments """ - PARSER.print_help() - exit(0) - - -if __name__ == "__main__": - __WARNING_STRING = "Important: face-cnn method will cause an error when " - __WARNING_STRING += "this tool is called directly instead of through the " - __WARNING_STRING += "tools.py command script." - print(__WARNING_STRING) - print("Images sort tool.\n") - - PARSER = FullHelpArgumentParser() - SUBPARSER = PARSER.add_subparsers() - SORT = cli.SortArgs( - SUBPARSER, "sort", "Sort images using various methods.") - PARSER.set_defaults(func=bad_args) - ARGUMENTS = PARSER.parse_args() - ARGUMENTS.func(ARGUMENTS) diff --git a/tools/sort/__init__.py b/tools/sort/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tools/sort/cli.py b/tools/sort/cli.py new file mode 100644 index 0000000000..607bd68a2e --- /dev/null +++ b/tools/sort/cli.py @@ -0,0 +1,228 @@ +#!/usr/bin/env python3 +""" Command Line Arguments for tools """ +import gettext + +from lib.cli.args import FaceSwapArgs +from lib.cli.actions import DirFullPaths, SaveFileFullPaths, Radio, Slider +from lib.utils import get_module_objects + + +# pylint:disable=duplicate-code +# # LOCALES +_LANG = gettext.translation("tools.sort.cli", localedir="locales", fallback=True) +_ = _LANG.gettext + + +_HELPTEXT = _("This command lets you sort images using various methods.") +_SORT_METHODS = ( + "none", "blur", "blur-fft", "distance", "face", "face-cnn", "face-cnn-dissim", + "yaw", "pitch", "roll", "hist", "hist-dissim", "color-black", "color-gray", "color-luma", + "color-green", "color-orange", "size") + +_GPTHRESHOLD = _(" Adjust the '-t' ('--threshold') parameter to control the strength of grouping.") +_GPCOLOR = _(" Adjust the '-b' ('--bins') parameter to control the number of bins for grouping. " + "Each image is allocated to a bin by the percentage of color pixels that appear in " + "the image.") +_GPDEGREES = _(" Adjust the '-b' ('--bins') parameter to control the number of bins for grouping. " + "Each image is allocated to a bin by the number of degrees the face is orientated " + "from center.") +_GPLINEAR = _(" Adjust the '-b' ('--bins') parameter to control the number of bins for grouping. " + "The minimum and maximum values are taken for the chosen sort metric. The bins " + "are then populated with the results from the group sorting.") +_METHOD_TEXT = { + "blur": _("faces by blurriness."), + "blur-fft": _("faces by fft filtered blurriness."), + "distance": _("faces by the estimated distance of the alignments from an 'average' face. This " + "can be useful for eliminating misaligned faces. Sorts from most like an " + "average face to least like an average face."), + "face": _("faces using VGG Face2 by face similarity. This uses a pairwise clustering " + "algorithm to check the distances between 512 features on every face in your set " + "and order them appropriately."), + "face-cnn": _("faces by their landmarks."), + "face-cnn-dissim": _("Like 'face-cnn' but sorts by dissimilarity."), + "yaw": _("faces by Yaw (rotation left to right)."), + "pitch": _("faces by Pitch (rotation up and down)."), + "roll": _("faces by Roll (rotation). Aligned faces should have a roll value close to zero. 
" + "The further the Roll value from zero the higher liklihood the face is misaligned."), + "hist": _("faces by their color histogram."), + "hist-dissim": _("Like 'hist' but sorts by dissimilarity."), + "color-gray": _("images by the average intensity of the converted grayscale color channel."), + "color-black": _("images by their number of black pixels. Useful when faces are near borders " + "and a large part of the image is black."), + "color-luma": _("images by the average intensity of the converted Y color channel. Bright " + "lighting and oversaturated images will be ranked first."), + "color-green": _("images by the average intensity of the converted Cg color channel. Green " + "images will be ranked first and red images will be last."), + "color-orange": _("images by the average intensity of the converted Co color channel. Orange " + "images will be ranked first and blue images will be last."), + "size": _("images by their size in the original frame. Faces further from the camera and from " + "lower resolution sources will be sorted first, whilst faces closer to the camera " + "and from higher resolution sources will be sorted last.")} + +_BIN_TYPES = [ + (("face", "face-cnn", "face-cnn-dissim", "hist", "hist-dissim"), _GPTHRESHOLD), + (("color-black", "color-gray", "color-luma", "color-green", "color-orange"), _GPCOLOR), + (("yaw", "pitch", "roll"), _GPDEGREES), + (("blur", "blur-fft", "distance", "size"), _GPLINEAR)] +_sort_help = "" +_GROUP_HELP = "" + +for method in sorted(_METHOD_TEXT): + _sort_help += f"\nL|{method}: {_('Sort')} {_METHOD_TEXT[method]}" + _GROUP_HELP += (f"\nL|{method}: {_('Group')} {_METHOD_TEXT[method]} " + f"{next((x[1] for x in _BIN_TYPES if method in x[0]), '')}") + + +class SortArgs(FaceSwapArgs): + """ Class to parse the command line arguments for sort tool """ + + @staticmethod + def get_info(): + """ Return command information """ + return _("Sort faces using a number of different techniques") + + @staticmethod + def get_argument_list(): + """ Put the arguments in a list so that they are accessible from both argparse and gui """ + argument_list = [] + argument_list.append({ + "opts": ('-i', '--input'), + "action": DirFullPaths, + "dest": "input_dir", + "group": _("data"), + "help": _("Input directory of aligned faces."), + "required": True}) + argument_list.append({ + "opts": ('-o', '--output'), + "action": DirFullPaths, + "dest": "output_dir", + "group": _("data"), + "help": _( + "Output directory for sorted aligned faces. If not provided and 'keep' is " + "selected then a new folder called 'sorted' will be created within the input " + "folder to house the output. If not provided and 'keep' is not selected then the " + "images will be sorted in-place, overwriting the original contents of the " + "'input_dir'")}) + argument_list.append({ + "opts": ("-B", "--batch-mode"), + "action": "store_true", + "dest": "batch_mode", + "default": False, + "group": _("data"), + "help": _( + "R|If selected then the input_dir should be a parent folder containing multiple " + "folders of faces you wish to sort. The faces will be output to separate sub-" + "folders in the output_dir")}) + argument_list.append({ + "opts": ('-s', '--sort-by'), + "action": Radio, + "type": str, + "choices": _SORT_METHODS, + "dest": 'sort_method', + "group": _("sort settings"), + "default": "face", + "help": _( + "R|Choose how images are sorted. Selecting a sort method gives the images a new " + "filename based on the order the image appears within the given method." 
+ "\nL|'none': Don't sort the images. When a 'group-by' method is selected, " + "selecting 'none' means that the files will be moved/copied into their respective " + "bins, but the files will keep their original filenames. Selecting 'none' for " + "both 'sort-by' and 'group-by' will do nothing" + _sort_help + "\nDefault: face")}) + argument_list.append({ + "opts": ('-g', '--group-by'), + "action": Radio, + "type": str, + "choices": _SORT_METHODS, + "dest": 'group_method', + "group": _("group settings"), + "default": "none", + "help": _( + "R|Selecting a group by method will move/copy files into numbered bins based on " + "the selected method." + "\nL|'none': Don't bin the images. Folders will be sorted by the selected 'sort-" + "by' but will not be binned, instead they will be sorted into a single folder. " + "Selecting 'none' for both 'sort-by' and 'group-by' will do nothing" + + _GROUP_HELP + "\nDefault: none")}) + argument_list.append({ + "opts": ('-k', '--keep'), + "action": 'store_true', + "dest": 'keep_original', + "default": False, + "group": _("data"), + "help": _( + "Whether to keep the original files in their original location. Choosing a 'sort-" + "by' method means that the files have to be renamed. Selecting 'keep' means that " + "the original files will be kept, and the renamed files will be created in the " + "specified output folder. Unselecting keep means that the original files will be " + "moved and renamed based on the selected sort/group criteria.")}) + argument_list.append({ + "opts": ('-t', '--threshold'), + "action": Slider, + "min_max": (-1.0, 10.0), + "rounding": 2, + "type": float, + "dest": 'threshold', + "group": _("group settings"), + "default": -1.0, + "help": _( + "R|Float value. Minimum threshold to use for grouping comparison with 'face-cnn' " + "'hist' and 'face' methods." + "\nThe lower the value the more discriminating the grouping is. Leaving -1.0 will " + "allow Faceswap to choose the default value." + "\nL|For 'face-cnn' 7.2 should be enough, with 4 being very discriminating. " + "\nL|For 'hist' 0.3 should be enough, with 0.2 being very discriminating. " + "\nL|For 'face' between 0.1 (more bins) to 0.5 (fewer bins) should be about right." + "\nBe careful setting a value that's too extrene in a directory with many images, " + "as this could result in a lot of folders being created. Defaults: face-cnn 7.2, " + "hist 0.3, face 0.25")}) + argument_list.append({ + "opts": ('-b', '--bins'), + "action": Slider, + "min_max": (1, 100), + "rounding": 1, + "type": int, + "dest": 'num_bins', + "group": _("group settings"), + "default": 5, + "help": _( + "R|Integer value. Used to control the number of bins created for grouping by: any " + "'blur' methods, 'color' methods or 'face metric' methods ('distance', 'size') " + "and 'orientation; methods ('yaw', 'pitch'). For any other grouping " + "methods see the '-t' ('--threshold') option." + "\nL|For 'face metric' methods the bins are filled, according the the " + "distribution of faces between the minimum and maximum chosen metric." + "\nL|For 'color' methods the number of bins represents the divider of the " + "percentage of colored pixels. Eg. For a bin number of '5': The first folder will " + "have the faces with 0%% to 20%% colored pixels, second 21%% to 40%%, etc. Any " + "empty bins will be deleted, so you may end up with fewer bins than selected." + "\nL|For 'blur' methods folder 0 will be the least blurry, while the last folder " + "will be the blurriest." 
+ "\nL|For 'orientation' methods the number of bins is dictated by how much 180 " + "degrees is divided. Eg. If 18 is selected, then each folder will be a 10 degree " + "increment. Folder 0 will contain faces looking the most to the left/down whereas " + "the last folder will contain the faces looking the most to the right/up. NB: " + "Some bins may be empty if faces do not fit the criteria. \nDefault value: 5")}) + argument_list.append({ + "opts": ('-l', '--log-changes'), + "action": 'store_true', + "group": _("settings"), + "default": False, + "help": _( + "Logs file renaming changes if grouping by renaming, or it logs the file copying/" + "movement if grouping by folders. If no log file is specified with '--log-file', " + "then a 'sort_log.json' file will be created in the input directory.")}) + argument_list.append({ + "opts": ('-f', '--log-file'), + "action": SaveFileFullPaths, + "filetypes": "alignments", + "group": _("settings"), + "dest": 'log_file_path', + "default": 'sort_log.json', + "help": _( + "Specify a log file to use for saving the renaming or grouping information. If " + "specified extension isn't 'json' or 'yaml', then json will be used as the " + "serializer, with the supplied filename. Default: sort_log.json")}) + return argument_list + + +__all__ = get_module_objects(__name__) diff --git a/tools/sort/sort.py b/tools/sort/sort.py new file mode 100644 index 0000000000..80dfc9566e --- /dev/null +++ b/tools/sort/sort.py @@ -0,0 +1,334 @@ +#!/usr/bin/env python3 +""" +A tool that allows for sorting and grouping images in different ways. +""" +from __future__ import annotations +import logging +import os +import sys +import typing as T + +from argparse import Namespace +from shutil import copyfile, rmtree + +from tqdm import tqdm + +# faceswap imports +from lib.serializer import Serializer, get_serializer_from_filename +from lib.utils import get_module_objects, handle_deprecated_cliopts + +from .sort_methods import SortBlur, SortColor, SortFace, SortHistogram, SortMultiMethod +from .sort_methods_aligned import SortDistance, SortFaceCNN, SortPitch, SortSize, SortYaw, SortRoll + +if T.TYPE_CHECKING: + from .sort_methods import SortMethod + +logger = logging.getLogger(__name__) + + +class Sort(): + """ Sorts folders of faces based on input criteria + + Wrapper for the sort process to run in either batch mode or single use mode + + Parameters + ---------- + arguments: :class:`argparse.Namespace` + The arguments to be passed to the extraction process as generated from Faceswap's command + line arguments + """ + def __init__(self, arguments: Namespace) -> None: + logger.debug("Initializing: %s (args: %s)", self.__class__.__name__, arguments) + self._args = handle_deprecated_cliopts(arguments) + self._input_locations = self._get_input_locations() + logger.debug("Initialized: %s", self.__class__.__name__) + + def _get_input_locations(self) -> list[str]: + """ Obtain the full path to input locations. Will be a list of locations if batch mode is + selected, or a containing a single location if batch mode is not selected. 
+
+        Returns
+        -------
+        list:
+            The list of input location paths
+        """
+        if not self._args.batch_mode:
+            return [self._args.input_dir]
+
+        retval = [os.path.join(self._args.input_dir, fname)
+                  for fname in os.listdir(self._args.input_dir)
+                  if os.path.isdir(os.path.join(self._args.input_dir, fname))]
+        logger.debug("Input locations: %s", retval)
+        return retval
+
+    def _output_for_input(self, input_location: str) -> str:
+        """ Obtain the path to an output folder for faces for a given input location.
+
+        If not running in batch mode, then the user supplied output location will be returned,
+        otherwise a sub-folder within the user supplied output location will be returned based on
+        the input filename
+
+        Parameters
+        ----------
+        input_location: str
+            The full path to an input folder of faces
+
+        Returns
+        -------
+        str
+            The output folder to use for the given input location
+        """
+        if not self._args.batch_mode or self._args.output_dir is None:
+            return self._args.output_dir
+
+        retval = os.path.join(self._args.output_dir, os.path.basename(input_location))
+        logger.debug("Returning output: '%s' for input: '%s'", retval, input_location)
+        return retval
+
+    def process(self) -> None:
+        """ The entry point for triggering the Sort Process.
+
+        Should only be called from :class:`lib.cli.launcher.ScriptExecutor`
+        """
+        logger.info('Starting, this may take a while...')
+        inputs = self._input_locations
+        if self._args.batch_mode:
+            logger.info("Batch mode selected processing: %s", self._input_locations)
+        for job_no, location in enumerate(self._input_locations):
+            if self._args.batch_mode:
+                logger.info("Processing job %s of %s: '%s'", job_no + 1, len(inputs), location)
+                arguments = Namespace(**self._args.__dict__)
+                arguments.input_dir = location
+                arguments.output_dir = self._output_for_input(location)
+            else:
+                arguments = self._args
+            sort = _Sort(arguments)
+            sort.process()
+
+
+class _Sort():
+    """ Sorts folders of faces based on input criteria """
+    def __init__(self, arguments: Namespace) -> None:
+        logger.debug("Initializing %s: arguments: %s", self.__class__.__name__, arguments)
+        self._processes = {"blur": SortBlur,
+                           "blur_fft": SortBlur,
+                           "distance": SortDistance,
+                           "yaw": SortYaw,
+                           "pitch": SortPitch,
+                           "roll": SortRoll,
+                           "size": SortSize,
+                           "face": SortFace,
+                           "face_cnn": SortFaceCNN,
+                           "face_cnn_dissim": SortFaceCNN,
+                           "hist": SortHistogram,
+                           "hist_dissim": SortHistogram,
+                           "color_black": SortColor,
+                           "color_gray": SortColor,
+                           "color_luma": SortColor,
+                           "color_green": SortColor,
+                           "color_orange": SortColor}
+
+        self._args = self._parse_arguments(arguments)
+        self._changes: dict[str, str] = {}
+        self.serializer: Serializer | None = None
+
+        if arguments.log_changes:
+            self.serializer = get_serializer_from_filename(arguments.log_file_path)
+
+        self._sorter = self._get_sorter()
+        logger.debug("Initialized %s", self.__class__.__name__)
+
+    def _set_output_folder(self, arguments):
+        """ Set the output folder correctly if it has not been provided
+
+        Parameters
+        ----------
+        arguments: :class:`argparse.Namespace`
+            The command line arguments passed to the sort process
+
+        Returns
+        -------
+        :class:`argparse.Namespace`
+            The command line arguments with the output folder correctly set
+        """
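+        # In summary: files that need renaming with 'keep' selected get a 'sorted'
+        # sub-folder when no separate output folder is given; otherwise a missing
+        # output folder falls back to sorting in place in the input folder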
+        logger.debug("setting output folder: %s", arguments.output_dir)
+        input_dir = arguments.input_dir
+        output_dir = arguments.output_dir
+        sort_method = arguments.sort_method
+        group_method = arguments.group_method
+
+        needs_rename = sort_method != "none" and group_method == "none"
+
+        if needs_rename and arguments.keep_original and (not output_dir
+                                                         or output_dir == input_dir):
+            output_dir = os.path.join(input_dir, "sorted")
+            logger.warning("No output folder selected, but files need renaming. "
+                           "Outputting to: '%s'", output_dir)
+        elif not output_dir:
+            output_dir = input_dir
+            logger.warning("No output folder selected, files will be sorted in place in: '%s'",
+                           output_dir)
+
+        arguments.output_dir = output_dir
+        logger.debug("Set output folder: %s", arguments.output_dir)
+        return arguments
+
+    def _parse_arguments(self, arguments):
+        """ Parse the arguments and update/format relevant choices
+
+        Parameters
+        ----------
+        arguments: :class:`argparse.Namespace`
+            The command line arguments passed to the sort process
+
+        Returns
+        -------
+        :class:`argparse.Namespace`
+            The formatted command line arguments
+        """
+        logger.debug("Cleaning arguments: %s", arguments)
+        if arguments.sort_method == "none" and arguments.group_method == "none":
+            logger.error("Both sort-by and group-by are 'None'. Nothing to do.")
+            sys.exit(1)
+
+        # Prepare sort, group and final process method names
+        arguments.sort_method = arguments.sort_method.lower().replace("-", "_")
+        arguments.group_method = arguments.group_method.lower().replace("-", "_")
+
+        arguments = self._set_output_folder(arguments)
+
+        if arguments.log_changes and arguments.log_file_path == "sort_log.json":
+            # Assign the default 'sort_log.json' value if the user didn't specify one. NB: use
+            # the incoming arguments here; self._args is not set until this method returns
+            arguments.log_file_path = os.path.join(arguments.input_dir, 'sort_log.json')
+
+        logger.debug("Cleaned arguments: %s", arguments)
+        return arguments
+
+    def _get_sorter(self) -> SortMethod:
+        """ Obtain a sorter/grouper combo for the selected sort/group by options
+
+        Returns
+        -------
+        :class:`SortMethod`
+            The sorter or combined sorter for sorting and grouping based on user selections
+        """
+        sort_method = self._args.sort_method
+        group_method = self._args.group_method
+
+        sort_method = group_method if sort_method == "none" else sort_method
+        sorter = self._processes[sort_method](self._args,
+                                              is_group=self._args.sort_method == "none")
+
+        if sort_method != "none" and group_method != "none" and group_method != sort_method:
+            grouper = self._processes[group_method](self._args, is_group=True)
+            retval = SortMultiMethod(self._args, sorter, grouper)
+            logger.debug("Got sorter + grouper: %s (%s, %s)", retval, sorter, grouper)
+        else:
+            retval = sorter
+
+        logger.debug("Final sorter: %s", retval)
+        return retval
+
+    def _write_to_log(self, changes):
+        """ Write the changes to the log file """
+        assert self.serializer is not None
+        logger.info("Writing sort log to: '%s'", self._args.log_file_path)
+        self.serializer.save(self._args.log_file_path, changes)
+
+    def process(self) -> None:
+        """ Main processing function of the sort tool
+
+        Groups the faces into bins if a group method has been selected, otherwise renames the
+        sorted files into a single folder, then logs any changes if required.
+        """
+        if self._args.group_method != "none":
+            self._output_groups()
+        else:
+            self._output_non_grouped()
+
+        if self._args.log_changes:
+            self._write_to_log(self._changes)
+
+        logger.info("Done.")
+
+    def _sort_file(self, source: str, destination: str) -> None:
+        """ Copy or move a file based on whether 'keep original' has been selected and log
+        changes if required.
+
+        Parameters
+        ----------
+        source: str
+            The full path to the source file that is being sorted
+        destination: str
+            The full path to where the source file should be moved/renamed
+        """
+        try:
+            if self._args.keep_original:
+                copyfile(source, destination)
+            else:
+                os.rename(source, destination)
+        except FileNotFoundError as err:
+            logger.error("Failed to sort '%s' to '%s'. Original error: %s",
+                         source, destination, str(err))
+
+        if self._args.log_changes:
+            self._changes[source] = destination
+
+    def _output_groups(self) -> None:
+        """ Move the files to folders.
+
+        Obtains the bins and original filenames from :attr:`_sorter` and outputs into appropriate
+        bins in the output location
+        """
+        is_rename = self._args.sort_method != "none"
+
+        logger.info("Creating %s group folders in '%s'.",
+                    len(self._sorter.binned), self._args.output_dir)
+        bin_names = [f"_{b}" for b in self._sorter.bin_names]
+        if is_rename:
+            bin_names = [f"{name}_by_{self._args.sort_method}" for name in bin_names]
+        for name in bin_names:
+            folder = os.path.join(self._args.output_dir, name)
+            if os.path.exists(folder):
+                rmtree(folder)
+            os.makedirs(folder)
+
+        description = f"{'Copying' if self._args.keep_original else 'Moving'} into groups"
+        description += " and renaming" if is_rename else ""
+
+        pbar = tqdm(range(len(self._sorter.sorted_filelist)),
+                    desc=description,
+                    file=sys.stdout,
+                    leave=False)
+        idx = 0
+        for bin_id, bin_ in enumerate(self._sorter.binned):
+            pbar.set_description(f"{description}: Bin {bin_id + 1} of {len(self._sorter.binned)}")
+            output_path = os.path.join(self._args.output_dir, bin_names[bin_id])
+            if not bin_:
+                logger.debug("Removing empty bin: %s", output_path)
+                os.rmdir(output_path)
+            for source in bin_:
+                basename = os.path.basename(source)
+                dst_name = f"{idx:06d}_{basename}" if is_rename else basename
+                dest = os.path.join(output_path, dst_name)
+                self._sort_file(source, dest)
+                idx += 1
+                pbar.update(1)
+
+    def _output_non_grouped(self) -> None:
+        """ Output non-grouped files.
+
+        These are files which are sorted but not binned, so just the filename gets updated
+        """
+        output_dir = self._args.output_dir
+        os.makedirs(output_dir, exist_ok=True)
+
+        description = f"{'Copying' if self._args.keep_original else 'Moving'} and renaming"
+        for idx, source in enumerate(tqdm(self._sorter.sorted_filelist,
+                                          desc=description,
+                                          file=sys.stdout,
+                                          leave=False)):
+            dest = os.path.join(output_dir, f"{idx:06d}_{os.path.basename(source)}")
+
+            self._sort_file(source, dest)
+
+
+__all__ = get_module_objects(__name__)
diff --git a/tools/sort/sort_methods.py b/tools/sort/sort_methods.py
new file mode 100644
index 0000000000..273f7fe8be
--- /dev/null
+++ b/tools/sort/sort_methods.py
@@ -0,0 +1,1116 @@
+#!/usr/bin/env python3
+""" Sorting methods for the sorting tool.
+
+All sorting methods inherit from :class:`SortMethod` and control functions for scoring one item,
+sorting a full list of scores and binning based on those sorted scores.
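+
+Each method implements :func:`score_image`, :func:`sort` and :func:`binning`, while the shared
+logic for loading faces and their metadata lives in :class:`InfoLoader`.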
+""" +from __future__ import annotations +import logging +import operator +import sys +import typing as T + +from collections.abc import Generator + +import cv2 +import numpy as np +from tqdm import tqdm + +from lib.align import AlignedFace, DetectedFace, LandmarkType +from lib.image import FacesLoader, ImagesLoader, read_image_meta_batch, update_existing_metadata +from lib.utils import get_module_objects, FaceswapError +from plugins.extract.recognition.vgg_face2 import Cluster, Recognition as VGGFace + +if T.TYPE_CHECKING: + from argparse import Namespace + from lib.align.alignments import PNGHeaderAlignmentsDict, PNGHeaderSourceDict + +logger = logging.getLogger(__name__) + + +ImgMetaType: T.TypeAlias = Generator[tuple[str, + np.ndarray | None, + T.Union["PNGHeaderAlignmentsDict", None]], None, None] + + +class InfoLoader(): + """ Loads aligned faces and/or face metadata + + Parameters + ---------- + input_dir: str + Full path to containing folder of faces to be supported + loader_type: ["face", "meta", "all"] + Dictates the type of iterator that will be used. "face" just loads the image with the + filename, "meta" just loads the image alignment data with the filename. "all" loads + the image and the alignment data with the filename + """ + def __init__(self, + input_dir: str, + info_type: T.Literal["face", "meta", "all"]) -> None: + logger.debug("Initializing: %s (input_dir: %s, info_type: %s)", + self.__class__.__name__, input_dir, info_type) + self._info_type = info_type + self._iterator = None + self._description = "Reading image statistics..." + self._loader = ImagesLoader(input_dir) if info_type == "face" else FacesLoader(input_dir) + self._cached_source_data: dict[str, PNGHeaderSourceDict] = {} + if self._loader.count == 0: + logger.error("No images to process in location: '%s'", input_dir) + sys.exit(1) + + logger.debug("Initialized: %s", self.__class__.__name__) + + @property + def filelist_count(self) -> int: + """ int: The number of files to be processed """ + return len(self._loader.file_list) + + def _get_iterator(self) -> ImgMetaType: + """ Obtain the iterator for the selected :attr:`info_type`. + + Returns + ------- + generator + The correct generator for the given info_type + """ + if self._info_type == "all": + return self._full_data_reader() + if self._info_type == "meta": + return self._metadata_reader() + return self._image_data_reader() + + def __call__(self) -> ImgMetaType: + """ Return the selected iterator + + The resulting generator: + + Yields + ------ + filename: str + The filename that has been read + image: :class:`numpy.ndarray or ``None`` + The aligned face image loaded from disk for 'face' and 'all' info_types + otherwise ``None`` + alignments: dict or ``None`` + The alignments dict for 'all' and 'meta' infor_types otherwise ``None`` + """ + iterator = self._get_iterator() + return iterator + + def _get_alignments(self, + filename: str, + metadata: dict[str, T.Any]) -> PNGHeaderAlignmentsDict | None: + """ Obtain the alignments from a PNG Header. 
+
+        The other image metadata is cached locally in case a sort method needs to write back to
+        the PNG header
+
+        Parameters
+        ----------
+        filename: str
+            Full path to the image PNG file
+        metadata: dict
+            The header data from a PNG file
+
+        Returns
+        -------
+        dict or ``None``
+            The alignments dictionary from the PNG header, if it exists, otherwise ``None``
+        """
+        if not metadata or not metadata.get("alignments") or not metadata.get("source"):
+            return None
+        self._cached_source_data[filename] = metadata["source"]
+        return metadata["alignments"]
+
+    def _metadata_reader(self) -> ImgMetaType:
+        """ Load metadata from saved aligned faces
+
+        Yields
+        ------
+        filename: str
+            The filename that has been read
+        image: None
+            This will always be ``None`` with the metadata reader
+        alignments: dict or ``None``
+            The alignment data for the given face or ``None`` if no alignments found
+        """
+        for filename, metadata in tqdm(read_image_meta_batch(self._loader.file_list),
+                                       total=self._loader.count,
+                                       desc=self._description,
+                                       leave=False):
+            alignments = self._get_alignments(filename, metadata.get("itxt", {}))
+            yield filename, None, alignments
+
+    def _full_data_reader(self) -> ImgMetaType:
+        """ Load the image and metadata from a folder of aligned faces
+
+        Yields
+        ------
+        filename: str
+            The filename that has been read
+        image: :class:`numpy.ndarray`
+            The aligned face image loaded from disk
+        alignments: dict or ``None``
+            The alignment data for the given face or ``None`` if no alignments found
+        """
+        for filename, image, metadata in tqdm(self._loader.load(),
+                                              desc=self._description,
+                                              total=self._loader.count,
+                                              leave=False):
+            alignments = self._get_alignments(filename, metadata)
+            yield filename, image, alignments
+
+    def _image_data_reader(self) -> ImgMetaType:
+        """ Just loads the images with their filenames
+
+        Yields
+        ------
+        filename: str
+            The filename that has been read
+        image: :class:`numpy.ndarray`
+            The aligned face image loaded from disk
+        alignments: ``None``
+            Alignments will always be ``None`` with the image data reader
+        """
+        for filename, image in tqdm(self._loader.load(),
+                                    desc=self._description,
+                                    total=self._loader.count,
+                                    leave=False):
+            yield filename, image, None
+
+    def update_png_header(self, filename: str, alignments: PNGHeaderAlignmentsDict) -> None:
+        """ Update the PNG header of the given file with the given alignments.
+
+        NB: Header information can only be updated if the face is already on at least alignment
+        version 2.2. If below this version, then the header is not updated
+
+        Parameters
+        ----------
+        filename: str
+            Full path to the PNG file to update
+        alignments: dict
+            The alignments to update into the PNG header
+        """
+        vers = self._cached_source_data[filename]["alignments_version"]
+        if vers < 2.2:
+            return
+
+        self._cached_source_data[filename]["alignments_version"] = 2.3 if vers == 2.2 else vers
+        header = {"alignments": alignments, "source": self._cached_source_data[filename]}
+        update_existing_metadata(filename, header)
+
+
+class SortMethod():
+    """ Parent class for sort methods. All sort methods should inherit from this class
+
+    Parameters
+    ----------
+    arguments: :class:`argparse.Namespace`
+        The command line arguments passed to the sort process
+    loader_type: ["face", "meta", "all"]
+        The type of image loader to use. "face" just loads the image with the filename, "meta"
+        just loads the image alignment data with the filename. "all" loads the image and the
"all" loads the image and the + alignment data with the filename + is_group: bool, optional + Set to ``True`` if this class is going to be called exclusively for binning. + Default: ``False`` + """ + _log_mask_once = False + + def __init__(self, + arguments: Namespace, + loader_type: T.Literal["face", "meta", "all"] = "meta", + is_group: bool = False) -> None: + logger.debug("Initializing %s: loader_type: '%s' is_group: %s, arguments: %s", + self.__class__.__name__, loader_type, is_group, arguments) + self._is_group = is_group + self._log_once = True + self._method = arguments.group_method if self._is_group else arguments.sort_method + + self._num_bins: int = arguments.num_bins + self._bin_names: list[str] = [] + + self._loader_type = loader_type + self._iterator = self._get_file_iterator(arguments.input_dir) + + self._result: list[tuple[str, float | np.ndarray]] = [] + self._binned: list[list[str]] = [] + logger.debug("Initialized %s", self.__class__.__name__) + + @property + def loader_type(self) -> T.Literal["face", "meta", "all"]: + """ ["face", "meta", "all"]: The loader that this sorter uses """ + return self._loader_type + + @property + def binned(self) -> list[list[str]]: + """ list: List of bins (list) containing the filenames belonging to the bin. The binning + process is called when this property is first accessed""" + if not self._binned: + self._binned = self._binning() + logger.debug({f"bin_{idx}": len(bin_) for idx, bin_ in enumerate(self._binned)}) + return self._binned + + @property + def sorted_filelist(self) -> list[str]: + """ list: List of sorted filenames for given sorter in a single list. The sort process is + called when this property is first accessed """ + if not self._result: + self._sort_filelist() + retval = [item[0] for item in self._result] + logger.debug(retval) + else: + retval = [item[0] for item in self._result] + return retval + + @property + def bin_names(self) -> list[str]: + """ list: The name of each created bin, if they exist, otherwise an empty list """ + return self._bin_names + + def _get_file_iterator(self, input_dir: str) -> InfoLoader: + """ Override for method specific iterators. + + Parameters + ---------- + input_dir: str + Full path to containing folder of faces to be supported + + Returns + ------- + :class:`InfoLoader` + The correct InfoLoader iterator for the current sort method + """ + return InfoLoader(input_dir, self.loader_type) + + def _sort_filelist(self) -> None: + """ Call the sort method's logic to populate the :attr:`_results` attribute. + + Put logic for scoring an individual frame in in :attr:`score_image` of the child + + Returns + ------- + list + The sorted file. A list of tuples with the filename in the first position and score in + the second position + """ + for filename, image, alignments in self._iterator(): + self.score_image(filename, image, alignments) + + self.sort() + logger.debug("sorted list: %s", + [r[0] if isinstance(r, (tuple, list)) else r for r in self._result]) + + @classmethod + def _get_unique_labels(cls, numbers: np.ndarray) -> list[str]: + """ For a list of threshold values for displaying in the bin name, get the lowest number of + decimal figures (down to int) required to have a unique set of folder names and return the + formatted numbers. 
+    @classmethod
+    def _get_unique_labels(cls, numbers: np.ndarray) -> list[str]:
+        """ For a list of threshold values for displaying in the bin name, get the lowest number
+        of decimal figures (down to int) required to have a unique set of folder names and return
+        the formatted numbers.
+
+        Parameters
+        ----------
+        numbers: :class:`numpy.ndarray`
+            The list of floating point threshold numbers being used as boundary points
+
+        Returns
+        -------
+        list[str]
+            The string formatted numbers at the lowest precision possible to represent them
+            uniquely
+        """
+        i = 0
+        while True:
+            rounded = [round(n, i) for n in numbers]
+            if len(set(rounded)) == len(numbers):
+                break
+            i += 1
+
+        if i == 0:
+            retval = [str(int(n)) for n in rounded]
+        else:
+            pre, post = zip(*[str(r).split(".") for r in rounded])
+            rpad = max(len(x) for x in post)
+            # Pad the decimal part as a string. Casting it through int() would strip significant
+            # leading zeros (e.g. '0.05' would come out as '0.50')
+            retval = [f"{left}.{right.ljust(rpad, '0')}" for left, right in zip(pre, post)]
+        logger.debug("rounded values: %s, formatted labels: %s", rounded, retval)
+        return retval
+
+    def _binning_linear_threshold(self, units: str = "", multiplier: int = 1) -> list[list[str]]:
+        """ Standard linear binning method for binning by threshold.
+
+        The minimum and maximum results from :attr:`_result` are taken, and the range between
+        these min and max values is divided to get the number of bins to hold the data
+
+        Parameters
+        ----------
+        units: str, optional
+            The units to use in the bin name for displaying the threshold values. This should
+            correspond to the value in position 1 of :attr:`_result`.
+            Default: "" (no units)
+        multiplier: int, optional
+            The amount to multiply the contents in position 1 of :attr:`_result` by for
+            displaying in the bin folder name
+
+        Returns
+        -------
+        list
+            List of bins of filenames
+        """
+        sizes = np.array([i[1] for i in self._result])
+        thresholds = np.linspace(sizes.min(), sizes.max(), self._num_bins + 1)
+        labels = self._get_unique_labels(thresholds * multiplier)
+
+        self._bin_names = [f"{self._method}_{idx:03d}_"
+                           f"{labels[idx]}{units}_to_{labels[idx + 1]}{units}"
+                           for idx in range(self._num_bins)]
+
+        bins: list[list[str]] = [[] for _ in range(self._num_bins)]
+        for filename, result in self._result:
+            bin_idx = next(bin_id for bin_id, thresh in enumerate(thresholds)
+                           if result <= thresh) - 1
+            # Clamp so that the minimum scoring item lands in the first bin rather than wrapping
+            # around to the last
+            bins[max(bin_idx, 0)].append(filename)
+
+        return bins
+
+    def _binning(self) -> list[list[str]]:
+        """ Called when :attr:`binned` is first accessed. Checks if sorting has been done, if
+        not triggers it, then does the binning
+
+        Returns
+        -------
+        list
+            List of bins of filenames
+        """
+        if not self._result:
+            self._sort_filelist()
+        retval = self.binning()
+
+        if not self._bin_names:
+            self._bin_names = [f"{self._method}_{i:03d}" for i in range(len(retval))]
+
+        logger.debug({bin_name: len(bin_) for bin_name, bin_ in zip(self._bin_names, retval)})
+
+        return retval
+
+    def sort(self) -> None:
+        """ Override for method specific logic for sorting the loaded statistics
+
+        The scored list :attr:`_result` should be sorted in place
+        """
+        raise NotImplementedError()
+
+    def score_image(self,
+                    filename: str,
+                    image: np.ndarray | None,
+                    alignments: PNGHeaderAlignmentsDict | None) -> None:
+        """ Override for the sort method's specific logic. This method should be executed to get
+        a single score from a single image and add the result to :attr:`_result`
+
+        Parameters
+        ----------
+        filename: str
+            The filename of the currently processing image
+        image: :class:`np.ndarray` or ``None``
+            A face image loaded from disk or ``None``
+        alignments: dict or ``None``
+            The alignments dictionary for the aligned face or ``None``
+        """
+        raise NotImplementedError()
+
+    def binning(self) -> list[list[str]]:
+        """ Group into bins by their sorted score. Override for method specific binning
+        techniques.
+
+        Binning takes the results from :attr:`_result` compiled during :func:`_sort_filelist`
+        and organizes them into bins for output.
+
+        Returns
+        -------
+        list
+            List of bins of filenames
+        """
+        raise NotImplementedError()
+
+    @classmethod
+    def _mask_face(cls, image: np.ndarray, alignments: PNGHeaderAlignmentsDict) -> np.ndarray:
+        """ Function for applying the mask to an aligned face if both the face image and
+        alignment data are available.
+
+        Parameters
+        ----------
+        image: :class:`numpy.ndarray`
+            The aligned face image loaded from disk
+        alignments: dict
+            The alignments data corresponding to the loaded image
+
+        Returns
+        -------
+        :class:`numpy.ndarray`
+            The original image with the mask applied
+        """
+        det_face = DetectedFace()
+        det_face.from_png_meta(alignments)
+        aln_face = AlignedFace(np.array(alignments["landmarks_xy"], dtype="float32"),
+                               image=image,
+                               centering="legacy",
+                               size=256,
+                               is_aligned=True)
+        assert aln_face.face is not None
+
+        mask = det_face.mask.get("components", det_face.mask.get("extended", None))
+
+        if mask is None and not cls._log_mask_once:
+            logger.warning("No masks are available for masking the data. Results are likely to "
+                           "be sub-standard")
+            cls._log_mask_once = True
+
+        if mask is None:
+            return aln_face.face
+
+        mask.set_sub_crop(aln_face.pose.offset[mask.stored_centering],
+                          aln_face.pose.offset["legacy"],
+                          centering="legacy")
+        nmask = cv2.resize(mask.mask, (256, 256), interpolation=cv2.INTER_CUBIC)[..., None]
+        return np.minimum(aln_face.face, nmask)
+
+
+class SortMultiMethod(SortMethod):
+    """ A parent sort method that runs two different underlying methods (one for sorting, one
+    for binning) in instances where grouping has been requested, but the sort method is different
+    from the group method
+
+    Parameters
+    ----------
+    arguments: :class:`argparse.Namespace`
+        The command line arguments passed to the sort process
+    sort_method: :class:`SortMethod`
+        A sort method object for sorting the images
+    group_method: :class:`SortMethod`
+        A sort method object used for sorting and binning the images
+    """
+    def __init__(self,
+                 arguments: Namespace,
+                 sort_method: SortMethod,
+                 group_method: SortMethod) -> None:
+        self._sorter = sort_method
+        self._grouper = group_method
+        self._is_built = False
+        super().__init__(arguments)
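+
+    # NB: a single InfoLoader instance is shared between the sort and group methods (see
+    # _get_file_iterator below), so the folder of faces is only read from disk once per run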
+    def _get_file_iterator(self, input_dir: str) -> InfoLoader:
+        """ Override to get a group specific iterator. If the sorter and grouper use the same
+        kind of iterator, use that. Otherwise return the 'all' iterator, as whichever way it is
+        cut all outputs will be required. Monkey patch the actual loader used into the children
+        in case of any callbacks.
+
+        Parameters
+        ----------
+        input_dir: str
+            Full path to the containing folder of faces to be sorted
+
+        Returns
+        -------
+        :class:`InfoLoader`
+            The correct InfoLoader iterator for the current sort method
+        """
+        if self._sorter.loader_type == self._grouper.loader_type:
+            retval = InfoLoader(input_dir, self._sorter.loader_type)
+        else:
+            retval = InfoLoader(input_dir, "all")
+        self._sorter._iterator = retval  # pylint:disable=protected-access
+        self._grouper._iterator = retval  # pylint:disable=protected-access
+        return retval
+
+    def score_image(self,
+                    filename: str,
+                    image: np.ndarray | None,
+                    alignments: PNGHeaderAlignmentsDict | None) -> None:
+        """ Score a single image with both the sort method and the group method, adding the
+        result to each of their respective :attr:`_result` lists
+
+        Parameters
+        ----------
+        filename: str
+            The filename of the currently processing image
+        image: :class:`np.ndarray` or ``None``
+            A face image loaded from disk or ``None``
+        alignments: dict or ``None``
+            The alignments dictionary for the aligned face or ``None``
+        """
+        self._sorter.score_image(filename, image, alignments)
+        self._grouper.score_image(filename, image, alignments)
+
+    def sort(self) -> None:
+        """ Sort with the sorter method and bin with the grouper method """
+        logger.debug("Sorting")
+        self._sorter.sort()
+        self._result = self._sorter.sorted_filelist  # type:ignore
+        self._grouper.sort()
+        self._binned = self._grouper.binned
+        self._bin_names = self._grouper.bin_names
+        logger.debug("Sorted")
+
+    def binning(self) -> list[list[str]]:
+        """ Override standard binning, to bin by the group-by method and sort by the sorting
+        method.
+
+        Go through the grouped binned results, and reorder each bin's contents based on the
+        sorted list
+
+        Returns
+        -------
+        list
+            List of bins of filenames
+        """
+        sorted_ = self._result
+        output: list[list[str]] = []
+        for bin_ in tqdm(self._binned, desc="Binning and sorting", file=sys.stdout, leave=False):
+            indices: dict[int, str] = {}
+            for filename in bin_:
+                indices[sorted_.index(filename)] = filename
+            output.append([indices[idx] for idx in sorted(indices)])
+        return output
+
+
+class SortBlur(SortMethod):
+    """ Sort images by blur or blur-fft amount
+
+    Parameters
+    ----------
+    arguments: :class:`argparse.Namespace`
+        The command line arguments passed to the sort process
+    is_group: bool, optional
+        Set to ``True`` if this class is going to be called exclusively for binning.
+        Default: ``False``
+    """
+    def __init__(self, arguments: Namespace, is_group: bool = False) -> None:
+        super().__init__(arguments, loader_type="all", is_group=is_group)
+        method = arguments.group_method if self._is_group else arguments.sort_method
+        self._use_fft = method == "blur_fft"
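+
+    # Both estimators below score sharper faces higher: plain blur uses the variance of the
+    # Laplacian (edge response), while the FFT variant zeroes out the low-frequency centre of
+    # the shifted spectrum and measures the magnitude of what remains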
+    def estimate_blur(self,
+                      image: np.ndarray,
+                      alignments: PNGHeaderAlignmentsDict | None = None) -> float:
+        """ Estimate the amount of blur an image has with the variance of the Laplacian.
+        Normalize by pixel number to offset the effect of image size on pixel gradients and
+        variance.
+
+        Parameters
+        ----------
+        image: :class:`numpy.ndarray`
+            The face image to calculate blur for
+        alignments: dict, optional
+            The metadata for the face image or ``None`` if no metadata is available. If metadata
+            is provided the face will be masked by the "components" mask prior to calculating
+            blur. Default: ``None``
+
+        Returns
+        -------
+        float
+            The estimated blur score for the face
+        """
+        if alignments is not None:
+            image = self._mask_face(image, alignments)
+        if image.ndim == 3:
+            image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
+        blur_map = T.cast(np.ndarray, cv2.Laplacian(image, cv2.CV_32F))
+        score = np.var(blur_map) / np.sqrt(image.shape[0] * image.shape[1])
+        return score
+
+    def estimate_blur_fft(self,
+                          image: np.ndarray,
+                          alignments: PNGHeaderAlignmentsDict | None = None) -> float:
+        """ Estimate the amount of blur an FFT filtered image has.
+
+        Uses a 2D Discrete Fourier Transform (DFT) to analyze the frequency characteristics of
+        the masked face, zeroes out the low frequency domain and returns the mean of the
+        remaining magnitude spectrum as the blur score.
+        Adapted from https://www.pyimagesearch.com/2020/06/15/
+        opencv-fast-fourier-transform-fft-for-blur-detection-in-images-and-video-streams/
+
+        Parameters
+        ----------
+        image: :class:`numpy.ndarray`
+            The face image to calculate the blur score for
+        alignments: dict, optional
+            The metadata for the face image or ``None`` if no metadata is available. If metadata
+            is provided the face will be masked by the "components" mask prior to calculating
+            blur. Default: ``None``
+
+        Returns
+        -------
+        float
+            The estimated fft blur score for the face
+        """
+        if alignments is not None:
+            image = self._mask_face(image, alignments)
+
+        if image.ndim == 3:
+            image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
+
+        height, width = image.shape
+        c_height, c_width = (int(height / 2.0), int(width / 2.0))
+        fft = np.fft.fft2(image)
+        fft_shift = np.fft.fftshift(fft)
+        fft_shift[c_height - 75:c_height + 75, c_width - 75:c_width + 75] = 0
+        ifft_shift = np.fft.ifftshift(fft_shift)
+        shift_back = np.fft.ifft2(ifft_shift)
+        magnitude = np.log(np.abs(shift_back))
+        score = np.mean(magnitude)
+
+        return score
+
+    def score_image(self,
+                    filename: str,
+                    image: np.ndarray | None,
+                    alignments: PNGHeaderAlignmentsDict | None) -> None:
+        """ Score a single image for blur or blur-fft and add the result to :attr:`_result`
+
+        Parameters
+        ----------
+        filename: str
+            The filename of the currently processing image
+        image: :class:`np.ndarray`
+            A face image loaded from disk
+        alignments: dict or ``None``
+            The alignments dictionary for the aligned face or ``None``
+        """
+        assert image is not None
+        if self._log_once:
+            msg = "Grouping" if self._is_group else "Sorting"
+            inf = "fft_filtered " if self._use_fft else ""
+            logger.info("%s by estimated %simage blur...", msg, inf)
+            self._log_once = False
+
+        estimator = self.estimate_blur_fft if self._use_fft else self.estimate_blur
+        self._result.append((filename, estimator(image, alignments)))
+
+    def sort(self) -> None:
+        """ Sort by blur score in descending order, so the sharpest faces come first """
+        logger.info("Sorting...")
+        self._result = sorted(self._result, key=operator.itemgetter(1), reverse=True)
+
+    def binning(self) -> list[list[str]]:
+        """ Create bins to split linearly from the lowest to the highest sample value
+
+        Returns
+        -------
+        list
+            List of bins of filenames
+        """
+        return self._binning_linear_threshold(multiplier=100)
+
+
+class SortColor(SortMethod):
+    """ Score by channel average intensity or black pixels.
+
+    Parameters
+    ----------
+    arguments: :class:`argparse.Namespace`
+        The command line arguments passed to the sort process
+    is_group: bool, optional
+        Set to ``True`` if this class is going to be called exclusively for binning.
+        Default: ``False``
+    """
+    def __init__(self, arguments: Namespace, is_group: bool = False) -> None:
+        super().__init__(arguments, loader_type="face", is_group=is_group)
+        self._desired_channel = {'gray': 0, 'luma': 0, 'orange': 1, 'green': 2}
+
+        method = arguments.group_method if self._is_group else arguments.sort_method
+        self._method = method.replace("color_", "")
+
+    def _convert_color(self, image: np.ndarray) -> np.ndarray:
+        """ Helper function to convert color spaces
+
+        Parameters
+        ----------
+        image: :class:`numpy.ndarray`
+            The original image to convert the color space for
+
+        Returns
+        -------
+        :class:`numpy.ndarray`
+            The color converted image
+        """
+        if self._method == 'gray':
+            conversion = np.array([[0.0722], [0.7152], [0.2126]])
+        else:
+            conversion = np.array([[0.25, 0.5, 0.25], [-0.5, 0.0, 0.5], [-0.25, 0.5, -0.25]])
+
+        operation = 'ijk, kl -> ijl' if self._method == "gray" else 'ijl, kl -> ijk'
+        path = np.einsum_path(operation, image[..., :3], conversion, optimize='optimal')[0]
+        return np.einsum(operation, image[..., :3], conversion, optimize=path).astype('float32')
+
+    def _near_split(self, bin_range: int) -> list[int]:
+        """ Obtain the split for the given number of bins for the given range
+
+        Parameters
+        ----------
+        bin_range: int
+            The range of data to separate into bins
+
+        Returns
+        -------
+        list
+            The split dividers for the given number of bins for the given range
+        """
+        quotient, remainder = divmod(bin_range, self._num_bins)
+        seps = [quotient + 1] * remainder + [quotient] * (self._num_bins - remainder)
+        uplimit = 0
+        bins = [0]
+        for sep in seps:
+            bins.append(uplimit + sep)
+            uplimit += sep
+        return bins
+
+    def binning(self) -> list[list[str]]:
+        """ Group into bins by percentage score """
+        # TODO. Only grouped by black pixels. Check color
+
+        logger.info("Grouping by percentage of %s...", self._method)
+
+        # Starting the binning process
+        bins: list[list[str]] = [[] for _ in range(self._num_bins)]
+        # Get the edges of the bins from 0 to 100
+        bins_edges = self._near_split(100)
+        # Get the proper bin number for each image score
+        img_bins = np.digitize([float(x[1]) for x in self._result], bins_edges, right=True)
+
+        # Place the images into their bins
+        for idx, _bin in enumerate(img_bins):
+            bins[_bin].append(self._result[idx][0])
+
+        retval = [b for b in bins if b]
+        return retval
+
+    def score_image(self,
+                    filename: str,
+                    image: np.ndarray | None,
+                    alignments: PNGHeaderAlignmentsDict | None) -> None:
+        """ Score a single image for color
+
+        Parameters
+        ----------
+        filename: str
+            The filename of the currently processing image
+        image: :class:`np.ndarray`
+            A face image loaded from disk
+        alignments: dict or ``None``
+            The alignments dictionary for the aligned face or ``None``
+        """
+        if self._log_once:
+            msg = "Grouping" if self._is_group else "Sorting"
+            if self._method == "black":
+                logger.info("%s by percentage of black pixels...", msg)
+            else:
+                logger.info("%s by channel average intensity...", msg)
+            self._log_once = False
+
+        assert image is not None
+        if self._method == "black":
+            score = np.all(image == [0, 0, 0], axis=2).sum() / image.size * 100 * 3
+        else:
+            channel_to_sort = self._desired_channel[self._method]
+            score = np.average(self._convert_color(image), axis=(0, 1))[channel_to_sort]
+        self._result.append((filename, score))
+
+    def sort(self) -> None:
+        """ Sort by channel average intensity in descending order, or by percentage of black
+        pixels in ascending order """
""" + if self._method == "black": + self._sort_black_pixels() + return + self._result = sorted(self._result, key=operator.itemgetter(1), reverse=True) + + def _sort_black_pixels(self) -> None: + """ Sort by percentage of black pixels + + Calculates the sum of black pixels, gets the percentage X 3 channels + """ + img_list_len = len(self._result) + for i in tqdm(range(0, img_list_len - 1), + desc="Comparing black pixels", file=sys.stdout, + leave=False): + for j in range(0, img_list_len-i-1): + if self._result[j][1] > self._result[j+1][1]: + temp = self._result[j] + self._result[j] = self._result[j+1] + self._result[j+1] = temp + + +class SortFace(SortMethod): + """ Sort by identity similarity using VGG Face 2 + + Parameters + ---------- + arguments: :class:`argparse.Namespace` + The command line arguments passed to the sort process + is_group: bool, optional + Set to ``True`` if this class is going to be called exclusively for binning. + Default: ``False`` + """ + + _logged_lm_count_once = False + _warning = ("Extracted faces do not contain facial landmark data. Results sorted by this " + "method are likely to be sub-standard.") + + def __init__(self, arguments: Namespace, is_group: bool = False) -> None: + super().__init__(arguments, loader_type="all", is_group=is_group) + self._vgg_face = VGGFace() + self._vgg_face.init_model() + threshold = arguments.threshold + self._output_update_info = True + self._threshold: float | None = 0.25 if threshold < 0 else threshold + + def score_image(self, + filename: str, + image: np.ndarray | None, + alignments: PNGHeaderAlignmentsDict | None) -> None: + """ Processing logic for sort by face method. + + Reads header information from the PNG file to look for VGGFace2 embedding. If it does not + exist, the embedding is obtained and added back into the PNG Header. + + Parameters + ---------- + filename: str + The filename of the currently processing image + image: :class:`np.ndarray` + A face image loaded from disk + alignments: dict or ``None`` + The alignments dictionary for the aligned face or ``None`` + """ + # pylint:disable=duplicate-code + if not alignments: + msg = ("The images to be sorted do not contain alignment data. Images must have " + "been generated by Faceswap's Extract process.\nIf you are sorting an " + "older faceset, then you should re-extract the faces from your source " + "alignments file to generate this data.") + raise FaceswapError(msg) + + if self._log_once: + msg = "Grouping" if self._is_group else "Sorting" + logger.info("%s by identity similarity...", msg) + self._log_once = False + + if alignments.get("identity", {}).get("vggface2"): + embedding = np.array(alignments["identity"]["vggface2"], dtype="float32") + + if not self._logged_lm_count_once and len(alignments["landmarks_xy"]) == 4: + logger.warning(self._warning) + self._logged_lm_count_once = True + + self._result.append((filename, embedding)) + return + + if self._output_update_info: + logger.info("VGG Face2 Embeddings are being written to the image header. 
" + "Sorting by this method will be quicker next time") + self._output_update_info = False + + a_face = AlignedFace(np.array(alignments["landmarks_xy"], dtype="float32"), + image=image, + centering="legacy", + size=self._vgg_face.input_size, + is_aligned=True) + + if a_face.landmark_type == LandmarkType.LM_2D_4 and not self._logged_lm_count_once: + logger.warning(self._warning) + self._logged_lm_count_once = True + + face = a_face.face + assert face is not None + embedding = self._vgg_face.predict(face[None, ...])[0] + alignments.setdefault("identity", {})["vggface2"] = embedding.tolist() + self._iterator.update_png_header(filename, alignments) + self._result.append((filename, embedding)) + + def sort(self) -> None: + """ Sort by dendogram. + + Parameters + ---------- + matched_list: list + The list of tuples with filename in first position and face encoding in the 2nd + + Returns + ------- + list + The original list, sorted for this metric + """ + logger.info("Sorting by ward linkage. This may take some time...") + preds = np.array([item[1] for item in self._result]) + indices = Cluster(np.array(preds), "ward", threshold=self._threshold)() + self._result = [(self._result[idx][0], float(score)) for idx, score in indices] + + def binning(self) -> list[list[str]]: + """ Group into bins by their sorted score + + The bin ID has been output in the 2nd column of :attr:`_result` so use that for binnin + + Returns + ------- + list + List of bins of filenames + """ + num_bins = len(set(int(i[1]) for i in self._result)) + logger.info("Grouping by %s...", self.__class__.__name__.replace("Sort", "")) + bins: list[list[str]] = [[] for _ in range(num_bins)] + + for filename, bin_id in self._result: + bins[int(bin_id)].append(filename) + + return bins + + +class SortHistogram(SortMethod): + """ Sort by image histogram similarity or dissimilarity + + Parameters + ---------- + arguments: :class:`argparse.Namespace` + The command line arguments passed to the sort process + is_group: bool, optional + Set to ``True`` if this class is going to be called exclusively for binning. 
+        Default: ``False``
+    """
+    def __init__(self, arguments: Namespace, is_group: bool = False) -> None:
+        super().__init__(arguments, loader_type="all", is_group=is_group)
+        method = arguments.group_method if self._is_group else arguments.sort_method
+        # NB: method names are normalized to underscores in the sort tool's argument parsing
+        self._is_dissim = method == "hist_dissim"
+        self._threshold: float = 0.3 if arguments.threshold < 0.0 else arguments.threshold
+
+    def _calc_histogram(self,
+                        image: np.ndarray,
+                        alignments: PNGHeaderAlignmentsDict | None) -> np.ndarray:
+        """ Calculate the histogram for the given face, masking it first if alignment data is
+        available """
+        if alignments:
+            image = self._mask_face(image, alignments)
+        return cv2.calcHist([image], [0], None, [256], [0, 256])
+
+    def _sort_dissim(self) -> None:
+        """ Sort histograms by dissimilarity """
+        result = T.cast(list[tuple[str, np.ndarray]], self._result)
+        img_list_len = len(result)
+        for i in tqdm(range(0, img_list_len),
+                      desc="Comparing histograms",
+                      file=sys.stdout,
+                      leave=False):
+            score_total = 0.0
+            for j in range(0, img_list_len):
+                if i == j:
+                    continue
+                score_total += cv2.compareHist(result[i][1],
+                                               result[j][1],
+                                               cv2.HISTCMP_BHATTACHARYYA)
+            # Tuples are immutable, so rebuild the entry with the total score appended
+            result[i] = (result[i][0], result[i][1], score_total)  # type:ignore
+
+        self._result = sorted(result, key=operator.itemgetter(2), reverse=True)
+
+    def _sort_sim(self) -> None:
+        """ Sort histograms by similarity """
+        result = T.cast(list[tuple[str, np.ndarray]], self._result)
+        img_list_len = len(result)
+        for i in tqdm(range(0, img_list_len - 1),
+                      desc="Comparing histograms",
+                      file=sys.stdout,
+                      leave=False):
+            min_score = float("inf")
+            j_min_score = i + 1
+            for j in range(i + 1, img_list_len):
+                score = cv2.compareHist(result[i][1],
+                                        result[j][1],
+                                        cv2.HISTCMP_BHATTACHARYYA)
+                if score < min_score:
+                    min_score = score
+                    j_min_score = j
+            (self._result[i + 1],
+             self._result[j_min_score]) = (result[j_min_score], result[i + 1])
+
+    @classmethod
+    def _get_avg_score(cls, image: np.ndarray, references: list[np.ndarray]) -> float:
+        """ Return the average histogram score between a face and reference images
+
+        Parameters
+        ----------
+        image: :class:`numpy.ndarray`
+            The image to test
+        references: list
+            List of reference images to test the original image against
+
+        Returns
+        -------
+        float
+            The average score between the histograms
+        """
+        scores = []
+        for img2 in references:
+            score = cv2.compareHist(image, img2, cv2.HISTCMP_BHATTACHARYYA)
+            scores.append(score)
+        return sum(scores) / len(scores)
+
+    def binning(self) -> list[list[str]]:
+        """ Group into bins by histogram """
+        # pylint:disable=duplicate-code
+        msg = "dissimilarity" if self._is_dissim else "similarity"
+        logger.info("Grouping by %s...", msg)
+
+        # Groups are of the form: group_num -> reference histograms
+        reference_groups: dict[int, list[np.ndarray]] = {}
+
+        # Bins array, where index is the group number and value is
+        # an array containing the file paths to the images in that group
+        bins: list[list[str]] = []
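+
+        # Grouping is greedy: each face joins the existing reference group with the lowest
+        # average Bhattacharyya distance, provided that distance is under the threshold;
+        # otherwise it seeds a new group (a new bin)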
+        threshold = self._threshold
+
+        img_list_len = len(self._result)
+        reference_groups[0] = [T.cast(np.ndarray, self._result[0][1])]
+        bins.append([self._result[0][0]])
+
+        for i in tqdm(range(1, img_list_len),
+                      desc="Grouping",
+                      file=sys.stdout,
+                      leave=False):
+            current_key = -1
+            current_score = float("inf")
+            for key, value in reference_groups.items():
+                score = self._get_avg_score(self._result[i][1], value)
+                if score < current_score:
+                    current_key, current_score = key, score
+
+            if current_score < threshold:
+                reference_groups[T.cast(int, current_key)].append(self._result[i][1])
+                bins[current_key].append(self._result[i][0])
+            else:
+                reference_groups[len(reference_groups)] = [self._result[i][1]]
+                bins.append([self._result[i][0]])
+
+        return bins
+
+    def score_image(self,
+                    filename: str,
+                    image: np.ndarray | None,
+                    alignments: PNGHeaderAlignmentsDict | None) -> None:
+        """ Collect the histogram for the given face
+
+        Parameters
+        ----------
+        filename: str
+            The filename of the currently processing image
+        image: :class:`np.ndarray`
+            A face image loaded from disk
+        alignments: dict or ``None``
+            The alignments dictionary for the aligned face or ``None``
+        """
+        if self._log_once:
+            msg = "Grouping" if self._is_group else "Sorting"
+            logger.info("%s by histogram similarity...", msg)
+            self._log_once = False
+
+        assert image is not None
+        self._result.append((filename, self._calc_histogram(image, alignments)))
+
+    def sort(self) -> None:
+        """ Sort by histogram. """
+        logger.info("Comparing histograms and sorting...")
+        if self._is_dissim:
+            self._sort_dissim()
+            return
+        self._sort_sim()
+
+
+__all__ = get_module_objects(__name__)
diff --git a/tools/sort/sort_methods_aligned.py b/tools/sort/sort_methods_aligned.py
new file mode 100644
index 0000000000..5cb3ba99e1
--- /dev/null
+++ b/tools/sort/sort_methods_aligned.py
@@ -0,0 +1,396 @@
+#!/usr/bin/env python3
+""" Sorting methods that use the properties of a :class:`lib.align.AlignedFace` object to obtain
+their sorting metrics.
+"""
+from __future__ import annotations
+import logging
+import operator
+import sys
+import typing as T
+
+import numpy as np
+from tqdm import tqdm
+
+from lib.align import AlignedFace, LandmarkType
+from lib.utils import get_module_objects, FaceswapError
+from .sort_methods import SortMethod
+
+if T.TYPE_CHECKING:
+    from argparse import Namespace
+    from lib.align.alignments import PNGHeaderAlignmentsDict
+
+logger = logging.getLogger(__name__)
+
+
+class SortAlignedMetric(SortMethod):
+    """ Sort by comparison of metrics stored in an :class:`lib.align.AlignedFace` object. This is
+    a parent class for the sort by aligned metric methods. Individual methods should inherit from
+    this class
+
+    Parameters
+    ----------
+    arguments: :class:`argparse.Namespace`
+        The command line arguments passed to the sort process
+    is_group: bool, optional
+        Set to ``True`` if this class is going to be called exclusively for binning.
+        Default: ``False``
+    """
+
+    _logged_lm_count_once: bool = False
+
+    def _get_metric(self, aligned_face: AlignedFace) -> np.ndarray | float:
+        """ Obtain the correct metric for the given sort method
+
+        Parameters
+        ----------
+        aligned_face: :class:`lib.align.AlignedFace`
+            The aligned face to extract the metric from
+
+        Returns
+        -------
+        float or :class:`numpy.ndarray`
+            The metric for the current face based on the chosen sort method
+        """
+        raise NotImplementedError
+
+    def sort(self) -> None:
+        """ Sort by metric score in descending order """
""" + logger.info("Sorting...") + self._result = sorted(self._result, key=operator.itemgetter(1), reverse=True) + + def score_image(self, + filename: str, + image: np.ndarray | None, + alignments: PNGHeaderAlignmentsDict | None) -> None: + """ Score a single image for sort method: "distance", "yaw", "pitch" or "size" and add the + result to :attr:`_result` + + Parameters + ---------- + filename: str + The filename of the currently processing image + image: :class:`np.ndarray` or ``None`` + A face image loaded from disk or ``None`` + alignments: dict or ``None`` + The alignments dictionary for the aligned face or ``None`` + """ + if self._log_once: + msg = "Grouping" if self._is_group else "Sorting" + logger.info("%s by %s...", msg, self._method) + self._log_once = False + + if not alignments: + msg = ("The images to be sorted do not contain alignment data. Images must have " + "been generated by Faceswap's Extract process.\nIf you are sorting an " + "older faceset, then you should re-extract the faces from your source " + "alignments file to generate this data.") + raise FaceswapError(msg) + + face = AlignedFace(np.array(alignments["landmarks_xy"], dtype="float32")) + if (not self._logged_lm_count_once + and face.landmark_type == LandmarkType.LM_2D_4 + and self.__class__.__name__ != "SortSize"): + logger.warning("You have selected to sort by an aligned metric, but at least one face " + "does not contain facial landmark data. This probably won't work") + self._logged_lm_count_once = True + self._result.append((filename, self._get_metric(face))) + + +class SortDistance(SortAlignedMetric): + """ Sorting mechanism for sorting faces from small to large """ + def _get_metric(self, aligned_face: AlignedFace) -> float: + """ Obtain the distance from mean face metric for the given face + + Parameters + ---------- + aligned_face: :class:`lib.align.AlignedFace` + The aligned face to extract the metric from + + Returns + ------- + float + The distance metric for the current face + """ + return aligned_face.average_distance + + def sort(self) -> None: + """ Override default sort to sort in ascending order. 
""" + logger.info("Sorting...") + self._result = sorted(self._result, key=operator.itemgetter(1), reverse=False) + + def binning(self) -> list[list[str]]: + """ Create bins to split linearly from the lowest to the highest sample value + + Returns + ------- + list + List of bins of filenames + """ + return self._binning_linear_threshold(multiplier=100) + + +class SortPitch(SortAlignedMetric): + """ Sorting mechansim for sorting a face by pitch (down to up) """ + def _get_metric(self, aligned_face: AlignedFace) -> float: + """ Obtain the pitch metric for the given face + + Parameters + ---------- + aligned_face: :class:`lib.align.AlignedFace` + The aligned face to extract the metric from + + Returns + ------- + float + The pitch metric for the current face + """ + return aligned_face.pose.pitch + + def binning(self) -> list[list[str]]: + """ Create bins from 0 degrees to 180 degrees based on number of bins + + Allocate item to bin when it is in range of one of the pre-allocated bins + + Returns + ------- + list + List of bins of filenames + """ + thresholds = np.linspace(90, -90, self._num_bins + 1) + + # Start bin names from 0 for more intuitive experience + names = np.flip(thresholds.astype("int")) + 90 + self._bin_names = [f"{self._method}_" + f"{idx:03d}_{int(names[idx])}" + f"degs_to_{int(names[idx + 1])}degs" + for idx in range(self._num_bins)] + + bins: list[list[str]] = [[] for _ in range(self._num_bins)] + for filename, result in self._result: + result = np.clip(result, -90.0, 90.0) + bin_idx = next(bin_id for bin_id, thresh in enumerate(thresholds) + if result >= thresh) - 1 + bins[bin_idx].append(filename) + return bins + + +class SortYaw(SortPitch): + """ Sorting mechansim for sorting a face by yaw (left to right). Same logic as sort pitch, but + with different metric """ + def _get_metric(self, aligned_face: AlignedFace) -> float: + """ Obtain the yaw metric for the given face + + Parameters + ---------- + aligned_face: :class:`lib.align.AlignedFace` + The aligned face to extract the metric from + + Returns + ------- + float + The yaw metric for the current face + """ + return aligned_face.pose.yaw + + +class SortRoll(SortPitch): + """ Sorting mechansim for sorting a face by roll (rotation). 
+    """ Sorting mechanism for sorting a face by roll (rotation). Same logic as sort pitch, but
+    with a different metric """
+    def _get_metric(self, aligned_face: AlignedFace) -> float:
+        """ Obtain the roll metric for the given face
+
+        Parameters
+        ----------
+        aligned_face: :class:`lib.align.AlignedFace`
+            The aligned face to extract the metric from
+
+        Returns
+        -------
+        float
+            The roll metric for the current face
+        """
+        return aligned_face.pose.roll
+
+
+class SortSize(SortAlignedMetric):
+    """ Sorting mechanism for sorting faces from small to large """
+    def _get_metric(self, aligned_face: AlignedFace) -> float:
+        """ Obtain the size metric for the given face
+
+        Parameters
+        ----------
+        aligned_face: :class:`lib.align.AlignedFace`
+            The aligned face to extract the metric from
+
+        Returns
+        -------
+        float
+            The size metric for the current face
+        """
+        roi = aligned_face.original_roi
+        size = ((roi[1][0] - roi[0][0]) ** 2 + (roi[1][1] - roi[0][1]) ** 2) ** 0.5
+        return size
+
+    def binning(self) -> list[list[str]]:
+        """ Create bins to split linearly from the lowest to the highest sample value
+
+        Allocate an item to a bin when it is in range of one of the pre-allocated bins
+
+        Returns
+        -------
+        list
+            List of bins of filenames
+        """
+        return self._binning_linear_threshold(units="px")
+
+
+class SortFaceCNN(SortAlignedMetric):
+    """ Sort by landmark similarity or dissimilarity
+
+    Parameters
+    ----------
+    arguments: :class:`argparse.Namespace`
+        The command line arguments passed to the sort process
+    is_group: bool, optional
+        Set to ``True`` if this class is going to be called exclusively for binning.
+        Default: ``False``
+    """
+    def __init__(self, arguments: Namespace, is_group: bool = False) -> None:
+        super().__init__(arguments, is_group=is_group)
+        # NB: method names are normalized to underscores in the sort tool's argument parsing
+        self._is_dissim = self._method == "face_cnn_dissim"
+        self._threshold: float = 7.2 if arguments.threshold < 1.0 else arguments.threshold
+
+    def _get_metric(self, aligned_face: AlignedFace) -> np.ndarray:
+        """ Obtain the xy aligned landmarks for the face
+
+        Parameters
+        ----------
+        aligned_face: :class:`lib.align.AlignedFace`
+            The aligned face to extract the metric from
+
+        Returns
+        -------
+        :class:`numpy.ndarray`
+            The aligned landmarks for the current face
+        """
+        return aligned_face.landmarks
+
+    def sort(self) -> None:
+        """ Sort by landmark similarity or dissimilarity """
""" + logger.info("Comparing landmarks and sorting...") + if self._is_dissim: + self._sort_landmarks_dissim() + return + self._sort_landmarks_ssim() + + def _sort_landmarks_ssim(self) -> None: + """ Sort landmarks by similarity """ + img_list_len = len(self._result) + for i in tqdm(range(0, img_list_len - 1), desc="Comparing", file=sys.stdout, leave=False): + min_score = float("inf") + j_min_score = i + 1 + for j in range(i + 1, img_list_len): + fl1 = self._result[i][1] + fl2 = self._result[j][1] + score = np.sum(np.absolute((fl2 - fl1).flatten())) + if score < min_score: + min_score = score + j_min_score = j + (self._result[i + 1], self._result[j_min_score]) = (self._result[j_min_score], + self._result[i + 1]) + + def _sort_landmarks_dissim(self) -> None: + """ Sort landmarks by dissimilarity """ + logger.info("Comparing landmarks...") + img_list_len = len(self._result) + for i in tqdm(range(0, img_list_len - 1), desc="Comparing", file=sys.stdout, leave=False): + score_total = 0 + for j in range(i + 1, img_list_len): + if i == j: + continue + fl1 = self._result[i][1] + fl2 = self._result[j][1] + score_total += np.sum(np.absolute((fl2 - fl1).flatten())) + self._result[i][2] = score_total + + logger.info("Sorting...") + self._result = sorted(self._result, key=operator.itemgetter(2), reverse=True) + + def binning(self) -> list[list[str]]: + """ Group into bins by CNN face similarity + + Returns + ------- + list + List of bins of filenames + """ + msg = "dissimilarity" if self._is_dissim else "similarity" + logger.info("Grouping by face-cnn %s...", msg) + + # Groups are of the form: group_num -> reference faces + reference_groups: dict[int, list[np.ndarray]] = {} + + # Bins array, where index is the group number and value is + # an array containing the file paths to the images in that group. + bins: list[list[str]] = [] + + # Comparison threshold used to decide how similar + # faces have to be to be grouped together. + # It is multiplied by 1000 here to allow the cli option to use smaller + # numbers. 
+        threshold = self._threshold * 1000
+        img_list_len = len(self._result)
+
+        # Iterate over every face (previously the final face was skipped) so that all files
+        # end up in a bin
+        for i in tqdm(range(0, img_list_len),
+                      desc="Grouping",
+                      file=sys.stdout,
+                      leave=False):
+            fl1 = self._result[i][1]
+
+            current_key = -1
+            current_score = float("inf")
+
+            for key, references in reference_groups.items():
+                try:
+                    score = self._get_avg_score(fl1, references)
+                except (TypeError, ZeroDivisionError):
+                    score = float("inf")
+                if score < current_score:
+                    current_key, current_score = key, score
+
+            if current_score < threshold:
+                # Store the full landmark array (not just its first point) so that later
+                # comparisons are consistent with the arrays used to seed each group
+                reference_groups[current_key].append(fl1)
+                bins[current_key].append(self._result[i][0])
+            else:
+                reference_groups[len(reference_groups)] = [self._result[i][1]]
+                bins.append([self._result[i][0]])
+
+        return bins
+
+    @classmethod
+    def _get_avg_score(cls, face: np.ndarray, references: list[np.ndarray]) -> float:
+        """ Return the average CNN similarity score between a face and the reference images
+
+        Parameters
+        ----------
+        face: :class:`numpy.ndarray`
+            The face to check against the reference images
+        references: list
+            List of reference arrays to compare the face against
+
+        Returns
+        -------
+        float
+            The average score between the face and the references
+        """
+        scores = []
+        for ref in references:
+            score = np.sum(np.absolute((ref - face).flatten()))
+            scores.append(score)
+        return sum(scores) / len(scores)
+
+
+__all__ = get_module_objects(__name__)
diff --git a/update_deps.py b/update_deps.py
new file mode 100644
index 0000000000..0fb48b8be0
--- /dev/null
+++ b/update_deps.py
@@ -0,0 +1,38 @@
+#!/usr/bin/env python3
+""" Installs any required third party libs for faceswap.py
+
+    Checks for installed Conda / Pip packages and updates accordingly
+"""
+import logging
+import os
+import sys
+
+from lib.logger import log_setup
+from lib.utils import get_module_objects
+from setup import Environment, Install
+
+logger = logging.getLogger(__name__)
+
+
+def main(is_gui=False) -> None:
+    """ Check for and update dependencies
+
+    Parameters
+    ----------
+    is_gui: bool, optional
+        ``True`` if being called by the GUI. Prevents the updater from outputting progress bars
+        which get scrambled in the GUI
+    """
+    logger.info("Updating dependencies...")
+    update = Environment(updater=True)
+    Install(update, is_gui=is_gui)
+    logger.info("Dependencies updated")
+
+
+if __name__ == "__main__":
+    logfile = os.path.join(os.path.dirname(os.path.realpath(sys.argv[0])), "faceswap_update.log")
+    log_setup("INFO", logfile, "setup")
+    main()
+
+
+__all__ = get_module_objects(__name__)