diff --git a/check.sh b/check.sh index b5cbbc8..d5e3b5f 100755 --- a/check.sh +++ b/check.sh @@ -1,4 +1,4 @@ #!/bin/bash autoflake --in-place --remove-unused-variables -r --remove-all-unused-imports . mypy --install-types -pre-commit run --all-files \ No newline at end of file +pre-commit run --all-files diff --git a/client_api/__init__.py b/client_api/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/example/api/client_utils.py b/client_api/client_utils.py similarity index 100% rename from example/api/client_utils.py rename to client_api/client_utils.py diff --git a/example/api/roop_api_example.py b/client_api/faceswaplab_api_example.py similarity index 80% rename from example/api/roop_api_example.py rename to client_api/faceswaplab_api_example.py index 918d995..323e526 100644 --- a/example/api/roop_api_example.py +++ b/client_api/faceswaplab_api_example.py @@ -12,13 +12,13 @@ address = "http://127.0.0.1:7860" # First face unit : unit1 = FaceSwapUnit( - source_img=pil_to_base64("../../references/man.png"), # The face you want to use + source_img=pil_to_base64("../references/man.png"), # The face you want to use faces_index=(0,), # Replace first face ) # Second face unit : unit2 = FaceSwapUnit( - source_img=pil_to_base64("../../references/woman.png"), # The face you want to use + source_img=pil_to_base64("../references/woman.png"), # The face you want to use same_gender=True, faces_index=(0,), # Replace first woman since same gender is on ) @@ -48,5 +48,6 @@ result = requests.post( ) response = FaceSwapResponse.parse_obj(result.json()) -for img, info in zip(response.pil_images, response.infos): - img.show(title=info) +print(response.json()) +for img in response.pil_images: + img.show() diff --git a/example/api/test_image.png b/client_api/test_image.png similarity index 100% rename from example/api/test_image.png rename to client_api/test_image.png diff --git a/scripts/faceswaplab_api/faceswaplab_api.py b/scripts/faceswaplab_api/faceswaplab_api.py index e896523..fb3a7ef 100644 --- a/scripts/faceswaplab_api/faceswaplab_api.py +++ b/scripts/faceswaplab_api/faceswaplab_api.py @@ -13,9 +13,6 @@ from scripts.faceswaplab_ui.faceswaplab_unit_settings import FaceSwapUnitSetting from scripts.faceswaplab_utils.imgutils import ( base64_to_pil, ) -from scripts.faceswaplab_utils.models_utils import get_current_model -from modules.shared import opts -from scripts.faceswaplab_postprocessing.postprocessing import enhance_image from scripts.faceswaplab_postprocessing.postprocessing_options import ( PostProcessingOptions, ) @@ -135,22 +132,18 @@ def faceswaplab_api(_: gr.Blocks, app: FastAPI) -> None: units: List[FaceSwapUnitSettings] = [] src_image: Optional[Image.Image] = base64_to_pil(request.image) response = FaceSwapResponse(images=[], infos=[]) - if request.postprocessing: - pp_options = get_postprocessing_options(request.postprocessing) if src_image is not None: + if request.postprocessing: + pp_options = get_postprocessing_options(request.postprocessing) units = get_faceswap_units_settings(request.units) - swapped_images = swapper.process_images_units( - get_current_model(), - images=[(src_image, None)], - units=units, - upscaled_swapper=opts.data.get("faceswaplab_upscaled_swapper", False), + swapped_images = swapper.batch_process( + [src_image], None, units=units, postprocess_options=pp_options ) - for img, info in swapped_images: - if pp_options: - img = enhance_image(img, pp_options) + + for img in swapped_images: response.images.append(encode_to_base64(img)) - response.infos.append(info) + response.infos = [] # Not used atm return response diff --git a/scripts/faceswaplab_globals.py b/scripts/faceswaplab_globals.py index cdc4c66..40dbf12 100644 --- a/scripts/faceswaplab_globals.py +++ b/scripts/faceswaplab_globals.py @@ -8,7 +8,7 @@ REFERENCE_PATH = os.path.join( scripts.basedir(), "extensions", "sd-webui-faceswaplab", "references" ) -VERSION_FLAG: str = "v1.1.0" +VERSION_FLAG: str = "v1.1.1" EXTENSION_PATH = os.path.join("extensions", "sd-webui-faceswaplab") # The NSFW score threshold. If any part of the image has a score greater than this threshold, the image will be considered NSFW. diff --git a/scripts/faceswaplab_swapping/swapper.py b/scripts/faceswaplab_swapping/swapper.py index 9dfc454..cf8abc6 100644 --- a/scripts/faceswaplab_swapping/swapper.py +++ b/scripts/faceswaplab_swapping/swapper.py @@ -2,6 +2,7 @@ import copy import os from dataclasses import dataclass from typing import Any, Dict, List, Set, Tuple, Optional +import tempfile import cv2 import insightface @@ -21,6 +22,12 @@ from scripts import faceswaplab_globals from modules.shared import opts from functools import lru_cache from scripts.faceswaplab_ui.faceswaplab_unit_settings import FaceSwapUnitSettings +from scripts.faceswaplab_postprocessing.postprocessing import enhance_image +from scripts.faceswaplab_postprocessing.postprocessing_options import ( + PostProcessingOptions, +) +from scripts.faceswaplab_utils.models_utils import get_current_model + providers = ["CPUExecutionProvider"] @@ -78,6 +85,53 @@ def compare_faces(img1: Image.Image, img2: Image.Image) -> float: return -1 +def batch_process( + src_images: List[Image.Image], + save_path: Optional[str], + units: List[FaceSwapUnitSettings], + postprocess_options: PostProcessingOptions, +) -> Optional[List[Image.Image]]: + try: + if save_path: + os.makedirs(save_path, exist_ok=True) + + units = [u for u in units if u.enable] + if src_images is not None and len(units) > 0: + result_images = [] + for src_image in src_images: + current_images = [] + swapped_images = process_images_units( + get_current_model(), + images=[(src_image, None)], + units=units, + upscaled_swapper=opts.data.get( + "faceswaplab_upscaled_swapper", False + ), + ) + if len(swapped_images) > 0: + current_images += [img for img, _ in swapped_images] + + logger.info("%s images generated", len(current_images)) + for i, img in enumerate(current_images): + current_images[i] = enhance_image(img, postprocess_options) + + if save_path: + for img in current_images: + path = tempfile.NamedTemporaryFile( + delete=False, suffix=".png", dir=save_path + ).name + img.save(path) + + result_images += current_images + return result_images + except Exception as e: + logger.error("Batch Process error : %s", e) + import traceback + + traceback.print_exc() + return None + + class FaceModelException(Exception): """Exception raised when an error is encountered in the face model.""" diff --git a/scripts/faceswaplab_ui/faceswaplab_tab.py b/scripts/faceswaplab_ui/faceswaplab_tab.py index ac94cf9..ce33b02 100644 --- a/scripts/faceswaplab_ui/faceswaplab_tab.py +++ b/scripts/faceswaplab_ui/faceswaplab_tab.py @@ -26,7 +26,6 @@ from scripts.faceswaplab_postprocessing.postprocessing import enhance_image from dataclasses import fields from typing import Any, Dict, List, Optional from scripts.faceswaplab_ui.faceswaplab_unit_settings import FaceSwapUnitSettings -from scripts.faceswaplab_utils.models_utils import get_current_model import re @@ -291,9 +290,6 @@ def batch_process( files: List[gr.File], save_path: str, *components: List[gr.components.Component] ) -> Optional[List[Image.Image]]: try: - if save_path is not None: - os.makedirs(save_path, exist_ok=True) - units_count = opts.data.get("faceswaplab_units_count", 3) units: List[FaceSwapUnitSettings] = [] @@ -312,36 +308,15 @@ def batch_process( *components[shift : shift + len(fields(PostProcessingOptions))] # type: ignore ) logger.debug("%s", pformat(postprocess_options)) - - units = [u for u in units if u.enable] - if files is not None and len(units) > 0: - images = [] - for file in files: - current_images = [] - src_image = Image.open(file.name) - swapped_images = swapper.process_images_units( - get_current_model(), - images=[(src_image, None)], - units=units, - upscaled_swapper=opts.data.get( - "faceswaplab_upscaled_swapper", False - ), - ) - if len(swapped_images) > 0: - current_images += [img for img, _ in swapped_images] - - logger.info("%s images generated", len(current_images)) - for i, img in enumerate(current_images): - current_images[i] = enhance_image(img, postprocess_options) - - for img in current_images: - path = tempfile.NamedTemporaryFile( - delete=False, suffix=".png", dir=save_path - ).name - img.save(path) - - images += current_images - return images + images = [ + Image.open(file.name) for file in files + ] # potentially greedy but Image.open is supposed to be lazy + return swapper.batch_process( + images, + save_path=save_path, + units=units, + postprocess_options=postprocess_options, + ) except Exception as e: logger.error("Batch Process error : %s", e) import traceback diff --git a/test.sh b/test.sh new file mode 100755 index 0000000..9133d9c --- /dev/null +++ b/test.sh @@ -0,0 +1,3 @@ +#!/bin/bash +./check.sh +pytest -p no:warnings \ No newline at end of file diff --git a/tests/test_api.py b/tests/test_api.py new file mode 100644 index 0000000..f35cbd9 --- /dev/null +++ b/tests/test_api.py @@ -0,0 +1,83 @@ +from typing import List +import pytest +import requests +import sys + +sys.path.append(".") + +from client_api.client_utils import ( + FaceSwapUnit, + FaceSwapResponse, + PostProcessingOptions, + FaceSwapRequest, + base64_to_pil, + pil_to_base64, + InpaintingWhen, +) +from PIL import Image + +base_url = "http://127.0.0.1:7860" + + +@pytest.fixture +def face_swap_request() -> FaceSwapRequest: + # First face unit + unit1 = FaceSwapUnit( + source_img=pil_to_base64("references/man.png"), # The face you want to use + faces_index=(0,), # Replace first face + ) + + # Second face unit + unit2 = FaceSwapUnit( + source_img=pil_to_base64("references/woman.png"), # The face you want to use + same_gender=True, + faces_index=(0,), # Replace first woman since same gender is on + ) + + # Post-processing config + pp = PostProcessingOptions( + face_restorer_name="CodeFormer", + codeformer_weight=0.5, + restorer_visibility=1, + upscaler_name="Lanczos", + scale=4, + inpainting_steps=30, + inpainting_denoising_strengh=0.1, + inpainting_when=InpaintingWhen.BEFORE_RESTORE_FACE, + ) + + # Prepare the request + request = FaceSwapRequest( + image=pil_to_base64("tests/test_image.png"), + units=[unit1, unit2], + postprocessing=pp, + ) + + return request + + +def test_version() -> None: + response = requests.get(f"{base_url}/faceswaplab/version") + assert response.status_code == 200 + assert "version" in response.json() + + +def test_faceswap(face_swap_request: FaceSwapRequest) -> None: + response = requests.post( + f"{base_url}/faceswaplab/swap_face", + data=face_swap_request.json(), + headers={"Content-Type": "application/json; charset=utf-8"}, + ) + + assert response.status_code == 200 + data = response.json() + assert "images" in data + assert "infos" in data + + res = FaceSwapResponse.parse_obj(response.json()) + images: List[Image.Image] = res.pil_images + assert len(images) == 1 + image = images[0] + orig_image = base64_to_pil(face_swap_request.image) + assert image.width == orig_image.width * face_swap_request.postprocessing.scale + assert image.height == orig_image.height * face_swap_request.postprocessing.scale diff --git a/tests/test_image.png b/tests/test_image.png new file mode 100644 index 0000000..21a424b Binary files /dev/null and b/tests/test_image.png differ