Spaces:
Runtime error
Runtime error
| import os | |
| import unittest | |
| import modules.flags | |
| from modules import util | |
class TestUtils(unittest.TestCase):
    """Table-driven tests for the LoRA prompt-parsing helpers in ``modules.util``."""

    def test_can_parse_tokens_with_lora(self):
        """Check extraction of ``<lora:name:weight>`` tags from a prompt.

        Each case provides ``(prompt, loras, loras_limit, skip_file_check)`` as
        input and the expected ``(loras, cleaned_prompt)`` tuple as output.
        """
        test_cases = [
            {
                "input": ("some prompt, very cool, <lora:hey-lora:0.4>, cool <lora:you-lora:0.2>", [], 5, True),
                "output": (
                    [('hey-lora.safetensors', 0.4), ('you-lora.safetensors', 0.2)],
                    'some prompt, very cool, cool',
                ),
            },
            # Test can not exceed limit
            {
                "input": ("some prompt, very cool, <lora:hey-lora:0.4>, cool <lora:you-lora:0.2>", [], 1, True),
                "output": (
                    [('hey-lora.safetensors', 0.4)],
                    'some prompt, very cool, cool',
                ),
            },
            # test Loras from UI take precedence over prompt
            {
                "input": (
                    "some prompt, very cool, <lora:l1:0.4>, <lora:l2:-0.2>, <lora:l3:0.3>, <lora:l4:0.5>, <lora:l6:0.24>, <lora:l7:0.1>",
                    [("hey-lora.safetensors", 0.4)],
                    5,
                    True,
                ),
                "output": (
                    [
                        ('hey-lora.safetensors', 0.4),
                        ('l1.safetensors', 0.4),
                        ('l2.safetensors', -0.2),
                        ('l3.safetensors', 0.3),
                        ('l4.safetensors', 0.5),
                    ],
                    'some prompt, very cool',
                ),
            },
            # test correct matching even if there is no space separating loras in the same token
            {
                "input": ("some prompt, very cool, <lora:hey-lora:0.4><lora:you-lora:0.2>", [], 3, True),
                "output": (
                    [
                        ('hey-lora.safetensors', 0.4),
                        ('you-lora.safetensors', 0.2),
                    ],
                    'some prompt, very cool',
                ),
            },
            # test deduplication, also selected loras are never overridden with loras in prompt
            {
                "input": (
                    "some prompt, very cool, <lora:hey-lora:0.4><lora:hey-lora:0.4><lora:you-lora:0.2>",
                    [('you-lora.safetensors', 0.3)],
                    3,
                    True,
                ),
                "output": (
                    [
                        ('you-lora.safetensors', 0.3),
                        ('hey-lora.safetensors', 0.4),
                    ],
                    'some prompt, very cool',
                ),
            },
            # malformed weights and non-lora tags yield no loras and leave the prompt unchanged
            {
                "input": ("<lora:foo:1..2>, <lora:bar:.>, <test:1.0>, <lora:baz:+> and <lora:quux:>", [], 6, True),
                "output": (
                    [],
                    '<lora:foo:1..2>, <lora:bar:.>, <test:1.0>, <lora:baz:+> and <lora:quux:>',
                ),
            },
        ]

        for test in test_cases:
            prompt, loras, loras_limit, skip_file_check = test["input"]
            expected = test["output"]
            # subTest keeps the remaining cases running after a failure and
            # labels the failing case in the report.
            with self.subTest(prompt=prompt, loras_limit=loras_limit):
                actual = util.parse_lora_references_from_prompt(
                    prompt,
                    loras,
                    loras_limit=loras_limit,
                    skip_file_check=skip_file_check,
                )
                self.assertEqual(expected, actual)

    def test_can_parse_tokens_and_strip_performance_lora(self):
        """Check parsing against a filename list filtered by performance mode.

        Each case provides ``(prompt, loras, loras_limit, skip_file_check,
        performance)``. The available filenames are passed through
        ``util.remove_performance_lora`` for that performance mode and the
        result is handed to the parser as ``lora_filenames``.
        """
        available_lora_filenames = [
            'hey-lora.safetensors',
            modules.flags.PerformanceLoRA.EXTREME_SPEED.value,
            modules.flags.PerformanceLoRA.LIGHTNING.value,
            os.path.join('subfolder', modules.flags.PerformanceLoRA.HYPER_SD.value),
        ]

        test_cases = [
            {
                "input": ("some prompt, <lora:hey-lora:0.4>", [], 5, True, modules.flags.Performance.QUALITY),
                "output": (
                    [('hey-lora.safetensors', 0.4)],
                    'some prompt',
                ),
            },
            {
                "input": ("some prompt, <lora:hey-lora:0.4>", [], 5, True, modules.flags.Performance.SPEED),
                "output": (
                    [('hey-lora.safetensors', 0.4)],
                    'some prompt',
                ),
            },
            {
                "input": (
                    "some prompt, <lora:sdxl_lcm_lora:1>, <lora:hey-lora:0.4>",
                    [], 5, True, modules.flags.Performance.EXTREME_SPEED,
                ),
                "output": (
                    [('hey-lora.safetensors', 0.4)],
                    'some prompt',
                ),
            },
            {
                "input": (
                    "some prompt, <lora:sdxl_lightning_4step_lora:1>, <lora:hey-lora:0.4>",
                    [], 5, True, modules.flags.Performance.LIGHTNING,
                ),
                "output": (
                    [('hey-lora.safetensors', 0.4)],
                    'some prompt',
                ),
            },
            {
                "input": (
                    "some prompt, <lora:sdxl_hyper_sd_4step_lora:1>, <lora:hey-lora:0.4>",
                    [], 5, True, modules.flags.Performance.HYPER_SD,
                ),
                "output": (
                    [('hey-lora.safetensors', 0.4)],
                    'some prompt',
                ),
            },
        ]

        for test in test_cases:
            # The fourth element (skip_file_check) is not forwarded to the parser
            # in this test, so it is bound to an underscore-prefixed name.
            # NOTE(review): confirm whether it can be dropped from the fixtures.
            prompt, loras, loras_limit, _skip_file_check, performance = test["input"]
            expected = test["output"]
            with self.subTest(prompt=prompt, performance=performance):
                # Filter a fresh copy for every case so one case's filtering
                # (or any in-place mutation by the helper) cannot leak into the
                # next case and make the results order-dependent.
                lora_filenames = util.remove_performance_lora(list(available_lora_filenames), performance)
                actual = util.parse_lora_references_from_prompt(
                    prompt,
                    loras,
                    loras_limit=loras_limit,
                    lora_filenames=lora_filenames,
                )
                self.assertEqual(expected, actual)