| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109 |
- # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import copy
- from typing import List, Tuple, Union
- import numpy as np
- import paddle
- import paddle.distributed as dist
- from paddle.framework import core
def get_coordinator(
    mesh: Union[np.ndarray, List[List[int]]], rank: int
) -> Union[List[int], None]:
    """
    Locate ``rank`` inside ``mesh`` and return its coordinate.

    Args:
        mesh: Process ids laid out as an n-dimensional array or nested list.
        rank: The process id to look up.

    Returns:
        A list holding one index per mesh dimension, or ``None`` when
        ``rank`` does not occur in ``mesh``.

    Raises:
        AssertionError: If ``rank`` occurs more than once in ``mesh``.
    """
    # The mesh is small host-side metadata, so search it with numpy instead
    # of building a paddle tensor just to call ``nonzero`` on it.
    # ``argwhere`` yields one row per match and one column per mesh dim.
    matches = np.argwhere(np.asarray(mesh) == rank)
    # The message text is kept as-is so existing log greps still match.
    assert matches.shape[0] in (
        0,
        1,
    ), f"rand_coordinator.shape: {list(matches.shape)}"
    return matches[0].tolist() if matches.shape[0] > 0 else None
def compute_local_shape_and_global_offset(
    global_shape: List[int],
    process_mesh: core.ProcessMesh,
    placements: List[core.Placement],
) -> Tuple[Union[Tuple[int, ...], None], Union[Tuple[int, ...], None]]:
    """
    Compute the shape of the current rank's local shard and where that
    shard starts inside the global tensor.

    Args:
        global_shape: Shape of the full tensor (any sequence of ints).
        process_mesh: Mesh whose ``process_ids`` and ``shape`` describe the
            participating ranks.
        placements: One placement per mesh dimension; entry ``dim``
            describes how the tensor is laid out along mesh dimension
            ``dim``.

    Returns:
        ``(local_shape, global_offset)`` as tuples with one entry per tensor
        dimension, or ``(None, None)`` when the current rank is not part of
        ``process_mesh``.

    Raises:
        AssertionError: If a sharded tensor dimension is not evenly
            divisible by the size of the mesh dimension it is split over.
    """
    mesh = np.array(process_mesh.process_ids).reshape(process_mesh.shape)
    cur_rank = paddle.distributed.get_rank()
    # deal with cross mesh case
    if cur_rank not in mesh:
        return (None, None)
    rank_coordinator = get_coordinator(mesh, cur_rank)
    # Always build a fresh list: copy.copy() hands back the very same object
    # for a tuple, which would make the item assignment below fail.
    local_shape = list(global_shape)
    global_offset = [0] * len(local_shape)
    for dim, placement in enumerate(placements):
        if isinstance(placement, dist.Replicate):
            # Replicated along this mesh dimension: nothing is split.
            continue
        # NOTE(review): every non-Replicate placement is assumed to expose
        # get_dim() (presumably Shard) -- confirm other kinds never get here.
        i = placement.get_dim()
        assert (
            global_shape[i] % process_mesh.shape[dim] == 0
        ), f"i:{i}, global_shape[i]:{global_shape[i]}, process_mesh.shape[dim]:{process_mesh.shape[dim]}"
        # NOTE: both values derive from global_shape[i], so if two mesh
        # dimensions target the same tensor dimension the later one
        # overwrites the earlier result.
        local_shape[i] = global_shape[i] // process_mesh.shape[dim]
        chunk_idx = rank_coordinator[dim]
        global_offset[i] = chunk_idx * local_shape[i]
    return tuple(local_shape), tuple(global_offset)
def flatten_state_dict(state_dict):
    """
    Collapse a nested state dict into a single-level dict.

    Each tensor leaf is stored under the "."-joined chain of keys that
    leads to it:
    {"model": {"w0": xxx}} -> {"model.w0": xxx}

    Returns:
        A pair ``(flat_dict, mapping)`` where ``mapping`` takes every
        flattened key back to the original tuple of nested keys.

    Raises:
        AssertionError: If a nested dict has a key that is not a ``str``.
        ValueError: If a value is neither a ``dict`` nor a ``paddle.Tensor``.
    """

    def _iter_leaves(prefix, value):
        # Depth-first walk that lazily yields (key path, tensor) pairs in
        # dict insertion order.
        if isinstance(value, dict):
            for k, child in value.items():
                assert isinstance(k, str), f"The key should be str, but is {k}"
                yield from _iter_leaves((*prefix, k), child)
            return
        if not isinstance(value, paddle.Tensor):
            raise ValueError(
                f"The value should be dict or paddle.Tensor, but is {value}"
            )
        yield prefix, value

    flat_state = {}
    key_paths = {}
    for path, tensor in _iter_leaves((), state_dict):
        joined = ".".join(path)
        flat_state[joined] = tensor
        key_paths[joined] = path
    return flat_state, key_paths
def unflatten_state_dict(flat_state_dict, mapping):
    """
    Rebuild a nested dict from a flattened one.

    ``mapping`` supplies, for every flat key, the tuple of nested keys under
    which the value is stored again:
    {"model.w0": xxx} -> {"model": {"w0": xxx}}

    Raises:
        KeyError: If a flat key has no entry in ``mapping``.
        AssertionError: If a ``mapping`` entry is not a tuple.
    """
    nested = {}
    for flat_key, value in flat_state_dict.items():
        key_tuple = mapping[flat_key]
        assert isinstance(
            key_tuple, tuple
        ), f"The key should be tuple, but is {key_tuple}"
        # Walk (creating as needed) every intermediate dict, then attach the
        # value under the final path component.
        node = nested
        for part in key_tuple[:-1]:
            node = node.setdefault(part, {})
        node[key_tuple[-1]] = value
    return nested
|