clean code, fix extract

main
Tran Xen 2 years ago
parent 00d9cc6f62
commit a511214aaa

@ -0,0 +1,4 @@
#!/bin/bash
autoflake --in-place --remove-unused-variables -r --remove-all-unused-imports .
mypy --install-types
pre-commit run --all-files

@ -1,5 +1,4 @@
import requests import requests
from PIL import Image
from client_utils import ( from client_utils import (
FaceSwapRequest, FaceSwapRequest,
FaceSwapUnit, FaceSwapUnit,

@ -14,7 +14,7 @@ model_name = os.path.basename(model_url)
model_path = os.path.join(models_dir, model_name) model_path = os.path.join(models_dir, model_name)
def download(url, path): def download(url: str, path: str) -> None:
request = urllib.request.urlopen(url) request = urllib.request.urlopen(url)
total = int(request.headers.get("Content-Length", 0)) total = int(request.headers.get("Content-Length", 0))
with tqdm( with tqdm(

@ -4,4 +4,5 @@ disallow_any_generics = True
disallow_untyped_calls = True disallow_untyped_calls = True
disallow_untyped_defs = True disallow_untyped_defs = True
ignore_missing_imports = True ignore_missing_imports = True
strict_optional = False strict_optional = False
explicit_package_bases=True

@ -1,4 +1,7 @@
def preload(parser): from argparse import ArgumentParser
def preload(parser: ArgumentParser) -> None:
parser.add_argument( parser.add_argument(
"--faceswaplab_loglevel", "--faceswaplab_loglevel",
default="INFO", default="INFO",

@ -4,7 +4,6 @@ from scripts.faceswaplab_settings import faceswaplab_settings
from scripts.faceswaplab_ui import faceswaplab_tab, faceswaplab_unit_ui from scripts.faceswaplab_ui import faceswaplab_tab, faceswaplab_unit_ui
from scripts.faceswaplab_utils.models_utils import ( from scripts.faceswaplab_utils.models_utils import (
get_current_model, get_current_model,
get_face_checkpoints,
) )
from scripts import faceswaplab_globals from scripts import faceswaplab_globals
@ -12,7 +11,6 @@ from scripts.faceswaplab_swapping import swapper
from scripts.faceswaplab_utils import faceswaplab_logging, imgutils from scripts.faceswaplab_utils import faceswaplab_logging, imgutils
from scripts.faceswaplab_utils import models_utils from scripts.faceswaplab_utils import models_utils
from scripts.faceswaplab_postprocessing import upscaling from scripts.faceswaplab_postprocessing import upscaling
import numpy as np
# Reload all the modules when using "apply and restart" # Reload all the modules when using "apply and restart"
# This is mainly done for development purposes # This is mainly done for development purposes
@ -29,15 +27,13 @@ importlib.reload(faceswaplab_api)
import os import os
from dataclasses import fields from dataclasses import fields
from pprint import pformat from pprint import pformat
from typing import List, Optional, Tuple from typing import Any, List, Optional, Tuple
import dill as pickle
import gradio as gr import gradio as gr
import modules.scripts as scripts import modules.scripts as scripts
from modules import script_callbacks, scripts from modules import script_callbacks, scripts
from insightface.app.common import Face
from modules import scripts, shared from modules import scripts, shared
from modules.images import save_image, image_grid from modules.images import save_image
from modules.processing import ( from modules.processing import (
Processed, Processed,
StableDiffusionProcessing, StableDiffusionProcessing,
@ -46,7 +42,6 @@ from modules.processing import (
from modules.shared import opts from modules.shared import opts
from PIL import Image from PIL import Image
from scripts.faceswaplab_utils.imgutils import pil_to_cv2, check_against_nsfw
from scripts.faceswaplab_utils.faceswaplab_logging import logger, save_img_debug from scripts.faceswaplab_utils.faceswaplab_logging import logger, save_img_debug
from scripts.faceswaplab_globals import VERSION_FLAG from scripts.faceswaplab_globals import VERSION_FLAG
from scripts.faceswaplab_postprocessing.postprocessing_options import ( from scripts.faceswaplab_postprocessing.postprocessing_options import (
@ -76,15 +71,15 @@ class FaceSwapScript(scripts.Script):
super().__init__() super().__init__()
@property @property
def units_count(self): def units_count(self) -> int:
return opts.data.get("faceswaplab_units_count", 3) return opts.data.get("faceswaplab_units_count", 3)
@property @property
def upscaled_swapper_in_generated(self): def upscaled_swapper_in_generated(self) -> bool:
return opts.data.get("faceswaplab_upscaled_swapper", False) return opts.data.get("faceswaplab_upscaled_swapper", False)
@property @property
def upscaled_swapper_in_source(self): def upscaled_swapper_in_source(self) -> bool:
return opts.data.get("faceswaplab_upscaled_swapper_in_source", False) return opts.data.get("faceswaplab_upscaled_swapper_in_source", False)
@property @property
@ -93,24 +88,24 @@ class FaceSwapScript(scripts.Script):
return any([u.enable for u in self.units]) and not shared.state.interrupted return any([u.enable for u in self.units]) and not shared.state.interrupted
@property @property
def keep_original_images(self): def keep_original_images(self) -> bool:
return opts.data.get("faceswaplab_keep_original", False) return opts.data.get("faceswaplab_keep_original", False)
@property @property
def swap_in_generated_units(self): def swap_in_generated_units(self) -> List[FaceSwapUnitSettings]:
return [u for u in self.units if u.swap_in_generated and u.enable] return [u for u in self.units if u.swap_in_generated and u.enable]
@property @property
def swap_in_source_units(self): def swap_in_source_units(self) -> List[FaceSwapUnitSettings]:
return [u for u in self.units if u.swap_in_source and u.enable] return [u for u in self.units if u.swap_in_source and u.enable]
def title(self): def title(self) -> str:
return f"faceswaplab" return f"faceswaplab"
def show(self, is_img2img): def show(self, is_img2img: bool) -> bool:
return scripts.AlwaysVisible return scripts.AlwaysVisible
def ui(self, is_img2img): def ui(self, is_img2img: bool) -> List[gr.components.Component]:
with gr.Accordion(f"FaceSwapLab {VERSION_FLAG}", open=False): with gr.Accordion(f"FaceSwapLab {VERSION_FLAG}", open=False):
components = [] components = []
for i in range(1, self.units_count + 1): for i in range(1, self.units_count + 1):
@ -119,16 +114,9 @@ class FaceSwapScript(scripts.Script):
# If the order is modified, the before_process should be changed accordingly. # If the order is modified, the before_process should be changed accordingly.
return components + upscaler return components + upscaler
# def make_script_first(self,p: StableDiffusionProcessing) : def read_config(
# FIXME : not really useful, will only impact postprocessing (kept for further testing) self, p: StableDiffusionProcessing, *components: List[gr.components.Component]
# runner : scripts.ScriptRunner = p.scripts ) -> None:
# alwayson = runner.alwayson_scripts
# alwayson.pop(alwayson.index(self))
# alwayson.insert(0, self)
# print("Running in ", alwayson.index(self), "position")
# logger.info("Running scripts : %s", pformat(runner.alwayson_scripts))
def read_config(self, p: StableDiffusionProcessing, *components):
# The order of processing for the components is important # The order of processing for the components is important
# The method first process faceswap units then postprocessing units # The method first process faceswap units then postprocessing units
@ -148,14 +136,16 @@ class FaceSwapScript(scripts.Script):
len_conf: int = len(fields(FaceSwapUnitSettings)) len_conf: int = len(fields(FaceSwapUnitSettings))
shift: int = self.units_count * len_conf shift: int = self.units_count * len_conf
self.postprocess_options = PostProcessingOptions( self.postprocess_options = PostProcessingOptions(
*components[shift : shift + len(fields(PostProcessingOptions))] *components[shift : shift + len(fields(PostProcessingOptions))] # type: ignore
) )
logger.debug("%s", pformat(self.postprocess_options)) logger.debug("%s", pformat(self.postprocess_options))
if self.enabled: if self.enabled:
p.do_not_save_samples = not self.keep_original_images p.do_not_save_samples = not self.keep_original_images
def process(self, p: StableDiffusionProcessing, *components): def process(
self, p: StableDiffusionProcessing, *components: List[gr.components.Component]
) -> None:
self.read_config(p, *components) self.read_config(p, *components)
# If is instance of img2img, we check if face swapping in source is required. # If is instance of img2img, we check if face swapping in source is required.
@ -175,7 +165,9 @@ class FaceSwapScript(scripts.Script):
if new_inits is not None: if new_inits is not None:
p.init_images = [img[0] for img in new_inits] p.init_images = [img[0] for img in new_inits]
def postprocess(self, p: StableDiffusionProcessing, processed: Processed, *args): def postprocess(
self, p: StableDiffusionProcessing, processed: Processed, *args: List[Any]
) -> None:
if self.enabled: if self.enabled:
# Get the original images without the grid # Get the original images without the grid
orig_images: List[Image.Image] = processed.images[ orig_images: List[Image.Image] = processed.images[

@ -1,52 +1,68 @@
from PIL import Image from PIL import Image
import numpy as np import numpy as np
from fastapi import FastAPI, Body from fastapi import FastAPI
from fastapi.exceptions import HTTPException
from modules.api.models import *
from modules.api import api from modules.api import api
from scripts.faceswaplab_api.faceswaplab_api_types import ( from scripts.faceswaplab_api.faceswaplab_api_types import (
FaceSwapUnit,
FaceSwapRequest, FaceSwapRequest,
FaceSwapResponse, FaceSwapResponse,
) )
from scripts.faceswaplab_globals import VERSION_FLAG from scripts.faceswaplab_globals import VERSION_FLAG
import gradio as gr import gradio as gr
from typing import List, Optional from typing import Dict, List, Optional, Union
from scripts.faceswaplab_swapping import swapper from scripts.faceswaplab_swapping import swapper
from scripts.faceswaplab_utils.faceswaplab_logging import save_img_debug
from scripts.faceswaplab_ui.faceswaplab_unit_settings import FaceSwapUnitSettings from scripts.faceswaplab_ui.faceswaplab_unit_settings import FaceSwapUnitSettings
from scripts.faceswaplab_utils.imgutils import ( from scripts.faceswaplab_utils.imgutils import (
pil_to_cv2,
check_against_nsfw,
base64_to_pil, base64_to_pil,
) )
from scripts.faceswaplab_utils.models_utils import get_current_model from scripts.faceswaplab_utils.models_utils import get_current_model
from modules.shared import opts from modules.shared import opts
def encode_to_base64(image): def encode_to_base64(image: Union[str, Image.Image, np.ndarray]) -> str:
if type(image) is str: """
Encode an image to a base64 string.
The image can be a file path (str), a PIL Image, or a NumPy array.
Args:
image (Union[str, Image.Image, np.ndarray]): The image to encode.
Returns:
str: The base64-encoded image if successful, otherwise an empty string.
"""
if isinstance(image, str):
return image return image
elif type(image) is Image.Image: elif isinstance(image, Image.Image):
return api.encode_pil_to_base64(image) return api.encode_pil_to_base64(image)
elif type(image) is np.ndarray: elif isinstance(image, np.ndarray):
return encode_np_to_base64(image) return encode_np_to_base64(image)
else: else:
return "" return ""
def encode_np_to_base64(image): def encode_np_to_base64(image: np.ndarray) -> str:
"""
Encode a NumPy array to a base64 string.
The array is first converted to a PIL Image, then encoded.
Args:
image (np.ndarray): The NumPy array to encode.
Returns:
str: The base64-encoded image.
"""
pil = Image.fromarray(image) pil = Image.fromarray(image)
return api.encode_pil_to_base64(pil) return api.encode_pil_to_base64(pil)
def faceswaplab_api(_: gr.Blocks, app: FastAPI): def faceswaplab_api(_: gr.Blocks, app: FastAPI) -> None:
@app.get( @app.get(
"/faceswaplab/version", "/faceswaplab/version",
tags=["faceswaplab"], tags=["faceswaplab"],
description="Get faceswaplab version", description="Get faceswaplab version",
) )
async def version(): async def version() -> Dict[str, str]:
return {"version": VERSION_FLAG} return {"version": VERSION_FLAG}
# use post as we consider the method non idempotent (which is debatable) # use post as we consider the method non idempotent (which is debatable)

@ -1,16 +1,8 @@
from scripts.faceswaplab_swapping import swapper
import numpy as np
from typing import List, Tuple from typing import List, Tuple
import dill as pickle
import gradio as gr
from insightface.app.common import Face
from PIL import Image from PIL import Image
from scripts.faceswaplab_utils.imgutils import ( from scripts.faceswaplab_utils.imgutils import (
pil_to_cv2,
check_against_nsfw,
base64_to_pil, base64_to_pil,
) )
from scripts.faceswaplab_utils.faceswaplab_logging import logger
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from scripts.faceswaplab_postprocessing.postprocessing_options import InpaintingWhen from scripts.faceswaplab_postprocessing.postprocessing_options import InpaintingWhen

@ -1,10 +1,11 @@
from scripts.faceswaplab_utils.faceswaplab_logging import logger
import os import os
MODELS_DIR = os.path.abspath(os.path.join("models", "faceswaplab")) MODELS_DIR = os.path.abspath(os.path.join("models", "faceswaplab"))
ANALYZER_DIR = os.path.abspath(os.path.join(MODELS_DIR, "analysers")) ANALYZER_DIR = os.path.abspath(os.path.join(MODELS_DIR, "analysers"))
FACE_PARSER_DIR = os.path.abspath(os.path.join(MODELS_DIR, "parser")) FACE_PARSER_DIR = os.path.abspath(os.path.join(MODELS_DIR, "parser"))
VERSION_FLAG = "v1.1.0" VERSION_FLAG: str = "v1.1.0"
EXTENSION_PATH = os.path.join("extensions", "sd-webui-faceswaplab") EXTENSION_PATH = os.path.join("extensions", "sd-webui-faceswaplab")
NSFW_SCORE = 0.7
# The NSFW score threshold. If any part of the image has a score greater than this threshold, the image will be considered NSFW.
NSFW_SCORE_THRESHOLD: float = 0.7

@ -1,15 +1,11 @@
from modules.face_restoration import FaceRestoration
from modules.upscaler import UpscalerData
from scripts.faceswaplab_utils.faceswaplab_logging import logger from scripts.faceswaplab_utils.faceswaplab_logging import logger
from PIL import Image from PIL import Image
import numpy as np
from modules import shared from modules import shared
from scripts.faceswaplab_utils import imgutils from scripts.faceswaplab_utils import imgutils
from modules import shared, processing, codeformer_model from modules import shared, processing
from modules.processing import StableDiffusionProcessingImg2Img from modules.processing import StableDiffusionProcessingImg2Img
from scripts.faceswaplab_postprocessing.postprocessing_options import ( from scripts.faceswaplab_postprocessing.postprocessing_options import (
PostProcessingOptions, PostProcessingOptions,
InpaintingWhen,
) )
from modules import sd_models from modules import sd_models

@ -1,4 +1,3 @@
from modules.face_restoration import FaceRestoration
from scripts.faceswaplab_utils.faceswaplab_logging import logger from scripts.faceswaplab_utils.faceswaplab_logging import logger
from PIL import Image from PIL import Image
from scripts.faceswaplab_postprocessing.postprocessing_options import ( from scripts.faceswaplab_postprocessing.postprocessing_options import (

@ -1,11 +1,10 @@
from scripts.faceswaplab_postprocessing.postprocessing_options import ( from scripts.faceswaplab_postprocessing.postprocessing_options import (
PostProcessingOptions, PostProcessingOptions,
InpaintingWhen,
) )
from scripts.faceswaplab_utils.faceswaplab_logging import logger from scripts.faceswaplab_utils.faceswaplab_logging import logger
from PIL import Image from PIL import Image
import numpy as np import numpy as np
from modules import shared, processing, codeformer_model from modules import codeformer_model
def upscale_img(image: Image.Image, pp_options: PostProcessingOptions) -> Image.Image: def upscale_img(image: Image.Image, pp_options: PostProcessingOptions) -> Image.Image:

@ -3,7 +3,7 @@ from modules import script_callbacks, shared
import gradio as gr import gradio as gr
def on_ui_settings(): def on_ui_settings() -> None:
section = ("faceswaplab", "FaceSwapLab") section = ("faceswaplab", "FaceSwapLab")
models = get_models() models = get_models()
shared.opts.add_option( shared.opts.add_option(

@ -40,7 +40,6 @@ please contact the contributor(s) of the work.
import torch import torch
import cv2
import os import os
import torch import torch
from torch.hub import download_url_to_file, get_dir from torch.hub import download_url_to_file, get_dir

@ -1,7 +1,7 @@
import copy import copy
import os import os
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, List, Set, Tuple, Optional, Union from typing import Any, Dict, List, Set, Tuple, Optional
import cv2 import cv2
import insightface import insightface
@ -13,7 +13,6 @@ from sklearn.metrics.pairwise import cosine_similarity
from scripts.faceswaplab_swapping import upscaled_inswapper from scripts.faceswaplab_swapping import upscaled_inswapper
from scripts.faceswaplab_utils.imgutils import ( from scripts.faceswaplab_utils.imgutils import (
cv2_to_pil,
pil_to_cv2, pil_to_cv2,
check_against_nsfw, check_against_nsfw,
) )
@ -26,7 +25,7 @@ from scripts.faceswaplab_ui.faceswaplab_unit_settings import FaceSwapUnitSetting
providers = ["CPUExecutionProvider"] providers = ["CPUExecutionProvider"]
def cosine_similarity_face(face1, face2) -> float: def cosine_similarity_face(face1: Face, face2: Face) -> float:
""" """
Calculates the cosine similarity between two face embeddings. Calculates the cosine similarity between two face embeddings.
@ -92,7 +91,7 @@ class FaceModelException(Exception):
@lru_cache(maxsize=1) @lru_cache(maxsize=1)
def getAnalysisModel(): def getAnalysisModel() -> insightface.app.FaceAnalysis:
""" """
Retrieves the analysis model for face analysis. Retrieves the analysis model for face analysis.
@ -112,11 +111,11 @@ def getAnalysisModel():
logger.error( logger.error(
"Loading of swapping model failed, please check the requirements (On Windows, download and install Visual Studio. During the install, make sure to include the Python and C++ packages.)" "Loading of swapping model failed, please check the requirements (On Windows, download and install Visual Studio. During the install, make sure to include the Python and C++ packages.)"
) )
raise FaceModelException("Loading of swapping model failed") raise FaceModelException("Loading of analysis model failed")
@lru_cache(maxsize=1) @lru_cache(maxsize=1)
def getFaceSwapModel(model_path: str): def getFaceSwapModel(model_path: str) -> upscaled_inswapper.UpscaledINSwapper:
""" """
Retrieves the face swap model and initializes it if necessary. Retrieves the face swap model and initializes it if necessary.
@ -135,13 +134,14 @@ def getFaceSwapModel(model_path: str):
logger.error( logger.error(
"Loading of swapping model failed, please check the requirements (On Windows, download and install Visual Studio. During the install, make sure to include the Python and C++ packages.)" "Loading of swapping model failed, please check the requirements (On Windows, download and install Visual Studio. During the install, make sure to include the Python and C++ packages.)"
) )
raise FaceModelException("Loading of swapping model failed")
def get_faces( def get_faces(
img_data: np.ndarray, img_data: np.ndarray, # type: ignore
det_size=(640, 640), det_size: Tuple[int, int] = (640, 640),
det_thresh: Optional[int] = None, det_thresh: Optional[float] = None,
sort_by_face_size=False, sort_by_face_size: bool = False,
) -> List[Face]: ) -> List[Face]:
""" """
Detects and retrieves faces from an image using an analysis model. Detects and retrieves faces from an image using an analysis model.
@ -211,7 +211,7 @@ class ImageResult:
""" """
def get_or_default(l, index, default): def get_or_default(l: List[Any], index: int, default: Any) -> Any:
""" """
Retrieve the value at the specified index from the given list. Retrieve the value at the specified index from the given list.
If the index is out of bounds, return the default value instead. If the index is out of bounds, return the default value instead.
@ -227,7 +227,10 @@ def get_or_default(l, index, default):
return l[index] if index < len(l) else default return l[index] if index < len(l) else default
def get_faces_from_img_files(files): import gradio as gr
def get_faces_from_img_files(files: List[gr.File]) -> List[Optional[np.ndarray]]: # type: ignore
""" """
Extracts faces from a list of image files. Extracts faces from a list of image files.
@ -300,15 +303,15 @@ def blend_faces(faces: List[Face]) -> Face:
def swap_face( def swap_face(
reference_face: np.ndarray, reference_face: np.ndarray, # type: ignore
source_face: np.ndarray, source_face: np.ndarray, # type: ignore
target_img: Image.Image, target_img: Image.Image,
model: str, model: str,
faces_index: Set[int] = {0}, faces_index: Set[int] = {0},
same_gender=True, same_gender: bool = True,
upscaled_swapper=False, upscaled_swapper: bool = False,
compute_similarity=True, compute_similarity: bool = True,
sort_by_face_size=False, sort_by_face_size: bool = False,
) -> ImageResult: ) -> ImageResult:
""" """
Swaps faces in the target image with the source face. Swaps faces in the target image with the source face.
@ -344,6 +347,7 @@ def swap_face(
for i, swapped_face in enumerate(target_faces): for i, swapped_face in enumerate(target_faces):
logger.info(f"swap face {i}") logger.info(f"swap face {i}")
if i in faces_index: if i in faces_index:
# type : ignore
result = face_swapper.get( result = face_swapper.get(
result, swapped_face, source_face, upscale=upscaled_swapper result, swapped_face, source_face, upscale=upscaled_swapper
) )
@ -385,13 +389,13 @@ def swap_face(
def process_image_unit( def process_image_unit(
model, model: str,
unit: FaceSwapUnitSettings, unit: FaceSwapUnitSettings,
image: Image.Image, image: Image.Image,
info=None, info: str = None,
upscaled_swapper=False, upscaled_swapper: bool = False,
force_blend=False, force_blend: bool = False,
) -> List: ) -> List[Tuple[Image.Image, str]]:
"""Process one image and return a List of (image, info) (one if blended, many if not). """Process one image and return a List of (image, info) (one if blended, many if not).
Args: Args:
@ -472,12 +476,12 @@ def process_image_unit(
def process_images_units( def process_images_units(
model, model: str,
units: List[FaceSwapUnitSettings], units: List[FaceSwapUnitSettings],
images: List[Tuple[Optional[Image.Image], Optional[str]]], images: List[Tuple[Optional[Image.Image], Optional[str]]],
upscaled_swapper=False, upscaled_swapper: bool = False,
force_blend=False, force_blend: bool = False,
) -> Union[List, None]: ) -> Optional[List[Tuple[Image.Image, str]]]:
if len(units) == 0: if len(units) == 0:
logger.info("Finished processing image, return %s images", len(images)) logger.info("Finished processing image, return %s images", len(images))
return None return None

@ -1,17 +1,11 @@
import cv2 import cv2
import numpy as np import numpy as np
import onnx
import onnxruntime
from insightface.model_zoo.inswapper import INSwapper from insightface.model_zoo.inswapper import INSwapper
from insightface.utils import face_align from insightface.utils import face_align
from modules import codeformer_model, processing, scripts, shared from modules import processing, shared
from modules.face_restoration import FaceRestoration from modules.shared import opts
from modules.shared import cmd_opts, opts, state
from modules.upscaler import UpscalerData from modules.upscaler import UpscalerData
from onnx import numpy_helper
from PIL import Image
from scripts.faceswaplab_utils.faceswaplab_logging import logger
from scripts.faceswaplab_postprocessing import upscaling from scripts.faceswaplab_postprocessing import upscaling
from scripts.faceswaplab_postprocessing.postprocessing_options import ( from scripts.faceswaplab_postprocessing.postprocessing_options import (
PostProcessingOptions, PostProcessingOptions,

@ -5,13 +5,12 @@ from pprint import pformat, pprint
import dill as pickle import dill as pickle
import gradio as gr import gradio as gr
import modules.scripts as scripts import modules.scripts as scripts
import numpy as np
import onnx import onnx
import pandas as pd import pandas as pd
from scripts.faceswaplab_ui.faceswaplab_unit_ui import faceswap_unit_ui from scripts.faceswaplab_ui.faceswaplab_unit_ui import faceswap_unit_ui
from scripts.faceswaplab_ui.faceswaplab_upscaler_ui import upscaler_ui from scripts.faceswaplab_ui.faceswaplab_upscaler_ui import upscaler_ui
from insightface.app.common import Face from insightface.app.common import Face
from modules import script_callbacks, scripts from modules import scripts
from PIL import Image from PIL import Image
from modules.shared import opts from modules.shared import opts
@ -25,12 +24,13 @@ from scripts.faceswaplab_postprocessing.postprocessing_options import (
) )
from scripts.faceswaplab_postprocessing.postprocessing import enhance_image from scripts.faceswaplab_postprocessing.postprocessing import enhance_image
from dataclasses import fields from dataclasses import fields
from typing import List from typing import Any, List, Optional, Union
from scripts.faceswaplab_ui.faceswaplab_unit_settings import FaceSwapUnitSettings from scripts.faceswaplab_ui.faceswaplab_unit_settings import FaceSwapUnitSettings
from scripts.faceswaplab_utils.models_utils import get_current_model from scripts.faceswaplab_utils.models_utils import get_current_model
import re
def compare(img1, img2): def compare(img1: Image.Image, img2: Image.Image) -> Union[float, str]:
if img1 is not None and img2 is not None: if img1 is not None and img2 is not None:
return swapper.compare_faces(img1, img2) return swapper.compare_faces(img1, img2)
@ -40,19 +40,10 @@ def compare(img1, img2):
def extract_faces( def extract_faces(
files, files,
extract_path, extract_path,
face_restorer_name, *components: List[gr.components.Component],
face_restorer_visibility,
codeformer_weight,
upscaler_name,
upscaler_scale,
upscaler_visibility,
inpainting_denoising_strengh,
inpainting_prompt,
inpainting_negative_prompt,
inpainting_steps,
inpainting_sampler,
inpainting_when,
): ):
postprocess_options = PostProcessingOptions(*components) # type: ignore
if not extract_path: if not extract_path:
tempfile.mkdtemp() tempfile.mkdtemp()
if files is not None: if files is not None:
@ -66,24 +57,16 @@ def extract_faces(
bbox = face.bbox.astype(int) bbox = face.bbox.astype(int)
x_min, y_min, x_max, y_max = bbox x_min, y_min, x_max, y_max = bbox
face_image = img.crop((x_min, y_min, x_max, y_max)) face_image = img.crop((x_min, y_min, x_max, y_max))
if face_restorer_name or face_restorer_visibility: if (
scale = 1 if face_image.width > 512 else 512 // face_image.width postprocess_options.face_restorer_name
or postprocess_options.restorer_visibility
):
postprocess_options.scale = (
1 if face_image.width > 512 else 512 // face_image.width
)
face_image = enhance_image( face_image = enhance_image(
face_image, face_image,
PostProcessingOptions( postprocess_options,
face_restorer_name=face_restorer_name,
restorer_visibility=face_restorer_visibility,
codeformer_weight=codeformer_weight,
upscaler_name=upscaler_name,
upscale_visibility=upscaler_visibility,
scale=scale,
inpainting_denoising_strengh=inpainting_denoising_strengh,
inpainting_prompt=inpainting_prompt,
inpainting_steps=inpainting_steps,
inpainting_negative_prompt=inpainting_negative_prompt,
inpainting_when=inpainting_when,
inpainting_sampler=inpainting_sampler,
),
) )
path = tempfile.NamedTemporaryFile( path = tempfile.NamedTemporaryFile(
delete=False, suffix=".png", dir=extract_path delete=False, suffix=".png", dir=extract_path
@ -95,7 +78,7 @@ def extract_faces(
return None return None
def analyse_faces(image, det_threshold=0.5): def analyse_faces(image: Image.Image, det_threshold: float = 0.5) -> str:
try: try:
faces = swapper.get_faces(imgutils.pil_to_cv2(image), det_thresh=det_threshold) faces = swapper.get_faces(imgutils.pil_to_cv2(image), det_thresh=det_threshold)
result = "" result = ""
@ -110,27 +93,40 @@ def analyse_faces(image, det_threshold=0.5):
return "Analysis Failed" return "Analysis Failed"
def build_face_checkpoint_and_save(batch_files, name): def sanitize_name(name: str) -> str:
logger.debug(f"Sanitize name {name}")
name = re.sub("[^A-Za-z0-9_. ]+", "", name)
name = name.replace(" ", "_")
logger.debug(f"Sanitized name {name[:255]}")
return name[:255]
def build_face_checkpoint_and_save(
batch_files: gr.File, name: str
) -> Optional[Image.Image]:
""" """
Builds a face checkpoint, swaps faces, and saves the result to a file. Builds a face checkpoint using the provided image files, performs face swapping,
and saves the result to a file. If a blended face is successfully obtained and the face swapping
process succeeds, the resulting image is returned. Otherwise, None is returned.
Args: Args:
batch_files (list): List of image file paths. batch_files (list): List of image file paths used to create the face checkpoint.
name (str): Name of the face checkpoint name (str): The name assigned to the face checkpoint.
Returns: Returns:
PIL.Image.Image or None: Resulting swapped face image if successful, otherwise None. PIL.Image.Image or None: The resulting swapped face image if the process is successful; None otherwise.
""" """
name = sanitize_name(name)
batch_files = batch_files or [] batch_files = batch_files or []
print("Build", name, [x.name for x in batch_files]) logger.info("Build %s %s", name, [x.name for x in batch_files])
faces = swapper.get_faces_from_img_files(batch_files) faces = swapper.get_faces_from_img_files(batch_files)
blended_face = swapper.blend_faces(faces) blended_face = swapper.blend_faces(faces)
preview_path = os.path.join( preview_path = os.path.join(
scripts.basedir(), "extensions", "sd-webui-faceswaplab", "references" scripts.basedir(), "extensions", "sd-webui-faceswaplab", "references"
) )
faces_path = os.path.join(scripts.basedir(), "models", "faceswaplab", "faces") faces_path = os.path.join(scripts.basedir(), "models", "faceswaplab", "faces")
if not os.path.exists(faces_path):
os.makedirs(faces_path) os.makedirs(faces_path, exist_ok=True)
target_img = None target_img = None
if blended_face: if blended_face:
@ -208,7 +204,9 @@ def explore_onnx_faceswap_model(model_path):
return df return df
def batch_process(files, save_path, *components): def batch_process(
files, save_path, *components: List[gr.components.Component]
) -> Optional[List[Image.Image]]:
try: try:
if save_path is not None: if save_path is not None:
os.makedirs(save_path, exist_ok=True) os.makedirs(save_path, exist_ok=True)
@ -228,7 +226,7 @@ def batch_process(files, save_path, *components):
len_conf: int = len(fields(FaceSwapUnitSettings)) len_conf: int = len(fields(FaceSwapUnitSettings))
shift: int = units_count * len_conf shift: int = units_count * len_conf
postprocess_options = PostProcessingOptions( postprocess_options = PostProcessingOptions(
*components[shift : shift + len(fields(PostProcessingOptions))] *components[shift : shift + len(fields(PostProcessingOptions))] # type: ignore
) )
logger.debug("%s", pformat(postprocess_options)) logger.debug("%s", pformat(postprocess_options))
@ -247,7 +245,7 @@ def batch_process(files, save_path, *components):
), ),
) )
if len(swapped_images) > 0: if len(swapped_images) > 0:
current_images += [img for img, info in swapped_images] current_images += [img for img, _ in swapped_images]
logger.info("%s images generated", len(current_images)) logger.info("%s images generated", len(current_images))
for i, img in enumerate(current_images): for i, img in enumerate(current_images):
@ -269,7 +267,7 @@ def batch_process(files, save_path, *components):
return None return None
def tools_ui(): def tools_ui() -> None:
models = get_models() models = get_models()
with gr.Tab("Tools"): with gr.Tab("Tools"):
with gr.Tab("Build"): with gr.Tab("Build"):
@ -431,7 +429,7 @@ def tools_ui():
) )
def on_ui_tabs(): def on_ui_tabs() -> List[Any]:
with gr.Blocks(analytics_enabled=False) as ui_faceswap: with gr.Blocks(analytics_enabled=False) as ui_faceswap:
tools_ui() tools_ui()
return [(ui_faceswap, "FaceSwapLab", "faceswaplab_tab")] return [(ui_faceswap, "FaceSwapLab", "faceswaplab_tab")]

@ -8,7 +8,7 @@ import dill as pickle
import gradio as gr import gradio as gr
from insightface.app.common import Face from insightface.app.common import Face
from PIL import Image from PIL import Image
from scripts.faceswaplab_utils.imgutils import pil_to_cv2, check_against_nsfw from scripts.faceswaplab_utils.imgutils import pil_to_cv2
from scripts.faceswaplab_utils.faceswaplab_logging import logger from scripts.faceswaplab_utils.faceswaplab_logging import logger

@ -1,8 +1,11 @@
from typing import List
from scripts.faceswaplab_utils.models_utils import get_face_checkpoints from scripts.faceswaplab_utils.models_utils import get_face_checkpoints
import gradio as gr import gradio as gr
def faceswap_unit_ui(is_img2img, unit_num=1, id_prefix="faceswaplab"): def faceswap_unit_ui(
is_img2img: bool, unit_num: int = 1, id_prefix: str = "faceswaplab"
) -> List[gr.components.Component]:
with gr.Tab(f"Face {unit_num}"): with gr.Tab(f"Face {unit_num}"):
with gr.Column(): with gr.Column():
gr.Markdown( gr.Markdown(
@ -37,7 +40,7 @@ def faceswap_unit_ui(is_img2img, unit_num=1, id_prefix="faceswaplab"):
elem_id=f"{id_prefix}_face{unit_num}_refresh_checkpoints", elem_id=f"{id_prefix}_face{unit_num}_refresh_checkpoints",
) )
def refresh_fn(selected): def refresh_fn(selected: str) -> None:
return gr.Dropdown.update( return gr.Dropdown.update(
value=selected, choices=get_face_checkpoints() value=selected, choices=get_face_checkpoints()
) )

@ -1,13 +1,12 @@
from typing import List
import gradio as gr import gradio as gr
import modules import modules
from modules import shared, sd_models from modules import shared, sd_models
from modules.shared import cmd_opts, opts, state from modules.shared import opts
from scripts.faceswaplab_postprocessing.postprocessing_options import InpaintingWhen
import scripts.faceswaplab_postprocessing.upscaling as upscaling
from scripts.faceswaplab_utils.faceswaplab_logging import logger
def upscaler_ui() -> List[gr.components.Component]:
def upscaler_ui():
with gr.Tab(f"Post-Processing"): with gr.Tab(f"Post-Processing"):
gr.Markdown( gr.Markdown(
"""Upscaling is performed on the whole image. Upscaling happens before face restoration.""" """Upscaling is performed on the whole image. Upscaling happens before face restoration."""
@ -74,10 +73,8 @@ def upscaler_ui():
) )
inpainting_when = gr.Dropdown( inpainting_when = gr.Dropdown(
elem_id="faceswaplab_pp_inpainting_when", elem_id="faceswaplab_pp_inpainting_when",
choices=[ choices=[e.value for e in InpaintingWhen.__members__.values()],
e.value for e in upscaling.InpaintingWhen.__members__.values() value=[InpaintingWhen.BEFORE_RESTORE_FACE.value],
],
value=[upscaling.InpaintingWhen.BEFORE_RESTORE_FACE.value],
label="Enable/When", label="Enable/When",
) )
inpainting_denoising_strength = gr.Slider( inpainting_denoising_strength = gr.Slider(

@ -1,12 +1,24 @@
import logging import logging
import copy import copy
import sys import sys
from typing import Any
from modules import shared from modules import shared
from PIL import Image from PIL import Image
from logging import LogRecord
class ColoredFormatter(logging.Formatter): class ColoredFormatter(logging.Formatter):
COLORS = { """
A custom logging formatter that outputs logs with level names colored.
Class Attributes:
COLORS (dict): A dictionary mapping logging level names to their corresponding color codes.
Inherits From:
logging.Formatter
"""
COLORS: dict[str, str] = {
"DEBUG": "\033[0;36m", # CYAN "DEBUG": "\033[0;36m", # CYAN
"INFO": "\033[0;32m", # GREEN "INFO": "\033[0;32m", # GREEN
"WARNING": "\033[0;33m", # YELLOW "WARNING": "\033[0;33m", # YELLOW
@ -15,7 +27,21 @@ class ColoredFormatter(logging.Formatter):
"RESET": "\033[0m", # RESET COLOR "RESET": "\033[0m", # RESET COLOR
} }
def format(self, record): def format(self, record: LogRecord) -> str:
"""
Format the specified record as text.
The record's attribute dictionary is used as the operand to a string
formatting operation which yields the returned string. Before formatting
the dictionary, a check is made to see if the format uses the levelname
of the record. If it does, a colorized version is created and used.
Args:
record (LogRecord): The log record to be formatted.
Returns:
str: The formatted string which includes the colorized levelname.
"""
colored_record = copy.copy(record) colored_record = copy.copy(record)
levelname = colored_record.levelname levelname = colored_record.levelname
seq = self.COLORS.get(levelname, self.COLORS["RESET"]) seq = self.COLORS.get(levelname, self.COLORS["RESET"])
@ -46,7 +72,24 @@ if logger.getEffectiveLevel() <= logging.DEBUG:
DEBUG_DIR = tempfile.mkdtemp() DEBUG_DIR = tempfile.mkdtemp()
def save_img_debug(img: Image.Image, message: str, *opts): def save_img_debug(img: Image.Image, message: str, *opts: Any) -> None:
"""
Saves an image to a temporary file if the logger's effective level is set to DEBUG or lower.
After saving, it logs a debug message along with the file URI of the image.
Parameters
----------
img : Image.Image
The image to be saved.
message : str
The message to be logged.
*opts : Any
Additional arguments to be passed to the logger's debug method.
Returns
-------
None
"""
if logger.getEffectiveLevel() <= logging.DEBUG: if logger.getEffectiveLevel() <= logging.DEBUG:
with tempfile.NamedTemporaryFile( with tempfile.NamedTemporaryFile(
dir=DEBUG_DIR, delete=False, suffix=".png" dir=DEBUG_DIR, delete=False, suffix=".png"

@ -1,35 +1,76 @@
import io import io
from typing import Optional from typing import List, Optional, Tuple, Union, Dict
from PIL import Image, ImageChops, ImageOps, ImageFilter from PIL import Image
import cv2 import cv2
import numpy as np import numpy as np
from math import isqrt, ceil from math import isqrt, ceil
import torch import torch
from ifnude import detect from ifnude import detect
from scripts.faceswaplab_globals import NSFW_SCORE from scripts.faceswaplab_globals import NSFW_SCORE_THRESHOLD
from modules import processing from modules import processing
import base64 import base64
from collections import Counter
def check_against_nsfw(img: Image.Image) -> bool:
"""
Check if an image exceeds the Not Safe for Work (NSFW) score.
Parameters:
img (PIL.Image.Image): The image to be checked.
Returns:
bool: True if any part of the image is considered NSFW, False otherwise.
"""
shapes: List[bool] = []
chunks: List[Dict[str, Union[int, float]]] = detect(img)
def check_against_nsfw(img):
shapes = []
chunks = detect(img)
for chunk in chunks: for chunk in chunks:
shapes.append(chunk["score"] > NSFW_SCORE) shapes.append(chunk["score"] > NSFW_SCORE_THRESHOLD)
return any(shapes) return any(shapes)
def pil_to_cv2(pil_img): def pil_to_cv2(pil_img: Image.Image) -> np.ndarray: # type: ignore
"""
Convert a PIL Image into an OpenCV image (cv2).
Args:
pil_img (PIL.Image.Image): An image in PIL format.
Returns:
np.ndarray: The input image converted to OpenCV format (BGR).
"""
return cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR) return cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)
def cv2_to_pil(cv2_img): def cv2_to_pil(cv2_img: np.ndarray) -> Image.Image: # type: ignore
"""
Convert an OpenCV image (cv2) into a PIL Image.
Args:
cv2_img (np.ndarray): An image in OpenCV format (BGR).
Returns:
PIL.Image.Image: The input image converted to PIL format (RGB).
"""
return Image.fromarray(cv2.cvtColor(cv2_img, cv2.COLOR_BGR2RGB)) return Image.fromarray(cv2.cvtColor(cv2_img, cv2.COLOR_BGR2RGB))
def torch_to_pil(images): def torch_to_pil(images: torch.Tensor) -> List[Image.Image]:
""" """
Convert a numpy image or a batch of images to a PIL image. Converts a tensor image or a batch of tensor images to a PIL image or a list of PIL images.
Parameters
----------
images : torch.Tensor
A tensor representing an image or a batch of images.
Returns
-------
list
A list of PIL images.
""" """
images = images.cpu().permute(0, 2, 3, 1).numpy() images = images.cpu().permute(0, 2, 3, 1).numpy()
if images.ndim == 3: if images.ndim == 3:
@ -39,9 +80,19 @@ def torch_to_pil(images):
return pil_images return pil_images
def pil_to_torch(pil_images): def pil_to_torch(pil_images: Union[Image.Image, List[Image.Image]]) -> torch.Tensor:
""" """
Convert a PIL image or a list of PIL images to a torch tensor or a batch of torch tensors. Converts a PIL image or a list of PIL images to a torch tensor or a batch of torch tensors.
Parameters
----------
pil_images : Union[Image.Image, List[Image.Image]]
A PIL image or a list of PIL images.
Returns
-------
torch.Tensor
A tensor representing an image or a batch of images.
""" """
if isinstance(pil_images, list): if isinstance(pil_images, list):
numpy_images = [np.array(image) for image in pil_images] numpy_images = [np.array(image) for image in pil_images]
@ -53,10 +104,7 @@ def pil_to_torch(pil_images):
return torch_image return torch_image
from collections import Counter def create_square_image(image_list: List[Image.Image]) -> Optional[Image.Image]:
def create_square_image(image_list):
""" """
Creates a square image by combining multiple images in a grid pattern. Creates a square image by combining multiple images in a grid pattern.
@ -108,16 +156,41 @@ def create_square_image(image_list):
return None return None
def create_mask(image, box_coords): # def create_mask(image : Image.Image, box_coords : Tuple[int, int, int, int]) -> Image.Image:
# width, height = image.size
# mask = Image.new("L", (width, height), 255)
# x1, y1, x2, y2 = box_coords
# for x in range(width):
# for y in range(height):
# if x1 <= x <= x2 and y1 <= y <= y2:
# mask.putpixel((x, y), 255)
# else:
# mask.putpixel((x, y), 0)
# return mask
def create_mask(
image: Image.Image, box_coords: Tuple[int, int, int, int]
) -> Image.Image:
"""
Create a binary mask for a given image and bounding box coordinates.
Args:
image (PIL.Image.Image): The input image.
box_coords (Tuple[int, int, int, int]): A tuple of 4 integers defining the bounding box.
It follows the pattern (x1, y1, x2, y2), where (x1, y1) is the top-left coordinate of the
box and (x2, y2) is the bottom-right coordinate of the box.
Returns:
PIL.Image.Image: A binary mask of the same size as the input image, where pixels within
the bounding box are white (255) and pixels outside the bounding box are black (0).
"""
width, height = image.size width, height = image.size
mask = Image.new("L", (width, height), 255) mask = Image.new("L", (width, height), 0)
x1, y1, x2, y2 = box_coords x1, y1, x2, y2 = box_coords
for x in range(width): for x in range(x1, x2 + 1):
for y in range(height): for y in range(y1, y2 + 1):
if x1 <= x <= x2 and y1 <= y <= y2: mask.putpixel((x, y), 255)
mask.putpixel((x, y), 255)
else:
mask.putpixel((x, y), 0)
return mask return mask
@ -185,12 +258,32 @@ def prepare_mask(
def base64_to_pil(base64str: Optional[str]) -> Optional[Image.Image]: def base64_to_pil(base64str: Optional[str]) -> Optional[Image.Image]:
"""
Converts a base64 string to a PIL Image object.
Parameters:
base64str (Optional[str]): The base64 string to convert. This string may contain a data URL scheme
(i.e., 'data:image/jpeg;base64,') or just be the raw base64 encoded data. If None, the function
will return None.
Returns:
Optional[Image.Image]: A PIL Image object created from the base64 string. If the input is None,
the function returns None.
Raises:
binascii.Error: If the base64 string is not properly formatted or encoded.
PIL.UnidentifiedImageError: If the image format cannot be identified.
"""
if base64str is None: if base64str is None:
return None return None
if "base64," in base64str: # check if the base64 string has a data URL scheme
# Check if the base64 string has a data URL scheme
if "base64," in base64str:
base64_data = base64str.split("base64,")[-1] base64_data = base64str.split("base64,")[-1]
img_bytes = base64.b64decode(base64_data) img_bytes = base64.b64decode(base64_data)
else: else:
# if no data URL scheme, just decode # If no data URL scheme, just decode
img_bytes = base64.b64decode(base64str) img_bytes = base64.b64decode(base64str)
return Image.open(io.BytesIO(img_bytes)) return Image.open(io.BytesIO(img_bytes))

@ -1,5 +1,6 @@
import glob import glob
import os import os
from typing import List
import modules.scripts as scripts import modules.scripts as scripts
from modules import scripts from modules import scripts
from scripts.faceswaplab_globals import EXTENSION_PATH from scripts.faceswaplab_globals import EXTENSION_PATH
@ -7,7 +8,7 @@ from modules.shared import opts
from scripts.faceswaplab_utils.faceswaplab_logging import logger from scripts.faceswaplab_utils.faceswaplab_logging import logger
def get_models(): def get_models() -> List[str]:
""" """
Retrieve a list of swap model files. Retrieve a list of swap model files.
@ -44,7 +45,7 @@ def get_current_model() -> str:
return model return model
def get_face_checkpoints(): def get_face_checkpoints() -> List[str]:
""" """
Retrieve a list of face checkpoint paths. Retrieve a list of face checkpoint paths.

Loading…
Cancel
Save