feat: add support to use datasets using features with sequences of class labels #5613

Open
wants to merge 2 commits into
base: feat/argilla-direct-feature-branch

jfcalvo (Member) commented Oct 18, 2024

Description

This PR adds the following changes:

  • Cast features that are sequences of class labels to their label names (using the int2str function).
  • Cast suggestion values mapped to multi-label questions to strings (iterating over the values).

Refs argilla-io/roadmap#21

Type of change

  • New feature (non-breaking change which adds functionality)

How Has This Been Tested

  • Adding additional tests to our suite.

Checklist

  • I added relevant documentation
  • I followed the style guidelines of this project
  • I did a self-review of my code
  • I made corresponding changes to the documentation
  • I confirm my changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • I have added relevant notes to the CHANGELOG.md file (See https://keepachangelog.com/)

@@ -102,6 +102,8 @@ def _batch_index_to_row(self, batch: dict, index: int) -> dict:
        row[feature_name] = None
    else:
        row[feature_name] = feature.int2str(value)
elif isinstance(feature, features.Sequence) and isinstance(feature.feature, features.ClassLabel):

Member Author

My understanding is that a feature of type Sequence wraps another feature (available as feature.feature), but I'm not 100% sure.

Member

From the docs:

If your data type contains a list of objects, then you want to use the Sequence feature. Remember the SQuAD dataset?

from datasets import load_dataset

dataset = load_dataset('squad', split='train')
dataset.features
{'answers': Sequence(feature={'text': Value(dtype='string', id=None), 'answer_start': Value(dtype='int32', id=None)}, length=-1, id=None),
'context': Value(dtype='string', id=None),
'id': Value(dtype='string', id=None),
'question': Value(dtype='string', id=None),
'title': Value(dtype='string', id=None)}

The answers field is constructed using the Sequence feature because it contains two subfields, text and answer_start, which are lists of string and int32, respectively.

The inner feature could be anything, but this condition is sufficient.

Member Author

OK, then my code is fine.
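To illustrate the condition discussed in this thread, here is a minimal sketch of casting a sequence of class-label ids to label names. The ClassLabel and Sequence classes below are simplified stand-ins for the ones in datasets.features (only the attributes the diff relies on, int2str and feature, are modeled), and cast_value is a hypothetical helper, not the PR's actual code. The per-item list comprehension is an assumption about how the new branch is implemented.

```python
from dataclasses import dataclass
from typing import Any, List


@dataclass
class ClassLabel:
    """Simplified stand-in for datasets.features.ClassLabel (assumption)."""

    names: List[str]

    def int2str(self, value: int) -> str:
        # Map an integer class id to its label name.
        return self.names[value]


@dataclass
class Sequence:
    """Simplified stand-in for datasets.features.Sequence (assumption)."""

    feature: Any  # the inner feature describing each item of the sequence


def cast_value(feature: Any, value: Any) -> Any:
    """Hypothetical helper: turn class-label ids into label names."""
    if value is None:
        return None
    if isinstance(feature, ClassLabel):
        return feature.int2str(value)
    # Same condition as in the diff: a Sequence whose inner feature is a ClassLabel.
    if isinstance(feature, Sequence) and isinstance(feature.feature, ClassLabel):
        return [feature.feature.int2str(v) for v in value]
    # Any other feature type is returned unchanged.
    return value


labels = ClassLabel(names=["negative", "neutral", "positive"])
print(cast_value(labels, 2))                  # positive
print(cast_value(Sequence(labels), [0, 2]))   # ['negative', 'positive']
print(cast_value(Sequence("other"), ["a"]))   # ['a']
```

A Sequence whose inner feature is not a ClassLabel falls through unchanged, which matches the reviewer's point that the inner feature could be anything.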

@@ -164,6 +166,9 @@ def _row_suggestions(self, row: dict, dataset: Dataset) -> list:
if question.is_text or question.is_label_selection:
    value = str(value)

if question.is_multi_label_selection:
    value = [str(v) for v in value]

Member

Here we should check whether the value is a list. If it is not, we should "list-ilize" it first: value = [value]

Member Author

Done.
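
The reviewer's suggestion can be sketched as follows. The function name to_multi_label_value is hypothetical and the final code in the PR may differ; the sketch only shows why the list check matters before iterating.

```python
from typing import Any, List


def to_multi_label_value(value: Any) -> List[str]:
    """Hypothetical helper: normalize a suggestion value to a list of strings."""
    # Wrap scalars first, otherwise iterating a bare string would split it
    # into single characters.
    if not isinstance(value, list):
        value = [value]
    return [str(v) for v in value]


print(to_multi_label_value(["positive", 3]))  # ['positive', '3']
print(to_multi_label_value("positive"))       # ['positive']
```

Without the check, a scalar string such as "positive" would be cast to a list of one-character strings, and a scalar integer would raise a TypeError because it is not iterable.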

codecov bot commented Oct 18, 2024

Codecov Report

Attention: Patch coverage is 88.88889% with 1 line in your changes missing coverage. Please review.

Project coverage is 91.18%. Comparing base (d375c4b) to head (80f4d60).

Files with missing lines                             Patch %   Lines
argilla-server/src/argilla_server/contexts/hub.py    83.33%    1 Missing ⚠️
Additional details and impacted files
@@                          Coverage Diff                           @@
##           feat/argilla-direct-feature-branch    #5613      +/-   ##
======================================================================
- Coverage                               91.20%   91.18%   -0.02%     
======================================================================
  Files                                     150      150              
  Lines                                    6251     6260       +9     
======================================================================
+ Hits                                     5701     5708       +7     
- Misses                                    550      552       +2     
Flag             Coverage Δ
argilla-server   91.18% <88.88%> (-0.02%) ⬇️
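
The project percentages in the report can be reproduced from the hit and line counts in the table above. This is a small arithmetic check, not part of Codecov's tooling. The 8-of-9 figure for the patch is inferred from "88.88889% with 1 line missing" and is an assumption.

```python
def coverage(hits: int, lines: int) -> float:
    """Coverage as a percentage of covered lines."""
    return 100 * hits / lines


base = coverage(5701, 6251)  # base branch: hits / lines
head = coverage(5708, 6260)  # PR head: hits / lines
print(f"{base:.2f}% -> {head:.2f}% ({head - base:+.2f})")  # 91.20% -> 91.18% (-0.02)

# Patch coverage, assuming 9 changed lines with 1 uncovered (inferred, see above).
print(f"{coverage(8, 9):.5f}%")  # 88.88889%
```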

2 participants