-
Notifications
You must be signed in to change notification settings - Fork 365
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: add support to use datasets using features with sequences of class labels #5613
base: feat/argilla-direct-feature-branch
Are you sure you want to change the base?
Conversation
…es of class labels
@@ -102,6 +102,8 @@ def _batch_index_to_row(self, batch: dict, index: int) -> dict: | |||
row[feature_name] = None | |||
else: | |||
row[feature_name] = feature.int2str(value) | |||
elif isinstance(feature, features.Sequence) and isinstance(feature.feature, features.ClassLabel): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here I understand that inside a feature that it's a Sequence
we have inside another feature. But I'm not 100% sure.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
From the docs:
If your data type contains a list of objects, then you want to use the Sequence feature. Remember the SQuAD dataset?
from datasets import load_dataset dataset = load_dataset('squad', split='train') dataset.features {'answers': Sequence(feature={'text': Value(dtype='string', id=None), 'answer_start': Value(dtype='int32', id=None)}, length=-1, id=None), 'context': Value(dtype='string', id=None), 'id': Value(dtype='string', id=None), 'question': Value(dtype='string', id=None), 'title': Value(dtype='string', id=None)}The answers field is constructed using the Sequence feature because it contains two subfields, text and >answer_start, which are lists of string and int32, respectively.
The feature could be whatever. But this condition is quite enough.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok, then my code it's fine.
@@ -164,6 +166,9 @@ def _row_suggestions(self, row: dict, dataset: Dataset) -> list: | |||
if question.is_text or question.is_label_selection: | |||
value = str(value) | |||
|
|||
if question.is_multi_label_selection: | |||
value = [str(v) for v in value] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here we should check if the value is a list. Otherwise, we should "list-ilize" it: value = [value]
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## feat/argilla-direct-feature-branch #5613 +/- ##
======================================================================
- Coverage 91.20% 91.18% -0.02%
======================================================================
Files 150 150
Lines 6251 6260 +9
======================================================================
+ Hits 5701 5708 +7
- Misses 550 552 +2
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Sentry. |
Description
This PR adds the following changes:
int2str
function).Refs argilla-io/roadmap#21
Type of change
How Has This Been Tested
Checklist