fix similarity, add checksum for swapper, fix minor bugs

main
Tran Xen 2 years ago
parent ee7f7d09d2
commit b773bda19f

@ -20,6 +20,10 @@ In short:
More on this here : https://glucauze.github.io/sd-webui-faceswaplab/ More on this here : https://glucauze.github.io/sd-webui-faceswaplab/
### Known problems (wontfix):
+ Older versions of gradio don't work well with the extension. See this bug : https://github.com/glucauze/sd-webui-faceswaplab/issues/5
### Features ### Features
+ **Face Unit Concept**: Similar to controlNet, the program introduces the concept of a face unit. You can configure up to 10 units (3 units are the default setting) in the program settings (sd). + **Face Unit Concept**: Similar to controlNet, the program introduces the concept of a face unit. You can configure up to 10 units (3 units are the default setting) in the program settings (sd).

@ -2,7 +2,7 @@ cython
ifnude ifnude
insightface==0.7.3 insightface==0.7.3
onnx==1.14.0 onnx==1.14.0
onnxruntime==1.15.0 onnxruntime==1.15.1
opencv-python==4.7.0.72 opencv-python==4.7.0.72
pandas pandas
pydantic==1.10.9 pydantic==1.10.9

@ -27,7 +27,6 @@ from scripts.faceswaplab_postprocessing.postprocessing_options import (
PostProcessingOptions, PostProcessingOptions,
) )
from scripts.faceswaplab_utils.models_utils import get_current_model from scripts.faceswaplab_utils.models_utils import get_current_model
import gradio as gr
from scripts.faceswaplab_utils.typing import CV2ImgU8, PILImage, Face from scripts.faceswaplab_utils.typing import CV2ImgU8, PILImage, Face
from scripts.faceswaplab_inpainting.i2i_pp import img2img_diffusion from scripts.faceswaplab_inpainting.i2i_pp import img2img_diffusion
@ -250,6 +249,21 @@ def getAnalysisModel() -> insightface.app.FaceAnalysis:
raise FaceModelException("Loading of analysis model failed") raise FaceModelException("Loading of analysis model failed")
import hashlib
def is_sha1_matching(file_path: str, expected_sha1: str) -> bool:
sha1_hash = hashlib.sha1()
with open(file_path, "rb") as file:
for byte_block in iter(lambda: file.read(4096), b""):
sha1_hash.update(byte_block)
if sha1_hash.hexdigest() == expected_sha1:
return True
else:
return False
@lru_cache(maxsize=1) @lru_cache(maxsize=1)
def getFaceSwapModel(model_path: str) -> upscaled_inswapper.UpscaledINSwapper: def getFaceSwapModel(model_path: str) -> upscaled_inswapper.UpscaledINSwapper:
""" """
@ -262,6 +276,14 @@ def getFaceSwapModel(model_path: str) -> upscaled_inswapper.UpscaledINSwapper:
insightface.model_zoo.FaceModel: The face swap model. insightface.model_zoo.FaceModel: The face swap model.
""" """
try: try:
expected_sha1 = "17a64851eaefd55ea597ee41e5c18409754244c5"
if not is_sha1_matching(model_path, expected_sha1):
logger.error(
"Suspicious sha1 for model %s, check the model is valid or has been downloaded adequately. Should be %s",
model_path,
expected_sha1,
)
# Initializes the face swap model using the specified model path. # Initializes the face swap model using the specified model path.
return upscaled_inswapper.UpscaledINSwapper( return upscaled_inswapper.UpscaledINSwapper(
insightface.model_zoo.get_model(model_path, providers=providers) insightface.model_zoo.get_model(model_path, providers=providers)
@ -270,6 +292,9 @@ def getFaceSwapModel(model_path: str) -> upscaled_inswapper.UpscaledINSwapper:
logger.error( logger.error(
"Loading of swapping model failed, please check the requirements (On Windows, download and install Visual Studio. During the install, make sure to include the Python and C++ packages.)" "Loading of swapping model failed, please check the requirements (On Windows, download and install Visual Studio. During the install, make sure to include the Python and C++ packages.)"
) )
import traceback
traceback.print_exc()
raise FaceModelException("Loading of swapping model failed") raise FaceModelException("Loading of swapping model failed")
@ -315,11 +340,15 @@ def get_faces(
return [] return []
@dataclass
class FaceFilteringOptions:
faces_index: Set[int]
source_gender: Optional[int] = None # if none will not use same gender
sort_by_face_size: bool = False
def filter_faces( def filter_faces(
all_faces: List[Face], all_faces: List[Face], filtering_options: FaceFilteringOptions
faces_index: Set[int],
source_gender: int = None,
sort_by_face_size: bool = False,
) -> List[Face]: ) -> List[Face]:
""" """
Sorts and filters a list of faces based on specified criteria. Sorts and filters a list of faces based on specified criteria.
@ -337,18 +366,24 @@ def filter_faces(
:return: A list of Face objects sorted and filtered according to the specified criteria. :return: A list of Face objects sorted and filtered according to the specified criteria.
""" """
filtered_faces = copy.copy(all_faces) filtered_faces = copy.copy(all_faces)
if sort_by_face_size: if filtering_options.sort_by_face_size:
filtered_faces = sorted( filtered_faces = sorted(
all_faces, all_faces,
reverse=True, reverse=True,
key=lambda x: (x.bbox[2] - x.bbox[0]) * (x.bbox[3] - x.bbox[1]), key=lambda x: (x.bbox[2] - x.bbox[0]) * (x.bbox[3] - x.bbox[1]),
) )
if source_gender is not None: if filtering_options.source_gender is not None:
filtered_faces = [ filtered_faces = [
face for face in filtered_faces if face["gender"] == source_gender face
for face in filtered_faces
if face["gender"] == filtering_options.source_gender
]
return [
face
for i, face in enumerate(filtered_faces)
if i in filtering_options.faces_index
] ]
return [face for i, face in enumerate(filtered_faces) if i in faces_index]
@dataclass @dataclass
@ -391,7 +426,7 @@ def get_or_default(l: List[Any], index: int, default: Any) -> Any:
return l[index] if index < len(l) else default return l[index] if index < len(l) else default
def get_faces_from_img_files(files: List[gr.File]) -> List[Optional[CV2ImgU8]]: def get_faces_from_img_files(files: List[str]) -> List[Optional[CV2ImgU8]]:
""" """
Extracts faces from a list of image files. Extracts faces from a list of image files.
@ -407,7 +442,7 @@ def get_faces_from_img_files(files: List[gr.File]) -> List[Optional[CV2ImgU8]]:
if len(files) > 0: if len(files) > 0:
for file in files: for file in files:
img = Image.open(file.name) # Open the image file img = Image.open(file) # Open the image file
face = get_or_default( face = get_or_default(
get_faces(pil_to_cv2(img)), 0, None get_faces(pil_to_cv2(img)), 0, None
) # Extract faces from the image ) # Extract faces from the image
@ -503,41 +538,44 @@ def swap_face(
result_image = Image.fromarray(cv2.cvtColor(result, cv2.COLOR_BGR2RGB)) result_image = Image.fromarray(cv2.cvtColor(result, cv2.COLOR_BGR2RGB))
return_result.image = result_image return_result.image = result_image
# FIXME : recompute similarity
# if compute_similarity:
# try:
# result_faces = get_faces(
# cv2.cvtColor(np.array(result_image), cv2.COLOR_RGB2BGR),
# sort_by_face_size=sort_by_face_size,
# )
# if same_gender:
# result_faces = [
# x for x in result_faces if x["gender"] == gender
# ]
# for i, swapped_face in enumerate(result_faces):
# logger.info(f"compare face {i}")
# if i in faces_index and i < len(target_faces):
# return_result.similarity[i] = cosine_similarity_face(
# source_face, swapped_face
# )
# return_result.ref_similarity[i] = cosine_similarity_face(
# reference_face, swapped_face
# )
# logger.info(f"similarity {return_result.similarity}")
# logger.info(f"ref similarity {return_result.ref_similarity}")
# except Exception as e:
# logger.error("Similarity processing failed %s", e)
# raise e
except Exception as e: except Exception as e:
logger.error("Conversion failed %s", e) logger.error("Conversion failed %s", e)
raise e raise e
return return_result return return_result
def compute_similarity(
reference_face: Face,
source_face: Face,
swapped_image: PILImage,
filtering: FaceFilteringOptions,
) -> Tuple[Dict[int, float], Dict[int, float]]:
similarity: Dict[int, float] = {}
ref_similarity: Dict[int, float] = {}
try:
swapped_image_cv2: CV2ImgU8 = cv2.cvtColor(
np.array(swapped_image), cv2.COLOR_RGB2BGR
)
new_faces = filter_faces(get_faces(swapped_image_cv2), filtering)
if len(new_faces) == 0:
logger.error("compute_similarity : No faces to compare with !")
return None
for i, swapped_face in enumerate(new_faces):
logger.info(f"compare face {i}")
similarity[i] = cosine_similarity_face(source_face, swapped_face)
ref_similarity[i] = cosine_similarity_face(reference_face, swapped_face)
logger.info(f"similarity {similarity}")
logger.info(f"ref similarity {ref_similarity}")
return (similarity, ref_similarity)
except Exception as e:
logger.error("Similarity processing failed %s", e)
raise e
return None
def process_image_unit( def process_image_unit(
model: str, model: str,
unit: FaceSwapUnitSettings, unit: FaceSwapUnitSettings,
@ -580,13 +618,14 @@ def process_image_unit(
logger.info("Use source face as reference face") logger.info("Use source face as reference face")
reference_face = src_face reference_face = src_face
target_faces = filter_faces( face_filtering_options = FaceFilteringOptions(
faces,
faces_index=unit.faces_index, faces_index=unit.faces_index,
source_gender=src_face["gender"] if unit.same_gender else None, source_gender=src_face["gender"] if unit.same_gender else None,
sort_by_face_size=unit.sort_by_size, sort_by_face_size=unit.sort_by_size,
) )
target_faces = filter_faces(faces, filtering_options=face_filtering_options)
# Apply pre-inpainting to image # Apply pre-inpainting to image
if unit.pre_inpainting.inpainting_denoising_strengh > 0: if unit.pre_inpainting.inpainting_denoising_strengh > 0:
current_image = img2img_diffusion( current_image = img2img_diffusion(
@ -611,6 +650,18 @@ def process_image_unit(
save_img_debug(result.image, "After swap") save_img_debug(result.image, "After swap")
if unit.compute_similarity:
similarities = compute_similarity(
reference_face=reference_face,
source_face=src_face,
swapped_image=result.image,
filtering=face_filtering_options,
)
if similarities:
(result.similarity, result.ref_similarity) = similarities
else:
logger.error("Failed to compute similarity")
if result.image is None: if result.image is None:
logger.error("Result image is None") logger.error("Result image is None")
if ( if (

@ -1,26 +1,21 @@
import os
import re
import traceback import traceback
from pprint import pformat, pprint from pprint import pformat
from typing import * from typing import *
from scripts.faceswaplab_utils.typing import * from scripts.faceswaplab_utils.typing import *
import gradio as gr import gradio as gr
import modules.scripts as scripts
import onnx import onnx
import pandas as pd import pandas as pd
from modules import scripts
from modules.shared import opts from modules.shared import opts
from PIL import Image from PIL import Image
import scripts.faceswaplab_swapping.swapper as swapper import scripts.faceswaplab_swapping.swapper as swapper
from scripts.faceswaplab_postprocessing.postprocessing import enhance_image
from scripts.faceswaplab_postprocessing.postprocessing_options import ( from scripts.faceswaplab_postprocessing.postprocessing_options import (
PostProcessingOptions, PostProcessingOptions,
) )
from scripts.faceswaplab_ui.faceswaplab_postprocessing_ui import postprocessing_ui from scripts.faceswaplab_ui.faceswaplab_postprocessing_ui import postprocessing_ui
from scripts.faceswaplab_ui.faceswaplab_unit_settings import FaceSwapUnitSettings from scripts.faceswaplab_ui.faceswaplab_unit_settings import FaceSwapUnitSettings
from scripts.faceswaplab_ui.faceswaplab_unit_ui import faceswap_unit_ui from scripts.faceswaplab_ui.faceswaplab_unit_ui import faceswap_unit_ui
from scripts.faceswaplab_utils import face_utils, imgutils from scripts.faceswaplab_utils import face_checkpoints_utils, imgutils
from scripts.faceswaplab_utils.faceswaplab_logging import logger from scripts.faceswaplab_utils.faceswaplab_logging import logger
from scripts.faceswaplab_utils.models_utils import get_models from scripts.faceswaplab_utils.models_utils import get_models
from scripts.faceswaplab_utils.ui_utils import dataclasses_from_flat_list from scripts.faceswaplab_utils.ui_utils import dataclasses_from_flat_list
@ -138,24 +133,9 @@ def analyse_faces(image: PILImage, det_threshold: float = 0.5) -> Optional[str]:
return None return None
def sanitize_name(name: str) -> str:
"""
Sanitize the input name by removing special characters and replacing spaces with underscores.
Parameters:
name (str): The input name to be sanitized.
Returns:
str: The sanitized name with special characters removed and spaces replaced by underscores.
"""
name = re.sub("[^A-Za-z0-9_. ]+", "", name)
name = name.replace(" ", "_")
return name[:255]
def build_face_checkpoint_and_save( def build_face_checkpoint_and_save(
batch_files: gr.File, name: str batch_files: gr.File, name: str, overwrite: bool
) -> Optional[PILImage]: ) -> PILImage:
""" """
Builds a face checkpoint using the provided image files, performs face swapping, Builds a face checkpoint using the provided image files, performs face swapping,
and saves the result to a file. If a blended face is successfully obtained and the face swapping and saves the result to a file. If a blended face is successfully obtained and the face swapping
@ -170,79 +150,19 @@ def build_face_checkpoint_and_save(
""" """
try: try:
name = sanitize_name(name) if not batch_files:
batch_files = batch_files or []
logger.info("Build %s %s", name, [x.name for x in batch_files])
faces = swapper.get_faces_from_img_files(batch_files)
blended_face = swapper.blend_faces(faces)
preview_path = os.path.join(
scripts.basedir(), "extensions", "sd-webui-faceswaplab", "references"
)
faces_path = os.path.join(scripts.basedir(), "models", "faceswaplab", "faces")
os.makedirs(faces_path, exist_ok=True)
target_img: PILImage = None
if blended_face:
if blended_face["gender"] == 0:
target_img = Image.open(os.path.join(preview_path, "woman.png"))
else:
target_img = Image.open(os.path.join(preview_path, "man.png"))
if name == "":
name = "default_name"
pprint(blended_face)
target_face = swapper.get_or_default(
swapper.get_faces(imgutils.pil_to_cv2(target_img)), 0, None
)
if target_face is None:
logger.error(
"Failed to open reference image, cannot create preview : That should not happen unless you deleted the references folder or change the detection threshold."
)
else:
result = swapper.swap_face(
reference_face=blended_face,
target_faces=[target_face],
source_face=blended_face,
target_img=target_img,
model=get_models()[0],
upscaled_swapper=opts.data.get(
"faceswaplab_upscaled_swapper", False
),
)
result_image = enhance_image(
result.image,
PostProcessingOptions(
face_restorer_name="CodeFormer", restorer_visibility=1
),
)
file_path = os.path.join(faces_path, f"{name}.safetensors")
file_number = 1
while os.path.exists(file_path):
file_path = os.path.join(
faces_path, f"{name}_{file_number}.safetensors"
)
file_number += 1
result_image.save(file_path + ".png")
face_utils.save_face(filename=file_path, face=blended_face)
try:
data = face_utils.load_face(filename=file_path)
logger.debug(data)
except Exception as e:
print(e)
return result_image
logger.error("No face found") logger.error("No face found")
return None
filenames = [x.name for x in batch_files]
preview_image = face_checkpoints_utils.build_face_checkpoint_and_save(
filenames, name, overwrite=overwrite
)
except Exception as e: except Exception as e:
logger.error("Failed to build checkpoint %s", e) logger.error("Failed to build checkpoint %s", e)
traceback.print_exc() traceback.print_exc()
return None return None
return preview_image
return target_img
def explore_onnx_faceswap_model(model_path: str) -> pd.DataFrame: def explore_onnx_faceswap_model(model_path: str) -> pd.DataFrame:
@ -281,7 +201,7 @@ def explore_onnx_faceswap_model(model_path: str) -> pd.DataFrame:
def batch_process( def batch_process(
files: List[gr.File], save_path: str, *components: Tuple[Any, ...] files: List[gr.File], save_path: str, *components: Tuple[Any, ...]
) -> Optional[List[PILImage]]: ) -> List[PILImage]:
try: try:
units_count = opts.data.get("faceswaplab_units_count", 3) units_count = opts.data.get("faceswaplab_units_count", 3)
@ -308,7 +228,7 @@ def batch_process(
logger.error("Batch Process error : %s", e) logger.error("Batch Process error : %s", e)
traceback.print_exc() traceback.print_exc()
return None return []
def tools_ui() -> None: def tools_ui() -> None:
@ -319,7 +239,7 @@ def tools_ui() -> None:
"""Build a face based on a batch list of images. Will blend the resulting face and store the checkpoint in the faceswaplab/faces directory.""" """Build a face based on a batch list of images. Will blend the resulting face and store the checkpoint in the faceswaplab/faces directory."""
) )
with gr.Row(): with gr.Row():
batch_files = gr.components.File( build_batch_files = gr.components.File(
type="file", type="file",
file_count="multiple", file_count="multiple",
label="Batch Sources Images", label="Batch Sources Images",
@ -332,12 +252,18 @@ def tools_ui() -> None:
interactive=False, interactive=False,
elem_id="faceswaplab_build_preview_face", elem_id="faceswaplab_build_preview_face",
) )
name = gr.Textbox( build_name = gr.Textbox(
value="Face", value="Face",
placeholder="Name of the character", placeholder="Name of the character",
label="Name of the character", label="Name of the character",
elem_id="faceswaplab_build_character_name", elem_id="faceswaplab_build_character_name",
) )
build_overwrite = gr.Checkbox(
False,
placeholder="overwrite",
label="Overwrite Checkpoint if exist (else will add number)",
elem_id="faceswaplab_build_overwrite",
)
generate_checkpoint_btn = gr.Button( generate_checkpoint_btn = gr.Button(
"Save", elem_id="faceswaplab_build_save_btn" "Save", elem_id="faceswaplab_build_save_btn"
) )
@ -452,7 +378,9 @@ def tools_ui() -> None:
) )
compare_btn.click(compare, inputs=[img1, img2], outputs=[compare_result_text]) compare_btn.click(compare, inputs=[img1, img2], outputs=[compare_result_text])
generate_checkpoint_btn.click( generate_checkpoint_btn.click(
build_face_checkpoint_and_save, inputs=[batch_files, name], outputs=[preview] build_face_checkpoint_and_save,
inputs=[build_batch_files, build_name, build_overwrite],
outputs=[preview],
) )
extract_btn.click( extract_btn.click(
extract_faces, extract_faces,

@ -8,7 +8,7 @@ from insightface.app.common import Face
from PIL import Image from PIL import Image
from scripts.faceswaplab_utils.imgutils import pil_to_cv2 from scripts.faceswaplab_utils.imgutils import pil_to_cv2
from scripts.faceswaplab_utils.faceswaplab_logging import logger from scripts.faceswaplab_utils.faceswaplab_logging import logger
from scripts.faceswaplab_utils import face_utils from scripts.faceswaplab_utils import face_checkpoints_utils
from scripts.faceswaplab_inpainting.faceswaplab_inpainting import InpaintingOptions from scripts.faceswaplab_inpainting.faceswaplab_inpainting import InpaintingOptions
from client_api import api_utils from client_api import api_utils
@ -118,10 +118,9 @@ class FaceSwapUnitSettings:
""" """
if not hasattr(self, "_reference_face"): if not hasattr(self, "_reference_face"):
if self.source_face and self.source_face != "None": if self.source_face and self.source_face != "None":
with open(self.source_face, "rb") as file:
try: try:
logger.info(f"loading face {file.name}") logger.info(f"loading face {self.source_face}")
face = face_utils.load_face(file.name) face = face_checkpoints_utils.load_face(self.source_face)
self._reference_face = face self._reference_face = face
except Exception as e: except Exception as e:
logger.error("Failed to load checkpoint : %s", e) logger.error("Failed to load checkpoint : %s", e)

@ -1,6 +1,6 @@
from typing import List from typing import List
from scripts.faceswaplab_ui.faceswaplab_inpainting_ui import face_inpainting_ui from scripts.faceswaplab_ui.faceswaplab_inpainting_ui import face_inpainting_ui
from scripts.faceswaplab_utils.face_utils import get_face_checkpoints from scripts.faceswaplab_utils.face_checkpoints_utils import get_face_checkpoints
import gradio as gr import gradio as gr

@ -0,0 +1,236 @@
import glob
import os
from typing import *
from insightface.app.common import Face
from safetensors.torch import save_file, safe_open
import torch
import modules.scripts as scripts
from modules import scripts
from scripts.faceswaplab_utils.faceswaplab_logging import logger
from scripts.faceswaplab_utils.typing import *
from scripts.faceswaplab_utils import imgutils
from scripts.faceswaplab_postprocessing.postprocessing import enhance_image
from scripts.faceswaplab_postprocessing.postprocessing_options import (
PostProcessingOptions,
)
from scripts.faceswaplab_utils.models_utils import get_models
from modules.shared import opts
import traceback
import dill as pickle # will be removed in future versions
from scripts.faceswaplab_swapping import swapper
from pprint import pformat
import re
def sanitize_name(name: str) -> str:
"""
Sanitize the input name by removing special characters and replacing spaces with underscores.
Parameters:
name (str): The input name to be sanitized.
Returns:
str: The sanitized name with special characters removed and spaces replaced by underscores.
"""
name = re.sub("[^A-Za-z0-9_. ]+", "", name)
name = name.replace(" ", "_")
return name[:255]
def build_face_checkpoint_and_save(
batch_files: List[str], name: str, overwrite: bool = False
) -> PILImage:
"""
Builds a face checkpoint using the provided image files, performs face swapping,
and saves the result to a file. If a blended face is successfully obtained and the face swapping
process succeeds, the resulting image is returned. Otherwise, None is returned.
Args:
batch_files (list): List of image file paths used to create the face checkpoint.
name (str): The name assigned to the face checkpoint.
Returns:
PIL.PILImage or None: The resulting swapped face image if the process is successful; None otherwise.
"""
try:
name = sanitize_name(name)
batch_files = batch_files or []
logger.info("Build %s %s", name, [x for x in batch_files])
faces = swapper.get_faces_from_img_files(batch_files)
blended_face = swapper.blend_faces(faces)
preview_path = os.path.join(
scripts.basedir(), "extensions", "sd-webui-faceswaplab", "references"
)
reference_preview_img: PILImage = None
if blended_face:
if blended_face["gender"] == 0:
reference_preview_img = Image.open(
os.path.join(preview_path, "woman.png")
)
else:
reference_preview_img = Image.open(
os.path.join(preview_path, "man.png")
)
if name == "":
name = "default_name"
logger.debug("Face %s", pformat(blended_face))
target_face = swapper.get_or_default(
swapper.get_faces(imgutils.pil_to_cv2(reference_preview_img)), 0, None
)
if target_face is None:
logger.error(
"Failed to open reference image, cannot create preview : That should not happen unless you deleted the references folder or change the detection threshold."
)
else:
result = swapper.swap_face(
reference_face=blended_face,
target_faces=[target_face],
source_face=blended_face,
target_img=reference_preview_img,
model=get_models()[0],
upscaled_swapper=opts.data.get(
"faceswaplab_upscaled_swapper", False
),
)
preview_image = enhance_image(
result.image,
PostProcessingOptions(
face_restorer_name="CodeFormer", restorer_visibility=1
),
)
file_path = os.path.join(get_checkpoint_path(), f"{name}.safetensors")
if not overwrite:
file_number = 1
while os.path.exists(file_path):
file_path = os.path.join(
get_checkpoint_path(), f"{name}_{file_number}.safetensors"
)
file_number += 1
save_face(filename=file_path, face=blended_face)
preview_image.save(file_path + ".png")
try:
data = load_face(file_path)
logger.debug(data)
except Exception as e:
logger.error("Error loading checkpoint, after creation %s", e)
traceback.print_exc()
return preview_image
else:
logger.error("No face found")
return None
except Exception as e:
logger.error("Failed to build checkpoint %s", e)
traceback.print_exc()
return None
def save_face(face: Face, filename: str) -> None:
try:
tensors = {
"embedding": torch.tensor(face["embedding"]),
"gender": torch.tensor(face["gender"]),
"age": torch.tensor(face["age"]),
}
save_file(tensors, filename)
except Exception as e:
traceback.print_exc
logger.error("Failed to save checkpoint %s", e)
raise e
def load_face(name: str) -> Face:
filename = matching_checkpoint(name)
if filename is None:
return None
if filename.endswith(".pkl"):
logger.warning(
"Pkl files for faces are deprecated to enhance safety, they will be unsupported in future versions."
)
logger.warning("The file will be converted to .safetensors")
logger.warning(
"You can also use this script https://gist.github.com/glucauze/4a3c458541f2278ad801f6625e5b9d3d"
)
with open(filename, "rb") as file:
logger.info("Load pkl")
face = Face(pickle.load(file))
logger.warning(
"Convert to safetensors, you can remove the pkl version once you have ensured that the safetensor is working"
)
save_face(face, filename.replace(".pkl", ".safetensors"))
return face
elif filename.endswith(".safetensors"):
face = {}
with safe_open(filename, framework="pt", device="cpu") as f:
for k in f.keys():
logger.debug("load key %s", k)
face[k] = f.get_tensor(k).numpy()
return Face(face)
raise NotImplementedError("Unknown file type, face extraction not implemented")
def get_checkpoint_path() -> str:
checkpoint_path = os.path.join(scripts.basedir(), "models", "faceswaplab", "faces")
os.makedirs(checkpoint_path, exist_ok=True)
return checkpoint_path
def matching_checkpoint(name: str) -> Optional[str]:
"""
Retrieve the full path of a checkpoint file matching the given name.
If the name already includes a path separator, it is returned as-is. Otherwise, the function looks for a matching
file with the extensions ".safetensors" or ".pkl" in the checkpoint directory.
Args:
name (str): The name or path of the checkpoint file.
Returns:
Optional[str]: The full path of the matching checkpoint file, or None if no match is found.
"""
# If the name already includes a path separator, return it as is
if os.path.sep in name:
return name
# If the name doesn't end with the specified extensions, look for a matching file
if not (name.endswith(".safetensors") or name.endswith(".pkl")):
# Try appending each extension and check if the file exists in the checkpoint path
for ext in [".safetensors", ".pkl"]:
full_path = os.path.join(get_checkpoint_path(), name + ext)
if os.path.exists(full_path):
return full_path
# If no matching file is found, return None
return None
# If the name already ends with the specified extensions, simply complete the path
return os.path.join(get_checkpoint_path(), name)
def get_face_checkpoints() -> List[str]:
"""
Retrieve a list of face checkpoint paths.
This function searches for face files with the extension ".safetensors" in the specified directory and returns a list
containing the paths of those files.
Returns:
list: A list of face paths, including the string "None" as the first element.
"""
faces_path = os.path.join(get_checkpoint_path(), "*.safetensors")
faces = glob.glob(faces_path)
faces_path = os.path.join(get_checkpoint_path(), "*.pkl")
faces += glob.glob(faces_path)
return ["None"] + [os.path.basename(face) for face in sorted(faces)]

@ -1,72 +0,0 @@
import glob
import os
from typing import List
from insightface.app.common import Face
from safetensors.torch import save_file, safe_open
import torch
import modules.scripts as scripts
from modules import scripts
from scripts.faceswaplab_utils.faceswaplab_logging import logger
import dill as pickle # will be removed in future versions
def save_face(face: Face, filename: str) -> None:
tensors = {
"embedding": torch.tensor(face["embedding"]),
"gender": torch.tensor(face["gender"]),
"age": torch.tensor(face["age"]),
}
save_file(tensors, filename)
def load_face(filename: str) -> Face:
if filename.endswith(".pkl"):
logger.warning(
"Pkl files for faces are deprecated to enhance safety, they will be unsupported in future versions."
)
logger.warning("The file will be converted to .safetensors")
logger.warning(
"You can also use this script https://gist.github.com/glucauze/4a3c458541f2278ad801f6625e5b9d3d"
)
with open(filename, "rb") as file:
logger.info("Load pkl")
face = Face(pickle.load(file))
logger.warning(
"Convert to safetensors, you can remove the pkl version once you have ensured that the safetensor is working"
)
save_face(face, filename.replace(".pkl", ".safetensors"))
return face
elif filename.endswith(".safetensors"):
face = {}
with safe_open(filename, framework="pt", device="cpu") as f:
for k in f.keys():
logger.debug("load key %s", k)
face[k] = f.get_tensor(k).numpy()
return Face(face)
raise NotImplementedError("Unknown file type, face extraction not implemented")
def get_face_checkpoints() -> List[str]:
"""
Retrieve a list of face checkpoint paths.
This function searches for face files with the extension ".safetensors" in the specified directory and returns a list
containing the paths of those files.
Returns:
list: A list of face paths, including the string "None" as the first element.
"""
faces_path = os.path.join(
scripts.basedir(), "models", "faceswaplab", "faces", "*.safetensors"
)
faces = glob.glob(faces_path)
faces_path = os.path.join(
scripts.basedir(), "models", "faceswaplab", "faces", "*.pkl"
)
faces += glob.glob(faces_path)
return ["None"] + sorted(faces)

@ -11,6 +11,7 @@ from modules import processing
import base64 import base64
from collections import Counter from collections import Counter
from scripts.faceswaplab_utils.typing import BoxCoords, CV2ImgU8, PILImage from scripts.faceswaplab_utils.typing import BoxCoords, CV2ImgU8, PILImage
from scripts.faceswaplab_utils.faceswaplab_logging import logger
def check_against_nsfw(img: PILImage) -> bool: def check_against_nsfw(img: PILImage) -> bool:
@ -157,19 +158,6 @@ def create_square_image(image_list: List[PILImage]) -> Optional[PILImage]:
return None return None
# def create_mask(image : PILImage, box_coords : Tuple[int, int, int, int]) -> PILImage:
# width, height = image.size
# mask = Image.new("L", (width, height), 255)
# x1, y1, x2, y2 = box_coords
# for x in range(width):
# for y in range(height):
# if x1 <= x <= x2 and y1 <= y <= y2:
# mask.putpixel((x, y), 255)
# else:
# mask.putpixel((x, y), 0)
# return mask
def create_mask( def create_mask(
image: PILImage, image: PILImage,
box_coords: BoxCoords, box_coords: BoxCoords,
@ -216,6 +204,8 @@ def apply_mask(
if overlays is None or batch_index >= len(overlays): if overlays is None or batch_index >= len(overlays):
return img return img
overlay: PILImage = overlays[batch_index] overlay: PILImage = overlays[batch_index]
logger.debug("Overlay size %s, Image size %s", overlay.size, img.size)
if overlay.size != img.size:
overlay = overlay.resize((img.size), resample=Image.Resampling.LANCZOS) overlay = overlay.resize((img.size), resample=Image.Resampling.LANCZOS)
img = img.copy() img = img.copy()
img.paste(overlay, (0, 0), overlay) img.paste(overlay, (0, 0), overlay)

Loading…
Cancel
Save