import copy
import functools
import re
import typing as t
import numpy as np
from aiida.common.constants import elements
from aiida.common.exceptions import UnsupportedSpeciesError
try:
import ase # noqa: F401
except ImportError:
pass
try:
import pymatgen.core as core # noqa: F401
except ImportError:
pass
from aiida.engine import calcfunction
from aiida.orm import List
from . import _GLOBAL_PROPERTIES, _COMPUTED_PROPERTIES, _CONVERSION_PLURAL_SINGULAR
# Threshold used to check if the mass of two different Site objects is the same.
_MASS_THRESHOLD = 1.0e-3
# Threshold to check if the sum is one or not
_SUM_THRESHOLD = 1.0e-6
# Default cell
_DEFAULT_CELL = ((0, 0, 0), (0, 0, 0), (0, 0, 0))
_valid_symbols = tuple(i["symbol"] for i in elements.values())
_atomic_masses = {el["symbol"]: el["mass"] for el in elements.values()}
_atomic_numbers = {data["symbol"]: num for num, data in elements.items()}
_dimensionality_label = {0: '', 1: 'length', 2: 'surface', 3: 'volume'}
class ObservedArray(np.ndarray):
"""
    This is a subclass of numpy.ndarray that makes it possible to observe changes to the array.
    In this way, the full flexibility of StructureDataMutable is preserved while all
    changes are still tracked.
"""
def __new__(cls, input_array):
"""
Create a new instance of ObservedArray.
Parameters:
- input_array: array-like
The input array to be converted to an instance of ObservedArray.
Returns:
- obj: ObservedArray
The new instance of ObservedArray.
"""
obj = np.asarray(input_array).view(cls)
return obj
def __setitem__(self, index, value):
"""
Set the value of an item in the ObservedArray.
Parameters:
- index: int or tuple
The index or indices of the item(s) to be set.
- value: any
The value to be assigned to the item(s).
Returns:
None
"""
super(ObservedArray, self).__setitem__(index, value)
def __array_finalize__(self, obj):
"""
Finalize the creation of the ObservedArray.
This method is called when the view is created or sliced.
Parameters:
- obj: ObservedArray or None
The object being finalized.
Returns:
None
"""
if obj is None:
return
def efficient_copy(obj):
    """Copy only the mutable parts of ``obj``; this is much more efficient than always using copy.deepcopy."""
    if isinstance(obj, dict):
        return {k: v if isinstance(v, (str, int, float, tuple)) else copy.deepcopy(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return type(obj)(v if isinstance(v, (str, int, float, tuple)) else copy.deepcopy(v) for v in obj)
    return copy.deepcopy(obj)
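# A minimal usage sketch (the site dictionary below is illustrative): only mutable
# values such as the position list are deep-copied, immutable ones are shared.
#   site = {"symbol": "Fe", "weight": (1.0,), "position": [0.0, 0.0, 0.0]}
#   site_copy = efficient_copy(site)
#   site_copy["position"][0] = 1.0   # leaves the original position untouched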
def _get_valid_cell(inputcell):
"""Return the cell in a valid format from a generic input.
:raise ValueError: whenever the format is not valid.
"""
try:
the_cell = list(list(float(c) for c in i) for i in inputcell)
if len(the_cell) != 3:
raise ValueError
if any(len(i) != 3 for i in the_cell):
raise ValueError
except (IndexError, ValueError, TypeError):
raise ValueError(
"Cell must be a list of three vectors, each defined as a list of three coordinates."
)
return np.array(the_cell)
def _get_valid_pbc(inputpbc):
"""Return a list of three booleans for the periodic boundary conditions,
in a valid format from a generic input.
:raise ValueError: if the format is not valid.
"""
if isinstance(inputpbc, bool):
the_pbc = [inputpbc, inputpbc, inputpbc]
elif hasattr(inputpbc, "__iter__"):
        # To manage numpy lists of bools, whose elements are of type numpy.bool_
        # and for which isinstance(i, bool) returns False...
if hasattr(inputpbc, "tolist"):
the_value = list(i for i in inputpbc.tolist())
else:
the_value = inputpbc
if all(isinstance(i, bool) for i in the_value):
if len(the_value) == 3:
the_pbc = list(i for i in the_value)
elif len(the_value) == 1:
the_pbc = (the_value[0], the_value[0], the_value[0])
else:
raise ValueError("pbc length must be either one or three.")
else:
raise ValueError("pbc elements are not booleans.")
else:
raise ValueError("pbc must be a boolean or a list of three booleans.", inputpbc)
return the_pbc
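# A minimal usage sketch (inputs are illustrative):
#   _get_valid_pbc(True)                  # -> [True, True, True]
#   _get_valid_pbc([True, False, True])   # -> [True, False, True]
#   _get_valid_pbc([1, 0, 1])             # raises ValueError: pbc elements are not booleans.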
def _check_valid_sites(sites):
"""Check that no two sites have positions that are too close to each other."""
positions = np.array([site['position'] for site in sites])
n_sites = len(positions)
if n_sites <= 1:
return
# Calculate pairwise distances using broadcasting (this is much more efficient than loops...)
diff = positions[:, np.newaxis, :] - positions[np.newaxis, :, :] # Shape: (n_sites, n_sites, 3)
distances = np.linalg.norm(diff, axis=2) # Shape: (n_sites, n_sites)
# Set diagonal to large value to ignore self-comparisons
np.fill_diagonal(distances, np.inf)
# Check if any distance is below threshold
min_distance = 1e-3 # You can adjust this threshold
close_pairs = np.where(distances < min_distance)
if len(close_pairs[0]) > 0:
i, j = close_pairs[0][0], close_pairs[1][0] # Get first problematic pair
raise ValueError(f"Sites {i} and {j} have positions that are too close: "
f"{positions[i]} and {positions[j]} (distance: {distances[i,j]:.6f})")
return
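# A minimal usage sketch (positions are illustrative, in Angstrom):
#   _check_valid_sites([{"position": [0.0, 0.0, 0.0]}, {"position": [0.0, 0.0, 2.0]}])   # passes silently
#   _check_valid_sites([{"position": [0.0, 0.0, 0.0]}, {"position": [0.0, 0.0, 1e-4]}])  # raises ValueError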
def has_ase():
""":return: True if the ase module can be imported, False otherwise."""
try:
import ase # noqa: F401
except ImportError:
return False
return True
def has_pymatgen():
""":return: True if the pymatgen module can be imported, False otherwise."""
try:
import pymatgen # noqa: F401
except ImportError:
return False
return True
def get_pymatgen_version():
    """:return: string with the pymatgen version, or None if it cannot be imported."""
if not has_pymatgen():
return None
try:
from pymatgen import __version__
except ImportError:
# this was changed in version 2022.0.3
from pymatgen.core import __version__
return __version__
def has_spglib():
""":return: True if the spglib module can be imported, False otherwise."""
try:
import spglib # noqa: F401
except ImportError:
return False
return True
def get_dimensionality(
pbc,
cell,
):
"""Return the dimensionality of the structure and its length/surface/volume.
Zero-dimensional structures are assigned "volume" 0.
    :return: a dictionary with keys "dim" (dimensionality integer), "label" (dimensionality label)
and "value" (numerical length/surface/volume).
"""
import numpy as np
retdict = {}
pbc = np.array(pbc)
cell = np.array(cell)
dim = len(pbc[pbc])
retdict["dim"] = dim
retdict["label"] = _dimensionality_label[dim]
if dim not in (0, 1, 2, 3):
raise ValueError(f"Dimensionality {dim} must be one of 0, 1, 2, 3")
if dim == 0:
# We have no concept of 0d volume. Let's return a value of 0 for a consistent output dictionary
retdict["value"] = 0
elif dim == 1:
retdict["value"] = np.linalg.norm(cell[pbc])
elif dim == 2:
vectors = cell[pbc]
retdict["value"] = np.linalg.norm(np.cross(vectors[0], vectors[1]))
elif dim == 3:
retdict["value"] = calc_cell_volume(cell)
return retdict
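# A minimal usage sketch for a 2D slab (cell vectors are illustrative, in Angstrom):
#   get_dimensionality(pbc=[True, True, False],
#                      cell=[[3.0, 0.0, 0.0], [0.0, 3.0, 0.0], [0.0, 0.0, 20.0]])
#   # -> {'dim': 2, 'label': 'surface', 'value': 9.0}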
def calc_cell_volume(cell):
"""Compute the three-dimensional cell volume in Angstrom^3.
    :param cell: the cell vectors; it must be a 3x3 list of lists of floats
:returns: the cell volume.
"""
return np.abs(np.dot(cell[0], np.cross(cell[1], cell[2])))
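# A minimal usage sketch: the volume of an orthorhombic 2 x 3 x 4 Angstrom cell:
#   calc_cell_volume([[2.0, 0.0, 0.0], [0.0, 3.0, 0.0], [0.0, 0.0, 4.0]])  # -> 24.0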
def _create_symbols_tuple(symbols):
"""Returns a tuple with the symbols provided. If a string is provided,
this is converted to a tuple with one single element.
"""
    if isinstance(symbols, str):
        symbols_list = re.sub(r"([A-Z])", r" \1", symbols).split()
    else:
        symbols_list = list(symbols)
    for symbol in symbols_list:
        if symbol not in _valid_symbols:
            raise ValueError(f"Some or all of the symbols provided are not correct: {symbols_list}")
    return tuple(symbols_list)
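# A minimal usage sketch:
#   _create_symbols_tuple("BaTiO")        # -> ('Ba', 'Ti', 'O')
#   _create_symbols_tuple(["Ba", "Ti"])   # -> ('Ba', 'Ti')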
def _create_weights_tuple(weights):
"""Returns a tuple with the weights provided. If a number is provided,
this is converted to a tuple with one single element.
If None is provided, this is converted to the tuple (1.,)
"""
import numbers
if weights is None:
weights_tuple = (1.0,)
elif isinstance(weights, numbers.Number):
weights_tuple = (weights,)
else:
weights_tuple = tuple(float(i) for i in weights)
return weights_tuple
def create_automatic_kind_name(symbols, weights):
"""Create a string obtained with the symbols appended one
after the other, without spaces, in alphabetical order;
    if the site has a vacancy, an X is appended at the end too.
"""
sorted_symbol_list = list(set(symbols))
sorted_symbol_list.sort() # In-place sort
name_string = "".join(sorted_symbol_list)
if has_vacancies(weights):
name_string += "X"
return name_string
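# A minimal usage sketch: partial occupancies (vacancies) add an "X" suffix to the kind name:
#   create_automatic_kind_name(["Si", "Ge"], [0.5, 0.4])   # -> 'GeSiX'
#   create_automatic_kind_name(["Fe"], [1.0])               # -> 'Fe'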
def validate_weights_tuple(weights_tuple, threshold):
"""Validates the weight of the atomic kinds.
:raise: ValueError if the weights_tuple is not valid.
    :param weights_tuple: the tuple to validate. It must be a
        tuple of floats (as created by :func:_create_weights_tuple).
:param threshold: a float number used as a threshold to check that the sum
of the weights is <= 1.
If the sum is less than one, it means that there are vacancies.
Each element of the list must be >= 0, and the sum must be <= 1.
"""
w_sum = sum(weights_tuple)
if any(i < 0.0 for i in weights_tuple) or (w_sum - 1.0 > threshold):
raise ValueError(
"The weight list is not valid (each element must be positive, and the sum must be <= 1)."
)
def is_valid_symbol(symbol):
"""Validates the chemical symbol name.
:return: True if the symbol is a valid chemical symbol (with correct
capitalization), or the dummy X, False otherwise.
Recognized symbols are for elements from hydrogen (Z=1) to lawrencium
    (Z=103). In addition, a dummy element with unknown name (Z=0) is supported.
"""
return symbol in _valid_symbols
def validate_symbols_tuple(symbols_tuple):
"""Used to validate whether the chemical species are valid.
:param symbols_tuple: a tuple (or list) with the chemical symbols name.
:raises: UnsupportedSpeciesError if any symbol in the tuple is not a valid chemical
symbol (with correct capitalization).
Refer also to the documentation of :func:is_valid_symbol
"""
if len(symbols_tuple) == 0:
valid = False
else:
valid = all(is_valid_symbol(sym) for sym in symbols_tuple)
if not valid:
raise UnsupportedSpeciesError(
f"At least one element of the symbol list {symbols_tuple} has not been recognized."
)
def group_symbols(_list):
"""Group a list of symbols to a list containing the number of consecutive
identical symbols, and the symbol itself.
Examples
--------
* ``['Ba','Ti','O','O','O','Ba']`` will return
``[[1,'Ba'],[1,'Ti'],[3,'O'],[1,'Ba']]``
* ``[ [ [1,'Ba'],[1,'Ti'] ],[ [1,'Ba'],[1,'Ti'] ] ]`` will return
``[[2, [ [1, 'Ba'], [1, 'Ti'] ] ]]``
:param _list: a list of elements representing a chemical formula
:return: a list of length-2 lists of the form [ multiplicity , element ]
"""
the_list = efficient_copy(_list)
the_list.reverse()
grouped_list = [[1, the_list.pop()]]
while the_list:
elem = the_list.pop()
if elem == grouped_list[-1][1]:
# same symbol is repeated
grouped_list[-1][0] += 1
else:
grouped_list.append([1, elem])
return grouped_list
def get_symbols_string(symbols, weights):
"""Return a string that tries to match as good as possible the symbols
and weights. If there is only one symbol (no alloy) with 100%
occupancy, just returns the symbol name. Otherwise, groups the full
string in curly brackets, and try to write also the composition
(with 2 precision only).
If (sum of weights<1), we indicate it with the X symbol followed
by 1-sum(weights) (still with 2 digits precision, so it can be 0.00)
:param symbols: the symbols as obtained from <kind>._symbols
:param weights: the weights as obtained from <kind>._weights
.. note:: Note the difference with respect to the symbols and the
symbol properties!
"""
if len(symbols) == 1 and weights[0] == 1.0:
return symbols[0]
pieces = []
for symbol, weight in zip(symbols, weights):
pieces.append(f"{symbol}{weight:4.2f}")
if has_vacancies(weights):
pieces.append(f"X{1.0 - sum(weights):4.2f}")
return f"{{{''.join(sorted(pieces))}}}"
def has_vacancies(weights):
"""Returns True if the sum of the weights is less than one.
It uses the internal variable _SUM_THRESHOLD as a threshold.
:param weights: the weights
:return: a boolean
"""
w_sum = sum(weights)
return not 1.0 - w_sum < _SUM_THRESHOLD
def symop_ortho_from_fract(cell):
"""Creates a matrix for conversion from orthogonal to fractional
coordinates.
Taken from
svn://www.crystallography.net/cod-tools/trunk/lib/perl5/Fractional.pm,
revision 850.
:param cell: array of cell parameters (three lengths and three angles)
"""
import math
import numpy
a, b, c, alpha, beta, gamma = cell
alpha, beta, gamma = (math.pi * x / 180 for x in [alpha, beta, gamma])
ca, cb, cg = (math.cos(x) for x in [alpha, beta, gamma])
sg = math.sin(gamma)
return numpy.array(
[
[a, b * cg, c * cb],
[0, b * sg, c * (ca - cb * cg) / sg],
[0, 0, c * math.sqrt(sg * sg - ca * ca - cb * cb + 2 * ca * cb * cg) / sg],
]
)
def symop_fract_from_ortho(cell):
"""Creates a matrix for conversion from fractional to orthogonal
coordinates.
Taken from
svn://www.crystallography.net/cod-tools/trunk/lib/perl5/Fractional.pm,
revision 850.
:param cell: array of cell parameters (three lengths and three angles)
"""
import math
import numpy
a, b, c, alpha, beta, gamma = cell
alpha, beta, gamma = (math.pi * x / 180 for x in [alpha, beta, gamma])
ca, cb, cg = (math.cos(x) for x in [alpha, beta, gamma])
sg = math.sin(gamma)
ctg = cg / sg
D = math.sqrt(sg * sg - cb * cb - ca * ca + 2 * ca * cb * cg) # noqa: N806
return numpy.array(
[
[1.0 / a, -(1.0 / a) * ctg, (ca * cg - cb) / (a * D)],
[0, 1.0 / (b * sg), -(ca - cb * cg) / (b * D * sg)],
[0, 0, sg / (c * D)],
]
)
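# A minimal usage sketch (cell parameters are illustrative): the two matrices are
# inverses of each other, so converting back and forth recovers the coordinates.
#   cell = (5.0, 5.0, 5.0, 90.0, 90.0, 120.0)    # a, b, c in Angstrom; angles in degrees
#   to_ortho = symop_ortho_from_fract(cell)      # fractional -> orthogonal (Cartesian)
#   to_fract = symop_fract_from_ortho(cell)      # orthogonal (Cartesian) -> fractional
#   np.allclose(to_ortho @ to_fract, np.eye(3))  # -> True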
def ase_refine_cell(aseatoms, **kwargs):
"""Detect the symmetry of the structure, remove symmetric atoms and
refine unit cell.
:param aseatoms: an ase.atoms.Atoms instance
:param symprec: symmetry precision, used by spglib
:return newase: refined cell with reduced set of atoms
:return symmetry: a dictionary describing the symmetry space group
"""
from ase.atoms import Atoms
from spglib import get_symmetry_dataset, refine_cell
spglib_tuple = (
aseatoms.get_cell(),
aseatoms.get_scaled_positions(),
aseatoms.get_atomic_numbers(),
)
cell, positions, numbers = refine_cell(spglib_tuple, **kwargs)
refined_atoms = (
cell,
positions,
numbers,
)
sym_dataset = get_symmetry_dataset(refined_atoms, **kwargs)
unique_numbers = []
unique_positions = []
for i in set(sym_dataset["equivalent_atoms"]):
unique_numbers.append(numbers[i])
unique_positions.append(positions[i])
unique_atoms = Atoms(
unique_numbers, scaled_positions=unique_positions, cell=cell, pbc=True
)
return unique_atoms, {
"hm": sym_dataset["international"],
"hall": sym_dataset["hall"],
"tables": sym_dataset["number"],
"rotations": sym_dataset["rotations"],
"translations": sym_dataset["translations"],
}
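# A minimal usage sketch (requires ase and spglib; the structure is illustrative):
#   from ase.build import bulk
#   atoms = bulk("Cu", "fcc", a=3.6) * (2, 2, 2)             # supercell with symmetry-redundant atoms
#   refined, symmetry = ase_refine_cell(atoms, symprec=1e-5)
#   print(len(refined), symmetry["hm"], symmetry["tables"])  # e.g. 1 Fm-3m 225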
def atom_kinds_to_html(atom_kind):
"""Construct in html format
an alloy with 0.5 Ge, 0.4 Si and 0.1 vacancy is represented as
Ge<sub>0.5</sub> + Si<sub>0.4</sub> + vacancy<sub>0.1</sub>
Args:
-----
atom_kind: a string with the name of the atomic kind, as printed by
kind.get_symbols_string(), e.g. Ba0.80Ca0.10X0.10
Returns:
--------
html code for rendered formula
"""
    # Parse the formula (TODO: can be made more robust, though it never fails if
    # it takes strings generated with kind.get_symbols_string())
    matched_elements = re.findall(r"([A-Z][a-z]*)([0-1]\.[0-9]+)?", atom_kind)
# Compose the html string
html_formula_pieces = []
for element in matched_elements:
# replace element X by 'vacancy'
species = element[0] if element[0] != "X" else "vacancy"
weight = element[1] if element[1] != "" else None
if weight is not None:
html_formula_pieces.append(f"{species}<sub>{weight}</sub>")
else:
html_formula_pieces.append(species)
html_formula = " + ".join(html_formula_pieces)
return html_formula
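# A minimal usage sketch, using a kind string like those produced by get_symbols_string:
#   atom_kinds_to_html("Ba0.80Ca0.10X0.10")
#   # -> 'Ba<sub>0.80</sub> + Ca<sub>0.10</sub> + vacancy<sub>0.10</sub>'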
def set_symbols_and_weights(new_data):
"""Set the chemical symbols and the weights for the site.
.. note:: Note that the kind name remains unchanged.
"""
symbols_tuple = _create_symbols_tuple(new_data["symbol"]) if isinstance(new_data["symbol"], str) else new_data["symbol"]
for symbol in symbols_tuple:
if symbol not in _valid_symbols:
raise ValueError(f'This is not a valid element: {symbol}')
weights_tuple = _create_weights_tuple(new_data["weight"])
if len(symbols_tuple) != len(weights_tuple):
raise ValueError('The number of symbols and weights must coincide.')
validate_symbols_tuple(symbols_tuple)
validate_weights_tuple(weights_tuple, _SUM_THRESHOLD)
new_data["alloy"] = symbols_tuple
new_data["weight"] = weights_tuple
if "mass" not in new_data.keys() or np.isnan(new_data.get("mass", None)) or new_data.get("mass", None) == 0:
# Weighted mass
w_sum = sum(weights_tuple)
normalized_weights = (i / w_sum for i in weights_tuple)
element_masses = (_atomic_masses[sym] for sym in symbols_tuple)
new_data["mass"] = sum(i * j for i, j in zip(normalized_weights, element_masses))
def check_is_alloy(data):
"""Check if the data is an alloy or not.
:param data: the data to check. The dict of the SiteCore model.
:return: True if the data is an alloy, False otherwise.
"""
new_data = efficient_copy(data)
if "weight" not in new_data.keys() or new_data.get("weight", None) is None:
return new_data
if len(new_data.get("weight", [1,])) == 1:
if new_data["symbol"] not in _valid_symbols:
raise ValueError(f'This is not a valid element: {new_data["symbol"]}')
return None
set_symbols_and_weights(new_data)
return new_data
def check_plugin_support(structure, plugin_properties: set) -> set:
"""
Check if the plugin supports the given properties.
    :param structure: the structure whose defined properties should be checked.
    :param plugin_properties: the properties supported by the plugin.
:return: the defined properties which are not supported by the plugin
:rtype: set
"""
defined_properties = structure.get_defined_properties()
return defined_properties.difference(plugin_properties)
def order_k(k):
"""
Adjusts the order of elements in the array `k` by ensuring that there are no gaps in the sequence.
If the minimum value in `k` is 0, it increments all elements by 1. Then, it iterates from the maximum value
in `k` down to the minimum value, checking if each value minus one is not in `k`. If a value minus one is not
found, it decrements all elements in `k` that are greater than or equal to the current value.
Parameters:
k (numpy.ndarray): An array of integers to be reordered.
Returns:
numpy.ndarray: The reordered array `k`.
"""
if min(k) == 0:
k = k + 1
    for i in range(max(k), min(1, min(k)), -1):
        if i - 1 not in k:
            k[k >= i] -= 1
return k
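# A minimal usage sketch: gaps in the kind indices are removed while the grouping is preserved:
#   order_k(np.array([0, 3, 3, 5]))   # -> array([1, 2, 2, 3])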
def compress_properties_by_kind(props):
"""
Compress site-wise properties into kind-wise lists.
Returns a dict with properties as lists, one entry per kind.
"""
    if not props.get("kind_names", None):
        raise ValueError("The input properties must contain 'kind_names' information.")
    kind_names_array = np.array(props["kind_names"])
    site_props = set(props.keys()).difference(_GLOBAL_PROPERTIES + _COMPUTED_PROPERTIES + ["sites"])
    compressed = {
        prop: [] if prop in site_props.union(["site_indices"]) else props.get(prop, None)
        for prop in site_props.union(["site_indices"]).union(_GLOBAL_PROPERTIES)
    }
for kind_name in set(props["kind_names"]):
site_indices = np.where(kind_names_array == kind_name)[0]
for prop in site_props:
if prop == "positions":
compressed[prop].append([props[prop][i] for i in site_indices])
elif prop in props:
compressed[prop].append(props[prop][site_indices[0]])
else:
compressed.pop(prop)
compressed["site_indices"].append(site_indices.tolist())
for prop in _GLOBAL_PROPERTIES:
if compressed.get(prop, None) is None:
compressed.pop(prop, None)
return compressed
def rebuild_site_lists_from_kind_lists(compressed):
"""
Expand kinds into a list of site dictionaries, sorted by site_index.
"""
    site_props = set(compressed.keys()).difference(
        _GLOBAL_PROPERTIES + _COMPUTED_PROPERTIES + ["sites", "site_indices"]
    )
    expanded = {
        prop: [] if prop in site_props.union(["site_indices"]) else compressed.get(prop, None)
        for prop in site_props.union(["site_indices"]).union(_GLOBAL_PROPERTIES)
    }
for i, site_indices in enumerate(compressed["site_indices"]):
for prop in site_props:
if prop == "positions":
expanded[prop].extend(compressed[prop][i])
elif prop == "site_indices":
continue
elif prop in compressed:
expanded[prop].extend([compressed[prop][i]] * len(site_indices))
else:
expanded.pop(prop)
expanded["site_indices"] += site_indices
# Reorder by site_index
order = np.argsort(expanded["site_indices"])
for prop in site_props.union(["site_indices"]):
if prop in expanded:
expanded[prop] = [expanded[prop][i] for i in order]
for prop in _GLOBAL_PROPERTIES:
if expanded.get(prop, None) is None:
expanded.pop(prop, None)
expanded.pop("site_indices")
return expanded
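# A minimal round-trip sketch (``props`` stands for the site-wise property dictionary
# of a structure, containing at least 'kind_names' and 'positions'):
#   compressed = compress_properties_by_kind(props)              # one entry per kind + 'site_indices'
#   expanded = rebuild_site_lists_from_kind_lists(compressed)    # back to one entry per site
#   # the stored 'site_indices' are used to restore the original site ordering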
def build_sites_from_expanded_properties(expanded):
"""
Build the structure dictionary from expanded site-wise lists of properties.
"""
    # Site-wise properties: all keys except the global/computed properties and bookkeeping keys
site_props = set(expanded.keys()).difference(_GLOBAL_PROPERTIES + _COMPUTED_PROPERTIES + ["sites", "site_indices"])
n_sites = len(expanded.get("positions",[]))
sites = []
for i in range(n_sites):
site = {}
site["position"] = expanded["positions"][i]
for prop in site_props:
site[_CONVERSION_PLURAL_SINGULAR[prop]] = expanded[prop][i]
sites.append(site)
structure_dict = {}
for prop in _GLOBAL_PROPERTIES:
if expanded.get(prop, None) is not None:
structure_dict[prop] = expanded[prop]
structure_dict["sites"] = sites
return structure_dict
def classify_site_kinds(sites: list, exclude_props: t.Optional[set] = None, tolerance: t.Union[dict, float] = 1e-3):
"""
Classify sites into groups where each group (kind) has the same properties except position.
Args:
sites: List of site dictionaries
exclude_props: Set of property names to exclude from grouping (default: {'position'})
tolerance: Numerical tolerance for floating point comparisons (default: 1e-3)
Returns:
dict: {group_key: {'sites': [site_indices], 'properties': {prop: value}}}
"""
import numpy as np
from collections import defaultdict
if exclude_props is None:
exclude_props = {'position'}
def normalize_value(value, tol=tolerance):
"""Normalize values for consistent comparison."""
if isinstance(value, np.ndarray):
# Round numpy arrays to tolerance
normalized = np.round(value / tol) * tol
return tuple(normalized.tolist())
elif isinstance(value, (float, np.floating)):
# Round floats to tolerance
return round(value / tol) * tol
elif isinstance(value, (int, np.integer)):
return int(value)
elif value is None:
return None
else:
return value
groups = defaultdict(lambda: {'sites': [], 'positions': [], 'properties': {}})
for i, site in enumerate(sites):
# Create a hashable key from normalized properties
key_props = {}
for prop, value in site.items():
if prop not in exclude_props:
if isinstance(tolerance, dict):
tol = tolerance.get(prop, 1e-3)
else:
tol = tolerance
normalized_value = normalize_value(value, tol)
key_props[prop] = normalized_value
# Create a hashable key containing both property names and their normalized values, so it is a unique identifier
key = tuple(sorted(key_props.items()))
# Add site index to this group (or this specific hashable key)
groups[key]['sites'].append(i)
groups[key]['positions'].append(site['position'])
# Store the original properties (first occurrence)
if not groups[key]['properties']:
groups[key]['properties'] = {
prop: normalize_value(value, tolerance.get(prop, 1e-3) if isinstance(tolerance, dict) else tolerance) for prop, value in site.items()
if prop not in exclude_props
}
return dict(groups)
# Usage example:
# groups = classify_site_kinds(m.to_dict()['sites'])
# for i, (key, group) in enumerate(groups.items()):
# print(f"Group {i+1}:")
# print(f" Sites: {group['sites']}")
# print(f" Positions: {group['positions']}")
# print(f" Properties: {group['properties']}")
# print()
def check_kinds_match(structure, kinds_list):
    """Check that the site indices of every kind of ``structure`` also appear in ``kinds_list``.
    :param structure: a structure exposing a ``kinds`` attribute.
    :param kinds_list: a list of kind dictionaries, each containing a 'site_indices' entry.
    :return: True if each kind of the structure has a matching 'site_indices' entry, False otherwise.
    """
    check_kinds = []
    kind_names_indices = [kind['site_indices'] for kind in kinds_list]
    for kind in structure.kinds:
        site_indices = kind.site_indices
        check_kinds.append(site_indices in kind_names_indices)
    return all(check_kinds)
def sites_from_kinds(kinds):
"""
Expand kinds into a list of site dictionaries, sorted by site_index.
1. Create a list of site indices and positions from the kinds
2. Create a list of site dictionaries by copying the kind properties
and adding the position
3. Return the list of site dictionaries
4. Note: the returned list is sorted by site_index
Format of kinds (basically what can be obtained by structure.generate_kinds()):
[
{'site_indices': [0, 2],
'positions': [array([0., 0., 0.]), array([0., 1., 0.])],
'symbol': 'H',
'mass': 1.008,
'charge': 0.0,
'magmom': (0.0, 0.0, -1.0),
'kind_name': 'H1'},
{'site_indices': [1],
'positions': [array([0., 0., 1.])],
'symbol': 'O',
'mass': 15.999,
'charge': -2.0,
'magmom': (0.0, 0.0, 1.0),
'kind_name': 'O1'}
]
"""
    sites_list = []
    positions = []
    site_indices = []
    for i, kind in enumerate(kinds):
        sites_list += [i] * len(kind['site_indices'])
        positions += kind['positions']
        site_indices += list(kind['site_indices'])
    num_sites = len(sites_list)
    for i in range(num_sites):
        sites_list[i] = efficient_copy(kinds[sites_list[i]])
        sites_list[i].pop('site_indices', None)
        sites_list[i].pop('positions', None)
        sites_list[i]['position'] = positions[i]
    # Reorder so that the returned sites follow the original site indices
    order = np.argsort(site_indices)
    return [sites_list[i] for i in order]
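# A minimal usage sketch, using the kinds format shown in the docstring above:
#   kinds = [
#       {"site_indices": [0, 2], "positions": [[0., 0., 0.], [0., 1., 0.]],
#        "symbol": "H", "kind_name": "H1"},
#       {"site_indices": [1], "positions": [[0., 0., 1.]],
#        "symbol": "O", "kind_name": "O1"},
#   ]
#   sites = sites_from_kinds(kinds)
#   [site["kind_name"] for site in sites]   # -> ['H1', 'O1', 'H1']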