from scripts.faceswaplab_swapping import swapper
import numpy as np
import base64
import io
from dataclasses import dataclass, fields
from typing import Any, List, Optional, Set, Union
import dill as pickle
import gradio as gr
from insightface.app.common import Face
from PIL import Image
from scripts.faceswaplab_utils.imgutils import pil_to_cv2
from scripts.faceswaplab_utils.faceswaplab_logging import logger


@dataclass
class FaceSwapUnitSettings:
    """Settings of a single face-swap unit.

    Derived values (reference face, batch faces, blended face) are computed
    lazily on first access and cached on the instance as private attributes.
    """

    # ORDER of parameters is IMPORTANT. It should match the result of faceswap_unit_ui
    # (get_unit_configuration passes component values positionally to the constructor).

    # The image given in reference: a PIL image, or a base64 string
    # (optionally prefixed with a "...base64," data-URL header)
    source_img: Union[Image.Image, str]
    # The checkpoint file (path to a pickled face; empty or "None" means no checkpoint)
    source_face: str
    # The batch source images (uploaded files or PIL images); see batch_files
    _batch_files: Union[gr.components.File, List[Image.Image]]
    # Will blend faces if True
    blend_faces: bool
    # Enable this unit
    enable: bool
    # Use same gender filtering
    same_gender: bool
    # Sort faces by their size (from larger to smaller)
    sort_by_size: bool
    # If True, discard images with low similarity
    check_similarity: bool
    # If True will compute similarity and add it to the image info
    _compute_similarity: bool
    # Minimum similarity against the used face (reference, batch or checkpoint)
    min_sim: float
    # Minimum similarity against the reference (reference or checkpoint if checkpoint is given)
    min_ref_sim: float
    # Comma-separated face indices to use for swapping; parsed by faces_index
    _faces_index: str
    # The face index to get image from source
    reference_face_index: int
    # Swap in the source image in img2img (before processing)
    swap_in_source: bool
    # Swap in the generated image in img2img (always on for txt2img)
    swap_in_generated: bool

    @staticmethod
    def get_unit_configuration(
        unit: int, components: List[gr.components.Component]
    ) -> Any:
        """Build the settings of unit number ``unit`` from a flat list of values.

        ``components`` holds the values of all units back to back, each unit
        contributing one value per dataclass field, in field order.
        """
        fields_count = len(fields(FaceSwapUnitSettings))
        start = unit * fields_count
        return FaceSwapUnitSettings(*components[start : start + fields_count])

    @property
    def faces_index(self) -> Set[int]:
        """
        Convert _faces_index from a comma-separated str to a set of ints.

        Surrounding whitespace around each entry is ignored ("0, 1" -> {0, 1});
        entries that are not non-negative integers are skipped. Returns {0}
        when no valid index remains.
        """
        # isdecimal() (not isnumeric()) guarantees int() accepts the token;
        # isnumeric() is also True for e.g. "½" or "²", which int() rejects.
        tokens = (token.strip() for token in self._faces_index.split(","))
        faces_index = {int(token) for token in tokens if token.isdecimal()}
        if len(faces_index) == 0:
            return {0}
        logger.debug("FACES INDEX : %s", faces_index)
        return faces_index

    @property
    def compute_similarity(self) -> bool:
        """Similarity is needed if explicitly requested or required by the similarity check."""
        return self._compute_similarity or self.check_similarity

    @property
    def batch_files(self) -> List[gr.File]:
        """
        Return empty array instead of None for batch files
        """
        return self._batch_files or []

    @property
    def reference_face(self) -> Optional[Face]:
        """
        Extract the reference face once and cache it for the rest of processing.

        The reference face is loaded from the checkpoint (``source_face``) when
        one is set, otherwise detected in ``source_img`` at position
        ``reference_face_index``. Batch files are not used here.

        Raises:
            Exception: re-raised as-is if the checkpoint cannot be opened or
                unpickled; ``Exception("No reference face found")`` when no
                reference face is available.
        """
        if not hasattr(self, "_reference_face"):
            if self.source_face and self.source_face != "None":
                # NOTE(review): unpickling (dill) can execute arbitrary code;
                # only load checkpoint files from a trusted location.
                try:
                    with open(self.source_face, "rb") as file:
                        logger.info("loading pickle %s", file.name)
                        self._reference_face = Face(pickle.load(file))
                except Exception as e:
                    logger.error("Failed to load checkpoint : %s", e)
                    raise
            elif self.source_img is not None:
                if isinstance(self.source_img, str):
                    # source_img is a base64 string, possibly prefixed by a data-URL
                    # header; split()[-1] returns the whole string when no header exists.
                    base64_data = self.source_img.split("base64,")[-1]
                    img_bytes = base64.b64decode(base64_data)
                    self.source_img = Image.open(io.BytesIO(img_bytes))
                source_img = pil_to_cv2(self.source_img)
                self._reference_face = swapper.get_or_default(
                    swapper.get_faces(source_img), self.reference_face_index, None
                )
                if self._reference_face is None:
                    logger.error("Face not found in reference image")
            else:
                self._reference_face = None

        if self._reference_face is None:
            logger.error("You need at least one reference face")
            raise Exception("No reference face found")

        return self._reference_face

    @staticmethod
    def _first_face(img: Image.Image) -> Optional[Face]:
        """Return the face at index 0 of swapper.get_faces(img), or None if none is found."""
        return swapper.get_or_default(swapper.get_faces(pil_to_cv2(img)), 0, None)

    @property
    def faces(self) -> List[Face]:
        """
        Extract all faces (reference face first, then one face per batch image)
        and cache the resulting list. Only processed once.

        Batch images in which no face is found are skipped.
        """
        if not hasattr(self, "_faces"):
            reference = self.reference_face
            faces: List[Face] = [reference] if reference is not None else []
            for file in self.batch_files:
                if isinstance(file, Image.Image):
                    face = self._first_face(file)
                else:
                    # Uploaded file: open it from disk and close it once the face is extracted.
                    with Image.open(file.name) as img:
                        face = self._first_face(img)
                if face is not None:
                    faces.append(face)
            # Cache only after every batch file was processed, so an error does
            # not leave a partial list behind for later accesses.
            self._faces = faces
        return self._faces

    @property
    def blended_faces(self) -> Face:
        """
        Blend all faces with swapper.blend_faces (presumably the mean of their
        embeddings) and cache the result.
        """
        if not hasattr(self, "_blended_faces"):
            self._blended_faces = swapper.blend_faces(self.faces)
            # Sanity checks: with more than one face, blending must yield an
            # embedding distinct from every input face and from the reference.
            if len(self.faces) > 1:
                assert all(
                    not np.array_equal(self._blended_faces.embedding, face.embedding)
                    for face in self.faces
                ), "Blended faces cannot be the same as one of the face if len(face)>0"
                assert not np.array_equal(
                    self._blended_faces.embedding, self.reference_face.embedding
                ), "Blended faces cannot be the same as reference face if len(face)>0"

        return self._blended_faces