I have an imbalanced dataset and I would like to use class_weight in model.fit().
When I use model.fit() without class_weight, it works correctly, but if I add class_weight, I’ve got an error.
My model is this one :
model2_2 = Sequential()
model2_2.add( Dense(units = 70, activation = 'elu', input_shape = (X_train.shape[1],)))
model2_2.add(Dense(units = 140, activation = 'elu'))
model2_2.add(Dropout(rate = 0.2))
model2_2.add( Dense(units = 70, activation = 'elu', input_shape = (X_train.shape[1],)))
model2_2.add(Dense(units = 35, activation = 'elu'))
model2_2.add(Dropout(rate = 0.2))
model2_2.add(Dense(units = 14, activation = 'elu'))
model2_2.add(Dense(units = 1, activation = 'sigmoid'))
model2_2.summary()
And I create a weights like that :
weights = compute_class_weight(class_weight="balanced", classes=np.unique(y), y=y)
weights_dict = {0 : weights[0], 1 : weights[1]}
weights_dict
I compile and train the model :
model2_2.compile(loss = 'binary_crossentropy', optimizer = 'adam', metrics = ['Accuracy', 'Precision', 'Recall'])
history2_2 = model2_2.fit(X_train, y_train, epochs = 100, batch_size = 500, validation_data = 0.1, callbacks = [reduce_learning_rate], class_weight = weights_dict)
Then, I’ve got this error :
---------------------------------------------------------------------------
KeyError Traceback (most recent call last)
Cell In[92], line 2
1 model2_2.compile(loss = 'binary_crossentropy', optimizer = 'adam', metrics = ['Accuracy', 'Precision', 'Recall'])
----> 2 history2_2 = model2_2.fit(X_train, y_train, epochs = 100, batch_size = 500, validation_data = 0.1, callbacks = [reduce_learning_rate], class_weight = weights_dict)
File ~/anaconda3/envs/Projet/lib/python3.12/site-packages/keras/src/utils/traceback_utils.py:122, in filter_traceback.<locals>.error_handler(*args, **kwargs)
119 filtered_tb = _process_traceback_frames(e.__traceback__)
120 # To get the full stack trace, call:
121 # `keras.config.disable_traceback_filtering()`
--> 122 raise e.with_traceback(filtered_tb) from None
123 finally:
124 del filtered_tb
File ~/anaconda3/envs/Projet/lib/python3.12/site-packages/pandas/core/series.py:1040, in Series.__getitem__(self, key)
1037 return self._values[key]
1039 elif key_is_scalar:
-> 1040 return self._get_value(key)
1042 # Convert generator to list before going through hashable part
1043 # (We will iterate through the generator there to check for slices)
1044 if is_iterator(key):
File ~/anaconda3/envs/Projet/lib/python3.12/site-packages/pandas/core/series.py:1156, in Series._get_value(self, label, takeable)
1153 return self._values[label]
1155 # Similar to Index.get_value, but we do not fall back to positional
-> 1156 loc = self.index.get_loc(label)
1158 if is_integer(loc):
1159 return self._values[loc]
File ~/anaconda3/envs/Projet/lib/python3.12/site-packages/pandas/core/indexes/base.py:3798, in Index.get_loc(self, key)
3793 if isinstance(casted_key, slice) or (
3794 isinstance(casted_key, abc.Iterable)
3795 and any(isinstance(x, slice) for x in casted_key)
3796 ):
3797 raise InvalidIndexError(key)
-> 3798 raise KeyError(key) from err
3799 except TypeError:
3800 # If we have a listlike key, _check_indexing_error will raise
3801 # InvalidIndexError. Otherwise we fall through and re-raise
3802 # the TypeError.
3803 self._check_indexing_error(key)
KeyError: 5
How can I fix it ?
New contributor