Using class_weight in model.fit() doesn’t work

I have an imbalanced dataset and would like to use class_weight in model.fit().
Without class_weight, model.fit() runs correctly, but as soon as I add class_weight I get an error.

Here is my model:

from keras.models import Sequential
from keras.layers import Dense, Dropout

# Fully connected binary classifier; only the first layer needs input_shape
model2_2 = Sequential()
model2_2.add(Dense(units=70, activation='elu', input_shape=(X_train.shape[1],)))
model2_2.add(Dense(units=140, activation='elu'))
model2_2.add(Dropout(rate=0.2))
model2_2.add(Dense(units=70, activation='elu'))
model2_2.add(Dense(units=35, activation='elu'))
model2_2.add(Dropout(rate=0.2))
model2_2.add(Dense(units=14, activation='elu'))
# Single sigmoid unit: outputs the probability of class 1
model2_2.add(Dense(units=1, activation='sigmoid'))

model2_2.summary()

I compute the class weights like this:

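import numpy as np
from sklearn.utils.class_weight import compute_class_weight
# One weight per class, inversely proportional to class frequency
weights = compute_class_weight(class_weight="balanced", classes=np.unique(y), y=y)
weights_dict = {0: weights[0], 1: weights[1]}
weights_dict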
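
For context, the "balanced" mode gives each class the weight n_samples / (n_classes * count of that class), so the rarer class receives the larger weight. Here is a minimal NumPy sketch of that formula. The label array `y_demo` is made up for illustration and is not the data from this question:

```python
import numpy as np

# Hypothetical labels for illustration only: 8 negatives, 2 positives.
y_demo = np.array([0] * 8 + [1] * 2)

# Distinct class labels and how often each one occurs.
classes, counts = np.unique(y_demo, return_counts=True)

# "balanced" heuristic: n_samples / (n_classes * count_per_class)
balanced = len(y_demo) / (len(classes) * counts)

# Same dict shape that class_weight expects: {class_label: weight}
weights_demo = {int(c): float(w) for c, w in zip(classes, balanced)}
print(weights_demo)  # {0: 0.625, 1: 2.5}
```

The minority class (1) ends up weighted four times as heavily as the majority class, which matches the 8:2 imbalance in the made-up labels.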

I compile and train the model:

model2_2.compile(loss = 'binary_crossentropy', optimizer = 'adam', metrics = ['Accuracy', 'Precision', 'Recall'])
history2_2 = model2_2.fit(X_train, y_train, epochs = 100, batch_size = 500, validation_data = 0.1, callbacks = [reduce_learning_rate], class_weight = weights_dict)

Then I get this error:

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
Cell In[92], line 2
      1 model2_2.compile(loss = 'binary_crossentropy', optimizer = 'adam', metrics = ['Accuracy', 'Precision', 'Recall'])
----> 2 history2_2 = model2_2.fit(X_train, y_train, epochs = 100, batch_size = 500, validation_data = 0.1, callbacks = [reduce_learning_rate], class_weight = weights_dict)

File ~/anaconda3/envs/Projet/lib/python3.12/site-packages/keras/src/utils/traceback_utils.py:122, in filter_traceback.<locals>.error_handler(*args, **kwargs)
    119     filtered_tb = _process_traceback_frames(e.__traceback__)
    120     # To get the full stack trace, call:
    121     # `keras.config.disable_traceback_filtering()`
--> 122     raise e.with_traceback(filtered_tb) from None
    123 finally:
    124     del filtered_tb

File ~/anaconda3/envs/Projet/lib/python3.12/site-packages/pandas/core/series.py:1040, in Series.__getitem__(self, key)
   1037     return self._values[key]
   1039 elif key_is_scalar:
-> 1040     return self._get_value(key)
   1042 # Convert generator to list before going through hashable part
   1043 # (We will iterate through the generator there to check for slices)
   1044 if is_iterator(key):

File ~/anaconda3/envs/Projet/lib/python3.12/site-packages/pandas/core/series.py:1156, in Series._get_value(self, label, takeable)
   1153     return self._values[label]
   1155 # Similar to Index.get_value, but we do not fall back to positional
-> 1156 loc = self.index.get_loc(label)
   1158 if is_integer(loc):
   1159     return self._values[loc]

File ~/anaconda3/envs/Projet/lib/python3.12/site-packages/pandas/core/indexes/base.py:3798, in Index.get_loc(self, key)
   3793     if isinstance(casted_key, slice) or (
   3794         isinstance(casted_key, abc.Iterable)
   3795         and any(isinstance(x, slice) for x in casted_key)
   3796     ):
   3797         raise InvalidIndexError(key)
-> 3798     raise KeyError(key) from err
   3799 except TypeError:
   3800     # If we have a listlike key, _check_indexing_error will raise
   3801     #  InvalidIndexError. Otherwise we fall through and re-raise
   3802     #  the TypeError.
   3803     self._check_indexing_error(key)

KeyError: 5

How can I fix it?
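
One possible cause, judging only from the traceback: the last frames are inside pandas' `Series.__getitem__`, which looks values up by index label rather than by position. If `y_train` is a pandas Series that kept its original row labels after a shuffled split (an assumption, since the split is not shown here), a lookup such as `y_train[5]` raises `KeyError: 5` whenever no row carries the label 5. The sketch below reproduces that behaviour with a made-up Series `y_demo` and shows that a NumPy array does not have the problem:

```python
import numpy as np
import pandas as pd

# Hypothetical labels whose index is non-contiguous, as it could be
# after a shuffled train/test split (no row is labelled 5).
y_demo = pd.Series([0, 1, 0, 1, 0, 0], index=[7, 2, 9, 0, 11, 4])

try:
    y_demo[5]  # label-based lookup, not positional
    raised = False
except KeyError:
    raised = True
print("KeyError raised:", raised)

# A plain NumPy array is indexed by position, so position 5 is valid.
y_array = y_demo.to_numpy()
print(y_array[5])
```

If this is what is happening, passing `y_train.to_numpy()` (and `X_train.to_numpy()` if it is a DataFrame) to `fit()` may avoid the error. Separately, `validation_data` expects a tuple such as `(X_val, y_val)`; a fraction like 0.1 is normally passed as `validation_split`.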
