From b773bda19f0cbe29e013e4b31b63f2f02da05468 Mon Sep 17 00:00:00 2001 From: Tran Xen <137925069+glucauze@users.noreply.github.com> Date: Wed, 2 Aug 2023 01:21:21 +0200 Subject: [PATCH] fix similarity, add checksum for swapper, fix minor bugs --- README.md | 4 + requirements.txt | 2 +- scripts/faceswaplab_swapping/swapper.py | 135 ++++++---- scripts/faceswaplab_ui/faceswaplab_tab.py | 120 ++------- .../faceswaplab_unit_settings.py | 17 +- scripts/faceswaplab_ui/faceswaplab_unit_ui.py | 2 +- .../face_checkpoints_utils.py | 236 ++++++++++++++++++ scripts/faceswaplab_utils/face_utils.py | 72 ------ scripts/faceswaplab_utils/imgutils.py | 18 +- 9 files changed, 371 insertions(+), 235 deletions(-) create mode 100644 scripts/faceswaplab_utils/face_checkpoints_utils.py delete mode 100644 scripts/faceswaplab_utils/face_utils.py diff --git a/README.md b/README.md index b3423f3..3728c68 100644 --- a/README.md +++ b/README.md @@ -20,6 +20,10 @@ In short: More on this here : https://glucauze.github.io/sd-webui-faceswaplab/ +### Known problems (wontfix): + ++ Older versions of gradio don't work well with the extension. See this bug : https://github.com/glucauze/sd-webui-faceswaplab/issues/5 + ### Features + **Face Unit Concept**: Similar to controlNet, the program introduces the concept of a face unit. You can configure up to 10 units (3 units are the default setting) in the program settings (sd). diff --git a/requirements.txt b/requirements.txt index 2999cc3..1c8637a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,7 +2,7 @@ cython ifnude insightface==0.7.3 onnx==1.14.0 -onnxruntime==1.15.0 +onnxruntime==1.15.1 opencv-python==4.7.0.72 pandas pydantic==1.10.9 diff --git a/scripts/faceswaplab_swapping/swapper.py b/scripts/faceswaplab_swapping/swapper.py index d2dd2d4..89eaf8e 100644 --- a/scripts/faceswaplab_swapping/swapper.py +++ b/scripts/faceswaplab_swapping/swapper.py @@ -27,7 +27,6 @@ from scripts.faceswaplab_postprocessing.postprocessing_options import ( PostProcessingOptions, ) from scripts.faceswaplab_utils.models_utils import get_current_model -import gradio as gr from scripts.faceswaplab_utils.typing import CV2ImgU8, PILImage, Face from scripts.faceswaplab_inpainting.i2i_pp import img2img_diffusion @@ -250,6 +249,21 @@ def getAnalysisModel() -> insightface.app.FaceAnalysis: raise FaceModelException("Loading of analysis model failed") +import hashlib + + +def is_sha1_matching(file_path: str, expected_sha1: str) -> bool: + sha1_hash = hashlib.sha1() + + with open(file_path, "rb") as file: + for byte_block in iter(lambda: file.read(4096), b""): + sha1_hash.update(byte_block) + if sha1_hash.hexdigest() == expected_sha1: + return True + else: + return False + + @lru_cache(maxsize=1) def getFaceSwapModel(model_path: str) -> upscaled_inswapper.UpscaledINSwapper: """ @@ -262,6 +276,14 @@ def getFaceSwapModel(model_path: str) -> upscaled_inswapper.UpscaledINSwapper: insightface.model_zoo.FaceModel: The face swap model. """ try: + expected_sha1 = "17a64851eaefd55ea597ee41e5c18409754244c5" + if not is_sha1_matching(model_path, expected_sha1): + logger.error( + "Suspicious sha1 for model %s, check the model is valid or has been downloaded adequately. Should be %s", + model_path, + expected_sha1, + ) + # Initializes the face swap model using the specified model path. return upscaled_inswapper.UpscaledINSwapper( insightface.model_zoo.get_model(model_path, providers=providers) @@ -270,6 +292,9 @@ def getFaceSwapModel(model_path: str) -> upscaled_inswapper.UpscaledINSwapper: logger.error( "Loading of swapping model failed, please check the requirements (On Windows, download and install Visual Studio. During the install, make sure to include the Python and C++ packages.)" ) + import traceback + + traceback.print_exc() raise FaceModelException("Loading of swapping model failed") @@ -315,11 +340,15 @@ def get_faces( return [] +@dataclass +class FaceFilteringOptions: + faces_index: Set[int] + source_gender: Optional[int] = None # if none will not use same gender + sort_by_face_size: bool = False + + def filter_faces( - all_faces: List[Face], - faces_index: Set[int], - source_gender: int = None, - sort_by_face_size: bool = False, + all_faces: List[Face], filtering_options: FaceFilteringOptions ) -> List[Face]: """ Sorts and filters a list of faces based on specified criteria. @@ -337,18 +366,24 @@ def filter_faces( :return: A list of Face objects sorted and filtered according to the specified criteria. """ filtered_faces = copy.copy(all_faces) - if sort_by_face_size: + if filtering_options.sort_by_face_size: filtered_faces = sorted( all_faces, reverse=True, key=lambda x: (x.bbox[2] - x.bbox[0]) * (x.bbox[3] - x.bbox[1]), ) - if source_gender is not None: + if filtering_options.source_gender is not None: filtered_faces = [ - face for face in filtered_faces if face["gender"] == source_gender + face + for face in filtered_faces + if face["gender"] == filtering_options.source_gender ] - return [face for i, face in enumerate(filtered_faces) if i in faces_index] + return [ + face + for i, face in enumerate(filtered_faces) + if i in filtering_options.faces_index + ] @dataclass @@ -391,7 +426,7 @@ def get_or_default(l: List[Any], index: int, default: Any) -> Any: return l[index] if index < len(l) else default -def get_faces_from_img_files(files: List[gr.File]) -> List[Optional[CV2ImgU8]]: +def get_faces_from_img_files(files: List[str]) -> List[Optional[CV2ImgU8]]: """ Extracts faces from a list of image files. @@ -407,7 +442,7 @@ def get_faces_from_img_files(files: List[gr.File]) -> List[Optional[CV2ImgU8]]: if len(files) > 0: for file in files: - img = Image.open(file.name) # Open the image file + img = Image.open(file) # Open the image file face = get_or_default( get_faces(pil_to_cv2(img)), 0, None ) # Extract faces from the image @@ -503,41 +538,44 @@ def swap_face( result_image = Image.fromarray(cv2.cvtColor(result, cv2.COLOR_BGR2RGB)) return_result.image = result_image - # FIXME : recompute similarity - - # if compute_similarity: - # try: - # result_faces = get_faces( - # cv2.cvtColor(np.array(result_image), cv2.COLOR_RGB2BGR), - # sort_by_face_size=sort_by_face_size, - # ) - # if same_gender: - # result_faces = [ - # x for x in result_faces if x["gender"] == gender - # ] - - # for i, swapped_face in enumerate(result_faces): - # logger.info(f"compare face {i}") - # if i in faces_index and i < len(target_faces): - # return_result.similarity[i] = cosine_similarity_face( - # source_face, swapped_face - # ) - # return_result.ref_similarity[i] = cosine_similarity_face( - # reference_face, swapped_face - # ) - - # logger.info(f"similarity {return_result.similarity}") - # logger.info(f"ref similarity {return_result.ref_similarity}") - - # except Exception as e: - # logger.error("Similarity processing failed %s", e) - # raise e except Exception as e: logger.error("Conversion failed %s", e) raise e return return_result +def compute_similarity( + reference_face: Face, + source_face: Face, + swapped_image: PILImage, + filtering: FaceFilteringOptions, +) -> Tuple[Dict[int, float], Dict[int, float]]: + similarity: Dict[int, float] = {} + ref_similarity: Dict[int, float] = {} + try: + swapped_image_cv2: CV2ImgU8 = cv2.cvtColor( + np.array(swapped_image), cv2.COLOR_RGB2BGR + ) + new_faces = filter_faces(get_faces(swapped_image_cv2), filtering) + if len(new_faces) == 0: + logger.error("compute_similarity : No faces to compare with !") + return None + + for i, swapped_face in enumerate(new_faces): + logger.info(f"compare face {i}") + similarity[i] = cosine_similarity_face(source_face, swapped_face) + ref_similarity[i] = cosine_similarity_face(reference_face, swapped_face) + + logger.info(f"similarity {similarity}") + logger.info(f"ref similarity {ref_similarity}") + + return (similarity, ref_similarity) + except Exception as e: + logger.error("Similarity processing failed %s", e) + raise e + return None + + def process_image_unit( model: str, unit: FaceSwapUnitSettings, @@ -580,13 +618,14 @@ def process_image_unit( logger.info("Use source face as reference face") reference_face = src_face - target_faces = filter_faces( - faces, + face_filtering_options = FaceFilteringOptions( faces_index=unit.faces_index, source_gender=src_face["gender"] if unit.same_gender else None, sort_by_face_size=unit.sort_by_size, ) + target_faces = filter_faces(faces, filtering_options=face_filtering_options) + # Apply pre-inpainting to image if unit.pre_inpainting.inpainting_denoising_strengh > 0: current_image = img2img_diffusion( @@ -611,6 +650,18 @@ def process_image_unit( save_img_debug(result.image, "After swap") + if unit.compute_similarity: + similarities = compute_similarity( + reference_face=reference_face, + source_face=src_face, + swapped_image=result.image, + filtering=face_filtering_options, + ) + if similarities: + (result.similarity, result.ref_similarity) = similarities + else: + logger.error("Failed to compute similarity") + if result.image is None: logger.error("Result image is None") if ( diff --git a/scripts/faceswaplab_ui/faceswaplab_tab.py b/scripts/faceswaplab_ui/faceswaplab_tab.py index 622c48f..0d66cec 100644 --- a/scripts/faceswaplab_ui/faceswaplab_tab.py +++ b/scripts/faceswaplab_ui/faceswaplab_tab.py @@ -1,26 +1,21 @@ -import os -import re import traceback -from pprint import pformat, pprint +from pprint import pformat from typing import * from scripts.faceswaplab_utils.typing import * import gradio as gr -import modules.scripts as scripts import onnx import pandas as pd -from modules import scripts from modules.shared import opts from PIL import Image import scripts.faceswaplab_swapping.swapper as swapper -from scripts.faceswaplab_postprocessing.postprocessing import enhance_image from scripts.faceswaplab_postprocessing.postprocessing_options import ( PostProcessingOptions, ) from scripts.faceswaplab_ui.faceswaplab_postprocessing_ui import postprocessing_ui from scripts.faceswaplab_ui.faceswaplab_unit_settings import FaceSwapUnitSettings from scripts.faceswaplab_ui.faceswaplab_unit_ui import faceswap_unit_ui -from scripts.faceswaplab_utils import face_utils, imgutils +from scripts.faceswaplab_utils import face_checkpoints_utils, imgutils from scripts.faceswaplab_utils.faceswaplab_logging import logger from scripts.faceswaplab_utils.models_utils import get_models from scripts.faceswaplab_utils.ui_utils import dataclasses_from_flat_list @@ -138,24 +133,9 @@ def analyse_faces(image: PILImage, det_threshold: float = 0.5) -> Optional[str]: return None -def sanitize_name(name: str) -> str: - """ - Sanitize the input name by removing special characters and replacing spaces with underscores. - - Parameters: - name (str): The input name to be sanitized. - - Returns: - str: The sanitized name with special characters removed and spaces replaced by underscores. - """ - name = re.sub("[^A-Za-z0-9_. ]+", "", name) - name = name.replace(" ", "_") - return name[:255] - - def build_face_checkpoint_and_save( - batch_files: gr.File, name: str -) -> Optional[PILImage]: + batch_files: gr.File, name: str, overwrite: bool +) -> PILImage: """ Builds a face checkpoint using the provided image files, performs face swapping, and saves the result to a file. If a blended face is successfully obtained and the face swapping @@ -170,79 +150,19 @@ def build_face_checkpoint_and_save( """ try: - name = sanitize_name(name) - batch_files = batch_files or [] - logger.info("Build %s %s", name, [x.name for x in batch_files]) - faces = swapper.get_faces_from_img_files(batch_files) - blended_face = swapper.blend_faces(faces) - preview_path = os.path.join( - scripts.basedir(), "extensions", "sd-webui-faceswaplab", "references" + if not batch_files: + logger.error("No face found") + return None + filenames = [x.name for x in batch_files] + preview_image = face_checkpoints_utils.build_face_checkpoint_and_save( + filenames, name, overwrite=overwrite ) - - faces_path = os.path.join(scripts.basedir(), "models", "faceswaplab", "faces") - - os.makedirs(faces_path, exist_ok=True) - - target_img: PILImage = None - if blended_face: - if blended_face["gender"] == 0: - target_img = Image.open(os.path.join(preview_path, "woman.png")) - else: - target_img = Image.open(os.path.join(preview_path, "man.png")) - - if name == "": - name = "default_name" - pprint(blended_face) - target_face = swapper.get_or_default( - swapper.get_faces(imgutils.pil_to_cv2(target_img)), 0, None - ) - if target_face is None: - logger.error( - "Failed to open reference image, cannot create preview : That should not happen unless you deleted the references folder or change the detection threshold." - ) - else: - result = swapper.swap_face( - reference_face=blended_face, - target_faces=[target_face], - source_face=blended_face, - target_img=target_img, - model=get_models()[0], - upscaled_swapper=opts.data.get( - "faceswaplab_upscaled_swapper", False - ), - ) - result_image = enhance_image( - result.image, - PostProcessingOptions( - face_restorer_name="CodeFormer", restorer_visibility=1 - ), - ) - - file_path = os.path.join(faces_path, f"{name}.safetensors") - file_number = 1 - while os.path.exists(file_path): - file_path = os.path.join( - faces_path, f"{name}_{file_number}.safetensors" - ) - file_number += 1 - result_image.save(file_path + ".png") - - face_utils.save_face(filename=file_path, face=blended_face) - try: - data = face_utils.load_face(filename=file_path) - logger.debug(data) - except Exception as e: - print(e) - return result_image - - logger.error("No face found") except Exception as e: logger.error("Failed to build checkpoint %s", e) traceback.print_exc() return None - - return target_img + return preview_image def explore_onnx_faceswap_model(model_path: str) -> pd.DataFrame: @@ -281,7 +201,7 @@ def explore_onnx_faceswap_model(model_path: str) -> pd.DataFrame: def batch_process( files: List[gr.File], save_path: str, *components: Tuple[Any, ...] -) -> Optional[List[PILImage]]: +) -> List[PILImage]: try: units_count = opts.data.get("faceswaplab_units_count", 3) @@ -308,7 +228,7 @@ def batch_process( logger.error("Batch Process error : %s", e) traceback.print_exc() - return None + return [] def tools_ui() -> None: @@ -319,7 +239,7 @@ def tools_ui() -> None: """Build a face based on a batch list of images. Will blend the resulting face and store the checkpoint in the faceswaplab/faces directory.""" ) with gr.Row(): - batch_files = gr.components.File( + build_batch_files = gr.components.File( type="file", file_count="multiple", label="Batch Sources Images", @@ -332,12 +252,18 @@ def tools_ui() -> None: interactive=False, elem_id="faceswaplab_build_preview_face", ) - name = gr.Textbox( + build_name = gr.Textbox( value="Face", placeholder="Name of the character", label="Name of the character", elem_id="faceswaplab_build_character_name", ) + build_overwrite = gr.Checkbox( + False, + placeholder="overwrite", + label="Overwrite Checkpoint if exist (else will add number)", + elem_id="faceswaplab_build_overwrite", + ) generate_checkpoint_btn = gr.Button( "Save", elem_id="faceswaplab_build_save_btn" ) @@ -452,7 +378,9 @@ def tools_ui() -> None: ) compare_btn.click(compare, inputs=[img1, img2], outputs=[compare_result_text]) generate_checkpoint_btn.click( - build_face_checkpoint_and_save, inputs=[batch_files, name], outputs=[preview] + build_face_checkpoint_and_save, + inputs=[build_batch_files, build_name, build_overwrite], + outputs=[preview], ) extract_btn.click( extract_faces, diff --git a/scripts/faceswaplab_ui/faceswaplab_unit_settings.py b/scripts/faceswaplab_ui/faceswaplab_unit_settings.py index d3ecf7c..9bfe4b0 100644 --- a/scripts/faceswaplab_ui/faceswaplab_unit_settings.py +++ b/scripts/faceswaplab_ui/faceswaplab_unit_settings.py @@ -8,7 +8,7 @@ from insightface.app.common import Face from PIL import Image from scripts.faceswaplab_utils.imgutils import pil_to_cv2 from scripts.faceswaplab_utils.faceswaplab_logging import logger -from scripts.faceswaplab_utils import face_utils +from scripts.faceswaplab_utils import face_checkpoints_utils from scripts.faceswaplab_inpainting.faceswaplab_inpainting import InpaintingOptions from client_api import api_utils @@ -118,14 +118,13 @@ class FaceSwapUnitSettings: """ if not hasattr(self, "_reference_face"): if self.source_face and self.source_face != "None": - with open(self.source_face, "rb") as file: - try: - logger.info(f"loading face {file.name}") - face = face_utils.load_face(file.name) - self._reference_face = face - except Exception as e: - logger.error("Failed to load checkpoint : %s", e) - raise e + try: + logger.info(f"loading face {self.source_face}") + face = face_checkpoints_utils.load_face(self.source_face) + self._reference_face = face + except Exception as e: + logger.error("Failed to load checkpoint : %s", e) + raise e elif self.source_img is not None: if isinstance(self.source_img, str): # source_img is a base64 string if ( diff --git a/scripts/faceswaplab_ui/faceswaplab_unit_ui.py b/scripts/faceswaplab_ui/faceswaplab_unit_ui.py index 8cac3db..a035c14 100644 --- a/scripts/faceswaplab_ui/faceswaplab_unit_ui.py +++ b/scripts/faceswaplab_ui/faceswaplab_unit_ui.py @@ -1,6 +1,6 @@ from typing import List from scripts.faceswaplab_ui.faceswaplab_inpainting_ui import face_inpainting_ui -from scripts.faceswaplab_utils.face_utils import get_face_checkpoints +from scripts.faceswaplab_utils.face_checkpoints_utils import get_face_checkpoints import gradio as gr diff --git a/scripts/faceswaplab_utils/face_checkpoints_utils.py b/scripts/faceswaplab_utils/face_checkpoints_utils.py new file mode 100644 index 0000000..de3ebf1 --- /dev/null +++ b/scripts/faceswaplab_utils/face_checkpoints_utils.py @@ -0,0 +1,236 @@ +import glob +import os +from typing import * +from insightface.app.common import Face +from safetensors.torch import save_file, safe_open +import torch + +import modules.scripts as scripts +from modules import scripts +from scripts.faceswaplab_utils.faceswaplab_logging import logger +from scripts.faceswaplab_utils.typing import * +from scripts.faceswaplab_utils import imgutils +from scripts.faceswaplab_postprocessing.postprocessing import enhance_image +from scripts.faceswaplab_postprocessing.postprocessing_options import ( + PostProcessingOptions, +) +from scripts.faceswaplab_utils.models_utils import get_models +from modules.shared import opts +import traceback + +import dill as pickle # will be removed in future versions +from scripts.faceswaplab_swapping import swapper +from pprint import pformat +import re + + +def sanitize_name(name: str) -> str: + """ + Sanitize the input name by removing special characters and replacing spaces with underscores. + + Parameters: + name (str): The input name to be sanitized. + + Returns: + str: The sanitized name with special characters removed and spaces replaced by underscores. + """ + name = re.sub("[^A-Za-z0-9_. ]+", "", name) + name = name.replace(" ", "_") + return name[:255] + + +def build_face_checkpoint_and_save( + batch_files: List[str], name: str, overwrite: bool = False +) -> PILImage: + """ + Builds a face checkpoint using the provided image files, performs face swapping, + and saves the result to a file. If a blended face is successfully obtained and the face swapping + process succeeds, the resulting image is returned. Otherwise, None is returned. + + Args: + batch_files (list): List of image file paths used to create the face checkpoint. + name (str): The name assigned to the face checkpoint. + + Returns: + PIL.PILImage or None: The resulting swapped face image if the process is successful; None otherwise. + """ + + try: + name = sanitize_name(name) + batch_files = batch_files or [] + logger.info("Build %s %s", name, [x for x in batch_files]) + faces = swapper.get_faces_from_img_files(batch_files) + blended_face = swapper.blend_faces(faces) + preview_path = os.path.join( + scripts.basedir(), "extensions", "sd-webui-faceswaplab", "references" + ) + + reference_preview_img: PILImage = None + if blended_face: + if blended_face["gender"] == 0: + reference_preview_img = Image.open( + os.path.join(preview_path, "woman.png") + ) + else: + reference_preview_img = Image.open( + os.path.join(preview_path, "man.png") + ) + + if name == "": + name = "default_name" + logger.debug("Face %s", pformat(blended_face)) + target_face = swapper.get_or_default( + swapper.get_faces(imgutils.pil_to_cv2(reference_preview_img)), 0, None + ) + if target_face is None: + logger.error( + "Failed to open reference image, cannot create preview : That should not happen unless you deleted the references folder or change the detection threshold." + ) + else: + result = swapper.swap_face( + reference_face=blended_face, + target_faces=[target_face], + source_face=blended_face, + target_img=reference_preview_img, + model=get_models()[0], + upscaled_swapper=opts.data.get( + "faceswaplab_upscaled_swapper", False + ), + ) + preview_image = enhance_image( + result.image, + PostProcessingOptions( + face_restorer_name="CodeFormer", restorer_visibility=1 + ), + ) + + file_path = os.path.join(get_checkpoint_path(), f"{name}.safetensors") + if not overwrite: + file_number = 1 + while os.path.exists(file_path): + file_path = os.path.join( + get_checkpoint_path(), f"{name}_{file_number}.safetensors" + ) + file_number += 1 + save_face(filename=file_path, face=blended_face) + preview_image.save(file_path + ".png") + try: + data = load_face(file_path) + logger.debug(data) + except Exception as e: + logger.error("Error loading checkpoint, after creation %s", e) + traceback.print_exc() + + return preview_image + + else: + logger.error("No face found") + return None + except Exception as e: + logger.error("Failed to build checkpoint %s", e) + traceback.print_exc() + return None + + +def save_face(face: Face, filename: str) -> None: + try: + tensors = { + "embedding": torch.tensor(face["embedding"]), + "gender": torch.tensor(face["gender"]), + "age": torch.tensor(face["age"]), + } + save_file(tensors, filename) + except Exception as e: + traceback.print_exc + logger.error("Failed to save checkpoint %s", e) + raise e + + +def load_face(name: str) -> Face: + filename = matching_checkpoint(name) + if filename is None: + return None + + if filename.endswith(".pkl"): + logger.warning( + "Pkl files for faces are deprecated to enhance safety, they will be unsupported in future versions." + ) + logger.warning("The file will be converted to .safetensors") + logger.warning( + "You can also use this script https://gist.github.com/glucauze/4a3c458541f2278ad801f6625e5b9d3d" + ) + with open(filename, "rb") as file: + logger.info("Load pkl") + face = Face(pickle.load(file)) + logger.warning( + "Convert to safetensors, you can remove the pkl version once you have ensured that the safetensor is working" + ) + save_face(face, filename.replace(".pkl", ".safetensors")) + return face + + elif filename.endswith(".safetensors"): + face = {} + with safe_open(filename, framework="pt", device="cpu") as f: + for k in f.keys(): + logger.debug("load key %s", k) + face[k] = f.get_tensor(k).numpy() + return Face(face) + + raise NotImplementedError("Unknown file type, face extraction not implemented") + + +def get_checkpoint_path() -> str: + checkpoint_path = os.path.join(scripts.basedir(), "models", "faceswaplab", "faces") + os.makedirs(checkpoint_path, exist_ok=True) + return checkpoint_path + + +def matching_checkpoint(name: str) -> Optional[str]: + """ + Retrieve the full path of a checkpoint file matching the given name. + + If the name already includes a path separator, it is returned as-is. Otherwise, the function looks for a matching + file with the extensions ".safetensors" or ".pkl" in the checkpoint directory. + + Args: + name (str): The name or path of the checkpoint file. + + Returns: + Optional[str]: The full path of the matching checkpoint file, or None if no match is found. + """ + + # If the name already includes a path separator, return it as is + if os.path.sep in name: + return name + + # If the name doesn't end with the specified extensions, look for a matching file + if not (name.endswith(".safetensors") or name.endswith(".pkl")): + # Try appending each extension and check if the file exists in the checkpoint path + for ext in [".safetensors", ".pkl"]: + full_path = os.path.join(get_checkpoint_path(), name + ext) + if os.path.exists(full_path): + return full_path + # If no matching file is found, return None + return None + + # If the name already ends with the specified extensions, simply complete the path + return os.path.join(get_checkpoint_path(), name) + + +def get_face_checkpoints() -> List[str]: + """ + Retrieve a list of face checkpoint paths. + + This function searches for face files with the extension ".safetensors" in the specified directory and returns a list + containing the paths of those files. + + Returns: + list: A list of face paths, including the string "None" as the first element. + """ + faces_path = os.path.join(get_checkpoint_path(), "*.safetensors") + faces = glob.glob(faces_path) + + faces_path = os.path.join(get_checkpoint_path(), "*.pkl") + faces += glob.glob(faces_path) + + return ["None"] + [os.path.basename(face) for face in sorted(faces)] diff --git a/scripts/faceswaplab_utils/face_utils.py b/scripts/faceswaplab_utils/face_utils.py deleted file mode 100644 index e07a6b6..0000000 --- a/scripts/faceswaplab_utils/face_utils.py +++ /dev/null @@ -1,72 +0,0 @@ -import glob -import os -from typing import List -from insightface.app.common import Face -from safetensors.torch import save_file, safe_open -import torch - -import modules.scripts as scripts -from modules import scripts -from scripts.faceswaplab_utils.faceswaplab_logging import logger -import dill as pickle # will be removed in future versions - - -def save_face(face: Face, filename: str) -> None: - tensors = { - "embedding": torch.tensor(face["embedding"]), - "gender": torch.tensor(face["gender"]), - "age": torch.tensor(face["age"]), - } - save_file(tensors, filename) - - -def load_face(filename: str) -> Face: - if filename.endswith(".pkl"): - logger.warning( - "Pkl files for faces are deprecated to enhance safety, they will be unsupported in future versions." - ) - logger.warning("The file will be converted to .safetensors") - logger.warning( - "You can also use this script https://gist.github.com/glucauze/4a3c458541f2278ad801f6625e5b9d3d" - ) - with open(filename, "rb") as file: - logger.info("Load pkl") - face = Face(pickle.load(file)) - logger.warning( - "Convert to safetensors, you can remove the pkl version once you have ensured that the safetensor is working" - ) - save_face(face, filename.replace(".pkl", ".safetensors")) - return face - - elif filename.endswith(".safetensors"): - face = {} - with safe_open(filename, framework="pt", device="cpu") as f: - for k in f.keys(): - logger.debug("load key %s", k) - face[k] = f.get_tensor(k).numpy() - return Face(face) - - raise NotImplementedError("Unknown file type, face extraction not implemented") - - -def get_face_checkpoints() -> List[str]: - """ - Retrieve a list of face checkpoint paths. - - This function searches for face files with the extension ".safetensors" in the specified directory and returns a list - containing the paths of those files. - - Returns: - list: A list of face paths, including the string "None" as the first element. - """ - faces_path = os.path.join( - scripts.basedir(), "models", "faceswaplab", "faces", "*.safetensors" - ) - faces = glob.glob(faces_path) - - faces_path = os.path.join( - scripts.basedir(), "models", "faceswaplab", "faces", "*.pkl" - ) - faces += glob.glob(faces_path) - - return ["None"] + sorted(faces) diff --git a/scripts/faceswaplab_utils/imgutils.py b/scripts/faceswaplab_utils/imgutils.py index 7b0e534..e8d67fb 100644 --- a/scripts/faceswaplab_utils/imgutils.py +++ b/scripts/faceswaplab_utils/imgutils.py @@ -11,6 +11,7 @@ from modules import processing import base64 from collections import Counter from scripts.faceswaplab_utils.typing import BoxCoords, CV2ImgU8, PILImage +from scripts.faceswaplab_utils.faceswaplab_logging import logger def check_against_nsfw(img: PILImage) -> bool: @@ -157,19 +158,6 @@ def create_square_image(image_list: List[PILImage]) -> Optional[PILImage]: return None -# def create_mask(image : PILImage, box_coords : Tuple[int, int, int, int]) -> PILImage: -# width, height = image.size -# mask = Image.new("L", (width, height), 255) -# x1, y1, x2, y2 = box_coords -# for x in range(width): -# for y in range(height): -# if x1 <= x <= x2 and y1 <= y <= y2: -# mask.putpixel((x, y), 255) -# else: -# mask.putpixel((x, y), 0) -# return mask - - def create_mask( image: PILImage, box_coords: BoxCoords, @@ -216,7 +204,9 @@ def apply_mask( if overlays is None or batch_index >= len(overlays): return img overlay: PILImage = overlays[batch_index] - overlay = overlay.resize((img.size), resample=Image.Resampling.LANCZOS) + logger.debug("Overlay size %s, Image size %s", overlay.size, img.size) + if overlay.size != img.size: + overlay = overlay.resize((img.size), resample=Image.Resampling.LANCZOS) img = img.copy() img.paste(overlay, (0, 0), overlay) return img