add tests

main
Tran Xen 2 years ago
parent ddc403db1f
commit 15e9366eb6

@ -12,13 +12,13 @@ address = "http://127.0.0.1:7860"
# First face unit : # First face unit :
unit1 = FaceSwapUnit( unit1 = FaceSwapUnit(
source_img=pil_to_base64("../../references/man.png"), # The face you want to use source_img=pil_to_base64("../references/man.png"), # The face you want to use
faces_index=(0,), # Replace first face faces_index=(0,), # Replace first face
) )
# Second face unit : # Second face unit :
unit2 = FaceSwapUnit( unit2 = FaceSwapUnit(
source_img=pil_to_base64("../../references/woman.png"), # The face you want to use source_img=pil_to_base64("../references/woman.png"), # The face you want to use
same_gender=True, same_gender=True,
faces_index=(0,), # Replace first woman since same gender is on faces_index=(0,), # Replace first woman since same gender is on
) )
@ -48,5 +48,6 @@ result = requests.post(
) )
response = FaceSwapResponse.parse_obj(result.json()) response = FaceSwapResponse.parse_obj(result.json())
for img, info in zip(response.pil_images, response.infos): print(response.json())
img.show(title=info) for img in response.pil_images:
img.show()

Before

Width:  |  Height:  |  Size: 99 KiB

After

Width:  |  Height:  |  Size: 99 KiB

@ -13,9 +13,6 @@ from scripts.faceswaplab_ui.faceswaplab_unit_settings import FaceSwapUnitSetting
from scripts.faceswaplab_utils.imgutils import ( from scripts.faceswaplab_utils.imgutils import (
base64_to_pil, base64_to_pil,
) )
from scripts.faceswaplab_utils.models_utils import get_current_model
from modules.shared import opts
from scripts.faceswaplab_postprocessing.postprocessing import enhance_image
from scripts.faceswaplab_postprocessing.postprocessing_options import ( from scripts.faceswaplab_postprocessing.postprocessing_options import (
PostProcessingOptions, PostProcessingOptions,
) )
@ -135,22 +132,18 @@ def faceswaplab_api(_: gr.Blocks, app: FastAPI) -> None:
units: List[FaceSwapUnitSettings] = [] units: List[FaceSwapUnitSettings] = []
src_image: Optional[Image.Image] = base64_to_pil(request.image) src_image: Optional[Image.Image] = base64_to_pil(request.image)
response = FaceSwapResponse(images=[], infos=[]) response = FaceSwapResponse(images=[], infos=[])
if request.postprocessing:
pp_options = get_postprocessing_options(request.postprocessing)
if src_image is not None: if src_image is not None:
if request.postprocessing:
pp_options = get_postprocessing_options(request.postprocessing)
units = get_faceswap_units_settings(request.units) units = get_faceswap_units_settings(request.units)
swapped_images = swapper.process_images_units( swapped_images = swapper.batch_process(
get_current_model(), [src_image], None, units=units, postprocess_options=pp_options
images=[(src_image, None)],
units=units,
upscaled_swapper=opts.data.get("faceswaplab_upscaled_swapper", False),
) )
for img, info in swapped_images:
if pp_options: for img in swapped_images:
img = enhance_image(img, pp_options)
response.images.append(encode_to_base64(img)) response.images.append(encode_to_base64(img))
response.infos.append(info)
response.infos = [] # Not used atm
return response return response

@ -8,7 +8,7 @@ REFERENCE_PATH = os.path.join(
scripts.basedir(), "extensions", "sd-webui-faceswaplab", "references" scripts.basedir(), "extensions", "sd-webui-faceswaplab", "references"
) )
VERSION_FLAG: str = "v1.1.0" VERSION_FLAG: str = "v1.1.1"
EXTENSION_PATH = os.path.join("extensions", "sd-webui-faceswaplab") EXTENSION_PATH = os.path.join("extensions", "sd-webui-faceswaplab")
# The NSFW score threshold. If any part of the image has a score greater than this threshold, the image will be considered NSFW. # The NSFW score threshold. If any part of the image has a score greater than this threshold, the image will be considered NSFW.

@ -2,6 +2,7 @@ import copy
import os import os
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, List, Set, Tuple, Optional from typing import Any, Dict, List, Set, Tuple, Optional
import tempfile
import cv2 import cv2
import insightface import insightface
@ -21,6 +22,12 @@ from scripts import faceswaplab_globals
from modules.shared import opts from modules.shared import opts
from functools import lru_cache from functools import lru_cache
from scripts.faceswaplab_ui.faceswaplab_unit_settings import FaceSwapUnitSettings from scripts.faceswaplab_ui.faceswaplab_unit_settings import FaceSwapUnitSettings
from scripts.faceswaplab_postprocessing.postprocessing import enhance_image
from scripts.faceswaplab_postprocessing.postprocessing_options import (
PostProcessingOptions,
)
from scripts.faceswaplab_utils.models_utils import get_current_model
providers = ["CPUExecutionProvider"] providers = ["CPUExecutionProvider"]
@ -78,6 +85,53 @@ def compare_faces(img1: Image.Image, img2: Image.Image) -> float:
return -1 return -1
def batch_process(
src_images: List[Image.Image],
save_path: Optional[str],
units: List[FaceSwapUnitSettings],
postprocess_options: PostProcessingOptions,
) -> Optional[List[Image.Image]]:
try:
if save_path:
os.makedirs(save_path, exist_ok=True)
units = [u for u in units if u.enable]
if src_images is not None and len(units) > 0:
result_images = []
for src_image in src_images:
current_images = []
swapped_images = process_images_units(
get_current_model(),
images=[(src_image, None)],
units=units,
upscaled_swapper=opts.data.get(
"faceswaplab_upscaled_swapper", False
),
)
if len(swapped_images) > 0:
current_images += [img for img, _ in swapped_images]
logger.info("%s images generated", len(current_images))
for i, img in enumerate(current_images):
current_images[i] = enhance_image(img, postprocess_options)
if save_path:
for img in current_images:
path = tempfile.NamedTemporaryFile(
delete=False, suffix=".png", dir=save_path
).name
img.save(path)
result_images += current_images
return result_images
except Exception as e:
logger.error("Batch Process error : %s", e)
import traceback
traceback.print_exc()
return None
class FaceModelException(Exception): class FaceModelException(Exception):
"""Exception raised when an error is encountered in the face model.""" """Exception raised when an error is encountered in the face model."""

