| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
"""
State dict utilities: utility methods for converting state dicts easily
"""
import enum

from .logging import get_logger


# Module-level logger, named after this module.
logger = get_logger(__name__)
| |
|
| |
|
class StateDictType(enum.Enum):
    """
    The mode to use when converting state dicts.

    Each member names one LoRA key-naming convention handled by the conversion tables in this module.
    """

    # Keys carrying `.to_q_lora` / `.to_k_lora` / `.to_v_lora` / `.to_out_lora` segments.
    DIFFUSERS_OLD = "diffusers_old"
    # Keys using `lora_down` / `lora_up`.
    KOHYA_SS = "kohya_ss"
    # Keys using `lora_A` / `lora_B`.
    PEFT = "peft"
    # Keys carrying `.lora_linear_layer.up` / `.lora_linear_layer.down` segments.
    DIFFUSERS = "diffusers"
| |
|
| |
|
| | |
| | |
# Key-renaming tables. Each maps a substring found in a source key to its replacement.
# Entry order matters: `convert_state_dict` applies only the first pattern that matches a key,
# so the specific patterns are listed before the generic fallbacks.

# UNet LoRA suffixes -> `lora_A` / `lora_B` naming (used by `convert_unet_state_dict_to_peft`).
UNET_TO_DIFFUSERS = {
    ".to_out_lora.up": ".to_out.0.lora_B",
    ".to_out_lora.down": ".to_out.0.lora_A",
    ".to_q_lora.down": ".to_q.lora_A",
    ".to_q_lora.up": ".to_q.lora_B",
    ".to_k_lora.down": ".to_k.lora_A",
    ".to_k_lora.up": ".to_k.lora_B",
    ".to_v_lora.down": ".to_v.lora_A",
    ".to_v_lora.up": ".to_v.lora_B",
    ".lora.up": ".lora_B",
    ".lora.down": ".lora_A",
}


# New diffusers naming (`lora_linear_layer.up/down`) -> PEFT naming (`lora_B` / `lora_A`).
DIFFUSERS_TO_PEFT = {
    ".q_proj.lora_linear_layer.up": ".q_proj.lora_B",
    ".q_proj.lora_linear_layer.down": ".q_proj.lora_A",
    ".k_proj.lora_linear_layer.up": ".k_proj.lora_B",
    ".k_proj.lora_linear_layer.down": ".k_proj.lora_A",
    ".v_proj.lora_linear_layer.up": ".v_proj.lora_B",
    ".v_proj.lora_linear_layer.down": ".v_proj.lora_A",
    ".out_proj.lora_linear_layer.up": ".out_proj.lora_B",
    ".out_proj.lora_linear_layer.down": ".out_proj.lora_A",
    ".lora_linear_layer.up": ".lora_B",
    ".lora_linear_layer.down": ".lora_A",
}

# Old diffusers naming (`to_*_lora.up/down`) -> PEFT naming.
DIFFUSERS_OLD_TO_PEFT = {
    ".to_q_lora.up": ".q_proj.lora_B",
    ".to_q_lora.down": ".q_proj.lora_A",
    ".to_k_lora.up": ".k_proj.lora_B",
    ".to_k_lora.down": ".k_proj.lora_A",
    ".to_v_lora.up": ".v_proj.lora_B",
    ".to_v_lora.down": ".v_proj.lora_A",
    ".to_out_lora.up": ".out_proj.lora_B",
    ".to_out_lora.down": ".out_proj.lora_A",
    ".lora_linear_layer.up": ".lora_B",
    ".lora_linear_layer.down": ".lora_A",
}

# PEFT naming -> new diffusers naming.
PEFT_TO_DIFFUSERS = {
    ".q_proj.lora_B": ".q_proj.lora_linear_layer.up",
    ".q_proj.lora_A": ".q_proj.lora_linear_layer.down",
    ".k_proj.lora_B": ".k_proj.lora_linear_layer.up",
    ".k_proj.lora_A": ".k_proj.lora_linear_layer.down",
    ".v_proj.lora_B": ".v_proj.lora_linear_layer.up",
    ".v_proj.lora_A": ".v_proj.lora_linear_layer.down",
    ".out_proj.lora_B": ".out_proj.lora_linear_layer.up",
    ".out_proj.lora_A": ".out_proj.lora_linear_layer.down",
    "to_k.lora_A": "to_k.lora.down",
    "to_k.lora_B": "to_k.lora.up",
    "to_q.lora_A": "to_q.lora.down",
    "to_q.lora_B": "to_q.lora.up",
    "to_v.lora_A": "to_v.lora.down",
    "to_v.lora_B": "to_v.lora.up",
    "to_out.0.lora_A": "to_out.0.lora.down",
    "to_out.0.lora_B": "to_out.0.lora.up",
}

