API Documentation🔗

Text representation🔗

Core class🔗

Generate text representations of crystal structures for language modelling.

Attributes:

Name Type Description
structure

pymatgen structure

Methods:

Name Description
from_input

a classmethod

Source code in xtal2txt/core.py
class TextRep:
    """
    Generate text representations of crystal structures for language modelling.

    Attributes:
        structure : pymatgen structure

    Methods:
        from_input : a classmethod
        get_cif_string(n=3)
        get_parameters(n=3)
        get_coords(name, n=3)
        get_cartesian(n=3)
        get_fractional(n=3)
    """

    backend = InvCryRep()
    condenser = StructureCondenser()
    describer = StructureDescriber()

    def __init__(
        self,
        structure: Structure,
        transformations: List[Tuple[str, dict]] = None,
    ) -> None:
        self.structure = structure
        self.transformations = transformations or []
        self.apply_transformations()

    @classmethod
    def from_input(
        cls,
        input_data: Union[str, Path, Structure],
        transformations: List[Tuple[str, dict]] = None,
    ) -> "TextRep":
        """
        Instantiate the TextRep class object with the pymatgen structure from a cif file, a cif string, or a pymatgen Structure object.

        Args:
            input_data (Union[str,pymatgen.core.structure.Structure]): A cif file of a crystal structure, a cif string,
                or a pymatgen Structure object.

        Returns:
            TextRep: A TextRep object.
        """
        if isinstance(input_data, Structure):
            structure = input_data

        elif isinstance(input_data, (str, Path)):
            try:
                if Path(input_data).is_file():
                    structure = Structure.from_file(str(input_data))
                else:
                    raise ValueError
            except (OSError, ValueError):
                structure = Structure.from_str(str(input_data), "cif")

        else:
            structure = Structure.from_str(str(input_data), "cif")

        return cls(structure, transformations)

    def apply_transformations(self) -> None:
        """
        Apply transformations to the structure.
        """
        for transformation, params in self.transformations:
            transform_func = getattr(TransformationCallback, transformation)
            self.structure = transform_func(self.structure, **params)

    @staticmethod
    def _safe_call(func, *args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception:
            return None

    @staticmethod
    def round_numbers_in_string(original_string: str, decimal_places: int) -> str:
        """
        Rounds float numbers in the given string to the specified number of decimal places using regex.

        Args:
            original_string (str): The input string.
            decimal_places (int): The number of decimal places to round to.

        Returns:
            str: The string with the float numbers rounded to the specified number of decimal places.
        """
        pattern = r"\b\d+\.\d+\b"
        matches = re.findall(pattern, original_string)
        rounded_numbers = [round(float(match), decimal_places) for match in matches]
        new_string = re.sub(
            pattern, lambda x: str(rounded_numbers.pop(0)), original_string
        )
        return new_string

    def get_cif_string(
        self, format: str = "symmetrized", decimal_places: int = 3
    ) -> str:
        """
        Generate CIF as string in multi-line format.

        All float numbers can be rounded to the specified number (decimal_places).
        Currently supports two formats: symmetrized (CIF with symmetry operations and the least symmetric basis)
        and P1 (conventional unit cell, with all the atoms listed and only the identity symmetry operation).

        Args:
            format (str): The format of the CIF file. Can be "symmetrized" or "p1".
            decimal_places (int): The number of decimal places to round to.

        Returns:
            str: The CIF string.
        """

        if format == "symmetrized":
            symmetry_analyzer = SpacegroupAnalyzer(self.structure)
            symmetrized_structure = symmetry_analyzer.get_symmetrized_structure()
            cif_string = str(
                CifWriter(
                    symmetrized_structure,
                    symprec=0.1,
                    significant_figures=decimal_places,
                ).cif_file
            )
            cif = "\n".join(cif_string.split("\n")[1:])
            return self.round_numbers_in_string(cif, decimal_places)

        elif format == "p1":
            cif_string = "\n".join(self.structure.to(fmt="cif").split("\n")[1:])
            return self.round_numbers_in_string(cif_string, decimal_places)

    def get_lattice_parameters(self, decimal_places: int = 3) -> List[str]:
        """
        Return lattice parameters of unit cells in a crystal lattice:
        the lengths of the cell edges (a, b, and c) in angstrom and the angles between them (alpha, beta, and gamma) in degrees.

        All float numbers can be rounded to a specific number (decimal_places).

        Args:
            decimal_places (int): The number of decimal places to round to.

        Returns:
            List[str]: The lattice parameters.
        """
        return [
            str(round(i, decimal_places)) for i in self.structure.lattice.parameters
        ]

    def get_coords(self, name: str = "cartesian", decimal_places: int = 3) -> List[str]:
        """
        Return a list of atoms in the unit cell along with their positions in Cartesian or fractional coordinates, as chosen.

        Args:
            name (str): The name of the coordinates. Can be "cartesian" or "fractional".
            decimal_places (int): The number of decimal places to round to.

        Returns:
            List[str]: The list of atoms with their positions.
        """
        elements = []
        for site in self.structure.sites:
            elements.append(str(site.specie))
            coord = [
                str(x)
                for x in (
                    site.coords.round(decimal_places)
                    if name == "cartesian"
                    else site.frac_coords.round(decimal_places)
                )
            ]
            elements.extend(coord)
        return elements

    def get_slices(self, primitive: bool = True) -> str:
        """Returns SLICES representation of the crystal structure.
        https://www.nature.com/articles/s41467-023-42870-7

        Args:
            primitive (bool): Whether to use the primitive structure or not.

        Returns:
            str: The SLICES representation of the crystal structure.
        """

        if primitive:
            primitive_structure = (
                self.structure.get_primitive_structure()
            )  # convert to primitive structure
            return self.backend.structure2SLICES(primitive_structure)
        return self.backend.structure2SLICES(self.structure)

    def get_composition(self, format="hill") -> str:
        """Return composition in hill format.

        Args:
            format (str): format in which the composition is required.

        Returns:
            str: The composition in hill format.
        """
        if format == "hill":
            composition_string = self.structure.composition.hill_formula
            composition = composition_string.replace(" ", "")
        return composition

    def get_local_env_rep(self, local_env_kwargs: Optional[dict] = None) -> str:
        """
        Get the local environment representation of the crystal structure.

        The local environment representation is a string that contains
        the space group symbol and the local environment of each atom in the unit cell.
        The local environment of each atom is represented as SMILES string and the
        Wyckoff symbol of the local environment.

        Args:
            local_env_kwargs (dict): Keyword arguments to pass to the LocalEnvAnalyzer.

        Returns:
            str: The local environment representation of the crystal structure.
        """
        if not local_env_kwargs:
            local_env_kwargs = {}
        analyzer = LocalEnvAnalyzer(**local_env_kwargs)
        return analyzer.structure_to_local_env_string(self.structure)

    def get_crystal_text_llm(
        self,
        permute_atoms: bool = False,
    ) -> str:
        """
        Code adopted from https://github.com/facebookresearch/crystal-llm/blob/main/llama_finetune.py
        https://openreview.net/pdf?id=0r5DE2ZSwJ

        Returns the representation as per the above citation.

        Args:
            permute_atoms (bool): Whether to permute the atoms in the unit cell.

        Returns:
            str: The crystal-llm representation of the crystal structure.
        """

        lengths = self.structure.lattice.parameters[:3]
        angles = self.structure.lattice.parameters[3:]
        atom_ids = self.structure.species
        frac_coords = self.structure.frac_coords

        if permute_atoms:
            atom_coord_pairs = list(zip(atom_ids, frac_coords))
            random.shuffle(atom_coord_pairs)
            atom_ids, frac_coords = zip(*atom_coord_pairs)

        crystal_str = (
            " ".join(["{0:.1f}".format(x) for x in lengths])
            + "\n"
            + " ".join([str(int(x)) for x in angles])
            + "\n"
            + "\n".join(
                [
                    str(t) + "\n" + " ".join(["{0:.2f}".format(x) for x in c])
                    for t, c in zip(atom_ids, frac_coords)
                ]
            )
        )

        return crystal_str

    def get_robocrys_rep(self):
        """
        https://github.com/hackingmaterials/robocrystallographer/tree/main
        """

        condensed_structure = self.condenser.condense_structure(self.structure)
        return self.describer.describe(condensed_structure)

    def get_wyckoff_positions(self):
        """
        Get the Wyckoff positions of the elements in the unit cell as a combination of
        number and letter.

        Returns:
            str: A multi-line string that contains the elements of the unit cell along with their Wyckoff position, one per line.

        Hint:
            At the end of the string, there is an additional newline character.
        """

        spacegroup_analyzer = SpacegroupAnalyzer(self.structure)
        wyckoff_sites = spacegroup_analyzer.get_symmetry_dataset()
        element_symbols = [site.specie.element.symbol for site in self.structure.sites]

        data = []

        for i in range(len(wyckoff_sites["wyckoffs"])):
            sub_data = (
                element_symbols[i],
                wyckoff_sites["wyckoffs"][i],
                wyckoff_sites["equivalent_atoms"][i],
            )
            data.append(sub_data)

        a = dict(Counter(data))

        output = ""
        for i, j in a.items():
            output += str(i[0]) + " " + str(j) + " " + str(i[1]) + "\n"

        return output

    def get_wycryst(self):
        """
        Obtaining the wyckoff representation for crystal structures that include:
            chemical formula
            space group number
            elements of the unit cell with their wyckoff positions.

        Returns:
            str: A multi-line string that contains the chemical formula, space group number,
                and the elements of the unit cell with their wyckoff positions.
        """
        output = ""
        chemical_formula = self.structure.composition.formula
        output += chemical_formula
        output += "\n" + str(self.structure.get_space_group_info()[1])
        output += "\n" + self.get_wyckoff_positions()

        return output

    def get_atom_sequences_plusplus(
        self, lattice_params: bool = False, decimal_places: int = 1
    ) -> str:
        """
        Generate a space-separated string of the elements in the crystal lattice, with the option to
        append the lattice parameters as lengths (float) and angles (int).

        Args:
            lattice_params (bool): Whether to include lattice parameters or not.
            decimal_places (int): The number of decimal places to round to.

        Returns:
            str: The string representation of the crystal structure.
        """

        try:
            output = [site.specie.element.symbol for site in self.structure.sites]
        except AttributeError:
            output = [site.specie.symbol for site in self.structure.sites]
        if lattice_params:
            params = self.get_lattice_parameters(decimal_places=decimal_places)
            params[3:] = [str(int(float(i))) for i in params[3:]]
            output.extend(params)

        return " ".join(output)

    def updated_zmatrix_rep(self, zmatrix, decimal_places=1):
        """
        Replace the variables in the Z-matrix with their values and return the updated Z-matrix.
        For example, a Z-matrix from pymatgen
        'N\nN 1 B1\nN 1 B2 2 A2\nN 1 B3 2 A3 3 D3\n
        B1=3.79
        B2=6.54
        ....
        is replaced to
        'N\nN 1 3.79\nN 1 6.54 2 90\nN 1 6.54 2 90 3 120\n'

        Args:
            zmatrix (str): Z-matrix multi-line string as implemented in pymatgen.
            decimal_places (int): The number of decimal places to round to.

        Returns:
            str: The updated Z-matrix representation of the crystal structure.
        """
        lines = zmatrix.split("\n")
        main_part = []
        variables_part = []

        # Determine the main part and the variables part of the Z-matrix
        for line in lines:
            if "=" in line:
                variables_part.append(line)
            else:
                if line.strip():  # Skip empty lines
                    main_part.append(line)

        # Extract variables from the variables part
        variable_dict = {}
        for var_line in variables_part:
            var, value = var_line.split("=")
            if var.startswith("B"):
                rounded_value = round(float(value.strip()), decimal_places)
            else:
                rounded_value = int(round(float(value.strip())))
            variable_dict[var] = (
                f"{rounded_value}"
                if var.startswith(("A", "D"))
                else f"{rounded_value:.{decimal_places}f}"
            )

        # Replace variables in the main part
        replaced_lines = []
        for line in main_part:
            parts = line.split()
            # atom = parts[0]
            replaced_line = line
            for i in range(1, len(parts)):
                var = parts[i]
                if var in variable_dict:
                    replaced_line = replaced_line.replace(var, variable_dict[var])
            replaced_lines.append(replaced_line)

        return "\n".join(replaced_lines)

    def get_zmatrix_rep(self, decimal_places=1):
        """
        Generate the Z-matrix representation of the crystal structure.
        It provides a description of each atom in terms of its atomic number,
        bond length, bond angle, and dihedral angle, the so-called internal coordinates.

        Disclaimer: The Z-matrix is meant for molecules; the current implementation converts the atoms within the unit cell to a molecule.
        Hence it might overlook bonds across unit cells.
        """
        species = [
            s.element if hasattr(s, "element") else s for s in self.structure.species
        ]
        coords = [c for c in self.structure.cart_coords]
        molecule_ = Molecule(
            species,
            coords,
        )
        zmatrix = molecule_.get_zmatrix()
        return self.updated_zmatrix_rep(zmatrix, decimal_places)

    def get_all_text_reps(self, decimal_places: int = 2):
        """
        Returns all the Text representations of the crystal structure in a dictionary.
        """

        return {
            "cif_p1": self._safe_call(
                self.get_cif_string, format="p1", decimal_places=decimal_places
            ),
            "cif_symmetrized": self._safe_call(
                self.get_cif_string, format="symmetrized", decimal_places=decimal_places
            ),
            "cif_bonding": None,
            "slices": self._safe_call(self.get_slices),
            "composition": self._safe_call(self.get_composition),
            "crystal_text_llm": self._safe_call(self.get_crystal_text_llm),
            "robocrys_rep": self._safe_call(self.get_robocrys_rep),
            "wycoff_rep": None,
            "atom_sequences": self._safe_call(
                self.get_atom_sequences_plusplus,
                lattice_params=False,
                decimal_places=decimal_places,
            ),
            "atom_sequences_plusplus": self._safe_call(
                self.get_atom_sequences_plusplus,
                lattice_params=True,
                decimal_places=decimal_places,
            ),
            "zmatrix": self._safe_call(self.get_zmatrix_rep),
            "local_env": self._safe_call(self.get_local_env_rep, local_env_kwargs=None),
        }

    def get_requested_text_reps(
        self, requested_reps: List[str], decimal_places: int = 2
    ):
        """
        Returns the requested Text representations of the crystal structure in a dictionary.

        Args:
            requested_reps (List[str]): The list of representations to return.
            decimal_places (int): The number of decimal places to round to.

        Returns:
            dict: A dictionary containing the requested text representations of the crystal structure.
        """

        if requested_reps == "cif_p1":
            return self._safe_call(
                self.get_cif_string, format="p1", decimal_places=decimal_places
            )

        elif requested_reps == "cif_symmetrized":
            return self._safe_call(
                self.get_cif_string,
                format="symmetrized",
                decimal_places=decimal_places,
            )

        elif requested_reps == "slices":
            return self._safe_call(self.get_slices)

        elif requested_reps == "composition":
            return self._safe_call(self.get_composition)

        elif requested_reps == "crystal_text_llm":
            return self._safe_call(self.get_crystal_text_llm)

        elif requested_reps == "robocrys_rep":
            return self._safe_call(self.get_robocrys_rep)

        elif requested_reps == "atom_sequences":
            return self._safe_call(
                self.get_atom_sequences_plusplus,
                lattice_params=False,
                decimal_places=decimal_places,
            )

        elif requested_reps == "atom_sequences_plusplus":
            return self._safe_call(
                self.get_atom_sequences_plusplus,
                lattice_params=True,
                decimal_places=decimal_places,
            )

        elif requested_reps == "zmatrix":
            return self._safe_call(self.get_zmatrix_rep)

        elif requested_reps == "local_env":
            return self._safe_call(self.get_local_env_rep, local_env_kwargs=None)
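
Example usage (a minimal sketch, not part of the library source): build a simple cubic CsCl-type structure with pymatgen and query a few of the representations listed above.

from pymatgen.core import Lattice, Structure
from xtal2txt.core import TextRep

# CsCl-type cell: cubic lattice, two atoms per unit cell (fractional coordinates)
structure = Structure(
    Lattice.cubic(4.2),
    ["Cs", "Cl"],
    [[0.0, 0.0, 0.0], [0.5, 0.5, 0.5]],
)

text_rep = TextRep(structure)
print(text_rep.get_composition())         # Hill formula with the spaces removed
print(text_rep.get_lattice_parameters())  # lengths (angstrom) and angles (degrees) as strings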

apply_transformations() 🔗

Apply transformations to the structure.

Source code in xtal2txt/core.py
def apply_transformations(self) -> None:
    """
    Apply transformations to the structure.
    """
    for transformation, params in self.transformations:
        transform_func = getattr(TransformationCallback, transformation)
        self.structure = transform_func(self.structure, **params)
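
A sketch of how the transformations argument is consumed (the transformation name "translate_structure" and its keyword arguments below are hypothetical placeholders; valid names are the methods defined on TransformationCallback). Each entry is a (method_name, kwargs) tuple, applied in order when the object is created.

# "translate_structure" is a hypothetical example name; substitute a method
# that actually exists on TransformationCallback.
transformations = [("translate_structure", {"vector": [0.1, 0.0, 0.0]})]
text_rep = TextRep(structure, transformations=transformations)  # structure from the sketch above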

from_input(input_data, transformations=None) classmethod 🔗

Instantiate the TextRep class object with the pymatgen structure from a cif file, a cif string, or a pymatgen Structure object.

Parameters:

Name Type Description Default
input_data Union[str, Structure]

A cif file of a crystal structure, a cif string, or a pymatgen Structure object.

required

Returns:

Name Type Description
TextRep TextRep

A TextRep object.

Source code in xtal2txt/core.py
@classmethod
def from_input(
    cls,
    input_data: Union[str, Path, Structure],
    transformations: List[Tuple[str, dict]] = None,
) -> "TextRep":
    """
    Instantiate the TextRep class object with the pymatgen structure from a cif file, a cif string, or a pymatgen Structure object.

    Args:
        input_data (Union[str,pymatgen.core.structure.Structure]): A cif file of a crystal structure, a cif string,
            or a pymatgen Structure object.

    Returns:
        TextRep: A TextRep object.
    """
    if isinstance(input_data, Structure):
        structure = input_data

    elif isinstance(input_data, (str, Path)):
        try:
            if Path(input_data).is_file():
                structure = Structure.from_file(str(input_data))
            else:
                raise ValueError
        except (OSError, ValueError):
            structure = Structure.from_str(str(input_data), "cif")

    else:
        structure = Structure.from_str(str(input_data), "cif")

    return cls(structure, transformations)
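
Usage sketch (the file name is hypothetical): the classmethod accepts a path to a CIF file, a raw CIF string, or an existing pymatgen Structure.

from xtal2txt.core import TextRep

text_rep = TextRep.from_input("CsCl.cif")    # path to a local CIF file (hypothetical)
# text_rep = TextRep.from_input(cif_string)  # or a CIF string
# text_rep = TextRep.from_input(structure)   # or a pymatgen Structure object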

get_all_text_reps(decimal_places=2) 🔗

Returns all the Text representations of the crystal structure in a dictionary.

Source code in xtal2txt/core.py
def get_all_text_reps(self, decimal_places: int = 2):
    """
    Returns all the Text representations of the crystal structure in a dictionary.
    """

    return {
        "cif_p1": self._safe_call(
            self.get_cif_string, format="p1", decimal_places=decimal_places
        ),
        "cif_symmetrized": self._safe_call(
            self.get_cif_string, format="symmetrized", decimal_places=decimal_places
        ),
        "cif_bonding": None,
        "slices": self._safe_call(self.get_slices),
        "composition": self._safe_call(self.get_composition),
        "crystal_text_llm": self._safe_call(self.get_crystal_text_llm),
        "robocrys_rep": self._safe_call(self.get_robocrys_rep),
        "wycoff_rep": None,
        "atom_sequences": self._safe_call(
            self.get_atom_sequences_plusplus,
            lattice_params=False,
            decimal_places=decimal_places,
        ),
        "atom_sequences_plusplus": self._safe_call(
            self.get_atom_sequences_plusplus,
            lattice_params=True,
            decimal_places=decimal_places,
        ),
        "zmatrix": self._safe_call(self.get_zmatrix_rep),
        "local_env": self._safe_call(self.get_local_env_rep, local_env_kwargs=None),
    }
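
Usage sketch, continuing the example above: collect every representation at once. Entries produced by heavyweight backends (e.g. SLICES, robocrystallographer) come back as None if the underlying call fails, because each call is wrapped in _safe_call.

all_reps = text_rep.get_all_text_reps(decimal_places=2)
print(all_reps["composition"])
print(all_reps["atom_sequences_plusplus"])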

get_atom_sequences_plusplus(lattice_params=False, decimal_places=1) 🔗

Generate a space-separated string of the elements in the crystal lattice, with the option to append the lattice parameters as lengths (float) and angles (int).

Parameters:

Name Type Description Default
lattice_params bool

Whether to include lattice parameters or not.

False
decimal_places int

The number of decimal places to round to.

1

Returns:

Name Type Description
str str

The string representation of the crystal structure.

Source code in xtal2txt/core.py
def get_atom_sequences_plusplus(
    self, lattice_params: bool = False, decimal_places: int = 1
) -> str:
    """
    Generate a space-separated string of the elements in the crystal lattice, with the option to
    append the lattice parameters as lengths (float) and angles (int).

    Args:
        lattice_params (bool): Whether to include lattice parameters or not.
        decimal_places (int): The number of decimal places to round to.

    Returns:
        str: The string representation of the crystal structure.
    """

    try:
        output = [site.specie.element.symbol for site in self.structure.sites]
    except AttributeError:
        output = [site.specie.symbol for site in self.structure.sites]
    if lattice_params:
        params = self.get_lattice_parameters(decimal_places=decimal_places)
        params[3:] = [str(int(float(i))) for i in params[3:]]
        output.extend(params)

    return " ".join(output)
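
Usage sketch, continuing the example above: element symbols, optionally followed by the rounded lattice lengths and integer angles.

print(text_rep.get_atom_sequences_plusplus(lattice_params=True, decimal_places=1))
# e.g. "Cs Cl 4.2 4.2 4.2 90 90 90" for the cubic cell used earlier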

get_cif_string(format='symmetrized', decimal_places=3) 🔗

Generate CIF as string in multi-line format.

All float numbers can be rounded to the specified number of decimal places (decimal_places). Currently supports two formats: symmetrized (CIF with symmetry operations and the least symmetric basis) and P1 (conventional unit cell, with all the atoms listed and only the identity symmetry operation).

Parameters:

Name Type Description Default
format str

The format of the CIF file. Can be "symmetrized" or "p1".

'symmetrized'
decimal_places int

The number of decimal places to round to.

3

Returns:

Name Type Description
str str

The CIF string.

Source code in xtal2txt/core.py
def get_cif_string(
    self, format: str = "symmetrized", decimal_places: int = 3
) -> str:
    """
    Generate CIF as string in multi-line format.

    All float numbers can be rounded to the specified number (decimal_places).
    Currently supports two formats: symmetrized (CIF with symmetry operations and the least symmetric basis)
    and P1 (conventional unit cell, with all the atoms listed and only the identity symmetry operation).

    Args:
        format (str): The format of the CIF file. Can be "symmetrized" or "p1".
        decimal_places (int): The number of decimal places to round to.

    Returns:
        str: The CIF string.
    """

    if format == "symmetrized":
        symmetry_analyzer = SpacegroupAnalyzer(self.structure)
        symmetrized_structure = symmetry_analyzer.get_symmetrized_structure()
        cif_string = str(
            CifWriter(
                symmetrized_structure,
                symprec=0.1,
                significant_figures=decimal_places,
            ).cif_file
        )
        cif = "\n".join(cif_string.split("\n")[1:])
        return self.round_numbers_in_string(cif, decimal_places)

    elif format == "p1":
        cif_string = "\n".join(self.structure.to(fmt="cif").split("\n")[1:])
        return self.round_numbers_in_string(cif_string, decimal_places)
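
Usage sketch, continuing the example above: the "p1" format lists every atom explicitly, while "symmetrized" keeps only the symmetry-inequivalent sites together with the symmetry operations.

cif_p1 = text_rep.get_cif_string(format="p1", decimal_places=3)
cif_sym = text_rep.get_cif_string(format="symmetrized", decimal_places=3)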

get_composition(format='hill') 🔗

Return composition in hill format.

Parameters:

Name Type Description Default
format str

format in which the composition is required.

'hill'

Returns:

Name Type Description
str str

The composition in hill format.

Source code in xtal2txt/core.py
def get_composition(self, format="hill") -> str:
    """Return composition in hill format.

    Args:
        format (str): format in which the composition is required.

    Returns:
        str: The composition in hill format.
    """
    if format == "hill":
        composition_string = self.structure.composition.hill_formula
        composition = composition_string.replace(" ", "")
    return composition

get_coords(name='cartesian', decimal_places=3) 🔗

Return a list of atoms in the unit cell along with their positions in Cartesian or fractional coordinates, as chosen.

Parameters:

Name Type Description Default
name str

The name of the coordinates. Can be "cartesian" or "fractional".

'cartesian'
decimal_places int

The number of decimal places to round to.

3

Returns:

Type Description
List[str]

List[str]: The list of atoms with their positions.

Source code in xtal2txt/core.py
def get_coords(self, name: str = "cartesian", decimal_places: int = 3) -> List[str]:
    """
    Return a list of atoms in the unit cell along with their positions in Cartesian or fractional coordinates, as chosen.

    Args:
        name (str): The name of the coordinates. Can be "cartesian" or "fractional".
        decimal_places (int): The number of decimal places to round to.

    Returns:
        List[str]: The list of atoms with their positions.
    """
    elements = []
    for site in self.structure.sites:
        elements.append(str(site.specie))
        coord = [
            str(x)
            for x in (
                site.coords.round(decimal_places)
                if name == "cartesian"
                else site.frac_coords.round(decimal_places)
            )
        ]
        elements.extend(coord)
    return elements
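
Usage sketch, continuing the example above: the result is a flat list of strings, an element symbol followed by its three coordinates, repeated for each site.

coords = text_rep.get_coords(name="fractional", decimal_places=3)
# e.g. ['Cs', '0.0', '0.0', '0.0', 'Cl', '0.5', '0.5', '0.5'] for the cell used earlier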

get_crystal_text_llm(permute_atoms=False) 🔗

Code adopted from https://github.com/facebookresearch/crystal-llm/blob/main/llama_finetune.py https://openreview.net/pdf?id=0r5DE2ZSwJ

Returns the representation as per the above citation.

Parameters:

Name Type Description Default
permute_atoms bool

Whether to permute the atoms in the unit cell.

False

Returns:

Name Type Description
str str

The crystal-llm representation of the crystal structure.

Source code in xtal2txt/core.py
def get_crystal_text_llm(
    self,
    permute_atoms: bool = False,
) -> str:
    """
    Code adopted from https://github.com/facebookresearch/crystal-llm/blob/main/llama_finetune.py
    https://openreview.net/pdf?id=0r5DE2ZSwJ

    Returns the representation as per the above citation.

    Args:
        permute_atoms (bool): Whether to permute the atoms in the unit cell.

    Returns:
        str: The crystal-llm representation of the crystal structure.
    """

    lengths = self.structure.lattice.parameters[:3]
    angles = self.structure.lattice.parameters[3:]
    atom_ids = self.structure.species
    frac_coords = self.structure.frac_coords

    if permute_atoms:
        atom_coord_pairs = list(zip(atom_ids, frac_coords))
        random.shuffle(atom_coord_pairs)
        atom_ids, frac_coords = zip(*atom_coord_pairs)

    crystal_str = (
        " ".join(["{0:.1f}".format(x) for x in lengths])
        + "\n"
        + " ".join([str(int(x)) for x in angles])
        + "\n"
        + "\n".join(
            [
                str(t) + "\n" + " ".join(["{0:.2f}".format(x) for x in c])
                for t, c in zip(atom_ids, frac_coords)
            ]
        )
    )

    return crystal_str
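
Usage sketch, continuing the example above: one line of cell lengths (one decimal), one line of integer angles, then alternating element and fractional-coordinate lines.

print(text_rep.get_crystal_text_llm())
# 4.2 4.2 4.2
# 90 90 90
# Cs
# 0.00 0.00 0.00
# Cl
# 0.50 0.50 0.50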

get_lattice_parameters(decimal_places=3) 🔗

Return lattice parameters of unit cells in a crystal lattice: the lengths of the cell edges (a, b, and c) in angstrom and the angles between them (alpha, beta, and gamma) in degrees.

All float numbers can be rounded to a specific number (decimal_places).

Parameters:

Name Type Description Default
decimal_places int

The number of decimal places to round to.

3

Returns:

Type Description
List[str]

List[str]: The lattice parameters.

Source code in xtal2txt/core.py
def get_lattice_parameters(self, decimal_places: int = 3) -> List[str]:
    """
    Return lattice parameters of unit cells in a crystal lattice:
    the lengths of the cell edges (a, b, and c) in angstrom and the angles between them (alpha, beta, and gamma) in degrees.

    All float numbers can be rounded to a specific number (decimal_places).

    Args:
        decimal_places (int): The number of decimal places to round to.

    Returns:
        List[str]: The lattice parameters.
    """
    return [
        str(round(i, decimal_places)) for i in self.structure.lattice.parameters
    ]

get_local_env_rep(local_env_kwargs=None) 🔗

Get the local environment representation of the crystal structure.

The local environment representation is a string that contains the space group symbol and the local environment of each atom in the unit cell. The local environment of each atom is represented as SMILES string and the Wyckoff symbol of the local environment.

Parameters:

Name Type Description Default
local_env_kwargs dict

Keyword arguments to pass to the LocalEnvAnalyzer.

None

Returns:

Name Type Description
str str

The local environment representation of the crystal structure.

Source code in xtal2txt/core.py
def get_local_env_rep(self, local_env_kwargs: Optional[dict] = None) -> str:
    """
    Get the local environment representation of the crystal structure.

    The local environment representation is a string that contains
    the space group symbol and the local environment of each atom in the unit cell.
    The local environment of each atom is represented as SMILES string and the
    Wyckoff symbol of the local environment.

    Args:
        local_env_kwargs (dict): Keyword arguments to pass to the LocalEnvAnalyzer.

    Returns:
        str: The local environment representation of the crystal structure.
    """
    if not local_env_kwargs:
        local_env_kwargs = {}
    analyzer = LocalEnvAnalyzer(**local_env_kwargs)
    return analyzer.structure_to_local_env_string(self.structure)

get_requested_text_reps(requested_reps, decimal_places=2) 🔗

Returns the requested Text representations of the crystal structure in a dictionary.

Parameters:

Name Type Description Default
requested_reps List[str]

The list of representations to return.

required
decimal_places int

The number of decimal places to round to.

2

Returns:

Name Type Description
dict

A dictionary containing the requested text representations of the crystal structure.

Source code in xtal2txt/core.py
def get_requested_text_reps(
    self, requested_reps: List[str], decimal_places: int = 2
):
    """
    Returns the requested Text representations of the crystal structure in a dictionary.

    Args:
        requested_reps (List[str]): The list of representations to return.
        decimal_places (int): The number of decimal places to round to.

    Returns:
        dict: A dictionary containing the requested text representations of the crystal structure.
    """

    if requested_reps == "cif_p1":
        return self._safe_call(
            self.get_cif_string, format="p1", decimal_places=decimal_places
        )

    elif requested_reps == "cif_symmetrized":
        return self._safe_call(
            self.get_cif_string,
            format="symmetrized",
            decimal_places=decimal_places,
        )

    elif requested_reps == "slices":
        return self._safe_call(self.get_slices)

    elif requested_reps == "composition":
        return self._safe_call(self.get_composition)

    elif requested_reps == "crystal_text_llm":
        return self._safe_call(self.get_crystal_text_llm)

    elif requested_reps == "robocrys_rep":
        return self._safe_call(self.get_robocrys_rep)

    elif requested_reps == "atom_sequences":
        return self._safe_call(
            self.get_atom_sequences_plusplus,
            lattice_params=False,
            decimal_places=decimal_places,
        )

    elif requested_reps == "atom_sequences_plusplus":
        return self._safe_call(
            self.get_atom_sequences_plusplus,
            lattice_params=True,
            decimal_places=decimal_places,
        )

    elif requested_reps == "zmatrix":
        return self._safe_call(self.get_zmatrix_rep)

    elif requested_reps == "local_env":
        return self._safe_call(self.get_local_env_rep, local_env_kwargs=None)
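
Usage sketch, continuing the example above. Note that despite the List[str] annotation, the implementation shown compares requested_reps against single representation names, so pass one key at a time.

cif_p1 = text_rep.get_requested_text_reps("cif_p1", decimal_places=2)
slices = text_rep.get_requested_text_reps("slices")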

get_robocrys_rep() 🔗

Source code in xtal2txt/core.py
def get_robocrys_rep(self):
    """
    https://github.com/hackingmaterials/robocrystallographer/tree/main
    """

    condensed_structure = self.condenser.condense_structure(self.structure)
    return self.describer.describe(condensed_structure)

get_slices(primitive=True) 🔗

Returns SLICES representation of the crystal structure. https://www.nature.com/articles/s41467-023-42870-7

Parameters:

Name Type Description Default
primitive bool

Whether to use the primitive structure or not.

True

Returns:

Name Type Description
str str

The SLICES representation of the crystal structure.

Source code in xtal2txt/core.py
def get_slices(self, primitive: bool = True) -> str:
    """Returns SLICES representation of the crystal structure.
    https://www.nature.com/articles/s41467-023-42870-7

    Args:
        primitive (bool): Whether to use the primitive structure or not.

    Returns:
        str: The SLICES representation of the crystal structure.
    """

    if primitive:
        primitive_structure = (
            self.structure.get_primitive_structure()
        )  # convert to primitive structure
        return self.backend.structure2SLICES(primitive_structure)
    return self.backend.structure2SLICES(self.structure)
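
Usage sketch, continuing the example above; the SLICES backend (InvCryRep) must be available, and by default the structure is reduced to its primitive cell before encoding.

slices_string = text_rep.get_slices(primitive=True)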

get_wyckoff_positions() 🔗

Get the Wyckoff positions of the elements in the unit cell as a combination of number and letter.

Returns:

Name Type Description
str

A multi-line string that contains the elements of the unit cell along with their Wyckoff position, one per line.

Hint

At the end of the string, there is an additional newline character.

Source code in xtal2txt/core.py
def get_wyckoff_positions(self):
    """
    Get the Wyckoff positions of the elements in the unit cell as a combination of
    number and letter.

    Returns:
        str: A multi-line string that contains the elements of the unit cell along with their Wyckoff position, one per line.

    Hint:
        At the end of the string, there is an additional newline character.
    """

    spacegroup_analyzer = SpacegroupAnalyzer(self.structure)
    wyckoff_sites = spacegroup_analyzer.get_symmetry_dataset()
    element_symbols = [site.specie.element.symbol for site in self.structure.sites]

    data = []

    for i in range(len(wyckoff_sites["wyckoffs"])):
        sub_data = (
            element_symbols[i],
            wyckoff_sites["wyckoffs"][i],
            wyckoff_sites["equivalent_atoms"][i],
        )
        data.append(sub_data)

    a = dict(Counter(data))

    output = ""
    for i, j in a.items():
        output += str(i[0]) + " " + str(j) + " " + str(i[1]) + "\n"

    return output
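
Usage sketch, continuing the example above: each line holds an element symbol, a site count, and the Wyckoff letter.

print(text_rep.get_wyckoff_positions())
# e.g. something like:
# Cs 1 a
# Cl 1 b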

get_wycryst() 🔗

Obtain the Wyckoff representation of a crystal structure, which includes the chemical formula, the space group number, and the elements of the unit cell with their Wyckoff positions.

Returns:

Name Type Description
str

A multi-line string that contains the chemical formula, space group number, and the elements of the unit cell with their wyckoff positions.

Source code in xtal2txt/core.py
def get_wycryst(self):
    """
    Obtaining the wyckoff representation for crystal structures that include:
        chemical formula
        space group number
        elements of the unit cell with their wyckoff positions.

    Returns:
        str: A multi-line string that contains the chemical formula, space group number,
            and the elements of the unit cell with their wyckoff positions.
    """
    output = ""
    chemical_formula = self.structure.composition.formula
    output += chemical_formula
    output += "\n" + str(self.structure.get_space_group_info()[1])
    output += "\n" + self.get_wyckoff_positions()

    return output
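
Usage sketch, continuing the example above: the chemical formula, the space group number, and then the Wyckoff lines from get_wyckoff_positions().

print(text_rep.get_wycryst())
# e.g. something like:
# Cs1 Cl1
# 221
# Cs 1 a
# Cl 1 b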

get_zmatrix_rep(decimal_places=1) 🔗

Generate the Z-matrix representation of the crystal structure. It provides a description of each atom in terms of its atomic number, bond length, bond angle, and dihedral angle, the so-called internal coordinates.

Disclaimer: The Z-matrix is meant for molecules; the current implementation converts the atoms within the unit cell to a molecule. Hence it might overlook bonds across unit cells.

Source code in xtal2txt/core.py
def get_zmatrix_rep(self, decimal_places=1):
    """
    Generate the Z-matrix representation of the crystal structure.
    It provides a description of each atom in terms of its atomic number,
    bond length, bond angle, and dihedral angle, the so-called internal coordinates.

    Disclaimer: The Z-matrix is meant for molecules; the current implementation converts the atoms within the unit cell to a molecule.
    Hence it might overlook bonds across unit cells.
    """
    species = [
        s.element if hasattr(s, "element") else s for s in self.structure.species
    ]
    coords = [c for c in self.structure.cart_coords]
    molecule_ = Molecule(
        species,
        coords,
    )
    zmatrix = molecule_.get_zmatrix()
    return self.updated_zmatrix_rep(zmatrix, decimal_places)

round_numbers_in_string(original_string, decimal_places) staticmethod 🔗

Rounds float numbers in the given string to the specified number of decimal places using regex.

Parameters:

Name Type Description Default
original_string str

The input string.

required
decimal_places int

The number of decimal places to round to.

required

Returns:

Name Type Description
str str

The string with the float numbers rounded to the specified number of decimal places.

Source code in xtal2txt/core.py
@staticmethod
def round_numbers_in_string(original_string: str, decimal_places: int) -> str:
    """
    Rounds float numbers in the given string to the specified number of decimal places using regex.

    Args:
        original_string (str): The input string.
        decimal_places (int): The number of decimal places to round to.

    Returns:
        str: The string with the float numbers rounded to the specified number of decimal places.
    """
    pattern = r"\b\d+\.\d+\b"
    matches = re.findall(pattern, original_string)
    rounded_numbers = [round(float(match), decimal_places) for match in matches]
    new_string = re.sub(
        pattern, lambda x: str(rounded_numbers.pop(0)), original_string
    )
    return new_string
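
A self-contained sketch of the rounding helper on an arbitrary string (the helper is a staticmethod, so it can be called on the class directly).

from xtal2txt.core import TextRep

print(TextRep.round_numbers_in_string("a 4.198765 b 90.000001", 3))
# "a 4.199 b 90.0"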

updated_zmatrix_rep(zmatrix, decimal_places=1) 🔗

Replace the variables in the Z-matrix with their values and return the updated Z-matrix.

For example, a Z-matrix from pymatgen of the form

N
N 1 B1
N 1 B2 2 A2
N 1 B3 2 A3 3 D3
B1=3.79
B2=6.54
...

is replaced by

N
N 1 3.79
N 1 6.54 2 90
N 1 6.54 2 90 3 120

Parameters:

Name Type Description Default
zmatrix str

Z-matrix multi-line string as implemented in pymatgen.

required
decimal_places int

The number of decimal places to round to.

1

Returns:

Name Type Description
str str

The updated Z-matrix representation of the crystal structure.
Source code in xtal2txt/core.py
def updated_zmatrix_rep(self, zmatrix, decimal_places=1):
    """
    Replace the variables in the Z-matrix with their values and return the updated Z-matrix.
    For example, a Z-matrix from pymatgen
    'N\nN 1 B1\nN 1 B2 2 A2\nN 1 B3 2 A3 3 D3\n
    B1=3.79
    B2=6.54
    ....
    is replaced to
    'N\nN 1 3.79\nN 1 6.54 2 90\nN 1 6.54 2 90 3 120\n'

    Args:
        zmatrix (str): Z-matrix multi-line string as implemented in pymatgen.
        decimal_places (int): The number of decimal places to round to.

    Returns:
        str: The updated Z-matrix representation of the crystal structure.
    """
    lines = zmatrix.split("\n")
    main_part = []
    variables_part = []

    # Determine the main part and the variables part of the Z-matrix
    for line in lines:
        if "=" in line:
            variables_part.append(line)
        else:
            if line.strip():  # Skip empty lines
                main_part.append(line)

    # Extract variables from the variables part
    variable_dict = {}
    for var_line in variables_part:
        var, value = var_line.split("=")
        if var.startswith("B"):
            rounded_value = round(float(value.strip()), decimal_places)
        else:
            rounded_value = int(round(float(value.strip())))
        variable_dict[var] = (
            f"{rounded_value}"
            if var.startswith(("A", "D"))
            else f"{rounded_value:.{decimal_places}f}"
        )

    # Replace variables in the main part
    replaced_lines = []
    for line in main_part:
        parts = line.split()
        # atom = parts[0]
        replaced_line = line
        for i in range(1, len(parts)):
            var = parts[i]
            if var in variable_dict:
                replaced_line = replaced_line.replace(var, variable_dict[var])
        replaced_lines.append(replaced_line)

    return "\n".join(replaced_lines)
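
Continuing the TextRep examples above, a sketch of the variable substitution on a small hand-written pymatgen-style Z-matrix (the atom and values are arbitrary).

zmat = "N\nN 1 B1\nN 1 B2 2 A2\nB1=3.79\nB2=6.54\nA2=90.0"
print(text_rep.updated_zmatrix_rep(zmat, decimal_places=1))
# N
# N 1 3.8
# N 1 6.5 2 90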

Decoding🔗

DecodeTextRep 🔗

Source code in xtal2txt/decoder.py
class DecodeTextRep:
    def __init__(self, text):
        self.text = text

    def decode(self):
        return self.text

    def wyckoff_decoder(self, input: str, lattice_params: bool = False):
        """
        Generate a pymatgen object from the output of the get_wyckoff_rep() method by using the
        pyxtal package. In this method, all data are extracted from the multi-line string returned by
        that method.
        In pyxtal package, a 3D crystal is produced by specifying the dimensions, elements,
        composition of elements, space group, and sites as wyckoff positions of the elements.

        Params:
            lattice_params: boolean
                Whether to use lattice parameters when generating the crystal structure.

        Returns:
            pmg_struc: pymatgen.core.structure.Structure
        """

        # Always dimension is 3.
        dimensions = 3

        entities = input.split("\n")[:-1]
        elements = entities[0]
        spg = int(entities[1])
        wyckoff_sites = entities[2:]
        elements = elements.split(" ")

        atoms = []
        composition = []
        for el in elements:
            atom = el.rstrip("0123456789")
            number = el[len(atom) :]
            atoms.append(atom)
            composition.append(int(number))

        sites = []
        for atom in atoms:
            sub_site = []
            for site in wyckoff_sites:
                if atom in site:
                    sub_site.append(site.split()[1])

            sites.append(sub_site)

        xtal_struc = pyxtal()

        if lattice_params:
            a, b, c, alpha, beta, gamma = self.get_lattice_parameters()
            cell = pyLattice.from_para(
                float(a), float(b), float(c), float(alpha), float(beta), float(gamma)
            )
            xtal_struc.from_random(dimensions, spg, atoms, composition, sites=sites, lattice=cell)
        else:
            xtal_struc.from_random(dimensions, spg, atoms, composition, sites=sites)

        pmg_struc = xtal_struc.to_pymatgen()

        return pmg_struc

    def llm_decoder(self, input: str):
        """
        Returning pymatgen structure out of multi-line representation.

        Params:
            input: str
                String to obtain the items needed for the structure.

        Returns:
            pymatgen.core.structure.Structure
        """
        entities = input.split("\n")
        lengths = entities[0].split(" ")
        angles = entities[1].split(" ")
        lattice = Lattice.from_parameters(
            a=float(lengths[0]),
            b=float(lengths[1]),
            c=float(lengths[2]),
            alpha=float(angles[0]),
            beta=float(angles[1]),
            gamma=float(angles[2]),
        )

        elements = entities[2::2]
        coordinates = entities[3::2]
        m_coord = []
        for i in coordinates:
            s = [float(j) for j in i.split(" ")]
            m_coord.append(s)

        return Structure(lattice, elements, m_coord)

    def cif_string_decoder_p1(self, input: str):
        """
        Returning a pymatgen structure out of a string format of a cif file.

        Params:
            input: str
                String to obtain the items needed for the structure.

        Returns:
            pymatgen.core.structure.Structure
        """
        entities = input.split("\n")[:-1]

        params = []
        for i in range(2, 8):
            params.append(entities[i].split("   ")[1])

        lattice = Lattice.from_parameters(
            a=float(params[0]),
            b=float(params[1]),
            c=float(params[2]),
            alpha=float(params[3]),
            beta=float(params[4]),
            gamma=float(params[5]),
        )

        elements = []
        m_coord = []
        atoms = entities[entities.index(" _atom_site_occupancy") + 1 :]
        for atom in atoms:
            ls = atom.split("  ")
            elements.append(ls[1])
            m_coord.append([float(ls[4]), float(ls[5]), float(ls[6])])

        return Structure(lattice, elements, m_coord)

    def cif_string_decoder_sym(self, input: str):
        """
        Returning a pymatgen structure out of a string format of a symmetrized cif file.

        Params:
            input: str
                String to obtain the items needed for the structure.

        Returns:
            pymatgen.core.structure.Structure
        """
        entities = input.split("\n")[:-1]

        params = []
        for i in range(1, 8):
            params.append(entities[i].split("   ")[1])

        spg = params[0]
        params = params[1:]
        lattice = Lattice.from_parameters(
            a=float(params[0]),
            b=float(params[1]),
            c=float(params[2]),
            alpha=float(params[3]),
            beta=float(params[4]),
            gamma=float(params[5]),
        )

        elements = []
        m_coord = []
        atoms = entities[entities.index(" _atom_site_occupancy") + 1 :]
        for atom in atoms:
            ls = atom.split("  ")
            elements.append(ls[1])
            m_coord.append([float(ls[4]), float(ls[5]), float(ls[6])])

        # print(atoms)

        return Structure.from_spacegroup(spg, lattice, elements, m_coord)
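
Usage sketch (continuing the TextRep examples above): round-trip the crystal-llm representation back to a pymatgen Structure.

from xtal2txt.decoder import DecodeTextRep

llm_string = text_rep.get_crystal_text_llm()
decoder = DecodeTextRep(llm_string)
recovered_structure = decoder.llm_decoder(llm_string)  # pymatgen Structure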

cif_string_decoder_p1(input) 🔗

Returning a pymatgen structure out of a string format of a cif file.

Parameters:

Name Type Description Default
input str

str String to obtain the items needed for the structure.

required

Returns:

Type Description

pymatgen.core.structure.Structure

Source code in xtal2txt/decoder.py
def cif_string_decoder_p1(self, input: str):
    """
    Returning a pymatgen structure out of a string format of a cif file.

    Params:
        input: str
            String to obtain the items needed for the structure.

    Returns:
        pymatgen.core.structure.Structure
    """
    entities = input.split("\n")[:-1]

    params = []
    for i in range(2, 8):
        params.append(entities[i].split("   ")[1])

    lattice = Lattice.from_parameters(
        a=float(params[0]),
        b=float(params[1]),
        c=float(params[2]),
        alpha=float(params[3]),
        beta=float(params[4]),
        gamma=float(params[5]),
    )

    elements = []
    m_coord = []
    atoms = entities[entities.index(" _atom_site_occupancy") + 1 :]
    for atom in atoms:
        ls = atom.split("  ")
        elements.append(ls[1])
        m_coord.append([float(ls[4]), float(ls[5]), float(ls[6])])

    return Structure(lattice, elements, m_coord)
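
Usage sketch, continuing the examples above and assuming cif_p1 was produced by get_cif_string(format="p1"); the decoder expects that CIF layout.

structure_p1 = DecodeTextRep(cif_p1).cif_string_decoder_p1(cif_p1)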

cif_string_decoder_sym(input) 🔗

Returning a pymatgen structure out of a string format of a symmetrized cif file.

Parameters:

Name Type Description Default
input str

str String to obtain the items needed for the structure.

required

Returns:

Type Description

pymatgen.core.structure.Structure

Source code in xtal2txt/decoder.py
def cif_string_decoder_sym(self, input: str):
    """
    Returning a pymatgen structure out of a string format of a symmetrized cif file.

    Params:
        input: str
            String to obtain the items needed for the structure.

    Returns:
        pymatgen.core.structure.Structure
    """
    entities = input.split("\n")[:-1]

    params = []
    for i in range(1, 8):
        params.append(entities[i].split("   ")[1])

    spg = params[0]
    params = params[1:]
    lattice = Lattice.from_parameters(
        a=float(params[0]),
        b=float(params[1]),
        c=float(params[2]),
        alpha=float(params[3]),
        beta=float(params[4]),
        gamma=float(params[5]),
    )

    elements = []
    m_coord = []
    atoms = entities[entities.index(" _atom_site_occupancy") + 1 :]
    for atom in atoms:
        ls = atom.split("  ")
        elements.append(ls[1])
        m_coord.append([float(ls[4]), float(ls[5]), float(ls[6])])

    # print(atoms)

    return Structure.from_spacegroup(spg, lattice, elements, m_coord)

llm_decoder(input) 🔗

Returning pymatgen structure out of multi-line representation.

Parameters:

Name Type Description Default
input str

str String to obtain the items needed for the structure.

required

Returns:

Type Description

pymatgen.core.structure.Structure

Source code in xtal2txt/decoder.py
def llm_decoder(self, input: str):
    """
    Returning pymatgen structure out of multi-line representation.

    Params:
        input: str
            String to obtain the items needed for the structure.

    Returns:
        pymatgen.core.structure.Structure
    """
    entities = input.split("\n")
    lengths = entities[0].split(" ")
    angles = entities[1].split(" ")
    lattice = Lattice.from_parameters(
        a=float(lengths[0]),
        b=float(lengths[1]),
        c=float(lengths[2]),
        alpha=float(angles[0]),
        beta=float(angles[1]),
        gamma=float(angles[2]),
    )

    elements = entities[2::2]
    coordinates = entities[3::2]
    m_coord = []
    for i in coordinates:
        s = [float(j) for j in i.split(" ")]
        m_coord.append(s)

    return Structure(lattice, elements, m_coord)
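
The decoder above expects the lattice lengths on the first line, the lattice angles on the second, and then alternating element / fractional-coordinate lines. A small hand-built sketch (the DecodeTextRep import path is an assumption):

from xtal2txt.decoder import DecodeTextRep  # assumed import path

llm_text = "\n".join(
    [
        "5.64 5.64 5.64",  # a b c in angstrom
        "90.0 90.0 90.0",  # alpha beta gamma in degrees
        "Na",              # element symbol
        "0.0 0.0 0.0",     # fractional coordinates
        "Cl",
        "0.5 0.5 0.5",
    ]
)
structure = DecodeTextRep(llm_text).llm_decoder(llm_text)
print(structure.formula)  # Na1 Cl1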

wyckoff_decoder(input, lattice_params=False) 🔗

Generates a pymatgen structure from the output of the get_wyckoff_rep() method using the pyxtal package. All data are extracted from the multi-line string produced by that method. In pyxtal, a 3D crystal is built by specifying the dimensions, elements, composition of elements, space group, and the Wyckoff positions of the elements as sites.

Parameters:

    lattice_params (bool): Whether to use lattice parameters when generating the crystal structure. Defaults to False.

Returns:

    pmg_struc (pymatgen.core.structure.Structure)

Source code in xtal2txt/decoder.py
def wyckoff_decoder(self, input: str, lattice_params: bool = False):
    """
    Generating a pymatgen object from the output of the get_wyckoff_rep() method by using...
    pyxtal package. In this method, all data are extracted from the multi-line string of the...
    mentioned method.
    In pyxtal package, a 3D crystal is produced by specifying the dimensions, elements,
    composition of elements, space group, and sites as wyckoff positions of the elements.

    Params:
        lattice_params: boolean
            To specify whether use lattice parameters in generating crystal structure.

    Returns:
        pmg_struc: pymatgen.core.structure.Structure
    """

    # Always dimension is 3.
    dimensions = 3

    entities = input.split("\n")[:-1]
    elements = entities[0]
    spg = int(entities[1])
    wyckoff_sites = entities[2:]
    elements = elements.split(" ")

    atoms = []
    composition = []
    for el in elements:
        atom = el.rstrip("0123456789")
        number = el[len(atom) :]
        atoms.append(atom)
        composition.append(int(number))

    sites = []
    for atom in atoms:
        sub_site = []
        for site in wyckoff_sites:
            if atom in site:
                sub_site.append(site.split()[1])

        sites.append(sub_site)

    xtal_struc = pyxtal()

    if lattice_params:
        a, b, c, alpha, beta, gamma = self.get_lattice_parameters()
        cell = pyLattice.from_para(
            float(a), float(b), float(c), float(alpha), float(beta), float(gamma)
        )
        xtal_struc.from_random(dimensions, spg, atoms, composition, sites=sites, lattice=cell)
    else:
        xtal_struc.from_random(dimensions, spg, atoms, composition, sites=sites)

    pmg_struc = xtal_struc.to_pymatgen()

    return pmg_struc
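
A hedged usage sketch. get_wyckoff_rep() is the encoder counterpart named in the docstring; treating it as a TextRep method and the DecodeTextRep import path are assumptions. Because pyxtal's from_random places atoms randomly within the given Wyckoff sites, the decoded structure is symmetry- and composition-equivalent to the original rather than coordinate-identical.

from xtal2txt.core import TextRep
from xtal2txt.decoder import DecodeTextRep  # assumed import path

text_rep = TextRep.from_input("NaCl.cif")     # hypothetical input file
wyckoff_text = text_rep.get_wyckoff_rep()     # assumed encoder method, see docstring above
structure = DecodeTextRep(wyckoff_text).wyckoff_decoder(wyckoff_text, lattice_params=False)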

MatchRep 🔗

Source code in xtal2txt/decoder.py
class MatchRep:
    def __init__(self, textrep, structure):
        self.text = textrep
        self.structure = structure

    def wyckoff_matcher(
        self,
        ltol=0.2,
        stol=0.5,
        angle_tol=5,
        primitive_cell=True,
        scale=True,
        allow_subset=True,
        attempt_supercell=True,
        lattice_params: bool = False,
    ):
        """
        To check if pymatgen object from the original cif file match with the generated...
        pymatgen structure from wyckoff_decoder method out of wyckoff representation...
        using fit() method of StructureMatcher module in pymatgen package.

        Params:
            StructureMatcher module can be access in below link with its parameters:
                https://pymatgen.org/pymatgen.analysis.html#pymatgen.analysis.structure_matcher.StructureMatcher.get_mapping
            lattice_params: bool
                To specify using lattice parameters in the wyckoff_decoder method.

        Returns:
            StructureMatcher().fit_anonymous(): bool
        """

        original_struct = self.structure

        # output_struct = self.wyckoff_decoder(input, lattice_params)
        output_struct = DecodeTextRep(self.text).wyckoff_decoder(self.text, lattice_params=True)

        return StructureMatcher(
            ltol,
            stol,
            angle_tol,
            primitive_cell,
            scale,
            allow_subset,
            attempt_supercell,
        ).fit_anonymous(output_struct, original_struct)

    def llm_matcher(
        self,
        ltol=0.2,
        stol=0.5,
        angle_tol=5,
        primitive_cell=True,
        scale=True,
        allow_subset=True,
        attempt_supercell=True,
    ):
        """
        To check if pymatgen object from the original cif file match with the generated
        pymatgen structure from llm_decoder method out of llm representation
        using fit() method of StructureMatcher module in pymatgen package.

        Params:
            input: str
                String to obtain the items needed for the structure.

            StructureMatcher module can be access in below link with its parameters:
                https://pymatgen.org/pymatgen.analysis.html#pymatgen.analysis.structure_matcher.StructureMatcher.get_mapping

        Returns:
            StructureMatcher().fit(): bool
        """

        original_struct = self.structure
        output_struct = DecodeTextRep(self.text).llm_decoder(self.text)

        return StructureMatcher(
            ltol,
            stol,
            angle_tol,
            primitive_cell,
            scale,
            allow_subset,
            attempt_supercell,
        ).fit(output_struct, original_struct)

    def cif_string_matcher_sym(
        self,
        #        input: str,
        ltol=0.2,
        stol=0.5,
        angle_tol=5,
        primitive_cell=True,
        scale=True,
        allow_subset=True,
        attempt_supercell=True,
    ):
        """
        To check if pymatgen object from the original cif file match with the generated
        pymatgen structure from cif_string_decoder_sym method out of string cif representation.
        using fit() method of StructureMatcher module in pymatgen package.

        Params:
            input: str
                String to obtain the items needed for the structure.

            StructureMatcher module can be access in below link with its parameters:
                https://pymatgen.org/pymatgen.analysis.html#pymatgen.analysis.structure_matcher.StructureMatcher.get_mapping

        Returns:
            StructureMatcher().fit(): bool
        """

        original_struct = self.structure
        output_struct = DecodeTextRep(self.text).cif_string_decoder_sym(self.text)

        return StructureMatcher(
            ltol,
            stol,
            angle_tol,
            primitive_cell,
            scale,
            allow_subset,
            attempt_supercell,
        ).fit(output_struct, original_struct)

    def cif_string_matcher_p1(
        self,
        ltol=0.2,
        stol=0.5,
        angle_tol=5,
        primitive_cell=True,
        scale=True,
        allow_subset=True,
        attempt_supercell=True,
    ):
        """
        To check if pymatgen object from the original cif file match with the generated
        pymatgen structure from cif_string_decoder_p1 method out of string cif representation
        using fit() method of StructureMatcher module in pymatgen package.

        Params:
            input: str
                String to obtain the items needed for the structure.

            StructureMatcher module can be access in below link with its parameters:
                https://pymatgen.org/pymatgen.analysis.html#pymatgen.analysis.structure_matcher.StructureMatcher.get_mapping

        Returns:
            StructureMatcher().fit(): bool
        """

        original_struct = self.structure
        output_struct = DecodeTextRep(self.text).cif_string_decoder_p1(self.text)

        return StructureMatcher(
            ltol,
            stol,
            angle_tol,
            primitive_cell,
            scale,
            allow_subset,
            attempt_supercell,
        ).fit(output_struct, original_struct)
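
A minimal sketch of how the matchers can be used to validate a text representation against its source structure. The MatchRep import path and the get_crystal_llm_rep() encoder method name are assumptions; any text representation paired with its originating structure would work the same way.

from pymatgen.core import Structure
from xtal2txt.core import TextRep
from xtal2txt.decoder import MatchRep  # assumed import path

structure = Structure.from_file("NaCl.cif")                     # hypothetical reference structure
llm_text = TextRep.from_input(structure).get_crystal_llm_rep()  # assumed encoder method name
matcher = MatchRep(llm_text, structure)
print(matcher.llm_matcher())  # True if the decoded text matches the original structure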

cif_string_matcher_p1(ltol=0.2, stol=0.5, angle_tol=5, primitive_cell=True, scale=True, allow_subset=True, attempt_supercell=True) 🔗

Checks whether the pymatgen structure from the original CIF file matches the structure generated by the cif_string_decoder_p1 method from the CIF string representation, using the fit() method of pymatgen's StructureMatcher.

Parameters:

    input (str): String to obtain the items needed for the structure. Required.

    The StructureMatcher module and its parameters can be accessed at:
    https://pymatgen.org/pymatgen.analysis.html#pymatgen.analysis.structure_matcher.StructureMatcher.get_mapping

Returns:

    bool: The result of StructureMatcher().fit().

Source code in xtal2txt/decoder.py
def cif_string_matcher_p1(
    self,
    ltol=0.2,
    stol=0.5,
    angle_tol=5,
    primitive_cell=True,
    scale=True,
    allow_subset=True,
    attempt_supercell=True,
):
    """
    To check if pymatgen object from the original cif file match with the generated
    pymatgen structure from cif_string_decoder_p1 method out of string cif representation
    using fit() method of StructureMatcher module in pymatgen package.

    Params:
        input: str
            String to obtain the items needed for the structure.

        StructureMatcher module can be access in below link with its parameters:
            https://pymatgen.org/pymatgen.analysis.html#pymatgen.analysis.structure_matcher.StructureMatcher.get_mapping

    Returns:
        StructureMatcher().fit(): bool
    """

    original_struct = self.structure
    output_struct = DecodeTextRep(self.text).cif_string_decoder_p1(self.text)

    return StructureMatcher(
        ltol,
        stol,
        angle_tol,
        primitive_cell,
        scale,
        allow_subset,
        attempt_supercell,
    ).fit(output_struct, original_struct)

cif_string_matcher_sym(ltol=0.2, stol=0.5, angle_tol=5, primitive_cell=True, scale=True, allow_subset=True, attempt_supercell=True) 🔗

Checks whether the pymatgen structure from the original CIF file matches the structure generated by the cif_string_decoder_sym method from the symmetrized CIF string representation, using the fit() method of pymatgen's StructureMatcher.

Parameters:

    input (str): String to obtain the items needed for the structure. Required.

    The StructureMatcher module and its parameters can be accessed at:
    https://pymatgen.org/pymatgen.analysis.html#pymatgen.analysis.structure_matcher.StructureMatcher.get_mapping

Returns:

    bool: The result of StructureMatcher().fit().

Source code in xtal2txt/decoder.py
def cif_string_matcher_sym(
    self,
    #        input: str,
    ltol=0.2,
    stol=0.5,
    angle_tol=5,
    primitive_cell=True,
    scale=True,
    allow_subset=True,
    attempt_supercell=True,
):
    """
    To check if pymatgen object from the original cif file match with the generated
    pymatgen structure from cif_string_decoder_sym method out of string cif representation.
    using fit() method of StructureMatcher module in pymatgen package.

    Params:
        input: str
            String to obtain the items needed for the structure.

        StructureMatcher module can be access in below link with its parameters:
            https://pymatgen.org/pymatgen.analysis.html#pymatgen.analysis.structure_matcher.StructureMatcher.get_mapping

    Returns:
        StructureMatcher().fit(): bool
    """

    original_struct = self.structure
    output_struct = DecodeTextRep(self.text).cif_string_decoder_sym(self.text)

    return StructureMatcher(
        ltol,
        stol,
        angle_tol,
        primitive_cell,
        scale,
        allow_subset,
        attempt_supercell,
    ).fit(output_struct, original_struct)

llm_matcher(ltol=0.2, stol=0.5, angle_tol=5, primitive_cell=True, scale=True, allow_subset=True, attempt_supercell=True) 🔗

Checks whether the pymatgen structure from the original CIF file matches the structure generated by the llm_decoder method from the multi-line (llm) representation, using the fit() method of pymatgen's StructureMatcher.

Parameters:

    input (str): String to obtain the items needed for the structure. Required.

    The StructureMatcher module and its parameters can be accessed at:
    https://pymatgen.org/pymatgen.analysis.html#pymatgen.analysis.structure_matcher.StructureMatcher.get_mapping

Returns:

    bool: The result of StructureMatcher().fit().

Source code in xtal2txt/decoder.py
def llm_matcher(
    self,
    ltol=0.2,
    stol=0.5,
    angle_tol=5,
    primitive_cell=True,
    scale=True,
    allow_subset=True,
    attempt_supercell=True,
):
    """
    To check if pymatgen object from the original cif file match with the generated
    pymatgen structure from llm_decoder method out of llm representation
    using fit() method of StructureMatcher module in pymatgen package.

    Params:
        input: str
            String to obtain the items needed for the structure.

        StructureMatcher module can be access in below link with its parameters:
            https://pymatgen.org/pymatgen.analysis.html#pymatgen.analysis.structure_matcher.StructureMatcher.get_mapping

    Returns:
        StructureMatcher().fit(): bool
    """

    original_struct = self.structure
    output_struct = DecodeTextRep(self.text).llm_decoder(self.text)

    return StructureMatcher(
        ltol,
        stol,
        angle_tol,
        primitive_cell,
        scale,
        allow_subset,
        attempt_supercell,
    ).fit(output_struct, original_struct)

wyckoff_matcher(ltol=0.2, stol=0.5, angle_tol=5, primitive_cell=True, scale=True, allow_subset=True, attempt_supercell=True, lattice_params=False) 🔗

Checks whether the pymatgen structure from the original CIF file matches the structure generated by the wyckoff_decoder method from the Wyckoff representation, using the fit_anonymous() method of pymatgen's StructureMatcher.

Parameters:

    The StructureMatcher module and its parameters can be accessed at:
    https://pymatgen.org/pymatgen.analysis.html#pymatgen.analysis.structure_matcher.StructureMatcher.get_mapping

    lattice_params (bool): Whether to use lattice parameters in the wyckoff_decoder method. Defaults to False.

Returns:

    bool: The result of StructureMatcher().fit_anonymous().

Source code in xtal2txt/decoder.py
def wyckoff_matcher(
    self,
    ltol=0.2,
    stol=0.5,
    angle_tol=5,
    primitive_cell=True,
    scale=True,
    allow_subset=True,
    attempt_supercell=True,
    lattice_params: bool = False,
):
    """
    To check if pymatgen object from the original cif file match with the generated...
    pymatgen structure from wyckoff_decoder method out of wyckoff representation...
    using fit() method of StructureMatcher module in pymatgen package.

    Params:
        StructureMatcher module can be access in below link with its parameters:
            https://pymatgen.org/pymatgen.analysis.html#pymatgen.analysis.structure_matcher.StructureMatcher.get_mapping
        lattice_params: bool
            To specify using lattice parameters in the wyckoff_decoder method.

    Returns:
        StructureMatcher().fit_anonymous(): bool
    """

    original_struct = self.structure

    # output_struct = self.wyckoff_decoder(input, lattice_params)
    output_struct = DecodeTextRep(self.text).wyckoff_decoder(self.text, lattice_params=True)

    return StructureMatcher(
        ltol,
        stol,
        angle_tol,
        primitive_cell,
        scale,
        allow_subset,
        attempt_supercell,
    ).fit_anonymous(output_struct, original_struct)

Transformations🔗

Tokenizer🔗

CifTokenizer 🔗

Bases: Xtal2txtTokenizer

Source code in xtal2txt/tokenizer.py
class CifTokenizer(Xtal2txtTokenizer):
    def __init__(
        self,
        special_num_token: bool = False,
        vocab_file=None,
        model_max_length=None,
        padding_length=None,
        **kwargs,
    ):
        if special_num_token:
            vocab_file = CIF_RT_VOCAB
        else:
            vocab_file = CIF_VOCAB
        super(CifTokenizer, self).__init__(
            special_num_token=special_num_token,
            vocab_file=vocab_file,
            model_max_length=model_max_length,
            padding_length=padding_length,
            **kwargs,
        )

    def token_analysis(self, list_of_tokens):
        """Takes tokens after tokenize and returns a list with replacing the tokens with their MASK token. The
        token type is determined from the dict declared globally, and the token is replaced with the corresponding MASK token.
        """
        analysis_masks = ANALYSIS_MASK_TOKENS
        token_type = CIF_ANALYSIS_DICT
        return [
            analysis_masks[next((k for k, v in token_type.items() if token in v), None)]
            for token in list_of_tokens
        ]
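
A short usage sketch (the xtal2txt.tokenizer import path and the use of a CIF string from TextRep are assumptions):

from xtal2txt.core import TextRep
from xtal2txt.tokenizer import CifTokenizer  # assumed import path (source: xtal2txt/tokenizer.py)

cif_text = TextRep.from_input("NaCl.cif").get_cif_string(n=3)  # hypothetical input file
tokenizer = CifTokenizer(special_num_token=True, model_max_length=512)
tokens = tokenizer.tokenize(cif_text)
masks = tokenizer.token_analysis(tokens)  # per-token MASK categories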

token_analysis(list_of_tokens) 🔗

Takes the tokens produced by tokenize and returns a list in which each token is replaced by its MASK token. The token type is looked up in the globally declared dict, and the token is replaced with the corresponding MASK token.

Source code in xtal2txt/tokenizer.py
def token_analysis(self, list_of_tokens):
    """Takes tokens after tokenize and returns a list with replacing the tokens with their MASK token. The
    token type is determined from the dict declared globally, and the token is replaced with the corresponding MASK token.
    """
    analysis_masks = ANALYSIS_MASK_TOKENS
    token_type = CIF_ANALYSIS_DICT
    return [
        analysis_masks[next((k for k, v in token_type.items() if token in v), None)]
        for token in list_of_tokens
    ]

CompositionTokenizer 🔗

Bases: Xtal2txtTokenizer

Source code in xtal2txt/tokenizer.py
class CompositionTokenizer(Xtal2txtTokenizer):
    def __init__(
        self,
        special_num_token: bool = False,
        vocab_file=None,
        model_max_length=None,
        padding_length=None,
        **kwargs,
    ):
        if special_num_token:
            vocab_file = COMPOSITION_RT_VOCAB if vocab_file is None else vocab_file
        else:
            vocab_file = COMPOSITION_VOCAB if vocab_file is None else vocab_file
        super(CompositionTokenizer, self).__init__(
            special_num_token=special_num_token,
            vocab_file=vocab_file,
            model_max_length=model_max_length,
            padding_length=padding_length,
            **kwargs,
        )

    def token_analysis(self, list_of_tokens):
        """Takes tokens after tokenize and returns a list with replacing the tokens with their MASK token. The
        token type is determined from the dict declared globally, and the token is replaced with the corresponding MASK token.
        """
        analysis_masks = ANALYSIS_MASK_TOKENS
        token_type = COMPOSITION_ANALYSIS_DICT
        return [
            analysis_masks[next((k for k, v in token_type.items() if token in v), None)]
            for token in list_of_tokens
        ]

token_analysis(list_of_tokens) 🔗

Takes the tokens produced by tokenize and returns a list in which each token is replaced by its MASK token. The token type is looked up in the globally declared dict, and the token is replaced with the corresponding MASK token.

Source code in xtal2txt/tokenizer.py
def token_analysis(self, list_of_tokens):
    """Takes tokens after tokenize and returns a list with replacing the tokens with their MASK token. The
    token type is determined from the dict declared globally, and the token is replaced with the corresponding MASK token.
    """
    analysis_masks = ANALYSIS_MASK_TOKENS
    token_type = COMPOSITION_ANALYSIS_DICT
    return [
        analysis_masks[next((k for k, v in token_type.items() if token in v), None)]
        for token in list_of_tokens
    ]

CrysllmTokenizer 🔗

Bases: Xtal2txtTokenizer

Source code in xtal2txt/tokenizer.py
class CrysllmTokenizer(Xtal2txtTokenizer):
    def __init__(
        self,
        special_num_token: bool = False,
        vocab_file=CRYSTAL_LLM_VOCAB,
        model_max_length=None,
        padding_length=None,
        **kwargs,
    ):
        if special_num_token:
            vocab_file = CRYSTAL_LLM_RT_VOCAB
        else:
            vocab_file = CRYSTAL_LLM_VOCAB
        super(CrysllmTokenizer, self).__init__(
            special_num_token=special_num_token,
            vocab_file=vocab_file,
            model_max_length=model_max_length,
            padding_length=padding_length,
            **kwargs,
        )

    def token_analysis(self, list_of_tokens):
        """Takes tokens after tokenize and returns a list with replacing the tokens with their MASK token. The
        token type is determined from the dict declared globally, and the token is replaced with the corresponding MASK token.
        """
        analysis_masks = ANALYSIS_MASK_TOKENS
        token_type = CRYSTAL_LLM_ANALYSIS_DICT
        return [
            analysis_masks[next((k for k, v in token_type.items() if token in v), None)]
            for token in list_of_tokens
        ]

token_analysis(list_of_tokens) 🔗

Takes the tokens produced by tokenize and returns a list in which each token is replaced by its MASK token. The token type is looked up in the globally declared dict, and the token is replaced with the corresponding MASK token.

Source code in xtal2txt/tokenizer.py
def token_analysis(self, list_of_tokens):
    """Takes tokens after tokenize and returns a list with replacing the tokens with their MASK token. The
    token type is determined from the dict declared globally, and the token is replaced with the corresponding MASK token.
    """
    analysis_masks = ANALYSIS_MASK_TOKENS
    token_type = CRYSTAL_LLM_ANALYSIS_DICT
    return [
        analysis_masks[next((k for k, v in token_type.items() if token in v), None)]
        for token in list_of_tokens
    ]

NumTokenizer 🔗

Tokenize numbers as implemented in Regression Transformer. https://www.nature.com/articles/s42256-023-00639-z https://github.com/IBM/regression-transformer/tree/main

Source code in xtal2txt/tokenizer.py
class NumTokenizer:
    """Tokenize numbers as implemented in Regression Transformer.
    https://www.nature.com/articles/s42256-023-00639-z
    https://github.com/IBM/regression-transformer/tree/main"""

    def __init__(self) -> None:
        """Tokenizer for numbers."""
        self.regex = re.compile(r"(\+|-)?(\d+)(\.)?(\d+)?\s*")

    def num_matcher(self, text: str) -> str:
        """Extract numbers from a sentence and replace them with tokens."""
        # pattern = re.findall(r'(\d+\.\d+|\d+)', text)  # This regex captures both whole numbers and decimal numbers

        pattern = (
            r"\d+(?:\.\d+)?"  # Match any number, whether it is part of a string or not
        )
        matches = list(re.finditer(pattern, text))
        for match in reversed(
            matches
        ):  # since we are replacing substring with a bigger subtring the string we are working on
            start, end = match.start(), match.end()
            tokens = self.tokenize(match.group())
            replacement = "".join(tokens)
            text = text[:start] + replacement + text[end:]
        return text

    def tokenize(self, text: str) -> List[str]:
        """Tokenization of numbers as in RT.
         '0.9' -> '_0_0_', '_._', '_9_-1_'

        Args:
            text: number as string to be tokenized.

        Returns:
            extracted tokens.
        """
        tokens = []
        matched = self.regex.match(text)
        if matched:
            sign, units, dot, decimals = matched.groups()
            tokens = []
            if sign:
                tokens += [f"_{sign}_"]
            tokens += [
                f"_{number}_{position}_" for position, number in enumerate(units[::-1])
            ][::-1]
            if dot:
                tokens += [f"_{dot}_"]
            if decimals:
                tokens += [
                    f"_{number}_-{position}_"
                    for position, number in enumerate(decimals, 1)
                ]
        return tokens

    @staticmethod
    def convert_tokens_to_float(tokens: List[str]) -> float:
        """Converts tokens representing a float value into a float.
        NOTE: Expects that non-floating tokens are strippped off

        Args:
            tokens: List of tokens, each representing a float.
                E.g.: ['_0_0_', '_._', '_9_-1_', '_3_-2_', '_1_-3_']

        Returns:
            float: Float representation for the list of tokens.
        """
        try:
            float_string = "".join([token.split("_")[1] for token in tokens])
            float_value = float(float_string)
        except ValueError:
            float_value = -1
        return float_value

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Converts tokens to string.

        Args:
            tokens: List of tokens.

        Returns:
            str: String representation of the tokens.
        """
        return "".join([token.split("_")[1] for token in tokens])
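
Worked example based on the source above: tokenize a number into place-value tokens and convert it back. The xtal2txt.tokenizer import path is an assumption.

from xtal2txt.tokenizer import NumTokenizer  # assumed import path

nt = NumTokenizer()
tokens = nt.tokenize("0.9")
print(tokens)                                        # ['_0_0_', '_._', '_9_-1_']
print(NumTokenizer.convert_tokens_to_float(tokens))  # 0.9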

__init__() 🔗

Tokenizer for numbers.

Source code in xtal2txt/tokenizer.py
def __init__(self) -> None:
    """Tokenizer for numbers."""
    self.regex = re.compile(r"(\+|-)?(\d+)(\.)?(\d+)?\s*")

convert_tokens_to_float(tokens) staticmethod 🔗

Converts tokens representing a float value into a float. NOTE: Expects that non-floating tokens are stripped off.

Parameters:

    tokens (List[str]): List of tokens, each representing a float. E.g.: ['_0_0_', '_._', '_9_-1_', '_3_-2_', '_1_-3_']

Returns:

    float: Float representation for the list of tokens.

Source code in xtal2txt/tokenizer.py
@staticmethod
def convert_tokens_to_float(tokens: List[str]) -> float:
    """Converts tokens representing a float value into a float.
    NOTE: Expects that non-floating tokens are strippped off

    Args:
        tokens: List of tokens, each representing a float.
            E.g.: ['_0_0_', '_._', '_9_-1_', '_3_-2_', '_1_-3_']

    Returns:
        float: Float representation for the list of tokens.
    """
    try:
        float_string = "".join([token.split("_")[1] for token in tokens])
        float_value = float(float_string)
    except ValueError:
        float_value = -1
    return float_value

convert_tokens_to_string(tokens) 🔗

Converts tokens to string.

Parameters:

    tokens (List[str]): List of tokens.

Returns:

    str: String representation of the tokens.

Source code in xtal2txt/tokenizer.py
def convert_tokens_to_string(self, tokens: List[str]) -> str:
    """Converts tokens to string.

    Args:
        tokens: List of tokens.

    Returns:
        str: String representation of the tokens.
    """
    return "".join([token.split("_")[1] for token in tokens])

num_matcher(text) 🔗

Extract numbers from a sentence and replace them with tokens.

Source code in xtal2txt/tokenizer.py
def num_matcher(self, text: str) -> str:
    """Extract numbers from a sentence and replace them with tokens."""
    # pattern = re.findall(r'(\d+\.\d+|\d+)', text)  # This regex captures both whole numbers and decimal numbers

    pattern = (
        r"\d+(?:\.\d+)?"  # Match any number, whether it is part of a string or not
    )
    matches = list(re.finditer(pattern, text))
    for match in reversed(
        matches
    ):  # since we are replacing substring with a bigger subtring the string we are working on
        start, end = match.start(), match.end()
        tokens = self.tokenize(match.group())
        replacement = "".join(tokens)
        text = text[:start] + replacement + text[end:]
    return text
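
For example, applied to a sentence (continuing with the NumTokenizer instance from the sketch above), each number is replaced in place by its concatenated digit tokens:

print(nt.num_matcher("a = 4.2 angstrom"))
# a = _4_0__.__2_-1_ angstrom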

tokenize(text) 🔗

Tokenization of numbers as in RT: '0.9' -> '_0_0_', '_._', '_9_-1_'

Parameters:

    text (str): Number as a string to be tokenized. Required.

Returns:

    List[str]: Extracted tokens.

Source code in xtal2txt/tokenizer.py
def tokenize(self, text: str) -> List[str]:
    """Tokenization of numbers as in RT.
     '0.9' -> '_0_0_', '_._', '_9_-1_'

    Args:
        text: number as string to be tokenized.

    Returns:
        extracted tokens.
    """
    tokens = []
    matched = self.regex.match(text)
    if matched:
        sign, units, dot, decimals = matched.groups()
        tokens = []
        if sign:
            tokens += [f"_{sign}_"]
        tokens += [
            f"_{number}_{position}_" for position, number in enumerate(units[::-1])
        ][::-1]
        if dot:
            tokens += [f"_{dot}_"]
        if decimals:
            tokens += [
                f"_{number}_-{position}_"
                for position, number in enumerate(decimals, 1)
            ]
    return tokens

RobocrysTokenizer 🔗

Tokenizer for Robocrystallographer. Intended to be a BPE tokenizer trained on the Robocrystallographer dataset. TODO: Implement this tokenizer.

Source code in xtal2txt/tokenizer.py
class RobocrysTokenizer:
    """Tokenizer for Robocrystallographer. Would be BPE tokenizer.
    trained on the Robocrystallographer dataset.
    TODO: Implement this tokenizer.
    """

    def __init__(self, vocab_file=ROBOCRYS_VOCAB, **kwargs):
        tokenizer = Tokenizer.from_file(vocab_file)
        wrapped_tokenizer = PreTrainedTokenizerFast(tokenizer_object=tokenizer)
        self._tokenizer = wrapped_tokenizer

    def tokenize(self, text):
        return self._tokenizer.tokenize(text)

    def encode(self, text):
        return self._tokenizer.encode(text)

    def decode(self, token_ids, skip_special_tokens=True):
        # Check if token_ids is a string and convert it to a list of integers
        if isinstance(token_ids, str):
            token_ids = [int(token_ids)]
        return self._tokenizer.decode(
            token_ids, skip_special_tokens=skip_special_tokens
        )

SliceTokenizer 🔗

Bases: Xtal2txtTokenizer

Source code in xtal2txt/tokenizer.py
class SliceTokenizer(Xtal2txtTokenizer):
    def __init__(
        self,
        special_num_token: bool = False,
        vocab_file=None,
        model_max_length=None,
        padding_length=None,
        **kwargs,
    ):
        if special_num_token:
            vocab_file = SLICE_RT_VOCAB if vocab_file is None else vocab_file
        else:
            vocab_file = SLICE_VOCAB if vocab_file is None else vocab_file
        super(SliceTokenizer, self).__init__(
            special_num_token=special_num_token,
            vocab_file=vocab_file,
            model_max_length=model_max_length,
            padding_length=padding_length,
            **kwargs,
        )

    def convert_tokens_to_string(self, tokens):
        """Converts tokens to string."""
        if self.special_num_tokens:
            return " ".join(
                [
                    (
                        token
                        if not (token.startswith("_") and token.endswith("_"))
                        else token.split("_")[1]
                    )
                    for token in tokens
                ]
            )
        return " ".join(tokens).rstrip()

    def token_analysis(self, list_of_tokens):
        """Takes tokens after tokenize and returns a list with replacing the tokens with their MASK token. The
        token type is determined from the dict declared globally, and the token is replaced with the corresponding MASK token.
        """
        analysis_masks = ANALYSIS_MASK_TOKENS
        token_type = SLICE_ANALYSIS_DICT
        return [
            analysis_masks[next((k for k, v in token_type.items() if token in v), None)]
            for token in list_of_tokens
        ]
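
A brief usage sketch (import path assumed; the SLICES string here is a made-up placeholder, not a validated SLICES encoding):

from xtal2txt.tokenizer import SliceTokenizer  # assumed import path

tokenizer = SliceTokenizer(special_num_token=False)
slice_text = "Na Cl 0 1 o o o"             # hypothetical SLICES-style string
tokens = tokenizer.tokenize(slice_text)
print(tokenizer.convert_tokens_to_string(tokens))
masks = tokenizer.token_analysis(tokens)   # per-token MASK categories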

convert_tokens_to_string(tokens) 🔗

Converts tokens to string.

Source code in xtal2txt/tokenizer.py
def convert_tokens_to_string(self, tokens):
    """Converts tokens to string."""
    if self.special_num_tokens:
        return " ".join(
            [
                (
                    token
                    if not (token.startswith("_") and token.endswith("_"))
                    else token.split("_")[1]
                )
                for token in tokens
            ]
        )
    return " ".join(tokens).rstrip()

token_analysis(list_of_tokens) 🔗

Takes the tokens produced by tokenize and returns a list in which each token is replaced by its MASK token. The token type is looked up in the globally declared dict, and the token is replaced with the corresponding MASK token.

Source code in xtal2txt/tokenizer.py
def token_analysis(self, list_of_tokens):
    """Takes tokens after tokenize and returns a list with replacing the tokens with their MASK token. The
    token type is determined from the dict declared globally, and the token is replaced with the corresponding MASK token.
    """
    analysis_masks = ANALYSIS_MASK_TOKENS
    token_type = SLICE_ANALYSIS_DICT
    return [
        analysis_masks[next((k for k, v in token_type.items() if token in v), None)]
        for token in list_of_tokens
    ]

Models🔗

Matbenchmark 🔗

Class to perform predictions on Matbench datasets.

Args: - task_cfg (DictConfig): Configuration dictionary containing task parameters.

Source code in src/mattext/models/benchmark.py
class Matbenchmark:
    """
    Class to perform predictions on Matbench datasets.

    Args:
    - task_cfg (DictConfig): Configuration dictionary containing task parameters.
    """

    def __init__(self, task_cfg: DictConfig):
        """
        Initializes the object with the given task configuration.

        Parameters:
            task_cfg (DictConfig): The configuration dictionary containing task parameters.

        Returns:
            None
        """
        self.task_cfg = task_cfg
        self.representation = self.task_cfg.model.representation
        self.task = self.task_cfg.model.dataset
        self.task_type = self.task_cfg.model.dataset_type
        self.benchmark = self.task_cfg.model.inference.benchmark_dataset
        self.exp_names = self.task_cfg.model.finetune.exp_name
        self.test_exp_names = self.task_cfg.model.inference.exp_name
        self.train_data = self.task_cfg.model.finetune.dataset_name
        self.test_data = self.task_cfg.model.inference.benchmark_dataset
        self.benchmark_save_path = self.task_cfg.model.inference.benchmark_save_file

        # override wandb project name & tokenizer
        self.wandb_project = self.task_cfg.model.logging.wandb_project

    def run_benchmarking(self, local_rank=None) -> None:
        """
        Runs benchmarking on the specified dataset.

        Args:
            local_rank (int, optional): The local rank for distributed training. Defaults to None.

        Returns:
            None

        Raises:
            Exception: If an error occurs during inference for a finetuned checkpoint.

        """
        if self.task_type == "matbench":
            mb = MatbenchBenchmark(autoload=False)
            task = getattr(mb, MATTEXT_MATBENCH[self.task])
            task.load()
        else:
            task = MatTextTask(task_name=self.task)

        for i, (exp_name, test_name) in enumerate(
            zip(self.exp_names, self.test_exp_names)
        ):
            print(
                f"Running training on {self.train_data}, and testing on {self.test_data} for fold {i}"
            )
            wandb.init(
                config=dict(self.task_cfg.model.finetune),
                project=self.task_cfg.model.logging.wandb_project,
                name=exp_name,
            )
            fold_name = fold_key_namer(i)
            print("-------------------------")
            print(fold_name)
            print("-------------------------")

            exp_cfg = self.task_cfg.copy()
            exp_cfg.model.finetune.exp_name = exp_name
            exp_cfg.model.finetune.path.finetune_traindata = self.train_data

            finetuner = FinetuneModel(exp_cfg, local_rank, fold=fold_name)
            ckpt = finetuner.finetune()
            print("-------------------------")
            print(ckpt)
            print("-------------------------")

            wandb.init(
                config=dict(self.task_cfg.model.inference),
                project=self.task_cfg.model.logging.wandb_project,
                name=test_name,
            )

            exp_cfg.model.inference.path.test_data = self.test_data
            exp_cfg.model.inference.path.pretrained_checkpoint = ckpt

            try:
                predict = Inference(exp_cfg, fold=fold_name)
                predictions, prediction_ids = predict.predict()
                print(len(prediction_ids), len(predictions))

                if self.task_type == "matbench":
                    task.record(i, predictions)
                else:
                    task.record_fold(
                        fold=i, prediction_ids=prediction_ids, predictions=predictions
                    )

            except Exception as e:
                print(
                    f"Error occurred during inference for finetuned checkpoint '{exp_name}':"
                )
                print(traceback.format_exc())

        if not os.path.exists(self.benchmark_save_path):
            os.makedirs(self.benchmark_save_path)

        file_name = os.path.join(
            self.benchmark_save_path,
            f"mattext_benchmark_{self.representation}_{self.benchmark}.json",
        )
        task.to_file(file_name)

__init__(task_cfg) 🔗

Initializes the object with the given task configuration.

Parameters:

    task_cfg (DictConfig): The configuration dictionary containing task parameters. Required.

Returns:

    None

Source code in src/mattext/models/benchmark.py
def __init__(self, task_cfg: DictConfig):
    """
    Initializes the object with the given task configuration.

    Parameters:
        task_cfg (DictConfig): The configuration dictionary containing task parameters.

    Returns:
        None
    """
    self.task_cfg = task_cfg
    self.representation = self.task_cfg.model.representation
    self.task = self.task_cfg.model.dataset
    self.task_type = self.task_cfg.model.dataset_type
    self.benchmark = self.task_cfg.model.inference.benchmark_dataset
    self.exp_names = self.task_cfg.model.finetune.exp_name
    self.test_exp_names = self.task_cfg.model.inference.exp_name
    self.train_data = self.task_cfg.model.finetune.dataset_name
    self.test_data = self.task_cfg.model.inference.benchmark_dataset
    self.benchmark_save_path = self.task_cfg.model.inference.benchmark_save_file

    # override wandb project name & tokenizer
    self.wandb_project = self.task_cfg.model.logging.wandb_project

run_benchmarking(local_rank=None) 🔗

Runs benchmarking on the specified dataset.

Parameters:

    local_rank (int, optional): The local rank for distributed training. Defaults to None.

Returns:

    None

Raises:

    Exception: If an error occurs during inference for a finetuned checkpoint.

Source code in src/mattext/models/benchmark.py
def run_benchmarking(self, local_rank=None) -> None:
    """
    Runs benchmarking on the specified dataset.

    Args:
        local_rank (int, optional): The local rank for distributed training. Defaults to None.

    Returns:
        None

    Raises:
        Exception: If an error occurs during inference for a finetuned checkpoint.

    """
    if self.task_type == "matbench":
        mb = MatbenchBenchmark(autoload=False)
        task = getattr(mb, MATTEXT_MATBENCH[self.task])
        task.load()
    else:
        task = MatTextTask(task_name=self.task)

    for i, (exp_name, test_name) in enumerate(
        zip(self.exp_names, self.test_exp_names)
    ):
        print(
            f"Running training on {self.train_data}, and testing on {self.test_data} for fold {i}"
        )
        wandb.init(
            config=dict(self.task_cfg.model.finetune),
            project=self.task_cfg.model.logging.wandb_project,
            name=exp_name,
        )
        fold_name = fold_key_namer(i)
        print("-------------------------")
        print(fold_name)
        print("-------------------------")

        exp_cfg = self.task_cfg.copy()
        exp_cfg.model.finetune.exp_name = exp_name
        exp_cfg.model.finetune.path.finetune_traindata = self.train_data

        finetuner = FinetuneModel(exp_cfg, local_rank, fold=fold_name)
        ckpt = finetuner.finetune()
        print("-------------------------")
        print(ckpt)
        print("-------------------------")

        wandb.init(
            config=dict(self.task_cfg.model.inference),
            project=self.task_cfg.model.logging.wandb_project,
            name=test_name,
        )

        exp_cfg.model.inference.path.test_data = self.test_data
        exp_cfg.model.inference.path.pretrained_checkpoint = ckpt

        try:
            predict = Inference(exp_cfg, fold=fold_name)
            predictions, prediction_ids = predict.predict()
            print(len(prediction_ids), len(predictions))

            if self.task_type == "matbench":
                task.record(i, predictions)
            else:
                task.record_fold(
                    fold=i, prediction_ids=prediction_ids, predictions=predictions
                )

        except Exception as e:
            print(
                f"Error occurred during inference for finetuned checkpoint '{exp_name}':"
            )
            print(traceback.format_exc())

    if not os.path.exists(self.benchmark_save_path):
        os.makedirs(self.benchmark_save_path)

    file_name = os.path.join(
        self.benchmark_save_path,
        f"mattext_benchmark_{self.representation}_{self.benchmark}.json",
    )
    task.to_file(file_name)
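
A minimal invocation sketch, assuming a Hydra/OmegaConf config that provides the fields read in __init__ (model.representation, model.dataset, model.finetune.*, model.inference.*, model.logging.wandb_project). The import path follows the source location shown above.

from omegaconf import DictConfig
from mattext.models.benchmark import Matbenchmark  # assumed import path (source: src/mattext/models/benchmark.py)

def run(task_cfg: DictConfig) -> None:
    benchmark = Matbenchmark(task_cfg)
    benchmark.run_benchmarking(local_rank=None)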

FinetuneModel 🔗

Bases: TokenizerMixin

Class to perform fine-tuning of a language model. Initializes the FinetuneModel.

Parameters:

    cfg (DictConfig): Configuration for the fine-tuning. Required.

    local_rank (int, optional): Local rank for distributed training. Defaults to None.
Source code in src/mattext/models/finetune.py
class FinetuneModel(TokenizerMixin):
    """Class to perform finetuning of a language model.
        Initialize the FinetuneModel.

    Args:
        cfg (DictConfig): Configuration for the fine-tuning.
        local_rank (int, optional): Local rank for distributed training. Defaults to None.
    """

    def __init__(self, cfg: DictConfig, local_rank=None, fold="fold_0") -> None:
        super().__init__(
            cfg=cfg.model.representation,
            special_tokens=cfg.model.special_tokens,
            special_num_token=cfg.model.special_num_token,
        )
        self.fold = fold
        self.local_rank = local_rank
        self.representation = cfg.model.representation
        self.data_repository = cfg.model.data_repository
        self.cfg = cfg.model.finetune
        self.context_length: int = self.cfg.context_length
        self.callbacks = self.cfg.callbacks
        self.tokenized_dataset = self._prepare_datasets(
            self.cfg.path.finetune_traindata
        )

    def _prepare_datasets(self, subset: str) -> DatasetDict:
        """
        Prepare training and validation datasets.

        Args:
            train_df (pd.DataFrame): DataFrame containing training data.

        Returns:
            DatasetDict: Dictionary containing training and validation datasets.
        """

        def replace_none(example, replacement="[PAD]"):
            for key, value in example.items():
                if value is None:
                    example[key] = replacement
            return example

        ds = load_dataset(self.data_repository, subset)
        dataset = ds[self.fold].train_test_split(shuffle=True, test_size=0.2, seed=42)
        dataset = dataset.filter(
            lambda example: example[self.representation] is not None
        )
        return dataset.map(
            partial(
                self._tokenize_pad_and_truncate, context_length=self.context_length
            ),
            batched=True,
        )

    def _callbacks(self) -> List[TrainerCallback]:
        """Returns a list of callbacks for early stopping, and custom logging."""
        callbacks = []

        if self.callbacks.early_stopping:
            callbacks.append(
                EarlyStoppingCallback(
                    early_stopping_patience=self.callbacks.early_stopping_patience,
                    early_stopping_threshold=self.callbacks.early_stopping_threshold,
                )
            )

        if self.callbacks.custom_logger:
            callbacks.append(CustomWandbCallback_FineTune())

        callbacks.append(EvaluateFirstStepCallback)

        return callbacks

    def _compute_metrics(self, p: Any, eval=True) -> Dict[str, float]:
        preds = torch.tensor(
            p.predictions.squeeze()
        )  # Convert predictions to PyTorch tensor
        label_ids = torch.tensor(p.label_ids)  # Convert label_ids to PyTorch tensor

        if eval:
            # Calculate RMSE as evaluation metric
            eval_rmse = torch.sqrt(((preds - label_ids) ** 2).mean()).item()
            return {"eval_rmse": round(eval_rmse, 3)}
        else:
            # Calculate RMSE as training metric
            loss = torch.sqrt(((preds - label_ids) ** 2).mean()).item()
            return {"train_rmse": round(loss, 3), "loss": round(loss, 3)}

    def finetune(self) -> None:
        """
        Perform fine-tuning of the language model.
        """

        pretrained_ckpt = self.cfg.path.pretrained_checkpoint

        config_train_args = self.cfg.training_arguments
        callbacks = self._callbacks()

        training_args = TrainingArguments(
            **config_train_args,
            metric_for_best_model="eval_rmse",  # Metric to use for determining the best model
            greater_is_better=False,  # Lower eval_rmse is better
        )

        model = AutoModelForSequenceClassification.from_pretrained(
            pretrained_ckpt, num_labels=1, ignore_mismatched_sizes=False
        )

        if self.cfg.freeze_base_model:
            for param in model.base_model.parameters():
                param.requires_grad = False

        if self.local_rank is not None:
            model = model.to(self.local_rank)
            model = nn.parallel.DistributedDataParallel(
                model, device_ids=[self.local_rank]
            )
        else:
            model = model.to("cuda")

        trainer = Trainer(
            model=model,
            args=training_args,
            data_collator=None,
            compute_metrics=self._compute_metrics,
            tokenizer=self._wrapped_tokenizer,
            train_dataset=self.tokenized_dataset["train"],
            eval_dataset=self.tokenized_dataset["test"],
            callbacks=callbacks,
        )

        wandb.log({"Training Arguments": str(config_train_args)})
        wandb.log({"model_summary": str(model)})

        trainer.train()

        eval_result = trainer.evaluate(eval_dataset=self.tokenized_dataset["test"])
        wandb.log(eval_result)

        model.save_pretrained(self.cfg.path.finetuned_modelname)
        wandb.finish()
        return self.cfg.path.finetuned_modelname

    def evaluate(self):
        """
        Evaluate the fine-tuned model on the test dataset.
        """
        ckpt = self.finetune()
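
The evaluation metric logged during fine-tuning is the root-mean-square error between the regression head outputs and the labels, as computed in _compute_metrics above. A standalone sketch of the same computation with made-up numbers:

import torch

preds = torch.tensor([2.10, 3.95, 1.02])   # example model outputs
labels = torch.tensor([2.00, 4.00, 1.00])  # example targets
eval_rmse = torch.sqrt(((preds - labels) ** 2).mean()).item()
print(round(eval_rmse, 3))                 # ~0.066, reported as "eval_rmse"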

evaluate() 🔗

Evaluate the fine-tuned model on the test dataset.

Source code in src/mattext/models/finetune.py
def evaluate(self):
    """
    Evaluate the fine-tuned model on the test dataset.
    """
    ckpt = self.finetune()

finetune() 🔗

Perform fine-tuning of the language model.

Source code in src/mattext/models/finetune.py
def finetune(self) -> None:
    """
    Perform fine-tuning of the language model.
    """

    pretrained_ckpt = self.cfg.path.pretrained_checkpoint

    config_train_args = self.cfg.training_arguments
    callbacks = self._callbacks()

    training_args = TrainingArguments(
        **config_train_args,
        metric_for_best_model="eval_rmse",  # Metric to use for determining the best model
        greater_is_better=False,  # Lower eval_rmse is better
    )

    model = AutoModelForSequenceClassification.from_pretrained(
        pretrained_ckpt, num_labels=1, ignore_mismatched_sizes=False
    )

    if self.cfg.freeze_base_model:
        for param in model.base_model.parameters():
            param.requires_grad = False

    if self.local_rank is not None:
        model = model.to(self.local_rank)
        model = nn.parallel.DistributedDataParallel(
            model, device_ids=[self.local_rank]
        )
    else:
        model = model.to("cuda")

    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=None,
        compute_metrics=self._compute_metrics,
        tokenizer=self._wrapped_tokenizer,
        train_dataset=self.tokenized_dataset["train"],
        eval_dataset=self.tokenized_dataset["test"],
        callbacks=callbacks,
    )

    wandb.log({"Training Arguments": str(config_train_args)})
    wandb.log({"model_summary": str(model)})

    trainer.train()

    eval_result = trainer.evaluate(eval_dataset=self.tokenized_dataset["test"])
    wandb.log(eval_result)

    model.save_pretrained(self.cfg.path.finetuned_modelname)
    wandb.finish()
    return self.cfg.path.finetuned_modelname

FinetuneLLamaSFT 🔗

Class to perform fine-tuning of a language model. Initializes the FinetuneModel.

Parameters:

    cfg (DictConfig): Configuration for the fine-tuning. Required.

    local_rank (int, optional): Local rank for distributed training. Defaults to None.
Source code in src/mattext/models/llama_sft.py
class FinetuneLLamaSFT:
    """Class to perform finetuning of a language model.
        Initialize the FinetuneModel.

    Args:
        cfg (DictConfig): Configuration for the fine-tuning.
        local_rank (int, optional): Local rank for distributed training. Defaults to None.
    """

    def __init__(
        self, cfg: DictConfig, local_rank=None, fold="fold_0", test_sample_size=None
    ) -> None:
        self.fold = fold
        self.local_rank = local_rank
        self.representation = cfg.model.representation
        self.data_repository = cfg.model.data_repository
        self.dataset_ = cfg.model.dataset
        self.add_special_tokens = cfg.model.add_special_tokens
        self.property_map = cfg.model.PROPERTY_MAP
        self.material_map = cfg.model.MATERIAL_MAP
        self.cfg = cfg.model.finetune
        self.train_data = self.cfg.dataset_name
        self.test_data = self.cfg.benchmark_dataset
        self.context_length: int = self.cfg.context_length
        self.dataprep_seed: int = self.cfg.dataprep_seed
        self.callbacks = self.cfg.callbacks
        self.ckpt = self.cfg.path.pretrained_checkpoint
        self.bnb_config = self.cfg.bnb_config
        # test_sample_size must be set before prepare_test_data, which reads it
        self.test_sample_size = test_sample_size
        self.dataset = self.prepare_data(self.train_data)
        self.testdata = self.prepare_test_data(self.test_data)
        self.model, self.tokenizer, self.peft_config = self._setup_model_tokenizer()
        self.property_ = self.property_map[self.dataset_]
        self.material_ = self.material_map[self.dataset_]

    def prepare_test_data(self, subset):
        dataset = load_dataset(self.data_repository, subset)[self.fold]
        if self.test_sample_size:
            dataset = dataset.select(range(self.test_sample_size))
        return dataset

    def prepare_data(self, subset):
        dataset = load_dataset(self.data_repository, subset)
        dataset = dataset.shuffle(seed=self.dataprep_seed)[self.fold]
        return dataset.train_test_split(test_size=0.1, seed=self.dataprep_seed)

    def _setup_model_tokenizer(self):
        # device_string = PartialState().process_index
        # compute_dtype = getattr(torch, "float16")

        if self.bnb_config.use_4bit and self.bnb_config.use_8bit:
            raise ValueError(
                "You can't load the model in 8 bits and 4 bits at the same time"
            )

        elif self.bnb_config.use_4bit or self.bnb_config.use_8bit:
            compute_dtype = getattr(torch, self.bnb_config.bnb_4bit_compute_dtype)
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=self.bnb_config.use_4bit,
                load_in_8bit=self.bnb_config.use_8bit,
                bnb_4bit_quant_type=self.bnb_config.bnb_4bit_quant_type,
                bnb_4bit_compute_dtype=compute_dtype,
                bnb_4bit_use_double_quant=self.bnb_config.use_nested_quant,
            )
        else:
            compute_dtype = None
            bnb_config = None

        # Check GPU compatibility with bfloat16
        if compute_dtype == torch.float16:
            major, _ = torch.cuda.get_device_capability()
            if major >= 8:
                logger.info(
                    "Your GPU supports bfloat16: accelerate training with bf16=True!"
                )

        # LoRA config
        peft_config = LoraConfig(**self.cfg.lora_config)

        tokenizer = AutoTokenizer.from_pretrained(
            self.ckpt,
            use_fast=False,
        )
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.padding_side = "right"

        model = AutoModelForCausalLM.from_pretrained(
            self.ckpt,
            quantization_config=bnb_config,
            device_map="auto",
        )

        return model, tokenizer, peft_config

    def formatting_prompts_func(self, example):
        output_texts = []
        for i in range(len(example[self.representation])):
            text = f"### What is the {self.property_} of {example[self.representation][i]}\n ### Answer: {example['labels'][i]:.3f}@@@"
            output_texts.append(text)
        return output_texts

    def formatting_tests_func(self, example):
        output_texts = []
        for i in range(len(example[self.representation])):
            text = f"### What is the {self.property_} of {example[self.representation][i]}\n "
            output_texts.append(text)
        return output_texts

    def _callbacks(self) -> List[TrainerCallback]:
        """Returns a list of callbacks for early stopping, and custom logging."""
        callbacks = []

        if self.callbacks.early_stopping:
            callbacks.append(
                EarlyStoppingCallback(
                    early_stopping_patience=self.callbacks.early_stopping_patience,
                    early_stopping_threshold=self.callbacks.early_stopping_threshold,
                )
            )
        callbacks.append(EvaluateFirstStepCallback)
        return callbacks

    def finetune(self) -> str:
        """
        Perform fine-tuning of the language model.
        """

        config_train_args = self.cfg.training_arguments
        training_args = TrainingArguments(
            **config_train_args,
        )
        callbacks = self._callbacks()

        response_template = " ### Answer:"
        collator = DataCollatorForCompletionOnlyLM(
            response_template, tokenizer=self.tokenizer
        )

        packing = False
        max_seq_length = None
        if self.representation == "cif_p1":
            max_seq_length = 2048

        trainer = SFTTrainer(
            model=self.model,
            peft_config=self.peft_config,
            train_dataset=self.dataset["train"],
            eval_dataset=self.dataset["test"],
            formatting_func=self.formatting_prompts_func,
            data_collator=collator,
            max_seq_length=max_seq_length,
            tokenizer=self.tokenizer,
            args=training_args,
            packing=packing,
            callbacks=callbacks,
        )

        wandb.log({"Training Arguments": str(config_train_args)})
        wandb.log({"model_summary": str(self.model)})

        self.output_dir_ = (
            f"{self.cfg.path.finetuned_modelname}/llamav3-8b-lora-fine-tune"
        )
        trainer.train()

        pipe = pipeline(
            "text-generation",
            model=trainer.model,
            tokenizer=self.tokenizer,
            return_full_text=False,
            do_sample=False,
            max_new_tokens=4,
        )
        with torch.cuda.amp.autocast():
            pred = pipe(self.formatting_tests_func(self.testdata))
        logger.debug("Prediction: %s", pred)

        with open(
            f"{self.cfg.path.finetuned_modelname}_{self.fold}_predictions.json", "w"
        ) as json_file:
            json.dump(pred, json_file)

        trainer.save_state()
        trainer.save_model(self.output_dir_)

        # Merge LoRA and base model
        merged_model = trainer.model.merge_and_unload()
        # Save the merged model
        merged_model.save_pretrained(
            f"{self.cfg.path.finetuned_modelname}_{self.fold}/llamav3-8b-lora-save-pretrained",
            save_config=True,
            safe_serialization=True,
        )
        self.tokenizer.save_pretrained(
            f"{self.cfg.path.finetuned_modelname}_{self.fold}/llamav3-8b-lora-save-pretrained"
        )

        with torch.cuda.amp.autocast():
            merge_pred = pipe(self.formatting_tests_func(self.testdata))
        logger.debug("Prediction: %s", merge_pred)

        with open(
            f"{self.cfg.path.finetuned_modelname}__{self.fold}_predictions_merged.json",
            "w",
        ) as json_file:
            json.dump(merge_pred, json_file)

        wandb.finish()
        return self.cfg.path.finetuned_modelname

finetune() 🔗

Perform fine-tuning of the language model.
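
The trainer performs completion-only supervised fine-tuning: prompts are rendered by formatting_prompts_func and DataCollatorForCompletionOnlyLM masks everything up to the " ### Answer:" response template, so only the answer tokens contribute to the loss. A rough illustration of one rendered training example (property name, structure string, and label value are placeholders):

# One training example as produced by formatting_prompts_func (placeholder values):
prompt = (
    "### What is the <property> of <structure string>\n"
    " ### Answer: -1.234@@@"
)
# DataCollatorForCompletionOnlyLM(" ### Answer:", tokenizer=...) ignores the loss on
# every token before the response template, so the model only learns to emit the
# numeric answer (terminated by the "@@@" marker used at generation time).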

Source code in src/mattext/models/llama_sft.py
def finetune(self) -> str:
    """
    Perform fine-tuning of the language model.
    """

    config_train_args = self.cfg.training_arguments
    training_args = TrainingArguments(
        **config_train_args,
    )
    callbacks = self._callbacks()

    response_template = " ### Answer:"
    collator = DataCollatorForCompletionOnlyLM(
        response_template, tokenizer=self.tokenizer
    )

    packing = False
    max_seq_length = None
    if self.representation == "cif_p1":
        max_seq_length = 2048

    trainer = SFTTrainer(
        model=self.model,
        peft_config=self.peft_config,
        train_dataset=self.dataset["train"],
        eval_dataset=self.dataset["test"],
        formatting_func=self.formatting_prompts_func,
        data_collator=collator,
        max_seq_length=max_seq_length,
        tokenizer=self.tokenizer,
        args=training_args,
        packing=packing,
        callbacks=callbacks,
    )

    wandb.log({"Training Arguments": str(config_train_args)})
    wandb.log({"model_summary": str(self.model)})

    self.output_dir_ = (
        f"{self.cfg.path.finetuned_modelname}/llamav3-8b-lora-fine-tune"
    )
    trainer.train()

    pipe = pipeline(
        "text-generation",
        model=trainer.model,
        tokenizer=self.tokenizer,
        return_full_text=False,
        do_sample=False,
        max_new_tokens=4,
    )
    with torch.cuda.amp.autocast():
        pred = pipe(self.formatting_tests_func(self.testdata))
    logger.debug("Prediction: %s", pred)

    with open(
        f"{self.cfg.path.finetuned_modelname}_{self.fold}_predictions.json", "w"
    ) as json_file:
        json.dump(pred, json_file)

    trainer.save_state()
    trainer.save_model(self.output_dir_)

    # Merge LoRA and base model
    merged_model = trainer.model.merge_and_unload()
    # Save the merged model
    merged_model.save_pretrained(
        f"{self.cfg.path.finetuned_modelname}_{self.fold}/llamav3-8b-lora-save-pretrained",
        save_config=True,
        safe_serialization=True,
    )
    self.tokenizer.save_pretrained(
        f"{self.cfg.path.finetuned_modelname}_{self.fold}/llamav3-8b-lora-save-pretrained"
    )

    with torch.cuda.amp.autocast():
        merge_pred = pipe(self.formatting_tests_func(self.testdata))
    logger.debug("Prediction: %s", merge_pred)

    with open(
        f"{self.cfg.path.finetuned_modelname}__{self.fold}_predictions_merged.json",
        "w",
    ) as json_file:
        json.dump(merge_pred, json_file)

    wandb.finish()
    return self.cfg.path.finetuned_modelname

FinetuneLLama 🔗

Class to perform finetuning of LLama using a regression head.

Parameters:

Name         Type        Description                                              Default
cfg          DictConfig  Configuration for the fine-tuning.                       required
local_rank   int         Local rank for distributed training. Defaults to None.  None
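
The bnb_config block of the fine-tuning configuration controls bitsandbytes quantization in _setup_model_tokenizer. A sketch of the keys it reads, built here with OmegaConf for illustration (the values are examples, not project defaults):

from omegaconf import OmegaConf

# Keys consumed from cfg.model.finetune.bnb_config (illustrative values):
bnb_cfg = OmegaConf.create(
    {
        "use_4bit": True,                     # load the base model in 4-bit precision
        "use_8bit": False,                    # mutually exclusive with use_4bit
        "bnb_4bit_compute_dtype": "float16",
        "bnb_4bit_quant_type": "nf4",
        "use_nested_quant": False,            # maps to bnb_4bit_use_double_quant
    }
)
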
Source code in src/mattext/models/llama.py
class FinetuneLLama:
    """Class to perform finetuning of LLama using
    a regression head.

    Args:
        cfg (DictConfig): Configuration for the fine-tuning.
        local_rank (int, optional): Local rank for distributed training. Defaults to None.
    """

    def __init__(self, cfg: DictConfig, local_rank=None) -> None:
        self.local_rank = local_rank
        self.representation = cfg.model.representation
        self.cfg = cfg.model.finetune
        self.context_length: int = self.cfg.context_length
        self.callbacks = self.cfg.callbacks
        self.ckpt = self.cfg.path.pretrained_checkpoint
        self.bnb_config = self.cfg.bnb_config
        self.model, self.tokenizer = self._setup_model_tokenizer()
        self.tokenized_dataset = self._prepare_datasets(
            self.cfg.path.finetune_traindata
        )

    def _setup_model_tokenizer(self):
        llama_tokenizer = LlamaTokenizer.from_pretrained(
            self.ckpt,
            model_max_length=MAX_LENGTH,
            padding_side="right",
            use_fast=False,
        )

        if self.bnb_config.use_4bit and self.bnb_config.use_8bit:
            raise ValueError(
                "You can't load the model in 8 bits and 4 bits at the same time"
            )

        elif self.bnb_config.use_4bit or self.bnb_config.use_8bit:
            compute_dtype = getattr(torch, self.bnb_config.bnb_4bit_compute_dtype)
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=self.bnb_config.use_4bit,
                load_in_8bit=self.bnb_config.use_8bit,
                bnb_4bit_quant_type=self.bnb_config.bnb_4bit_quant_type,
                bnb_4bit_compute_dtype=compute_dtype,
                bnb_4bit_use_double_quant=self.bnb_config.use_nested_quant,
            )
        else:
            compute_dtype = None
            bnb_config = None

        # Check GPU compatibility with bfloat16
        if compute_dtype == torch.float16:
            major, _ = torch.cuda.get_device_capability()
            if major >= 8:
                print("=" * 80)
                print("Your GPU supports bfloat16: accelerate training with bf16=True")
                print("=" * 80)

        device_map = {"": 0}
        model = LlamaForSequenceClassification.from_pretrained(
            self.ckpt,
            num_labels=1,
            quantization_config=bnb_config,
            device_map=device_map,
        )

        lora_config = LoraConfig(**self.cfg.lora_config)
        model = get_peft_model(model, lora_config)
        model.print_trainable_parameters()

        special_tokens_dict = dict()
        if llama_tokenizer.pad_token is None:
            special_tokens_dict["pad_token"] = DEFAULT_PAD_TOKEN
        if llama_tokenizer.eos_token is None:
            special_tokens_dict["eos_token"] = DEFAULT_EOS_TOKEN
        if llama_tokenizer.bos_token is None:
            special_tokens_dict["bos_token"] = DEFAULT_BOS_TOKEN
        if llama_tokenizer.unk_token is None:
            special_tokens_dict["unk_token"] = DEFAULT_UNK_TOKEN

        smart_tokenizer_and_embedding_resize(
            special_tokens_dict=special_tokens_dict,
            llama_tokenizer=llama_tokenizer,
            model=model,
        )

        print(len(llama_tokenizer))
        return model, llama_tokenizer

    def _tokenize(self, examples):
        tokenized_examples = self.tokenizer(
            examples[self.representation],
            truncation=True,
            padding=True,
            return_tensors="pt",
        )
        return tokenized_examples

    def _prepare_datasets(self, path: str) -> DatasetDict:
        """
        Prepare training and validation datasets.

        Args:
           path (Union[str, Path]): Path to json file containing the data

        Returns:
            DatasetDict: Dictionary containing training and validation datasets.
        """

        ds = load_dataset("json", data_files=path, split="train")
        dataset = ds.train_test_split(shuffle=True, test_size=0.2, seed=42)
        return dataset.map(self._tokenize, batched=True)

    def _callbacks(self) -> List[TrainerCallback]:
        """Returns a list of callbacks for early stopping, and custom logging."""
        callbacks = []

        if self.callbacks.early_stopping:
            callbacks.append(
                EarlyStoppingCallback(
                    early_stopping_patience=self.callbacks.early_stopping_patience,
                    early_stopping_threshold=self.callbacks.early_stopping_threshold,
                )
            )

        if self.callbacks.custom_logger:
            callbacks.append(CustomWandbCallback_FineTune())

        callbacks.append(EvaluateFirstStepCallback)

        return callbacks

    def _compute_metrics(self, p: Any, eval=True) -> Dict[str, float]:
        preds = torch.tensor(
            p.predictions.squeeze()
        )  # Convert predictions to PyTorch tensor
        label_ids = torch.tensor(p.label_ids)  # Convert label_ids to PyTorch tensor

        if eval:
            # Calculate RMSE as evaluation metric
            eval_rmse = torch.sqrt(((preds - label_ids) ** 2).mean()).item()
            return {"eval_rmse": round(eval_rmse, 3)}
        else:
            # Calculate RMSE as training metric
            loss = torch.sqrt(((preds - label_ids) ** 2).mean()).item()
            return {"train_rmse": round(loss, 3), "loss": round(loss, 3)}

    def finetune(self) -> str:
        """
        Perform fine-tuning of the language model.
        """

        config_train_args = self.cfg.training_arguments
        callbacks = self._callbacks()

        # os.environ["ACCELERATE_MIXED_PRECISION"] = "no"
        training_args = TrainingArguments(
            **config_train_args,
            metric_for_best_model="eval_rmse",  # Metric to use for determining the best model
            greater_is_better=False,  # Lower eval_rmse is better
        )

        trainer = Trainer(
            model=self.model,
            args=training_args,
            data_collator=None,
            compute_metrics=self._compute_metrics,
            tokenizer=self.tokenizer,
            train_dataset=self.tokenized_dataset["train"],
            eval_dataset=self.tokenized_dataset["test"],
            callbacks=callbacks,
        )

        wandb.log({"Training Arguments": str(config_train_args)})
        wandb.log({"model_summary": str(self.model)})

        trainer.train()
        trainer.save_model(
            f"{self.cfg.path.finetuned_modelname}/llamav2-7b-lora-fine-tune"
        )

        eval_result = trainer.evaluate(eval_dataset=self.tokenized_dataset["test"])
        wandb.log(eval_result)

        self.model.save_pretrained(self.cfg.path.finetuned_modelname)
        wandb.finish()
        return self.cfg.path.finetuned_modelname

    def evaluate(self):
        """
        Evaluate the fine-tuned model on the test dataset.
        """
        ckpt = self.finetune()

evaluate() 🔗

Evaluate the fine-tuned model on the test dataset.

Source code in src/mattext/models/llama.py
def evaluate(self):
    """
    Evaluate the fine-tuned model on the test dataset.
    """
    ckpt = self.finetune()

finetune() 🔗

Perform fine-tuning of the language model.

Source code in src/mattext/models/llama.py
def finetune(self) -> str:
    """
    Perform fine-tuning of the language model.
    """

    config_train_args = self.cfg.training_arguments
    callbacks = self._callbacks()

    # os.environ["ACCELERATE_MIXED_PRECISION"] = "no"
    training_args = TrainingArguments(
        **config_train_args,
        metric_for_best_model="eval_rmse",  # Metric to use for determining the best model
        greater_is_better=False,  # Lower eval_rmse is better
    )

    trainer = Trainer(
        model=self.model,
        args=training_args,
        data_collator=None,
        compute_metrics=self._compute_metrics,
        tokenizer=self.tokenizer,
        train_dataset=self.tokenized_dataset["train"],
        eval_dataset=self.tokenized_dataset["test"],
        callbacks=callbacks,
    )

    wandb.log({"Training Arguments": str(config_train_args)})
    wandb.log({"model_summary": str(self.model)})

    trainer.train()
    trainer.save_model(
        f"{self.cfg.path.finetuned_modelname}/llamav2-7b-lora-fine-tune"
    )

    eval_result = trainer.evaluate(eval_dataset=self.tokenized_dataset["test"])
    wandb.log(eval_result)

    self.model.save_pretrained(self.cfg.path.finetuned_modelname)
    wandb.finish()
    return self.cfg.path.finetuned_modelname

smart_tokenizer_and_embedding_resize(special_tokens_dict, llama_tokenizer, model) 🔗

Resize tokenizer and embedding.

Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
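
A hedged usage sketch (the checkpoint path and the "[PAD]" literal are placeholders; in FinetuneLLama the call uses the DEFAULT_* token constants defined in the same module):

from transformers import LlamaForSequenceClassification, LlamaTokenizer

from mattext.models.llama import smart_tokenizer_and_embedding_resize

tokenizer = LlamaTokenizer.from_pretrained("<llama-checkpoint>")
model = LlamaForSequenceClassification.from_pretrained("<llama-checkpoint>", num_labels=1)

# Add a pad token if the checkpoint ships without one; the new embedding rows are
# initialised with the mean of the existing input embeddings.
if tokenizer.pad_token is None:
    smart_tokenizer_and_embedding_resize(
        special_tokens_dict={"pad_token": "[PAD]"},
        llama_tokenizer=tokenizer,
        model=model,
    )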

Source code in src/mattext/models/llama.py
def smart_tokenizer_and_embedding_resize(
    special_tokens_dict,
    llama_tokenizer,
    model,
):
    """Resize tokenizer and embedding.

    Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
    """
    num_new_tokens = llama_tokenizer.add_special_tokens(special_tokens_dict)
    model.resize_token_embeddings(len(llama_tokenizer), pad_to_multiple_of=8)

    if num_new_tokens > 0:
        input_embeddings = model.get_input_embeddings().weight.data
        #   output_embeddings = model.get_output_embeddings().weight.data

        input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
            dim=0, keepdim=True
        )
        #   output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)

        input_embeddings[-num_new_tokens:] = input_embeddings_avg

    model.config.pad_token_id = llama_tokenizer.pad_token_id

PotentialModel 🔗

Bases: TokenizerMixin

Class to perform finetuning of a language model on the hypothetical potential task.

Parameters:

Name         Type        Description                                              Default
cfg          DictConfig  Configuration for the fine-tuning.                       required
local_rank   int         Local rank for distributed training. Defaults to None.  None
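
The training JSON referenced by finetune_traindata is expected to carry one energy column per alpha value; _prepare_datasets renames total_energy_alpha_{alpha} to "labels" before tokenization. A sketch of one record (the representation key and alpha value are assumptions):

# One record of the finetune_traindata JSON as consumed by PotentialModel,
# assuming representation="slices" and alpha=0.5:
example = {
    "slices": "<text representation of the structure>",
    "total_energy_alpha_0.5": -3.21,  # renamed to "labels" during dataset preparation
}
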
Source code in src/mattext/models/potential.py
class PotentialModel(TokenizerMixin):
    """Class to perform finetuning of a language model on
    the hypothetical potential task.

    Args:
        cfg (DictConfig): Configuration for the fine-tuning.
        local_rank (int, optional): Local rank for distributed training. Defaults to None.
    """

    def __init__(self, cfg: DictConfig, local_rank=None) -> None:
        super().__init__(
            cfg=cfg.model.representation,
            special_tokens=cfg.model.special_tokens,
            special_num_token=cfg.model.special_num_token,
        )
        self.local_rank = local_rank
        self.representation = cfg.model.representation
        self.alpha = cfg.model.alpha
        self.test_data = cfg.model.inference.path.test_data
        self.cfg = cfg.model.finetune
        self.context_length: int = self.cfg.context_length
        self.callbacks = self.cfg.callbacks
        self.tokenized_dataset = self._prepare_datasets(
            self.cfg.path.finetune_traindata, split="train"
        )
        self.tokenized_testset = self._prepare_datasets(self.test_data, split="test")

    def _prepare_datasets(self, path: str, split) -> DatasetDict:
        """
        Prepare training and validation datasets.

        Args:
            path (str): Path to the JSON file containing the data.
            split (str): Which split is being prepared ("train" or "test").

        Returns:
            DatasetDict: Dictionary containing training and validation datasets.
        """

        ds = load_dataset("json", data_files=path, split="train")
        # with contextlib.suppress(KeyError):
        #     ds = ds.remove_columns("labels")
        if split == "train":
            ds = ds.remove_columns("labels")
        else:
            print("test set")

        label_name = f"total_energy_alpha_{self.alpha}"
        ds = ds.rename_column(label_name, "labels")
        dataset = ds.train_test_split(shuffle=True, test_size=0.2, seed=42)
        # dataset= dataset.filter(lambda example: example[self.representation] is not None)
        return dataset.map(
            partial(
                self._tokenize_pad_and_truncate, context_length=self.context_length
            ),
            batched=True,
        )

    def _callbacks(self) -> List[TrainerCallback]:
        """Returns a list of callbacks for early stopping, and custom logging."""
        callbacks = []

        if self.callbacks.early_stopping:
            callbacks.append(
                EarlyStoppingCallback(
                    early_stopping_patience=self.callbacks.early_stopping_patience,
                    early_stopping_threshold=self.callbacks.early_stopping_threshold,
                )
            )

        if self.callbacks.custom_logger:
            callbacks.append(CustomWandbCallback_FineTune())

        callbacks.append(EvaluateFirstStepCallback)

        return callbacks

    def _compute_metrics(self, p: Any, eval=True) -> Dict[str, float]:
        preds = torch.tensor(
            p.predictions.squeeze()
        )  # Convert predictions to PyTorch tensor
        label_ids = torch.tensor(p.label_ids)  # Convert label_ids to PyTorch tensor

        if eval:
            # Calculate RMSE as evaluation metric
            eval_rmse = torch.sqrt(((preds - label_ids) ** 2).mean()).item()
            return {"eval_rmse": round(eval_rmse, 3)}
        else:
            # Calculate RMSE as training metric
            loss = torch.sqrt(((preds - label_ids) ** 2).mean()).item()
            return {"train_rmse": round(loss, 3), "loss": round(loss, 3)}

    def finetune(self) -> str:
        """
        Perform fine-tuning of the language model.
        """

        pretrained_ckpt = self.cfg.path.pretrained_checkpoint

        config_train_args = self.cfg.training_arguments
        callbacks = self._callbacks()

        training_args = TrainingArguments(
            **config_train_args,
            metric_for_best_model="eval_rmse",  # Metric to use for determining the best model
            greater_is_better=False,  # Lower eval_rmse is better
        )

        model = AutoModelForSequenceClassification.from_pretrained(
            pretrained_ckpt, num_labels=1, ignore_mismatched_sizes=False
        )

        if self.cfg.freeze_base_model:
            for param in model.base_model.parameters():
                param.requires_grad = False

        if self.local_rank is not None:
            model = model.to(self.local_rank)
            model = nn.parallel.DistributedDataParallel(
                model, device_ids=[self.local_rank]
            )
        else:
            model = model.to("cuda")

        trainer = Trainer(
            model=model,
            args=training_args,
            data_collator=None,
            compute_metrics=self._compute_metrics,
            tokenizer=self._wrapped_tokenizer,
            train_dataset=self.tokenized_dataset["train"],
            eval_dataset=self.tokenized_dataset["test"],
            callbacks=callbacks,
        )

        wandb.log({"Training Arguments": str(config_train_args)})
        wandb.log({"model_summary": str(model)})

        trainer.train()
        model.save_pretrained(self.cfg.path.finetuned_modelname)

        eval_result = trainer.evaluate(eval_dataset=self.tokenized_testset)
        wandb.log(eval_result)
        wandb.finish()
        return self.cfg.path.finetuned_modelname

    def evaluate(self):
        """
        Evaluate the fine-tuned model on the test dataset.
        """
        ckpt = self.finetune()

evaluate() 🔗

Evaluate the fine-tuned model on the test dataset.

Source code in src/mattext/models/potential.py
def evaluate(self):
    """
    Evaluate the fine-tuned model on the test dataset.
    """
    ckpt = self.finetune()

finetune() 🔗

Perform fine-tuning of the language model.

Source code in src/mattext/models/potential.py
def finetune(self) -> str:
    """
    Perform fine-tuning of the language model.
    """

    pretrained_ckpt = self.cfg.path.pretrained_checkpoint

    config_train_args = self.cfg.training_arguments
    callbacks = self._callbacks()

    training_args = TrainingArguments(
        **config_train_args,
        metric_for_best_model="eval_rmse",  # Metric to use for determining the best model
        greater_is_better=False,  # Lower eval_rmse is better
    )

    model = AutoModelForSequenceClassification.from_pretrained(
        pretrained_ckpt, num_labels=1, ignore_mismatched_sizes=False
    )

    if self.cfg.freeze_base_model:
        for param in model.base_model.parameters():
            param.requires_grad = False

    if self.local_rank is not None:
        model = model.to(self.local_rank)
        model = nn.parallel.DistributedDataParallel(
            model, device_ids=[self.local_rank]
        )
    else:
        model = model.to("cuda")

    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=None,
        compute_metrics=self._compute_metrics,
        tokenizer=self._wrapped_tokenizer,
        train_dataset=self.tokenized_dataset["train"],
        eval_dataset=self.tokenized_dataset["test"],
        callbacks=callbacks,
    )

    wandb.log({"Training Arguments": str(config_train_args)})
    wandb.log({"model_summary": str(model)})

    trainer.train()
    model.save_pretrained(self.cfg.path.finetuned_modelname)

    eval_result = trainer.evaluate(eval_dataset=self.tokenized_testset)
    wandb.log(eval_result)
    wandb.finish()
    return self.cfg.path.finetuned_modelname

Inference 🔗

Bases: TokenizerMixin

Class to perform inference on a language model with a sequence classification head.
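
A minimal prediction sketch (the config composition and override are assumptions; predict() returns a pandas.Series of scalar predictions together with the mbid identifiers of the test fold):

import wandb
from hydra import compose, initialize

from mattext.models.predict import Inference

with initialize(version_base=None, config_path="../conf"):
    cfg = compose(config_name="config", overrides=["model=<model-config>"])

wandb.init(project="mattext")  # placeholder; the inference callback logs to the active run

inference = Inference(cfg, fold="fold_0")
predictions, prediction_ids = inference.predict()
print(predictions.head(), prediction_ids[:5])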

Source code in src/mattext/models/predict.py
class Inference(TokenizerMixin):
    """Class to perform inference on a language model with a sequence classification head."""

    def __init__(self, cfg: DictConfig, fold="fold_0"):
        super().__init__(
            cfg=cfg.model.representation,
            special_tokens=cfg.model.special_tokens,
            special_num_token=cfg.model.special_num_token,
        )
        self.fold = fold
        self.representation = cfg.model.representation
        self.data_repository = cfg.model.data_repository
        self.dataset_name = cfg.model.finetune.dataset_name
        self.cfg = cfg.model.inference
        self.context_length: int = self.cfg.context_length
        self.tokenized_test_datasets = self._prepare_datasets(self.cfg.path.test_data)
        self.prediction_ids = None

    def _prepare_datasets(self, path: str):
        """
        Prepare the tokenized test dataset for inference.

        Args:
            path (str): Name of the dataset subset in the data repository.

        Returns:
            The tokenized test split for the selected fold.
        """
        dataset = load_dataset(self.data_repository, path)
        filtered_dataset = dataset[self.fold].filter(
            lambda example: example[self.representation] is not None
        )

        return filtered_dataset.map(
            partial(
                self._tokenize_pad_and_truncate, context_length=self.context_length
            ),
            batched=True,
        )

    def _callbacks(self) -> List[TrainerCallback]:
        """Returns a list of callbacks for logging."""
        return [CustomWandbCallback_Inference()]

    def predict(self):
        pretrained_ckpt = self.cfg.path.pretrained_checkpoint
        callbacks = self._callbacks()

        model = AutoModelForSequenceClassification.from_pretrained(
            pretrained_ckpt, num_labels=1, ignore_mismatched_sizes=False
        )

        trainer = Trainer(
            model=model.to("cuda"), data_collator=None, callbacks=callbacks
        )

        predictions = trainer.predict(self.tokenized_test_datasets)
        for callback in callbacks:
            callback.on_predict_end(
                None, None, None, model, predictions
            )  # Manually trigger callback
        torch.cuda.empty_cache()

        # TODO: Save predictions to disk optional
        # os.makedirs(self.cfg.path.predictions, exist_ok=True)
        # predictions_path = os.path.join(self.cfg.path.predictions, 'predictions.npy')
        # np.save(predictions_path, predictions.predictions)
        prediction_ids = self.tokenized_test_datasets["mbid"]
        self.prediction_ids = prediction_ids

        return pd.Series(predictions.predictions.flatten()), prediction_ids

PretrainModel 🔗

Bases: TokenizerMixin

Class to perform pretraining of a language model.
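
A minimal pretraining sketch (config name and override are assumptions; set cfg.model.dataset_local_path to read train.json/test.json from disk instead of pulling the subset from the data repository):

import wandb
from hydra import compose, initialize

from mattext.models.pretrain import PretrainModel

with initialize(version_base=None, config_path="../conf"):
    cfg = compose(config_name="config", overrides=["model=pretrain"])

wandb.init(project="mattext")  # placeholder; pretrain_mlm logs config and metrics to wandb

# Runs masked-language-model pretraining and saves the weights to
# cfg.model.pretrain.path.finetuned_modelname.
pretrainer = PretrainModel(cfg)
pretrainer.pretrain_mlm()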

Source code in src/mattext/models/pretrain.py
class PretrainModel(TokenizerMixin):
    """Class to perform pretraining of a language model."""

    def __init__(self, cfg: DictConfig, local_rank=None):
        super().__init__(
            cfg=cfg.model.representation,
            special_tokens=cfg.model.special_tokens,
            special_num_token=cfg.model.special_num_token,
        )
        self.local_rank = local_rank
        self.representation = cfg.model.representation
        self.data_repository = cfg.model.data_repository
        self.cfg = cfg.model.pretrain
        self.context_length: int = self.cfg.context_length
        self.callbacks = self.cfg.callbacks
        self.model_name_or_path: str = self.cfg.model_name_or_path
        self.local_file_path = (
            cfg.model.dataset_local_path if cfg.model.dataset_local_path else None
        )
        self.tokenized_train_datasets, self.tokenized_eval_datasets = self._prepare_datasets(
            subset=self.cfg.dataset_name, local_file_path=self.local_file_path
        )

    def _prepare_datasets(self, subset: str, local_file_path: Optional[str] = None):
        """
        Prepare training and evaluation datasets.

        Args:
            subset (str): Name of the dataset subset in the data repository.
            local_file_path (Optional[str]): Directory containing local train.json and
                test.json files; used instead of the repository when set.

        Returns:
            Tuple of tokenized training and evaluation datasets.
        """
        if local_file_path:
            # Load data from a local JSON file
            train_dataset = load_dataset("json", data_files=f"{local_file_path}/train.json", split="train")
            eval_dataset = load_dataset("json", data_files=f"{local_file_path}/test.json", split="train")
        else:
            # Load data from the repository
            dataset = load_dataset(self.data_repository, subset)
            train_dataset = dataset["train"]
            eval_dataset = dataset["test"]

        filtered_train_dataset = train_dataset.filter(
            lambda example: example[self.representation] is not None
        )
        filtered_eval_dataset = eval_dataset.filter(
            lambda example: example[self.representation] is not None
        )

        return filtered_train_dataset.map(
            partial(
                self._tokenize_pad_and_truncate, context_length=self.context_length
            ),
            batched=True,
        ), filtered_eval_dataset.map(
            partial(
                self._tokenize_pad_and_truncate, context_length=self.context_length
            ),
            batched=True,
        )

    def _callbacks(self) -> List[TrainerCallback]:
        """Returns a list of callbacks for early stopping, and custom logging."""
        callbacks = []

        if self.callbacks.early_stopping:
            callbacks.append(
                EarlyStoppingCallback(
                    early_stopping_patience=self.callbacks.early_stopping_patience,
                    early_stopping_threshold=self.callbacks.early_stopping_threshold,
                )
            )

        if self.callbacks.custom_logger:
            callbacks.append(CustomWandbCallback_Pretrain())

        return callbacks

    def pretrain_mlm(self) -> None:
        """Performs MLM pretraining of the language model."""
        config_mlm = self.cfg.mlm
        config_train_args = self.cfg.training_arguments
        config_model_args = self.cfg.model_config
        #config_model_args["max_position_embeddings"] = self.context_length

        data_collator = DataCollatorForLanguageModeling(
            tokenizer=self._wrapped_tokenizer,
            mlm=config_mlm.is_mlm,
            mlm_probability=config_mlm.mlm_probability,
        )

        callbacks = self._callbacks()

        config = AutoConfig.from_pretrained(
            self.model_name_or_path, **config_model_args
        )

        model = AutoModelForMaskedLM.from_config(config)

        if self.local_rank is not None:
            model = model.to(self.local_rank)
            model = nn.parallel.DistributedDataParallel(
                model, device_ids=[self.local_rank]
            )
        else:
            model = model.to("cuda")

        training_args = TrainingArguments(**config_train_args)

        trainer = Trainer(
            model=model,
            data_collator=data_collator,
            train_dataset=self.tokenized_train_datasets,
            eval_dataset=self.tokenized_eval_datasets,
            args=training_args,
            callbacks=callbacks,
        )

        wandb.log({"config_details": str(config)})
        wandb.log({"Training Arguments": str(config_train_args)})
        wandb.log({"model_summary": str(model)})

        trainer.train()

        # Save the pretrained model
        model.save_pretrained(self.cfg.path.finetuned_modelname)

pretrain_mlm() 🔗

Performs MLM pretraining of the language model.

Source code in src/mattext/models/pretrain.py
def pretrain_mlm(self) -> None:
    """Performs MLM pretraining of the language model."""
    config_mlm = self.cfg.mlm
    config_train_args = self.cfg.training_arguments
    config_model_args = self.cfg.model_config
    #config_model_args["max_position_embeddings"] = self.context_length

    data_collator = DataCollatorForLanguageModeling(
        tokenizer=self._wrapped_tokenizer,
        mlm=config_mlm.is_mlm,
        mlm_probability=config_mlm.mlm_probability,
    )

    callbacks = self._callbacks()

    config = AutoConfig.from_pretrained(
        self.model_name_or_path, **config_model_args
    )

    model = AutoModelForMaskedLM.from_config(config)

    if self.local_rank is not None:
        model = model.to(self.local_rank)
        model = nn.parallel.DistributedDataParallel(
            model, device_ids=[self.local_rank]
        )
    else:
        model = model.to("cuda")

    training_args = TrainingArguments(**config_train_args)

    trainer = Trainer(
        model=model,
        data_collator=data_collator,
        train_dataset=self.tokenized_train_datasets,
        eval_dataset=self.tokenized_eval_datasets,
        args=training_args,
        callbacks=callbacks,
    )

    wandb.log({"config_details": str(config)})
    wandb.log({"Training Arguments": str(config_train_args)})
    wandb.log({"model_summary": str(model)})

    trainer.train()

    # Save the pretrained model
    model.save_pretrained(self.cfg.path.finetuned_modelname)