Fix datacollator and update docs (#39)
* update index.md

* fix ruff errors

* [pre-commit.ci] Add auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update user guide

* update user_guide examples

* update docs

* fix bug in maskgit obj

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
a-kore and pre-commit-ci[bot] authored Aug 12, 2024
1 parent 21ecb9d commit 86f8312
Showing 9 changed files with 29 additions and 30 deletions.
4 changes: 2 additions & 2 deletions atomgen/data/data_collator.py
@@ -184,11 +184,11 @@ def torch_mask_tokens(
batch_randperm = torch.rand((batch, seq_len)).argsort(dim=-1)
mask = batch_randperm < num_token_masked.unsqueeze(1)
inputs = torch.where(
- mask,
+ ~mask,
inputs,
self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token),
)
- labels = torch.where(~mask, labels, -100)
+ labels = torch.where(mask, labels, -100)
if special_tokens_mask is not None:
labels = torch.where(~special_tokens_mask, labels, -100)
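
The hunk above swaps which branch of each `torch.where` call keeps the original value. A minimal standalone sketch of the corrected behavior follows. It is not the repository's collator: `MASK_TOKEN_ID` and the free function `mask_tokens` are hypothetical stand-ins for the tokenizer lookup and the `torch_mask_tokens` method, and special-token handling is omitted.

```python
import torch

MASK_TOKEN_ID = 0    # hypothetical id standing in for the tokenizer's mask token
IGNORE_INDEX = -100  # label value that PyTorch's cross-entropy loss ignores by default


def mask_tokens(inputs, num_token_masked):
    """Mask num_token_masked[i] randomly chosen positions in row i of inputs."""
    batch, seq_len = inputs.shape
    labels = inputs.clone()
    # argsort of uniform noise assigns every position a random rank in [0, seq_len).
    batch_randperm = torch.rand((batch, seq_len)).argsort(dim=-1)
    # Positions whose rank is below the per-row budget are selected for masking.
    mask = batch_randperm < num_token_masked.unsqueeze(1)
    # Keep the original token where NOT masked; write the mask token elsewhere.
    masked_inputs = torch.where(~mask, inputs, torch.full_like(inputs, MASK_TOKEN_ID))
    # Keep the label only where masked; all other positions are ignored by the loss.
    labels = torch.where(mask, labels, torch.full_like(labels, IGNORE_INDEX))
    return masked_inputs, labels, mask


demo_inputs = torch.tensor([[5, 6, 7, 8], [9, 10, 11, 12]])
demo_budget = torch.tensor([1, 3])  # mask 1 token in row 0 and 3 tokens in row 1
demo_masked, demo_labels, demo_mask = mask_tokens(demo_inputs, demo_budget)
print(demo_masked)
print(demo_labels)
```

With the pre-fix argument order, the selected positions would have kept their original tokens while every other position was overwritten, and the loss would have been computed on the unmasked positions only.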

1 change: 0 additions & 1 deletion docs/source/index.md
@@ -1,4 +1,3 @@
- ```markdown
---
hide-toc: true
---
1 change: 1 addition & 0 deletions docs/source/reference/api/atomgen.data.rst
@@ -14,3 +14,4 @@ Submodules

atomgen.data.data_collator
atomgen.data.tokenizer
+ atomgen.data.utils
7 changes: 7 additions & 0 deletions docs/source/reference/api/atomgen.data.utils.rst
@@ -0,0 +1,7 @@
+ atomgen.data.utils module
+ =========================
+
+ .. automodule:: atomgen.data.utils
+    :members:
+    :undoc-members:
+    :show-inheritance:
@@ -0,0 +1,7 @@
+ atomgen.models.configuration\_atomformer module
+ ===============================================
+
+ .. automodule:: atomgen.models.configuration_atomformer
+    :members:
+    :undoc-members:
+    :show-inheritance:
@@ -0,0 +1,7 @@
+ atomgen.models.modeling\_atomformer module
+ ==========================================
+
+ .. automodule:: atomgen.models.modeling_atomformer
+    :members:
+    :undoc-members:
+    :show-inheritance:
3 changes: 2 additions & 1 deletion docs/source/reference/api/atomgen.models.rst
@@ -12,6 +12,7 @@ Submodules
.. toctree::
:maxdepth: 4

+ atomgen.models.configuration_atomformer
+ atomgen.models.modeling_atomformer
atomgen.models.schnet
atomgen.models.tokengt
- atomgen.models.unimolplus
7 changes: 0 additions & 7 deletions docs/source/reference/api/atomgen.models.unimolplus.rst

This file was deleted.

22 changes: 3 additions & 19 deletions docs/source/user_guide.md
@@ -110,33 +110,19 @@ python run_atom3d.py \
--model_name_or_path "vector-institute/atomformer-base" \
--dataset_name "vector-institute/atom3d-smp" \
--output_dir "./results" \
- --do_train \
- --do_eval \
- --max_seq_length 512 \
- --per_device_train_batch_size 32 \
+ --batch_size 32 \
--learning_rate 5e-5 \
--num_train_epochs 3 \
- --save_steps 10000 \
- --evaluation_strategy "steps" \
- --eval_steps 5000 \
- --load_best_model_at_end \
- --metric_for_best_model "mae" \
- --greater_is_better false
```

Key arguments for `run_atom3d.py`:

- `--model_name_or_path`: Pretrained model to start from
- `--dataset_name`: ATOM3D dataset to use for fine-tuning
- `--output_dir`: Directory to save results
- - `--do_train`: Perform training
- - `--do_eval`: Perform evaluation
- - `--max_seq_length`: Maximum sequence length
- - `--per_device_train_batch_size`: Batch size per GPU/CPU for training
+ - `--batch_size`: Batch size per GPU/CPU for training
- `--learning_rate`: Initial learning rate
- `--num_train_epochs`: Total number of training epochs
- - `--evaluation_strategy`: When to evaluate during training
- - `--metric_for_best_model`: Metric to use for saving best model

## Inference

@@ -195,9 +181,7 @@ python -m torch.distributed.launch --nproc_per_node=4 run_atom3d.py \
--model_name_or_path "vector-institute/atomformer-base" \
--dataset_name "vector-institute/atom3d-smp" \
--output_dir "./results" \
- --do_train \
- --do_eval \
- --per_device_train_batch_size 8 \
+ --batch_size 8 \
--learning_rate 5e-5 \
--num_train_epochs 3 \
```
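
After this change, the user guide examples pass only six flags to `run_atom3d.py`. As a rough illustration of that reduced interface, here is a sketch of an `argparse` parser that accepts the same flag names. This is an assumption for illustration only: the commit does not show how `run_atom3d.py` parses its arguments, and the function name `build_parser`, the types, and the defaults below are hypothetical.

```python
import argparse


def build_parser():
    """Hypothetical parser mirroring the flags kept in the updated user guide."""
    parser = argparse.ArgumentParser(
        description="Illustrative stand-in for run_atom3d.py's command line"
    )
    parser.add_argument("--model_name_or_path", required=True,
                        help="Pretrained model to start from")
    parser.add_argument("--dataset_name", required=True,
                        help="ATOM3D dataset to use for fine-tuning")
    parser.add_argument("--output_dir", default="./results",
                        help="Directory to save results")
    # Types and defaults below are assumptions taken from the example command.
    parser.add_argument("--batch_size", type=int, default=32,
                        help="Batch size per GPU/CPU for training")
    parser.add_argument("--learning_rate", type=float, default=5e-5,
                        help="Initial learning rate")
    parser.add_argument("--num_train_epochs", type=int, default=3,
                        help="Total number of training epochs")
    return parser


# Parse an explicit argument list so the sketch does not depend on sys.argv.
example_args = build_parser().parse_args([
    "--model_name_or_path", "vector-institute/atomformer-base",
    "--dataset_name", "vector-institute/atom3d-smp",
    "--batch_size", "8",
])
print(example_args.batch_size, example_args.learning_rate)
```

A parser like this would reject the removed flags such as `--do_train` or `--per_device_train_batch_size` with an "unrecognized arguments" error, which is one reason the documentation needs to track the script's actual interface.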