@ -26,7 +26,6 @@ from scripts.faceswaplab_postprocessing.postprocessing import enhance_image
from dataclasses import fields from dataclasses import fields
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from scripts.faceswaplab_ui.faceswaplab_unit_settings import FaceSwapUnitSettings from scripts.faceswaplab_ui.faceswaplab_unit_settings import FaceSwapUnitSettings
from scripts.faceswaplab_utils.models_utils import get_current_model
import re import re
@ -291,9 +290,6 @@ def batch_process(
files: List[gr.File], save_path: str, *components: List[gr.components.Component] files: List[gr.File], save_path: str, *components: List[gr.components.Component]
) -> Optional[List[Image.Image]]: ) -> Optional[List[Image.Image]]:
try: try:
if save_path is not None:
os.makedirs(save_path, exist_ok=True)
units_count = opts.data.get("faceswaplab_units_count", 3) units_count = opts.data.get("faceswaplab_units_count", 3)
units: List[FaceSwapUnitSettings] = [] units: List[FaceSwapUnitSettings] = []
@ -312,36 +308,15 @@ def batch_process(
*components[shift : shift + len(fields(PostProcessingOptions))] # type: ignore *components[shift : shift + len(fields(PostProcessingOptions))] # type: ignore
) )
logger.debug("%s", pformat(postprocess_options)) logger.debug("%s", pformat(postprocess_options))
images = [
units = [u for u in units if u.enable] Image.open(file.name) for file in files
if files is not None and len(units) > 0: ] # potentially greedy but Image.open is supposed to be lazy
images = [] return swapper.batch_process(
for file in files: images,
current_images = [] save_path=save_path,
src_image = Image.open(file.name)
swapped_images = swapper.process_images_units(
get_current_model(),
images=[(src_image, None)],
units=units, units=units,
upscaled_swapper=opts.data.get( postprocess_options=postprocess_options,
"faceswaplab_upscaled_swapper", False
),
) )
if len(swapped_images) > 0:
current_images += [img for img, _ in swapped_images]
logger.info("%s images generated", len(current_images))
for i, img in enumerate(current_images):
current_images[i] = enhance_image(img, postprocess_options)
for img in current_images:
path = tempfile.NamedTemporaryFile(
delete=False, suffix=".png", dir=save_path
).name
img.save(path)
images += current_images
return images
except Exception as e: except Exception as e:
logger.error("Batch Process error : %s", e) logger.error("Batch Process error : %s", e)
import traceback import traceback

@ -0,0 +1,3 @@
#!/bin/bash
./check.sh
pytest -p no:warnings

@ -0,0 +1,83 @@
from typing import List
import pytest
import requests
import sys
sys.path.append(".")
from client_api.client_utils import (
FaceSwapUnit,
FaceSwapResponse,
PostProcessingOptions,
FaceSwapRequest,
base64_to_pil,
pil_to_base64,
InpaintingWhen,
)
from PIL import Image
base_url = "http://127.0.0.1:7860"
@pytest.fixture
def face_swap_request() -> FaceSwapRequest:
# First face unit
unit1 = FaceSwapUnit(
source_img=pil_to_base64("references/man.png"), # The face you want to use
faces_index=(0,), # Replace first face
)
# Second face unit
unit2 = FaceSwapUnit(
source_img=pil_to_base64("references/woman.png"), # The face you want to use
same_gender=True,
faces_index=(0,), # Replace first woman since same gender is on
)
# Post-processing config
pp = PostProcessingOptions(
face_restorer_name="CodeFormer",
codeformer_weight=0.5,
restorer_visibility=1,
upscaler_name="Lanczos",
scale=4,
inpainting_steps=30,
inpainting_denoising_strengh=0.1,
inpainting_when=InpaintingWhen.BEFORE_RESTORE_FACE,
)
# Prepare the request
request = FaceSwapRequest(
image=pil_to_base64("tests/test_image.png"),
units=[unit1, unit2],
postprocessing=pp,
)
return request
def test_version() -> None:
response = requests.get(f"{base_url}/faceswaplab/version")
assert response.status_code == 200
assert "version" in response.json()
def test_faceswap(face_swap_request: FaceSwapRequest) -> None:
response = requests.post(
f"{base_url}/faceswaplab/swap_face",
data=face_swap_request.json(),
headers={"Content-Type": "application/json; charset=utf-8"},
)
assert response.status_code == 200
data = response.json()
assert "images" in data
assert "infos" in data
res = FaceSwapResponse.parse_obj(response.json())
images: List[Image.Image] = res.pil_images
assert len(images) == 1
image = images[0]
orig_image = base64_to_pil(face_swap_request.image)
assert image.width == orig_image.width * face_swap_request.postprocessing.scale
assert image.height == orig_image.height * face_swap_request.postprocessing.scale

Binary file not shown.

After

Width:  |  Height:  |  Size: 99 KiB

Loading…
Cancel
Save