From a511214aaacda7432ee9d2734183361f32e2cf26 Mon Sep 17 00:00:00 2001 From: Tran Xen <137925069+glucauze@users.noreply.github.com> Date: Fri, 28 Jul 2023 20:47:14 +0200 Subject: [PATCH] clean code, fix extract --- check.sh | 4 + example/api/roop_api_example.py | 1 - install.py | 2 +- mypy.ini | 3 +- preload.py | 5 +- scripts/faceswaplab.py | 50 +++--- scripts/faceswaplab_api/faceswaplab_api.py | 46 ++++-- .../faceswaplab_api/faceswaplab_api_types.py | 8 - scripts/faceswaplab_globals.py | 7 +- scripts/faceswaplab_postprocessing/i2i_pp.py | 6 +- .../postprocessing.py | 1 - .../faceswaplab_postprocessing/upscaling.py | 3 +- .../faceswaplab_settings.py | 2 +- .../faceswaplab_swapping/parsing/__init__.py | 1 - scripts/faceswaplab_swapping/swapper.py | 58 +++---- .../upscaled_inswapper.py | 10 +- scripts/faceswaplab_ui/faceswaplab_tab.py | 90 ++++++----- .../faceswaplab_unit_settings.py | 2 +- scripts/faceswaplab_ui/faceswaplab_unit_ui.py | 7 +- .../faceswaplab_ui/faceswaplab_upscaler_ui.py | 15 +- .../faceswaplab_utils/faceswaplab_logging.py | 49 +++++- scripts/faceswaplab_utils/imgutils.py | 147 ++++++++++++++---- scripts/faceswaplab_utils/models_utils.py | 5 +- 23 files changed, 328 insertions(+), 194 deletions(-) create mode 100755 check.sh diff --git a/check.sh b/check.sh new file mode 100755 index 0000000..b5cbbc8 --- /dev/null +++ b/check.sh @@ -0,0 +1,4 @@ +#!/bin/bash +autoflake --in-place --remove-unused-variables -r --remove-all-unused-imports . +mypy --install-types +pre-commit run --all-files \ No newline at end of file diff --git a/example/api/roop_api_example.py b/example/api/roop_api_example.py index f38539a..d262940 100644 --- a/example/api/roop_api_example.py +++ b/example/api/roop_api_example.py @@ -1,5 +1,4 @@ import requests -from PIL import Image from client_utils import ( FaceSwapRequest, FaceSwapUnit, diff --git a/install.py b/install.py index 5ae1104..f1d7033 100644 --- a/install.py +++ b/install.py @@ -14,7 +14,7 @@ model_name = os.path.basename(model_url) model_path = os.path.join(models_dir, model_name) -def download(url, path): +def download(url: str, path: str) -> None: request = urllib.request.urlopen(url) total = int(request.headers.get("Content-Length", 0)) with tqdm( diff --git a/mypy.ini b/mypy.ini index d90ae92..4a4f35f 100644 --- a/mypy.ini +++ b/mypy.ini @@ -4,4 +4,5 @@ disallow_any_generics = True disallow_untyped_calls = True disallow_untyped_defs = True ignore_missing_imports = True -strict_optional = False \ No newline at end of file +strict_optional = False +explicit_package_bases=True \ No newline at end of file diff --git a/preload.py b/preload.py index 7201e31..5c3eaf5 100644 --- a/preload.py +++ b/preload.py @@ -1,4 +1,7 @@ -def preload(parser): +from argparse import ArgumentParser + + +def preload(parser: ArgumentParser) -> None: parser.add_argument( "--faceswaplab_loglevel", default="INFO", diff --git a/scripts/faceswaplab.py b/scripts/faceswaplab.py index 37a2e0a..f058c3b 100644 --- a/scripts/faceswaplab.py +++ b/scripts/faceswaplab.py @@ -4,7 +4,6 @@ from scripts.faceswaplab_settings import faceswaplab_settings from scripts.faceswaplab_ui import faceswaplab_tab, faceswaplab_unit_ui from scripts.faceswaplab_utils.models_utils import ( get_current_model, - get_face_checkpoints, ) from scripts import faceswaplab_globals @@ -12,7 +11,6 @@ from scripts.faceswaplab_swapping import swapper from scripts.faceswaplab_utils import faceswaplab_logging, imgutils from scripts.faceswaplab_utils import models_utils from scripts.faceswaplab_postprocessing import upscaling -import numpy as np # Reload all the modules when using "apply and restart" # This is mainly done for development purposes @@ -29,15 +27,13 @@ importlib.reload(faceswaplab_api) import os from dataclasses import fields from pprint import pformat -from typing import List, Optional, Tuple +from typing import Any, List, Optional, Tuple -import dill as pickle import gradio as gr import modules.scripts as scripts from modules import script_callbacks, scripts -from insightface.app.common import Face from modules import scripts, shared -from modules.images import save_image, image_grid +from modules.images import save_image from modules.processing import ( Processed, StableDiffusionProcessing, @@ -46,7 +42,6 @@ from modules.processing import ( from modules.shared import opts from PIL import Image -from scripts.faceswaplab_utils.imgutils import pil_to_cv2, check_against_nsfw from scripts.faceswaplab_utils.faceswaplab_logging import logger, save_img_debug from scripts.faceswaplab_globals import VERSION_FLAG from scripts.faceswaplab_postprocessing.postprocessing_options import ( @@ -76,15 +71,15 @@ class FaceSwapScript(scripts.Script): super().__init__() @property - def units_count(self): + def units_count(self) -> int: return opts.data.get("faceswaplab_units_count", 3) @property - def upscaled_swapper_in_generated(self): + def upscaled_swapper_in_generated(self) -> bool: return opts.data.get("faceswaplab_upscaled_swapper", False) @property - def upscaled_swapper_in_source(self): + def upscaled_swapper_in_source(self) -> bool: return opts.data.get("faceswaplab_upscaled_swapper_in_source", False) @property @@ -93,24 +88,24 @@ class FaceSwapScript(scripts.Script): return any([u.enable for u in self.units]) and not shared.state.interrupted @property - def keep_original_images(self): + def keep_original_images(self) -> bool: return opts.data.get("faceswaplab_keep_original", False) @property - def swap_in_generated_units(self): + def swap_in_generated_units(self) -> List[FaceSwapUnitSettings]: return [u for u in self.units if u.swap_in_generated and u.enable] @property - def swap_in_source_units(self): + def swap_in_source_units(self) -> List[FaceSwapUnitSettings]: return [u for u in self.units if u.swap_in_source and u.enable] - def title(self): + def title(self) -> str: return f"faceswaplab" - def show(self, is_img2img): + def show(self, is_img2img: bool) -> bool: return scripts.AlwaysVisible - def ui(self, is_img2img): + def ui(self, is_img2img: bool) -> List[gr.components.Component]: with gr.Accordion(f"FaceSwapLab {VERSION_FLAG}", open=False): components = [] for i in range(1, self.units_count + 1): @@ -119,16 +114,9 @@ class FaceSwapScript(scripts.Script): # If the order is modified, the before_process should be changed accordingly. return components + upscaler - # def make_script_first(self,p: StableDiffusionProcessing) : - # FIXME : not really useful, will only impact postprocessing (kept for further testing) - # runner : scripts.ScriptRunner = p.scripts - # alwayson = runner.alwayson_scripts - # alwayson.pop(alwayson.index(self)) - # alwayson.insert(0, self) - # print("Running in ", alwayson.index(self), "position") - # logger.info("Running scripts : %s", pformat(runner.alwayson_scripts)) - - def read_config(self, p: StableDiffusionProcessing, *components): + def read_config( + self, p: StableDiffusionProcessing, *components: List[gr.components.Component] + ) -> None: # The order of processing for the components is important # The method first process faceswap units then postprocessing units @@ -148,14 +136,16 @@ class FaceSwapScript(scripts.Script): len_conf: int = len(fields(FaceSwapUnitSettings)) shift: int = self.units_count * len_conf self.postprocess_options = PostProcessingOptions( - *components[shift : shift + len(fields(PostProcessingOptions))] + *components[shift : shift + len(fields(PostProcessingOptions))] # type: ignore ) logger.debug("%s", pformat(self.postprocess_options)) if self.enabled: p.do_not_save_samples = not self.keep_original_images - def process(self, p: StableDiffusionProcessing, *components): + def process( + self, p: StableDiffusionProcessing, *components: List[gr.components.Component] + ) -> None: self.read_config(p, *components) # If is instance of img2img, we check if face swapping in source is required. @@ -175,7 +165,9 @@ class FaceSwapScript(scripts.Script): if new_inits is not None: p.init_images = [img[0] for img in new_inits] - def postprocess(self, p: StableDiffusionProcessing, processed: Processed, *args): + def postprocess( + self, p: StableDiffusionProcessing, processed: Processed, *args: List[Any] + ) -> None: if self.enabled: # Get the original images without the grid orig_images: List[Image.Image] = processed.images[ diff --git a/scripts/faceswaplab_api/faceswaplab_api.py b/scripts/faceswaplab_api/faceswaplab_api.py index 012cc9a..6decff3 100644 --- a/scripts/faceswaplab_api/faceswaplab_api.py +++ b/scripts/faceswaplab_api/faceswaplab_api.py @@ -1,52 +1,68 @@ from PIL import Image import numpy as np -from fastapi import FastAPI, Body -from fastapi.exceptions import HTTPException -from modules.api.models import * +from fastapi import FastAPI from modules.api import api from scripts.faceswaplab_api.faceswaplab_api_types import ( - FaceSwapUnit, FaceSwapRequest, FaceSwapResponse, ) from scripts.faceswaplab_globals import VERSION_FLAG import gradio as gr -from typing import List, Optional +from typing import Dict, List, Optional, Union from scripts.faceswaplab_swapping import swapper -from scripts.faceswaplab_utils.faceswaplab_logging import save_img_debug from scripts.faceswaplab_ui.faceswaplab_unit_settings import FaceSwapUnitSettings from scripts.faceswaplab_utils.imgutils import ( - pil_to_cv2, - check_against_nsfw, base64_to_pil, ) from scripts.faceswaplab_utils.models_utils import get_current_model from modules.shared import opts -def encode_to_base64(image): - if type(image) is str: +def encode_to_base64(image: Union[str, Image.Image, np.ndarray]) -> str: + """ + Encode an image to a base64 string. + + The image can be a file path (str), a PIL Image, or a NumPy array. + + Args: + image (Union[str, Image.Image, np.ndarray]): The image to encode. + + Returns: + str: The base64-encoded image if successful, otherwise an empty string. + """ + if isinstance(image, str): return image - elif type(image) is Image.Image: + elif isinstance(image, Image.Image): return api.encode_pil_to_base64(image) - elif type(image) is np.ndarray: + elif isinstance(image, np.ndarray): return encode_np_to_base64(image) else: return "" -def encode_np_to_base64(image): +def encode_np_to_base64(image: np.ndarray) -> str: + """ + Encode a NumPy array to a base64 string. + + The array is first converted to a PIL Image, then encoded. + + Args: + image (np.ndarray): The NumPy array to encode. + + Returns: + str: The base64-encoded image. + """ pil = Image.fromarray(image) return api.encode_pil_to_base64(pil) -def faceswaplab_api(_: gr.Blocks, app: FastAPI): +def faceswaplab_api(_: gr.Blocks, app: FastAPI) -> None: @app.get( "/faceswaplab/version", tags=["faceswaplab"], description="Get faceswaplab version", ) - async def version(): + async def version() -> Dict[str, str]: return {"version": VERSION_FLAG} # use post as we consider the method non idempotent (which is debatable) diff --git a/scripts/faceswaplab_api/faceswaplab_api_types.py b/scripts/faceswaplab_api/faceswaplab_api_types.py index 368840a..383a1d9 100644 --- a/scripts/faceswaplab_api/faceswaplab_api_types.py +++ b/scripts/faceswaplab_api/faceswaplab_api_types.py @@ -1,16 +1,8 @@ -from scripts.faceswaplab_swapping import swapper -import numpy as np from typing import List, Tuple -import dill as pickle -import gradio as gr -from insightface.app.common import Face from PIL import Image from scripts.faceswaplab_utils.imgutils import ( - pil_to_cv2, - check_against_nsfw, base64_to_pil, ) -from scripts.faceswaplab_utils.faceswaplab_logging import logger from pydantic import BaseModel, Field from scripts.faceswaplab_postprocessing.postprocessing_options import InpaintingWhen diff --git a/scripts/faceswaplab_globals.py b/scripts/faceswaplab_globals.py index 4df9abb..fc14c41 100644 --- a/scripts/faceswaplab_globals.py +++ b/scripts/faceswaplab_globals.py @@ -1,10 +1,11 @@ -from scripts.faceswaplab_utils.faceswaplab_logging import logger import os MODELS_DIR = os.path.abspath(os.path.join("models", "faceswaplab")) ANALYZER_DIR = os.path.abspath(os.path.join(MODELS_DIR, "analysers")) FACE_PARSER_DIR = os.path.abspath(os.path.join(MODELS_DIR, "parser")) -VERSION_FLAG = "v1.1.0" +VERSION_FLAG: str = "v1.1.0" EXTENSION_PATH = os.path.join("extensions", "sd-webui-faceswaplab") -NSFW_SCORE = 0.7 + +# The NSFW score threshold. If any part of the image has a score greater than this threshold, the image will be considered NSFW. +NSFW_SCORE_THRESHOLD: float = 0.7 diff --git a/scripts/faceswaplab_postprocessing/i2i_pp.py b/scripts/faceswaplab_postprocessing/i2i_pp.py index c5970e9..97952e4 100644 --- a/scripts/faceswaplab_postprocessing/i2i_pp.py +++ b/scripts/faceswaplab_postprocessing/i2i_pp.py @@ -1,15 +1,11 @@ -from modules.face_restoration import FaceRestoration -from modules.upscaler import UpscalerData from scripts.faceswaplab_utils.faceswaplab_logging import logger from PIL import Image -import numpy as np from modules import shared from scripts.faceswaplab_utils import imgutils -from modules import shared, processing, codeformer_model +from modules import shared, processing from modules.processing import StableDiffusionProcessingImg2Img from scripts.faceswaplab_postprocessing.postprocessing_options import ( PostProcessingOptions, - InpaintingWhen, ) from modules import sd_models diff --git a/scripts/faceswaplab_postprocessing/postprocessing.py b/scripts/faceswaplab_postprocessing/postprocessing.py index 1732696..ddc8599 100644 --- a/scripts/faceswaplab_postprocessing/postprocessing.py +++ b/scripts/faceswaplab_postprocessing/postprocessing.py @@ -1,4 +1,3 @@ -from modules.face_restoration import FaceRestoration from scripts.faceswaplab_utils.faceswaplab_logging import logger from PIL import Image from scripts.faceswaplab_postprocessing.postprocessing_options import ( diff --git a/scripts/faceswaplab_postprocessing/upscaling.py b/scripts/faceswaplab_postprocessing/upscaling.py index 01fa438..04ba5fb 100644 --- a/scripts/faceswaplab_postprocessing/upscaling.py +++ b/scripts/faceswaplab_postprocessing/upscaling.py @@ -1,11 +1,10 @@ from scripts.faceswaplab_postprocessing.postprocessing_options import ( PostProcessingOptions, - InpaintingWhen, ) from scripts.faceswaplab_utils.faceswaplab_logging import logger from PIL import Image import numpy as np -from modules import shared, processing, codeformer_model +from modules import codeformer_model def upscale_img(image: Image.Image, pp_options: PostProcessingOptions) -> Image.Image: diff --git a/scripts/faceswaplab_settings/faceswaplab_settings.py b/scripts/faceswaplab_settings/faceswaplab_settings.py index fddec02..ccaad47 100644 --- a/scripts/faceswaplab_settings/faceswaplab_settings.py +++ b/scripts/faceswaplab_settings/faceswaplab_settings.py @@ -3,7 +3,7 @@ from modules import script_callbacks, shared import gradio as gr -def on_ui_settings(): +def on_ui_settings() -> None: section = ("faceswaplab", "FaceSwapLab") models = get_models() shared.opts.add_option( diff --git a/scripts/faceswaplab_swapping/parsing/__init__.py b/scripts/faceswaplab_swapping/parsing/__init__.py index 6b4c5fd..39fe887 100644 --- a/scripts/faceswaplab_swapping/parsing/__init__.py +++ b/scripts/faceswaplab_swapping/parsing/__init__.py @@ -40,7 +40,6 @@ please contact the contributor(s) of the work. import torch -import cv2 import os import torch from torch.hub import download_url_to_file, get_dir diff --git a/scripts/faceswaplab_swapping/swapper.py b/scripts/faceswaplab_swapping/swapper.py index 5a138dd..9dfc454 100644 --- a/scripts/faceswaplab_swapping/swapper.py +++ b/scripts/faceswaplab_swapping/swapper.py @@ -1,7 +1,7 @@ import copy import os from dataclasses import dataclass -from typing import Dict, List, Set, Tuple, Optional, Union +from typing import Any, Dict, List, Set, Tuple, Optional import cv2 import insightface @@ -13,7 +13,6 @@ from sklearn.metrics.pairwise import cosine_similarity from scripts.faceswaplab_swapping import upscaled_inswapper from scripts.faceswaplab_utils.imgutils import ( - cv2_to_pil, pil_to_cv2, check_against_nsfw, ) @@ -26,7 +25,7 @@ from scripts.faceswaplab_ui.faceswaplab_unit_settings import FaceSwapUnitSetting providers = ["CPUExecutionProvider"] -def cosine_similarity_face(face1, face2) -> float: +def cosine_similarity_face(face1: Face, face2: Face) -> float: """ Calculates the cosine similarity between two face embeddings. @@ -92,7 +91,7 @@ class FaceModelException(Exception): @lru_cache(maxsize=1) -def getAnalysisModel(): +def getAnalysisModel() -> insightface.app.FaceAnalysis: """ Retrieves the analysis model for face analysis. @@ -112,11 +111,11 @@ def getAnalysisModel(): logger.error( "Loading of swapping model failed, please check the requirements (On Windows, download and install Visual Studio. During the install, make sure to include the Python and C++ packages.)" ) - raise FaceModelException("Loading of swapping model failed") + raise FaceModelException("Loading of analysis model failed") @lru_cache(maxsize=1) -def getFaceSwapModel(model_path: str): +def getFaceSwapModel(model_path: str) -> upscaled_inswapper.UpscaledINSwapper: """ Retrieves the face swap model and initializes it if necessary. @@ -135,13 +134,14 @@ def getFaceSwapModel(model_path: str): logger.error( "Loading of swapping model failed, please check the requirements (On Windows, download and install Visual Studio. During the install, make sure to include the Python and C++ packages.)" ) + raise FaceModelException("Loading of swapping model failed") def get_faces( - img_data: np.ndarray, - det_size=(640, 640), - det_thresh: Optional[int] = None, - sort_by_face_size=False, + img_data: np.ndarray, # type: ignore + det_size: Tuple[int, int] = (640, 640), + det_thresh: Optional[float] = None, + sort_by_face_size: bool = False, ) -> List[Face]: """ Detects and retrieves faces from an image using an analysis model. @@ -211,7 +211,7 @@ class ImageResult: """ -def get_or_default(l, index, default): +def get_or_default(l: List[Any], index: int, default: Any) -> Any: """ Retrieve the value at the specified index from the given list. If the index is out of bounds, return the default value instead. @@ -227,7 +227,10 @@ def get_or_default(l, index, default): return l[index] if index < len(l) else default -def get_faces_from_img_files(files): +import gradio as gr + + +def get_faces_from_img_files(files: List[gr.File]) -> List[Optional[np.ndarray]]: # type: ignore """ Extracts faces from a list of image files. @@ -300,15 +303,15 @@ def blend_faces(faces: List[Face]) -> Face: def swap_face( - reference_face: np.ndarray, - source_face: np.ndarray, + reference_face: np.ndarray, # type: ignore + source_face: np.ndarray, # type: ignore target_img: Image.Image, model: str, faces_index: Set[int] = {0}, - same_gender=True, - upscaled_swapper=False, - compute_similarity=True, - sort_by_face_size=False, + same_gender: bool = True, + upscaled_swapper: bool = False, + compute_similarity: bool = True, + sort_by_face_size: bool = False, ) -> ImageResult: """ Swaps faces in the target image with the source face. @@ -344,6 +347,7 @@ def swap_face( for i, swapped_face in enumerate(target_faces): logger.info(f"swap face {i}") if i in faces_index: + # type : ignore result = face_swapper.get( result, swapped_face, source_face, upscale=upscaled_swapper ) @@ -385,13 +389,13 @@ def swap_face( def process_image_unit( - model, + model: str, unit: FaceSwapUnitSettings, image: Image.Image, - info=None, - upscaled_swapper=False, - force_blend=False, -) -> List: + info: str = None, + upscaled_swapper: bool = False, + force_blend: bool = False, +) -> List[Tuple[Image.Image, str]]: """Process one image and return a List of (image, info) (one if blended, many if not). Args: @@ -472,12 +476,12 @@ def process_image_unit( def process_images_units( - model, + model: str, units: List[FaceSwapUnitSettings], images: List[Tuple[Optional[Image.Image], Optional[str]]], - upscaled_swapper=False, - force_blend=False, -) -> Union[List, None]: + upscaled_swapper: bool = False, + force_blend: bool = False, +) -> Optional[List[Tuple[Image.Image, str]]]: if len(units) == 0: logger.info("Finished processing image, return %s images", len(images)) return None diff --git a/scripts/faceswaplab_swapping/upscaled_inswapper.py b/scripts/faceswaplab_swapping/upscaled_inswapper.py index fc17ce0..2726e1b 100644 --- a/scripts/faceswaplab_swapping/upscaled_inswapper.py +++ b/scripts/faceswaplab_swapping/upscaled_inswapper.py @@ -1,17 +1,11 @@ import cv2 import numpy as np -import onnx -import onnxruntime from insightface.model_zoo.inswapper import INSwapper from insightface.utils import face_align -from modules import codeformer_model, processing, scripts, shared -from modules.face_restoration import FaceRestoration -from modules.shared import cmd_opts, opts, state +from modules import processing, shared +from modules.shared import opts from modules.upscaler import UpscalerData -from onnx import numpy_helper -from PIL import Image -from scripts.faceswaplab_utils.faceswaplab_logging import logger from scripts.faceswaplab_postprocessing import upscaling from scripts.faceswaplab_postprocessing.postprocessing_options import ( PostProcessingOptions, diff --git a/scripts/faceswaplab_ui/faceswaplab_tab.py b/scripts/faceswaplab_ui/faceswaplab_tab.py index d46c18e..bc3b1d5 100644 --- a/scripts/faceswaplab_ui/faceswaplab_tab.py +++ b/scripts/faceswaplab_ui/faceswaplab_tab.py @@ -5,13 +5,12 @@ from pprint import pformat, pprint import dill as pickle import gradio as gr import modules.scripts as scripts -import numpy as np import onnx import pandas as pd from scripts.faceswaplab_ui.faceswaplab_unit_ui import faceswap_unit_ui from scripts.faceswaplab_ui.faceswaplab_upscaler_ui import upscaler_ui from insightface.app.common import Face -from modules import script_callbacks, scripts +from modules import scripts from PIL import Image from modules.shared import opts @@ -25,12 +24,13 @@ from scripts.faceswaplab_postprocessing.postprocessing_options import ( ) from scripts.faceswaplab_postprocessing.postprocessing import enhance_image from dataclasses import fields -from typing import List +from typing import Any, List, Optional, Union from scripts.faceswaplab_ui.faceswaplab_unit_settings import FaceSwapUnitSettings from scripts.faceswaplab_utils.models_utils import get_current_model +import re -def compare(img1, img2): +def compare(img1: Image.Image, img2: Image.Image) -> Union[float, str]: if img1 is not None and img2 is not None: return swapper.compare_faces(img1, img2) @@ -40,19 +40,10 @@ def compare(img1, img2): def extract_faces( files, extract_path, - face_restorer_name, - face_restorer_visibility, - codeformer_weight, - upscaler_name, - upscaler_scale, - upscaler_visibility, - inpainting_denoising_strengh, - inpainting_prompt, - inpainting_negative_prompt, - inpainting_steps, - inpainting_sampler, - inpainting_when, + *components: List[gr.components.Component], ): + postprocess_options = PostProcessingOptions(*components) # type: ignore + if not extract_path: tempfile.mkdtemp() if files is not None: @@ -66,24 +57,16 @@ def extract_faces( bbox = face.bbox.astype(int) x_min, y_min, x_max, y_max = bbox face_image = img.crop((x_min, y_min, x_max, y_max)) - if face_restorer_name or face_restorer_visibility: - scale = 1 if face_image.width > 512 else 512 // face_image.width + if ( + postprocess_options.face_restorer_name + or postprocess_options.restorer_visibility + ): + postprocess_options.scale = ( + 1 if face_image.width > 512 else 512 // face_image.width + ) face_image = enhance_image( face_image, - PostProcessingOptions( - face_restorer_name=face_restorer_name, - restorer_visibility=face_restorer_visibility, - codeformer_weight=codeformer_weight, - upscaler_name=upscaler_name, - upscale_visibility=upscaler_visibility, - scale=scale, - inpainting_denoising_strengh=inpainting_denoising_strengh, - inpainting_prompt=inpainting_prompt, - inpainting_steps=inpainting_steps, - inpainting_negative_prompt=inpainting_negative_prompt, - inpainting_when=inpainting_when, - inpainting_sampler=inpainting_sampler, - ), + postprocess_options, ) path = tempfile.NamedTemporaryFile( delete=False, suffix=".png", dir=extract_path @@ -95,7 +78,7 @@ def extract_faces( return None -def analyse_faces(image, det_threshold=0.5): +def analyse_faces(image: Image.Image, det_threshold: float = 0.5) -> str: try: faces = swapper.get_faces(imgutils.pil_to_cv2(image), det_thresh=det_threshold) result = "" @@ -110,27 +93,40 @@ def analyse_faces(image, det_threshold=0.5): return "Analysis Failed" -def build_face_checkpoint_and_save(batch_files, name): +def sanitize_name(name: str) -> str: + logger.debug(f"Sanitize name {name}") + name = re.sub("[^A-Za-z0-9_. ]+", "", name) + name = name.replace(" ", "_") + logger.debug(f"Sanitized name {name[:255]}") + return name[:255] + + +def build_face_checkpoint_and_save( + batch_files: gr.File, name: str +) -> Optional[Image.Image]: """ - Builds a face checkpoint, swaps faces, and saves the result to a file. + Builds a face checkpoint using the provided image files, performs face swapping, + and saves the result to a file. If a blended face is successfully obtained and the face swapping + process succeeds, the resulting image is returned. Otherwise, None is returned. Args: - batch_files (list): List of image file paths. - name (str): Name of the face checkpoint + batch_files (list): List of image file paths used to create the face checkpoint. + name (str): The name assigned to the face checkpoint. Returns: - PIL.Image.Image or None: Resulting swapped face image if successful, otherwise None. + PIL.Image.Image or None: The resulting swapped face image if the process is successful; None otherwise. """ + name = sanitize_name(name) batch_files = batch_files or [] - print("Build", name, [x.name for x in batch_files]) + logger.info("Build %s %s", name, [x.name for x in batch_files]) faces = swapper.get_faces_from_img_files(batch_files) blended_face = swapper.blend_faces(faces) preview_path = os.path.join( scripts.basedir(), "extensions", "sd-webui-faceswaplab", "references" ) faces_path = os.path.join(scripts.basedir(), "models", "faceswaplab", "faces") - if not os.path.exists(faces_path): - os.makedirs(faces_path) + + os.makedirs(faces_path, exist_ok=True) target_img = None if blended_face: @@ -208,7 +204,9 @@ def explore_onnx_faceswap_model(model_path): return df -def batch_process(files, save_path, *components): +def batch_process( + files, save_path, *components: List[gr.components.Component] +) -> Optional[List[Image.Image]]: try: if save_path is not None: os.makedirs(save_path, exist_ok=True) @@ -228,7 +226,7 @@ def batch_process(files, save_path, *components): len_conf: int = len(fields(FaceSwapUnitSettings)) shift: int = units_count * len_conf postprocess_options = PostProcessingOptions( - *components[shift : shift + len(fields(PostProcessingOptions))] + *components[shift : shift + len(fields(PostProcessingOptions))] # type: ignore ) logger.debug("%s", pformat(postprocess_options)) @@ -247,7 +245,7 @@ def batch_process(files, save_path, *components): ), ) if len(swapped_images) > 0: - current_images += [img for img, info in swapped_images] + current_images += [img for img, _ in swapped_images] logger.info("%s images generated", len(current_images)) for i, img in enumerate(current_images): @@ -269,7 +267,7 @@ def batch_process(files, save_path, *components): return None -def tools_ui(): +def tools_ui() -> None: models = get_models() with gr.Tab("Tools"): with gr.Tab("Build"): @@ -431,7 +429,7 @@ def tools_ui(): ) -def on_ui_tabs(): +def on_ui_tabs() -> List[Any]: with gr.Blocks(analytics_enabled=False) as ui_faceswap: tools_ui() return [(ui_faceswap, "FaceSwapLab", "faceswaplab_tab")] diff --git a/scripts/faceswaplab_ui/faceswaplab_unit_settings.py b/scripts/faceswaplab_ui/faceswaplab_unit_settings.py index 51769cb..7e9348f 100644 --- a/scripts/faceswaplab_ui/faceswaplab_unit_settings.py +++ b/scripts/faceswaplab_ui/faceswaplab_unit_settings.py @@ -8,7 +8,7 @@ import dill as pickle import gradio as gr from insightface.app.common import Face from PIL import Image -from scripts.faceswaplab_utils.imgutils import pil_to_cv2, check_against_nsfw +from scripts.faceswaplab_utils.imgutils import pil_to_cv2 from scripts.faceswaplab_utils.faceswaplab_logging import logger diff --git a/scripts/faceswaplab_ui/faceswaplab_unit_ui.py b/scripts/faceswaplab_ui/faceswaplab_unit_ui.py index 4791803..a57f32c 100644 --- a/scripts/faceswaplab_ui/faceswaplab_unit_ui.py +++ b/scripts/faceswaplab_ui/faceswaplab_unit_ui.py @@ -1,8 +1,11 @@ +from typing import List from scripts.faceswaplab_utils.models_utils import get_face_checkpoints import gradio as gr -def faceswap_unit_ui(is_img2img, unit_num=1, id_prefix="faceswaplab"): +def faceswap_unit_ui( + is_img2img: bool, unit_num: int = 1, id_prefix: str = "faceswaplab" +) -> List[gr.components.Component]: with gr.Tab(f"Face {unit_num}"): with gr.Column(): gr.Markdown( @@ -37,7 +40,7 @@ def faceswap_unit_ui(is_img2img, unit_num=1, id_prefix="faceswaplab"): elem_id=f"{id_prefix}_face{unit_num}_refresh_checkpoints", ) - def refresh_fn(selected): + def refresh_fn(selected: str) -> None: return gr.Dropdown.update( value=selected, choices=get_face_checkpoints() ) diff --git a/scripts/faceswaplab_ui/faceswaplab_upscaler_ui.py b/scripts/faceswaplab_ui/faceswaplab_upscaler_ui.py index c767086..c7ce9fa 100644 --- a/scripts/faceswaplab_ui/faceswaplab_upscaler_ui.py +++ b/scripts/faceswaplab_ui/faceswaplab_upscaler_ui.py @@ -1,13 +1,12 @@ +from typing import List import gradio as gr import modules from modules import shared, sd_models -from modules.shared import cmd_opts, opts, state +from modules.shared import opts +from scripts.faceswaplab_postprocessing.postprocessing_options import InpaintingWhen -import scripts.faceswaplab_postprocessing.upscaling as upscaling -from scripts.faceswaplab_utils.faceswaplab_logging import logger - -def upscaler_ui(): +def upscaler_ui() -> List[gr.components.Component]: with gr.Tab(f"Post-Processing"): gr.Markdown( """Upscaling is performed on the whole image. Upscaling happens before face restoration.""" @@ -74,10 +73,8 @@ def upscaler_ui(): ) inpainting_when = gr.Dropdown( elem_id="faceswaplab_pp_inpainting_when", - choices=[ - e.value for e in upscaling.InpaintingWhen.__members__.values() - ], - value=[upscaling.InpaintingWhen.BEFORE_RESTORE_FACE.value], + choices=[e.value for e in InpaintingWhen.__members__.values()], + value=[InpaintingWhen.BEFORE_RESTORE_FACE.value], label="Enable/When", ) inpainting_denoising_strength = gr.Slider( diff --git a/scripts/faceswaplab_utils/faceswaplab_logging.py b/scripts/faceswaplab_utils/faceswaplab_logging.py index 8390e08..ffa58fd 100644 --- a/scripts/faceswaplab_utils/faceswaplab_logging.py +++ b/scripts/faceswaplab_utils/faceswaplab_logging.py @@ -1,12 +1,24 @@ import logging import copy import sys +from typing import Any from modules import shared from PIL import Image +from logging import LogRecord class ColoredFormatter(logging.Formatter): - COLORS = { + """ + A custom logging formatter that outputs logs with level names colored. + + Class Attributes: + COLORS (dict): A dictionary mapping logging level names to their corresponding color codes. + + Inherits From: + logging.Formatter + """ + + COLORS: dict[str, str] = { "DEBUG": "\033[0;36m", # CYAN "INFO": "\033[0;32m", # GREEN "WARNING": "\033[0;33m", # YELLOW @@ -15,7 +27,21 @@ class ColoredFormatter(logging.Formatter): "RESET": "\033[0m", # RESET COLOR } - def format(self, record): + def format(self, record: LogRecord) -> str: + """ + Format the specified record as text. + + The record's attribute dictionary is used as the operand to a string + formatting operation which yields the returned string. Before formatting + the dictionary, a check is made to see if the format uses the levelname + of the record. If it does, a colorized version is created and used. + + Args: + record (LogRecord): The log record to be formatted. + + Returns: + str: The formatted string which includes the colorized levelname. + """ colored_record = copy.copy(record) levelname = colored_record.levelname seq = self.COLORS.get(levelname, self.COLORS["RESET"]) @@ -46,7 +72,24 @@ if logger.getEffectiveLevel() <= logging.DEBUG: DEBUG_DIR = tempfile.mkdtemp() -def save_img_debug(img: Image.Image, message: str, *opts): +def save_img_debug(img: Image.Image, message: str, *opts: Any) -> None: + """ + Saves an image to a temporary file if the logger's effective level is set to DEBUG or lower. + After saving, it logs a debug message along with the file URI of the image. + + Parameters + ---------- + img : Image.Image + The image to be saved. + message : str + The message to be logged. + *opts : Any + Additional arguments to be passed to the logger's debug method. + + Returns + ------- + None + """ if logger.getEffectiveLevel() <= logging.DEBUG: with tempfile.NamedTemporaryFile( dir=DEBUG_DIR, delete=False, suffix=".png" diff --git a/scripts/faceswaplab_utils/imgutils.py b/scripts/faceswaplab_utils/imgutils.py index 23e0f05..f286bdc 100644 --- a/scripts/faceswaplab_utils/imgutils.py +++ b/scripts/faceswaplab_utils/imgutils.py @@ -1,35 +1,76 @@ import io -from typing import Optional -from PIL import Image, ImageChops, ImageOps, ImageFilter +from typing import List, Optional, Tuple, Union, Dict +from PIL import Image import cv2 import numpy as np from math import isqrt, ceil import torch from ifnude import detect -from scripts.faceswaplab_globals import NSFW_SCORE +from scripts.faceswaplab_globals import NSFW_SCORE_THRESHOLD from modules import processing import base64 +from collections import Counter + + +def check_against_nsfw(img: Image.Image) -> bool: + """ + Check if an image exceeds the Not Safe for Work (NSFW) score. + + Parameters: + img (PIL.Image.Image): The image to be checked. + + Returns: + bool: True if any part of the image is considered NSFW, False otherwise. + """ + shapes: List[bool] = [] + chunks: List[Dict[str, Union[int, float]]] = detect(img) -def check_against_nsfw(img): - shapes = [] - chunks = detect(img) for chunk in chunks: - shapes.append(chunk["score"] > NSFW_SCORE) + shapes.append(chunk["score"] > NSFW_SCORE_THRESHOLD) + return any(shapes) -def pil_to_cv2(pil_img): +def pil_to_cv2(pil_img: Image.Image) -> np.ndarray: # type: ignore + """ + Convert a PIL Image into an OpenCV image (cv2). + + Args: + pil_img (PIL.Image.Image): An image in PIL format. + + Returns: + np.ndarray: The input image converted to OpenCV format (BGR). + """ return cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR) -def cv2_to_pil(cv2_img): +def cv2_to_pil(cv2_img: np.ndarray) -> Image.Image: # type: ignore + """ + Convert an OpenCV image (cv2) into a PIL Image. + + Args: + cv2_img (np.ndarray): An image in OpenCV format (BGR). + + Returns: + PIL.Image.Image: The input image converted to PIL format (RGB). + """ return Image.fromarray(cv2.cvtColor(cv2_img, cv2.COLOR_BGR2RGB)) -def torch_to_pil(images): +def torch_to_pil(images: torch.Tensor) -> List[Image.Image]: """ - Convert a numpy image or a batch of images to a PIL image. + Converts a tensor image or a batch of tensor images to a PIL image or a list of PIL images. + + Parameters + ---------- + images : torch.Tensor + A tensor representing an image or a batch of images. + + Returns + ------- + list + A list of PIL images. """ images = images.cpu().permute(0, 2, 3, 1).numpy() if images.ndim == 3: @@ -39,9 +80,19 @@ def torch_to_pil(images): return pil_images -def pil_to_torch(pil_images): +def pil_to_torch(pil_images: Union[Image.Image, List[Image.Image]]) -> torch.Tensor: """ - Convert a PIL image or a list of PIL images to a torch tensor or a batch of torch tensors. + Converts a PIL image or a list of PIL images to a torch tensor or a batch of torch tensors. + + Parameters + ---------- + pil_images : Union[Image.Image, List[Image.Image]] + A PIL image or a list of PIL images. + + Returns + ------- + torch.Tensor + A tensor representing an image or a batch of images. """ if isinstance(pil_images, list): numpy_images = [np.array(image) for image in pil_images] @@ -53,10 +104,7 @@ def pil_to_torch(pil_images): return torch_image -from collections import Counter - - -def create_square_image(image_list): +def create_square_image(image_list: List[Image.Image]) -> Optional[Image.Image]: """ Creates a square image by combining multiple images in a grid pattern. @@ -108,16 +156,41 @@ def create_square_image(image_list): return None -def create_mask(image, box_coords): +# def create_mask(image : Image.Image, box_coords : Tuple[int, int, int, int]) -> Image.Image: +# width, height = image.size +# mask = Image.new("L", (width, height), 255) +# x1, y1, x2, y2 = box_coords +# for x in range(width): +# for y in range(height): +# if x1 <= x <= x2 and y1 <= y <= y2: +# mask.putpixel((x, y), 255) +# else: +# mask.putpixel((x, y), 0) +# return mask + + +def create_mask( + image: Image.Image, box_coords: Tuple[int, int, int, int] +) -> Image.Image: + """ + Create a binary mask for a given image and bounding box coordinates. + + Args: + image (PIL.Image.Image): The input image. + box_coords (Tuple[int, int, int, int]): A tuple of 4 integers defining the bounding box. + It follows the pattern (x1, y1, x2, y2), where (x1, y1) is the top-left coordinate of the + box and (x2, y2) is the bottom-right coordinate of the box. + + Returns: + PIL.Image.Image: A binary mask of the same size as the input image, where pixels within + the bounding box are white (255) and pixels outside the bounding box are black (0). + """ width, height = image.size - mask = Image.new("L", (width, height), 255) + mask = Image.new("L", (width, height), 0) x1, y1, x2, y2 = box_coords - for x in range(width): - for y in range(height): - if x1 <= x <= x2 and y1 <= y <= y2: - mask.putpixel((x, y), 255) - else: - mask.putpixel((x, y), 0) + for x in range(x1, x2 + 1): + for y in range(y1, y2 + 1): + mask.putpixel((x, y), 255) return mask @@ -185,12 +258,32 @@ def prepare_mask( def base64_to_pil(base64str: Optional[str]) -> Optional[Image.Image]: + """ + Converts a base64 string to a PIL Image object. + + Parameters: + base64str (Optional[str]): The base64 string to convert. This string may contain a data URL scheme + (i.e., 'data:image/jpeg;base64,') or just be the raw base64 encoded data. If None, the function + will return None. + + Returns: + Optional[Image.Image]: A PIL Image object created from the base64 string. If the input is None, + the function returns None. + + Raises: + binascii.Error: If the base64 string is not properly formatted or encoded. + PIL.UnidentifiedImageError: If the image format cannot be identified. + """ + if base64str is None: return None - if "base64," in base64str: # check if the base64 string has a data URL scheme + + # Check if the base64 string has a data URL scheme + if "base64," in base64str: base64_data = base64str.split("base64,")[-1] img_bytes = base64.b64decode(base64_data) else: - # if no data URL scheme, just decode + # If no data URL scheme, just decode img_bytes = base64.b64decode(base64str) + return Image.open(io.BytesIO(img_bytes)) diff --git a/scripts/faceswaplab_utils/models_utils.py b/scripts/faceswaplab_utils/models_utils.py index dba19a4..737a173 100644 --- a/scripts/faceswaplab_utils/models_utils.py +++ b/scripts/faceswaplab_utils/models_utils.py @@ -1,5 +1,6 @@ import glob import os +from typing import List import modules.scripts as scripts from modules import scripts from scripts.faceswaplab_globals import EXTENSION_PATH @@ -7,7 +8,7 @@ from modules.shared import opts from scripts.faceswaplab_utils.faceswaplab_logging import logger -def get_models(): +def get_models() -> List[str]: """ Retrieve a list of swap model files. @@ -44,7 +45,7 @@ def get_current_model() -> str: return model -def get_face_checkpoints(): +def get_face_checkpoints() -> List[str]: """ Retrieve a list of face checkpoint paths.