# Old diffusers naming -> new diffusers naming.
DIFFUSERS_OLD_TO_DIFFUSERS = {
    ".to_q_lora.up": ".q_proj.lora_linear_layer.up",
    ".to_q_lora.down": ".q_proj.lora_linear_layer.down",
    ".to_k_lora.up": ".k_proj.lora_linear_layer.up",
    ".to_k_lora.down": ".k_proj.lora_linear_layer.down",
    ".to_v_lora.up": ".v_proj.lora_linear_layer.up",
    ".to_v_lora.down": ".v_proj.lora_linear_layer.down",
    ".to_out_lora.up": ".out_proj.lora_linear_layer.up",
    ".to_out_lora.down": ".out_proj.lora_linear_layer.down",
}

# PEFT naming -> Kohya naming. This table only renames the A/B matrices; the prefixing,
# dot-to-underscore flattening and alpha entries are handled in `convert_state_dict_to_kohya`.
PEFT_TO_KOHYA_SS = {
    "lora_A": "lora_down",
    "lora_B": "lora_up",
}
| |
|
# Dispatch tables: for each target format, the rename table to use for a given source format.
# A source format that is absent from a table is rejected by the matching `convert_state_dict_to_*`.

# Target: PEFT.
PEFT_STATE_DICT_MAPPINGS = {
    StateDictType.DIFFUSERS_OLD: DIFFUSERS_OLD_TO_PEFT,
    StateDictType.DIFFUSERS: DIFFUSERS_TO_PEFT,
}

# Target: new diffusers format.
DIFFUSERS_STATE_DICT_MAPPINGS = {
    StateDictType.DIFFUSERS_OLD: DIFFUSERS_OLD_TO_DIFFUSERS,
    StateDictType.PEFT: PEFT_TO_DIFFUSERS,
}

# Target: Kohya.
KOHYA_STATE_DICT_MAPPINGS = {StateDictType.PEFT: PEFT_TO_KOHYA_SS}
| |
|
# Replacements applied to every key by `convert_state_dict`, before the format-specific table.
KEYS_TO_ALWAYS_REPLACE = {
    ".processor.": ".",
}
| |
|
| |
|
def convert_state_dict(state_dict, mapping):
    r"""
    Simply iterates over the state dict and replaces the patterns in `mapping` with the corresponding values.

    Every pattern in `KEYS_TO_ALWAYS_REPLACE` is applied to each key first. After that, only the first pattern of
    `mapping` (in insertion order) that occurs in the key is applied, so specific patterns must be listed before
    generic ones. Values are carried over unchanged and the input dict is not modified.

    Args:
        state_dict (`dict[str, torch.Tensor]`):
            The state dict to convert.
        mapping (`dict[str, str]`):
            The mapping to use for conversion, the mapping should be a dictionary with the following structure:
                - key: the pattern to replace
                - value: the pattern to replace with

    Returns:
        converted_state_dict (`dict`)
            The converted state dict. If two source keys are renamed to the same key, the later one wins.
    """
    converted_state_dict = {}
    for key, value in state_dict.items():
        # Unconditional clean-ups: all of them are applied (`str.replace` is a no-op without a match).
        for pattern, replacement in KEYS_TO_ALWAYS_REPLACE.items():
            key = key.replace(pattern, replacement)

        # Format-specific rename: stop at the first matching pattern so that a generic
        # fallback listed later cannot rewrite an already converted key.
        for pattern, replacement in mapping.items():
            if pattern in key:
                key = key.replace(pattern, replacement)
                break
        converted_state_dict[key] = value
    return converted_state_dict
| |
|
| |
|
def convert_state_dict_to_peft(state_dict, original_type=None, **kwargs):
    r"""
    Converts a state dict to the PEFT format. The state dict can be from previous diffusers format (`DIFFUSERS_OLD`),
    or new diffusers format (`DIFFUSERS`). The method only supports the conversion from diffusers old/new to PEFT for
    now.

    Args:
        state_dict (`dict[str, torch.Tensor]`):
            The state dict to convert.
        original_type (`StateDictType`, *optional*):
            The original type of the state dict, if not provided, the method will try to infer it automatically.
        kwargs (`dict`, *optional*):
            Accepted for signature parity with the other converters; currently unused.

    Returns:
        `dict`: The state dict with its keys renamed to the PEFT convention.

    Raises:
        ValueError: If the type cannot be inferred, or if `original_type` has no conversion table to PEFT.
    """
    if original_type is None:
        # Old diffusers format is recognised by its `to_out_lora` keys, the new one by `lora_linear_layer`.
        if any("to_out_lora" in k for k in state_dict):
            original_type = StateDictType.DIFFUSERS_OLD
        elif any("lora_linear_layer" in k for k in state_dict):
            original_type = StateDictType.DIFFUSERS
        else:
            # NOTE: `convert_all_state_dict_to_peft` matches on this exact message; keep it unchanged.
            raise ValueError("Could not automatically infer state dict type")

    if original_type not in PEFT_STATE_DICT_MAPPINGS:
        raise ValueError(f"Original type {original_type} is not supported")

    mapping = PEFT_STATE_DICT_MAPPINGS[original_type]
    return convert_state_dict(state_dict, mapping)
| |
|
| |
|
def convert_state_dict_to_diffusers(state_dict, original_type=None, **kwargs):
    r"""
    Converts a state dict to new diffusers format. The state dict can be from previous diffusers format
    (`DIFFUSERS_OLD`), or PEFT format (`PEFT`) or new diffusers format (`DIFFUSERS`). In the last case the method will
    return the state dict as is.

    The method only supports the conversion from diffusers old, PEFT to diffusers new for now.

    Args:
        state_dict (`dict[str, torch.Tensor]`):
            The state dict to convert.
        original_type (`StateDictType`, *optional*):
            The original type of the state dict, if not provided, the method will try to infer it automatically.
        kwargs (`dict`, *optional*):
            Additional arguments to pass to the method.

            - **adapter_name**: For example, in case of PEFT, some keys will be pre-pended
                with the adapter name, therefore needs a special handling. By default PEFT also takes care of that in
                `get_peft_model_state_dict` method:
                https://github.com/huggingface/peft/blob/ba0477f2985b1ba311b83459d29895c809404e99/src/peft/utils/save_and_load.py#L92
                but we add it here in case we don't want to rely on that method.

    Returns:
        `dict`: The state dict with its keys renamed to the new diffusers convention. When the type is inferred and
        the input is already in that format, the very same object is returned.

    Raises:
        ValueError: If the type cannot be inferred, or if `original_type` has no conversion table to diffusers.
    """
    # The adapter name is only used to recognise PEFT keys of the form `.lora_A.<adapter_name>.weight`.
    peft_adapter_name = kwargs.pop("adapter_name", None)
    if peft_adapter_name is not None:
        peft_adapter_name = "." + peft_adapter_name
    else:
        peft_adapter_name = ""

    if original_type is None:
        # Old diffusers format is checked first, then PEFT, then the new diffusers format.
        if any("to_out_lora" in k for k in state_dict):
            original_type = StateDictType.DIFFUSERS_OLD
        elif any(f".lora_A{peft_adapter_name}.weight" in k for k in state_dict):
            original_type = StateDictType.PEFT
        elif any("lora_linear_layer" in k for k in state_dict):
            # Already in the new diffusers format: nothing to do.
            return state_dict
        else:
            raise ValueError("Could not automatically infer state dict type")

    if original_type not in DIFFUSERS_STATE_DICT_MAPPINGS:
        raise ValueError(f"Original type {original_type} is not supported")

    mapping = DIFFUSERS_STATE_DICT_MAPPINGS[original_type]
    return convert_state_dict(state_dict, mapping)
| |
|
| |
|
def convert_unet_state_dict_to_peft(state_dict):
    r"""
    Converts a UNet LoRA state dict to the PEFT key naming by renaming its keys with the `UNET_TO_DIFFUSERS` table
    (e.g. `.to_q_lora.down` -> `.to_q.lora_A`, `.lora.up` -> `.lora_B`).

    Args:
        state_dict (`dict[str, torch.Tensor]`):
            The state dict to convert.

    Returns:
        `dict`: A new dict with the renamed keys; values are carried over unchanged.
    """
    return convert_state_dict(state_dict, UNET_TO_DIFFUSERS)
| |
|
| |
|
def convert_all_state_dict_to_peft(state_dict):
    r"""
    Attempts to first `convert_state_dict_to_peft`, and if it doesn't detect `lora_linear_layer` for a valid
    `DIFFUSERS` LoRA for example, attempts to exclusively convert the Unet `convert_unet_state_dict_to_peft`.

    Args:
        state_dict (`dict[str, torch.Tensor]`):
            The state dict to convert.

    Returns:
        `dict`: The state dict with PEFT-style keys.

    Raises:
        ValueError: If the first conversion fails for a reason other than type inference, or if the result contains
            no `lora_A` / `lora_B` key.
    """
    try:
        peft_dict = convert_state_dict_to_peft(state_dict)
    except ValueError as e:
        # Only a failed type inference triggers the UNet fallback; it is raised as a `ValueError`
        # with exactly this message. Any other error propagates unchanged.
        if str(e) != "Could not automatically infer state dict type":
            raise
        peft_dict = convert_unet_state_dict_to_peft(state_dict)

    # Sanity check: a successful conversion must have produced at least one PEFT-style key.
    if not any("lora_A" in key or "lora_B" in key for key in peft_dict):
        raise ValueError("Your LoRA was not converted to PEFT")

    return peft_dict
| |
|
| |
|
def convert_state_dict_to_kohya(state_dict, original_type=None, **kwargs):
    r"""
    Converts a `PEFT` state dict to `Kohya` format that can be used in AUTOMATIC1111, ComfyUI, SD.Next, InvokeAI, etc.
    The method only supports the conversion from PEFT to Kohya for now.

    Args:
        state_dict (`dict[str, torch.Tensor]`):
            The state dict to convert.
        original_type (`StateDictType`, *optional*):
            The original type of the state dict, if not provided, the method will try to infer it automatically.
        kwargs (`dict`, *optional*):
            Additional arguments to pass to the method.

            - **adapter_name**: For example, in case of PEFT, some keys will be pre-pended
                with the adapter name, therefore needs a special handling. By default PEFT also takes care of that in
                `get_peft_model_state_dict` method:
                https://github.com/huggingface/peft/blob/ba0477f2985b1ba311b83459d29895c809404e99/src/peft/utils/save_and_load.py#L92
                but we add it here in case we don't want to rely on that method.

    Returns:
        `dict`: A new dict with Kohya-style keys. For every `lora_down` entry an extra `<module>.alpha` entry is
        added, holding `torch.tensor(len(weight))` (the size of the first dimension of that weight).

    Raises:
        ImportError: If `torch` is not installed.
        ValueError: If the state dict is not recognised as PEFT, or `original_type` is not supported.
    """
    try:
        import torch
    except ImportError:
        logger.error("Converting PEFT state dicts to Kohya requires torch to be installed.")
        raise

    peft_adapter_name = kwargs.pop("adapter_name", None)
    if peft_adapter_name is not None:
        peft_adapter_name = "." + peft_adapter_name
    else:
        peft_adapter_name = ""

    if original_type is None:
        if any(f".lora_A{peft_adapter_name}.weight" in k for k in state_dict):
            original_type = StateDictType.PEFT

    if original_type not in KOHYA_STATE_DICT_MAPPINGS:
        raise ValueError(f"Original type {original_type} is not supported")

    # Use the convert_state_dict function with the appropriate mapping (`lora_A` -> `lora_down`, ...).
    kohya_ss_partial_state_dict = convert_state_dict(state_dict, KOHYA_STATE_DICT_MAPPINGS[StateDictType.PEFT])
    kohya_ss_state_dict = {}

    # Additional logic for replacing header, alpha parameters `.` with `_` in all keys.
    for kohya_key, weight in kohya_ss_partial_state_dict.items():
        if peft_adapter_name:
            # Drop the adapter segment that follows the LoRA matrix name *before* the dots are
            # collapsed. Otherwise the dot count is off by one and the separator in front of
            # `lora_down` / `lora_up` is flattened too (`..._to_q_lora_down.weight`), which also
            # yields a wrong `.alpha` key below.
            for lora_part in ("lora_down", "lora_up"):
                kohya_key = kohya_key.replace(f".{lora_part}{peft_adapter_name}.", f".{lora_part}.")

        if "text_encoder_2." in kohya_key:
            kohya_key = kohya_key.replace("text_encoder_2.", "lora_te2.")
        elif "text_encoder." in kohya_key:
            kohya_key = kohya_key.replace("text_encoder.", "lora_te1.")
        elif "unet" in kohya_key:
            kohya_key = kohya_key.replace("unet", "lora_unet")

        # Flatten the module path: every dot except the last two becomes an underscore,
        # leaving `<flattened_module>.<lora_down|lora_up>.<param>`.
        kohya_key = kohya_key.replace(".", "_", kohya_key.count(".") - 2)
        # Kohya doesn't take adapter names; strip any remaining occurrence.
        kohya_key = kohya_key.replace(peft_adapter_name, "")
        kohya_ss_state_dict[kohya_key] = weight
        if "lora_down" in kohya_key:
            alpha_key = f'{kohya_key.split(".")[0]}.alpha'
            kohya_ss_state_dict[alpha_key] = torch.tensor(len(weight))

    return kohya_ss_state_dict
| |
|